Clustering

This module clusters the model's latent representations with k-means and returns the clustering accuracy.

Basic code example:

from multivae.metrics import Clustering, ClusteringConfig

eval_config = ClusteringConfig(batch_size=128,
                            wandb_path='your_wandb_path',
                            n_clusters=10,
                            number_of_runs=10)

eval_module = Clustering(
    model=your_model,
    test_dataset=test_set,
    train_dataset=train_data,
    output='./metrics',  # where to save metrics
    eval_config=eval_config
)

# Compute clustering accuracy
eval_module.eval()

eval_module.finish()  # finishes the wandb run

class multivae.metrics.ClusteringConfig(batch_size=512, wandb_path=None, clustering_method='kmeans', n_clusters=10, number_of_runs=20, num_samples_for_fit=None, use_mean=True)[source]

Config class for the clustering module.

Parameters:
  • batch_size (int) – The batch size to use in the evaluation. Defaults to 512.

  • wandb_path (str) – The path of the wandb run where the metrics should be logged, in the format 'entity/project_name/run_id'. See "Where to find the WandB path for a trained model?" for more information. If None is provided, the metrics are not logged on wandb. Defaults to None.

  • clustering_method (Literal['kmeans']) – The clustering method to use. Defaults to 'kmeans'.

  • n_clusters (int) – The number of clusters. Defaults to 10.

  • number_of_runs (int) – When computing accuracies, how many runs of clustering to perform to compute the average accuracies. Defaults to 20.

  • num_samples_for_fit (int) – Number of training samples to use to fit the clusters. If None, all the samples are used. Defaults to None.

  • use_mean (bool) – Whether to use the mean of the encoding distribution (True) or a sample from it (False) as the representative embedding. Defaults to True.
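
To make use_mean and num_samples_for_fit more concrete, here is a minimal standard-library sketch. It assumes a diagonal Gaussian encoding distribution described by a mean and a log-variance. The helper names representative_embedding and subsample are hypothetical and are not part of multivae; the library's own implementation may differ.

```python
import math
import random


def representative_embedding(mu, log_var, use_mean=True, rng=None):
    """Pick one latent vector for a diagonal Gaussian (assumed form of the encoding)."""
    if use_mean:
        # Deterministic choice: the mean of the distribution.
        return list(mu)
    rng = rng or random.Random()
    # Stochastic choice: mean + std * standard normal noise, per dimension.
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


def subsample(train_embeddings, num_samples_for_fit=None, rng=None):
    """Keep every embedding when num_samples_for_fit is None, else a random subset."""
    embeddings = list(train_embeddings)
    if num_samples_for_fit is None:
        return embeddings
    rng = rng or random.Random()
    return rng.sample(embeddings, min(num_samples_for_fit, len(embeddings)))


mean_embedding = representative_embedding([1.0, -2.0], [0.0, 0.0])
sampled_embedding = representative_embedding(
    [1.0, -2.0], [0.0, 0.0], use_mean=False, rng=random.Random(0)
)
subset = subsample(range(10), num_samples_for_fit=4, rng=random.Random(0))
```

Using the mean makes the evaluation deterministic for a fixed model, while sampling also reflects the spread of the encoding distribution.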

class multivae.metrics.Clustering(model, test_dataset, train_dataset, output=None, eval_config=ClusteringConfig(name='ClusteringConfig', batch_size=512, wandb_path=None, clustering_method='kmeans', n_clusters=10, number_of_runs=20, num_samples_for_fit=None, use_mean=True))[source]

Module to perform clustering in the latent space. Currently, only the joint representation of the data is supported. The eval() function fits a k-means model on the training embeddings, then uses this model to classify the test samples and returns the accuracy of this prediction.
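
The description above does not spell out how cluster assignments are turned into an accuracy. One common convention, assumed here and not confirmed by the source, is to give each cluster the majority label of the training points it contains, then score the test points against that mapping. The sketch below starts from cluster indices that a k-means model would already have produced; the function names are hypothetical.

```python
from collections import Counter, defaultdict
from statistics import mean


def majority_label_per_cluster(train_clusters, train_labels):
    """Map each cluster index to the most frequent training label inside it."""
    votes = defaultdict(Counter)
    for cluster, label in zip(train_clusters, train_labels):
        votes[cluster][label] += 1
    return {cluster: counts.most_common(1)[0][0] for cluster, counts in votes.items()}


def clustering_accuracy(cluster_to_label, test_clusters, test_labels):
    """Fraction of test points whose cluster's majority label equals their true label."""
    hits = sum(
        cluster_to_label.get(cluster) == label
        for cluster, label in zip(test_clusters, test_labels)
    )
    return hits / len(test_labels)


# Toy data: cluster indices as k-means would assign them, plus ground-truth labels.
train_clusters = [0, 0, 0, 1, 1, 2]
train_labels = ["a", "a", "b", "b", "b", "c"]
test_clusters = [0, 1, 2, 2]
test_labels = ["a", "b", "c", "a"]

mapping = majority_label_per_cluster(train_clusters, train_labels)
accuracy = clustering_accuracy(mapping, test_clusters, test_labels)

# With several clustering runs (number_of_runs), the per-run accuracies are averaged.
average_accuracy = mean([accuracy, accuracy])
```

Because k-means depends on its random initialization, averaging over several runs gives a more stable figure than a single fit.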

Parameters:
  • model (BaseMultiVAE) – The model to evaluate.

  • test_dataset (MultimodalBaseDataset) – The dataset to use for computing the metrics.

  • train_dataset (MultimodalBaseDataset) – The training dataset to fit the k-means.

  • output (str) – Path of the folder in which to save the metrics. The metrics are saved in a metrics.txt file.

  • eval_config (ClusteringConfig) – The configuration class specifying the parameters for the evaluation. Defaults to a ClusteringConfig with its default values.