"""Latent-space clustering evaluator (module multivae.metrics.latent_clustering.clustering_class)."""

import numpy as np
import torch
from torch.utils.data import DataLoader

from multivae.data import MultimodalBaseDataset
from multivae.data.utils import set_inputs_to_device
from multivae.models.base import BaseMultiVAE, ModelOutput

from ..base.evaluator_class import Evaluator
from .clustering_config import ClusteringConfig


class Clustering(Evaluator):
    """Module to perform clustering in the latent space.

    As of now, it is only supported for the joint representation of the data
    (:meth:`eval` encodes with ``mods="all"``). The :meth:`eval` function fits a
    k-means model on embeddings of the training dataset, maps every cluster to
    the majority training label among its members, then uses this model to
    classify the test samples. It returns the resulting accuracy, averaged over
    ``eval_config.number_of_runs`` independent fits.

    Args:
        model (BaseMultiVAE) : The model to evaluate.
        test_dataset (MultimodalBaseDataset) : The dataset to use for computing
            the metrics. Its batches must carry labels (``inputs.labels``).
        train_dataset (MultimodalBaseDataset): The training dataset to fit the
            k-means.
        output (str) : The folder path to save metrics. The metrics will be
            saved in a metrics.txt file.
        eval_config (ClusteringConfig) : The configuration class to specify
            parameters for the evaluation. Defaults to a fresh
            ``ClusteringConfig()`` for each instance.

    Raises:
        ModuleNotFoundError: if scikit-learn is not installed.
        NotImplementedError: if ``eval_config.clustering_method`` is not
            ``"kmeans"``.
    """

    def __init__(
        self,
        model: BaseMultiVAE,
        test_dataset: MultimodalBaseDataset,
        train_dataset: MultimodalBaseDataset,
        output: str = None,
        eval_config=None,
    ) -> None:
        # Build the default config per instance instead of sharing one object
        # created at class-definition time.
        if eval_config is None:
            eval_config = ClusteringConfig()

        super().__init__(model, test_dataset, output, eval_config)
        self.num_samples_for_fit = eval_config.num_samples_for_fit
        self.n_fits = eval_config.number_of_runs
        self.use_mean = eval_config.use_mean

        if eval_config.clustering_method == "kmeans":
            try:
                from sklearn.cluster import KMeans
            except ImportError as err:
                raise ModuleNotFoundError(
                    "scikit-learn must be installed to perform clustering. "
                    "Run `pip install scikit-learn` to install it "
                ) from err
            self.clustering = KMeans(n_clusters=eval_config.n_clusters, max_iter=300)
        else:
            # Fail fast: otherwise `self.clustering` would be missing and the
            # evaluation would crash later with an unhelpful AttributeError.
            raise NotImplementedError(
                f"Clustering method {eval_config.clustering_method!r} is not "
                "supported. Only 'kmeans' is available."
            )

        self.train_dataset = train_dataset

    def fit_clustering(self, mods="all"):
        """Fit the clustering model on embeddings of the training dataset.

        Training batches are drawn in shuffled order. When
        ``num_samples_for_fit`` is set, encoding stops as soon as at least that
        many samples have been collected (rounded up to a whole batch). After
        fitting, ``self.labels_dict`` maps each cluster id (as ``str``) to a
        predicted label: the majority training label of the cluster when every
        encoded training sample had labels, otherwise the cluster id itself.

        Args:
            mods: The modalities passed to ``self.model.encode``.
        """
        dl = DataLoader(self.train_dataset, self.batch_size, shuffle=True)
        list_z = []
        labels = []
        n_samples = 0
        for inputs in dl:
            inputs = set_inputs_to_device(inputs, self.device)
            with torch.no_grad():
                z = self.model.encode(inputs, mods, return_mean=self.use_mean).z
            list_z.append(z)
            n_samples += len(z)
            if inputs.labels is not None:
                labels.append(inputs.labels)
            # Checked after encoding so at least one batch is always used.
            if (
                self.num_samples_for_fit is not None
                and n_samples >= self.num_samples_for_fit
            ):
                break

        all_z = torch.cat(list_z).cpu().numpy()
        clusters_labels = self.clustering.fit_predict(all_z)

        # Fallback mapping: each cluster predicts its own index.
        self.labels_dict = {str(m): m for m in np.unique(clusters_labels)}

        if labels:
            # `astype` converts the values. Assigning to `.dtype` would only
            # reinterpret the underlying bytes.
            labels = torch.cat(labels).cpu().numpy().astype(np.int64)
            # Majority voting needs a label for every embedded sample.
            if len(labels) == len(clusters_labels):
                for c in np.unique(clusters_labels):
                    # np.bincount assumes non-negative integer class ids.
                    maj_value = np.bincount(labels[clusters_labels == c]).argmax()
                    self.labels_dict[str(c)] = maj_value

    def cluster_accuracy(self, mods="all"):
        """Compute the clustering accuracy on the test dataset.

        The clustering is refit ``self.n_fits`` times. After each fit, every
        test sample is assigned the label mapped to its predicted cluster. The
        per-run accuracies are then averaged.

        Args:
            mods: The modalities passed to ``self.model.encode``.

        Returns:
            ModelOutput: with field ``cluster_accuracy``. The same value is
            stored in ``self.metrics["cluster_accuracy"]``.
        """
        run_accuracies = []
        for _ in range(self.n_fits):
            self.fit_clustering(mods)

            # Cluster the test dataset.
            n_correct = 0
            n_samples = 0
            for inputs in self.test_loader:
                inputs = set_inputs_to_device(inputs, self.device)
                with torch.no_grad():
                    z = self.model.encode(inputs, mods, return_mean=self.use_mean).z
                cluster_ids = self.clustering.predict(z.cpu().numpy())
                pred_labels = np.array([self.labels_dict[str(c)] for c in cluster_ids])
                # NOTE(review): assumes labels are 1-D (one class id per sample);
                # a (batch, 1) label tensor would broadcast in the comparison below.
                true_labels = inputs.labels.cpu().numpy()
                n_correct += np.sum(pred_labels == true_labels)
                n_samples += len(z)
            run_accuracies.append(n_correct / n_samples)

        accuracy = np.mean(run_accuracies)
        self.metrics["cluster_accuracy"] = accuracy
        self.logger.info(f"Cluster accuracy is {accuracy}")
        return ModelOutput(cluster_accuracy=accuracy)

    def eval(self):
        """Run the clustering evaluation on the joint representation and log the metrics to wandb."""
        output = self.cluster_accuracy("all")
        self.log_to_wandb()
        return output