Source code for multivae.metrics.latent_clustering.clustering_config

from typing import Literal, Optional

from pydantic.dataclasses import dataclass

from ..base.evaluator_config import EvaluatorConfig


@dataclass
class ClusteringConfig(EvaluatorConfig):
    """Config class for the clustering module.

    Args:
        batch_size (int): The batch size to use in the evaluation.
            Default to 512.
        wandb_path (str): The user can provide the path of the wandb run with
            a format 'entity/project_name/run_id' where the metrics should be
            logged. See :doc:`info_wandb` for more information. If None is
            provided, the metrics are not logged on wandb. Default to None.
        clustering_method (Literal['kmeans']): The method to use to cluster.
            Default to 'kmeans'.
        n_clusters (int): The number of clusters. Default to 10.
        number_of_runs (int): When computing accuracies, how many runs of
            clustering to perform to compute the average accuracies.
            Default to 20.
        num_samples_for_fit (Optional[int]): Number of training samples to use
            to fit the clusters. If None, uses all the samples.
            Default to None.
        use_mean (bool): Whether to use a sample or the mean of the encoding
            distribution as the representative embedding. Default to True.
    """

    clustering_method: Literal["kmeans"] = "kmeans"
    n_clusters: int = 10
    number_of_runs: int = 20
    # None is a documented, meaningful value ("use all samples"), so the
    # annotation must admit it; a bare `int` makes pydantic reject an
    # explicitly passed None.
    num_samples_for_fit: Optional[int] = None
    use_mean: bool = True