from itertools import combinations
from typing import Dict, List, Optional
import numpy as np
import torch
from pythae.models.base.base_utils import ModelOutput
from torchmetrics.classification import MulticlassAccuracy
from multivae.data import MultimodalBaseDataset
from multivae.data.utils import set_inputs_to_device
from multivae.models.base import BaseMultiVAE
from multivae.samplers.base import BaseSampler
from ..base.evaluator_class import Evaluator
from .coherences_config import CoherenceEvaluatorConfig
class CoherenceEvaluator(Evaluator):
    """Class for computing coherence metrics.

    Args:
        model (BaseMultiVAE) : The model to evaluate.
        classifiers (dict) : A dictionary mapping each modality name to the pretrained
            classifier used for the coherence evaluation of that modality.
        test_dataset (MultimodalBaseDataset) : The dataset to use for computing the metrics.
            Cross-modal coherences require the dataset to provide labels.
        output (str) : The folder path to save metrics. The metrics will be saved in a
            metrics.txt file.
        eval_config (CoherenceEvaluatorConfig) : The configuration class to specify
            parameters for the evaluation. If None, a default ``CoherenceEvaluatorConfig()``
            is created for this evaluator.
        sampler (BaseSampler) : A custom sampler for computing the joint coherence. If None
            is provided, samples are generated from the prior.
    """

    def __init__(
        self,
        model: BaseMultiVAE,
        classifiers: Dict[str, torch.nn.Module],
        test_dataset: MultimodalBaseDataset,
        output: Optional[str] = None,
        eval_config: Optional[CoherenceEvaluatorConfig] = None,
        sampler: Optional[BaseSampler] = None,
    ) -> None:
        # Build the default config per instance instead of sharing a single object
        # created once at function-definition time across every evaluator.
        if eval_config is None:
            eval_config = CoherenceEvaluatorConfig()
        super().__init__(model, test_dataset, output, eval_config, sampler)
        self.include_recon = eval_config.include_recon
        self.nb_samples_for_joint = eval_config.nb_samples_for_joint
        self.nb_samples_for_cross = eval_config.nb_samples_for_cross
        self.num_classes = eval_config.num_classes
        self.give_details_per_classes = eval_config.give_details_per_class
        # Explicit raise rather than `assert` so the check is not stripped by
        # `python -O`; AssertionError is kept for backward compatibility.
        if self.num_classes is None:
            raise AssertionError("Please provide the number of classes")
        # Move every classifier to the evaluation device and put it in eval mode.
        # A new dict is built so the caller's mapping is not rebound in place.
        self.clfs = {
            mod: clf.to(self.device).eval() for mod, clf in classifiers.items()
        }

    def cross_coherences(self):
        """Compute all the coherences from one subset of modalities to another modality.

        For every subset size ``n`` in ``1 .. n_modalities - 1`` and every subset of that
        size, :meth:`coherence_from_subset` is called. The per-pair results, and the
        per-size aggregates ``mean_coherence_{n}`` / ``std_coherence_{n}`` (plus
        ``mean_coherence_{n}_class_{c}`` when per-class details are requested), are
        stored in ``self.metrics``.

        Returns:
            list, list: The means and standard deviations of the cross-coherences, one
            entry per subset size (index ``i`` corresponds to subsets of size ``i + 1``).
        """
        modalities = list(self.model.encoders.keys())
        accs = []
        accs_per_class = []
        for n in range(1, self.model.n_modalities):
            size_accs = []
            size_accs_per_class = []
            for subset in combinations(modalities, n):
                subset_dict, mean_acc, mean_acc_per_class = self.coherence_from_subset(
                    list(subset)
                )
                self.metrics.update(subset_dict)
                size_accs.append(mean_acc)
                size_accs_per_class.append(mean_acc_per_class)
            accs.append(size_accs)
            accs_per_class.append(size_accs_per_class)

        mean_accs = [np.mean(size_accs) for size_accs in accs]
        std_accs = [np.std(size_accs) for size_accs in accs]
        mean_accs_per_class = [
            np.mean(np.stack(size_accs_per_class), axis=0)
            for size_accs_per_class in accs_per_class
        ]
        for i, (m, s) in enumerate(zip(mean_accs, std_accs)):
            self.logger.info(
                "Conditional accuracies for %s modalities : %s +- %s", i + 1, m, s
            )
            self.metrics.update(
                {
                    f"mean_coherence_{i + 1}": m,
                    f"std_coherence_{i + 1}": s,
                }
            )
            if self.give_details_per_classes:
                for c in range(self.num_classes):
                    self.logger.info(
                        "Conditional accuracies for %s modalities in class %s: %s",
                        i + 1,
                        c,
                        mean_accs_per_class[i][c],
                    )
                    self.metrics.update(
                        {
                            f"mean_coherence_{i + 1}_class_{c}": mean_accs_per_class[i][
                                c
                            ],
                        }
                    )
        return mean_accs, std_accs

    @torch.no_grad()  # evaluation only: no autograd graphs are needed
    def coherence_from_subset(self, subset: List[str]):
        """Compute all the coherences generating from the modalities in subset to a modality
        that is not in subset (or to every modality when ``include_recon`` is set).

        For each target modality, the predictions of the model conditioned on ``subset``
        are classified and compared to the dataset labels with a per-class accuracy;
        the reported accuracy of a pair is the mean over classes of those accuracies.

        Args:
            subset (List[str]): The subset of modalities to condition on.

        Returns:
            dict, float, np.ndarray: The dictionary ``{"<subset>_to_<mod>": accuracy}`` of
            all coherences from ``subset``, the mean of those coherences, and the per-class
            accuracies averaged over the target modalities.

        Raises:
            AttributeError: If the test batches carry no labels.
        """
        pred_mods = [
            m for m in self.model.encoders if (m not in subset) or self.include_recon
        ]
        subset_name = "_".join(subset)
        accuracies_per_class = {
            m: MulticlassAccuracy(num_classes=self.num_classes, average=None).to(
                self.device
            )
            for m in pred_mods
        }
        for batch in self.test_loader:
            if not hasattr(batch, "labels"):
                raise AttributeError(
                    "Cross-modal coherence can not be computed "
                    "on a dataset without labels"
                )
            if batch.labels is None:
                raise AttributeError(
                    "Cross-modal coherence can not be computed "
                    "on a dataset without labels, but the provided dataset"
                    " has None instead of tensor labels"
                )
            batch = set_inputs_to_device(batch, device=self.device)
            output = self.model.predict(
                batch,
                list(subset),
                pred_mods,
                N=self.nb_samples_for_cross,
                flatten=True,
            )
            # The labels do not depend on the target modality: tile them once per batch.
            # With N > 1 they are repeated N times along the first dimension
            # ([labels, labels, ...]); this presumably matches the sample-major order
            # of predict(..., flatten=True) — TODO confirm against BaseMultiVAE.predict.
            labels = batch.labels
            if self.nb_samples_for_cross > 1:
                labels = torch.stack(
                    [labels] * self.nb_samples_for_cross, dim=0
                ).reshape(-1, *labels.shape[1:])
            for pred_m in pred_mods:
                preds = self.clfs[pred_m](output[pred_m])
                # Only accumulate metric state; the per-batch value is never used.
                accuracies_per_class[pred_m].update(preds, labels)

        acc_per_class = {
            f"{subset_name}_to_{m}": metric.compute().cpu()
            for m, metric in accuracies_per_class.items()
        }
        # Accuracy of each pair = mean of its per-class accuracies.
        acc = {name: per_class.mean() for name, per_class in acc_per_class.items()}
        self.logger.info("Subset %s accuracies ", subset)
        self.logger.info(str(acc))
        mean_pair_acc = np.mean(list(acc.values()))
        self.logger.info("Mean subset %s accuracies : %s", subset, str(mean_pair_acc))
        mean_acc_per_class = np.mean(np.stack(list(acc_per_class.values())), axis=0)
        return acc, mean_pair_acc, mean_acc_per_class

    @torch.no_grad()  # evaluation only: no autograd graphs are needed
    def joint_coherence(self):
        """Generate in all modalities and compute the fraction of samples on which the
        classifiers predict the same label for every modality.

        ``nb_samples_for_joint`` latent samples are drawn in chunks of at most
        ``batch_size``, from the prior when no sampler was given, otherwise from the
        custom sampler, then decoded in all modalities. The result is stored in
        ``self.metrics`` under ``joint_coherence_<sampler name>``.

        Returns:
            torch.Tensor: The joint coherence metric (NaN if no sample is generated).
        """
        agreement_chunks = []
        samples_to_generate = self.nb_samples_for_joint
        # loop over chunks of at most batch_size samples
        while samples_to_generate > 0:
            batch_samples = min(self.batch_size, samples_to_generate)
            if self.sampler is None:
                output_prior = self.model.generate_from_prior(batch_samples)
            else:
                output_prior = self.sampler.sample(batch_samples)

            # set the latent codes to the evaluation device before decoding
            output_prior.z = output_prior.z.to(self.device)
            if not output_prior.one_latent_space:
                for m in output_prior.modalities_z:
                    output_prior.modalities_z[m] = output_prior.modalities_z[m].to(
                        self.device
                    )

            output_decode = self.model.decode(output_prior)
            # Predicted label per sample for each modality; shape (batch_samples,)
            # assuming the classifiers return (batch, num_classes) scores.
            pred_labels = [
                torch.argmax(self.clfs[m](output_decode[m]), dim=1)
                for m in output_decode.keys()
            ]
            all_same_labels = torch.all(
                torch.stack([lab == pred_labels[0] for lab in pred_labels]), dim=0
            )
            agreement_chunks.append(all_same_labels.float().to(self.device))
            samples_to_generate -= batch_samples

        # Concatenate once instead of growing a tensor inside the loop.
        if agreement_chunks:
            agreement = torch.cat(agreement_chunks, dim=0)
        else:
            # No sample generated: the mean of an empty tensor is NaN.
            agreement = torch.empty(0, device=self.device)
        joint_coherence = agreement.mean()

        sampler_name = "prior" if self.sampler is None else self.sampler.name
        self.logger.info(
            "Joint coherence with sampler %s: %s", sampler_name, joint_coherence
        )
        self.metrics.update({f"joint_coherence_{sampler_name}": joint_coherence})
        return joint_coherence

    def eval(self):
        """Compute all cross-modal coherences and the joint coherence.

        Returns:
            ModelOutput: All the metrics accumulated in ``self.metrics``.
        """
        self.cross_coherences()
        self.joint_coherence()
        self.log_to_wandb()
        return ModelOutput(**self.metrics)