Clustering
This module runs k-means clustering on the model's latent representations and returns the clustering accuracy.
Basic code example:
from multivae.metrics import Clustering, ClusteringConfig

eval_config = ClusteringConfig(
    batch_size=128,
    wandb_path='your_wandb_path',
    n_clusters=10,
    number_of_runs=10,
)

eval_module = Clustering(
    model=your_model,
    test_dataset=test_set,
    train_dataset=train_data,
    output='./metrics',  # where to save metrics
    eval_config=eval_config,
)

# Compute clustering accuracy
eval_module.eval()
eval_module.finish()  # finishes wandb run
- class multivae.metrics.ClusteringConfig(batch_size=512, wandb_path=None, clustering_method='kmeans', n_clusters=10, number_of_runs=20, num_samples_for_fit=None, use_mean=True)
Config class for the clustering module.
- Parameters:
batch_size (int) – The batch size to use in the evaluation. Defaults to 512.
wandb_path (str) – The path of the wandb run where the metrics should be logged, in the format 'entity/project_name/run_id'. See "Where to find the WandB path for a trained model?" for more information. If None is provided, the metrics are not logged on wandb. Defaults to None.
clustering_method (Literal['kmeans']) – The clustering method to use. Defaults to 'kmeans'.
n_clusters (int) – The number of clusters. Defaults to 10.
number_of_runs (int) – The number of clustering runs to perform when computing the average accuracy. Defaults to 20.
num_samples_for_fit (int) – The number of training samples to use to fit the clusters. If None, all the samples are used. Defaults to None.
use_mean (bool) – Whether to use the mean of the encoding distribution (True) or a sample from it (False) as the representative embedding. Defaults to True.
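The use_mean and num_samples_for_fit options can be pictured with a minimal sketch. It assumes a diagonal Gaussian encoding distribution described by a mean and a log-variance, which may not match every model. The helper names pick_embedding and subsample are hypothetical and are not part of multivae.

```python
import math
import random


def pick_embedding(mu, log_var, use_mean=True, rng=None):
    """Return the representative embedding for one data point.

    Hypothetical helper: assumes a diagonal Gaussian N(mu, exp(log_var)).
    """
    if use_mean:
        return list(mu)
    rng = rng or random.Random()
    # Reparameterized sample: mu + sigma * eps, with eps ~ N(0, 1)
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


def subsample(embeddings, num_samples_for_fit=None, rng=None):
    """Keep at most num_samples_for_fit embeddings; None keeps all of them."""
    if num_samples_for_fit is None or num_samples_for_fit >= len(embeddings):
        return list(embeddings)
    rng = rng or random.Random()
    return rng.sample(embeddings, num_samples_for_fit)
```

Using the mean gives the same embedding on every call, whereas sampling adds noise that can change the cluster assignments from one run to the next.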
- class multivae.metrics.Clustering(model, test_dataset, train_dataset, output=None, eval_config=ClusteringConfig(name='ClusteringConfig', batch_size=512, wandb_path=None, clustering_method='kmeans', n_clusters=10, number_of_runs=20, num_samples_for_fit=None, use_mean=True))
Module to perform clustering in the latent space. As of now, only the joint representation of the data is supported. The eval() function fits a k-means model on the training embeddings, then uses this model to assign the test samples to clusters and returns the clustering accuracy of this prediction.
- Parameters:
model (BaseMultiVAE) – The model to evaluate.
test_dataset (MultimodalBaseDataset) – The dataset to use for computing the metrics.
train_dataset (MultimodalBaseDataset) – The training dataset to fit the k-means.
output (str) – The folder path to save metrics. The metrics will be saved in a metrics.txt file.
eval_config (EvaluatorConfig) – The configuration class to specify parameters for the evaluation.
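The description of eval() does not spell out how an accuracy is derived from the k-means assignments. One common convention, shown here as an assumption rather than as multivae's actual implementation, labels each cluster with the majority class among the training points assigned to it, then scores the test points against those labels. The function name and its arguments are hypothetical.

```python
from collections import Counter, defaultdict


def majority_vote_accuracy(train_clusters, train_labels, test_clusters, test_labels):
    """Clustering accuracy via majority vote (hypothetical helper, not multivae's API)."""
    # Count the true labels seen in each cluster on the training set
    votes = defaultdict(Counter)
    for cluster, label in zip(train_clusters, train_labels):
        votes[cluster][label] += 1

    # Map each cluster id to its most frequent training label
    cluster_to_label = {c: counts.most_common(1)[0][0] for c, counts in votes.items()}

    # A test point is correct when its cluster's label matches its true label;
    # clusters never seen during training map to None and count as errors
    correct = sum(
        cluster_to_label.get(c) == y for c, y in zip(test_clusters, test_labels)
    )
    return correct / len(test_labels)
```

Under this convention, a setting such as number_of_runs=10 would correspond to computing this score once per k-means run and averaging the results.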