Coherences
Use these classes to compute the generative coherence of your model, as measured by pretrained classifiers.
A simple usage example:
from multivae.metrics import CoherenceEvaluator, CoherenceEvaluatorConfig

eval_config = CoherenceEvaluatorConfig(
    batch_size=128,
    wandb_path='your_wandb_path',  # optional: to log the metrics to wandb
    num_classes=10,  # number of classes in your multimodal dataset
    nb_samples_for_cross=10,
    nb_samples_for_joint=100,
)

eval_module = CoherenceEvaluator(
    model=your_model,  # the trained MultiVae model to evaluate
    test_dataset=test_data,
    output='your_output_path',  # where to save the metrics
    eval_config=eval_config,  # without this, the default configuration is used
    sampler=None,  # optionally, pass a trained MultiVae sampler
    classifiers=your_dict_of_classifiers,  # dict of pretrained classifiers
)

# Compute the joint coherence and all the cross-coherences
eval_module.eval()

# If you only wish to compute the joint coherence:
eval_module.joint_coherence()

# If you want the cross-modal coherences from one specific subset:
eval_module.coherence_from_subset(['mod1', 'mod2'])

eval_module.finish()  # to finish the wandb run
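To get a feel for how many cross-coherence computations a full `eval()` run may involve, the sketch below enumerates (subset, target) pairs. It assumes that "all cross-coherences" means every non-empty proper subset of modalities paired with each modality outside it; `cross_coherence_pairs` is a hypothetical helper, not part of multivae, and the library's exact enumeration may differ.

```python
from itertools import combinations


def cross_coherence_pairs(modalities):
    """Yield (subset, target) pairs for conditional generation.

    Hypothetical helper: assumes every non-empty proper subset of
    `modalities` is used as conditioning input, and each modality that is
    NOT in that subset is a generation target.
    """
    for size in range(1, len(modalities)):
        for subset in combinations(modalities, size):
            for target in modalities:
                if target not in subset:
                    yield list(subset), target


pairs = list(cross_coherence_pairs(['mod1', 'mod2', 'mod3']))
print(len(pairs))  # 9
```

With three modalities this gives 6 pairs from single-modality subsets and 3 from two-modality subsets, which is why evaluation time grows quickly with the number of modalities.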
- class multivae.metrics.CoherenceEvaluatorConfig(batch_size=512, wandb_path=None, num_classes=10, include_recon=False, nb_samples_for_joint=10000, nb_samples_for_cross=1, give_details_per_class=False)
Configuration class for the coherence evaluation module.
- Parameters:
batch_size (int): The batch size to use during evaluation. Defaults to 512.
wandb_path (str): The path of the wandb run where the metrics should be logged, in the format 'entity/project_name/run_id'. See 'Where to find the WandB path for a trained model?' for more information. If None, the metrics are not logged to wandb. Defaults to None.
num_classes (int): Number of classes in the dataset. Defaults to 10.
include_recon (bool): If True, reconstructions are included when averaging the conditional-generation coherences. Defaults to False.
nb_samples_for_joint (int): Number of samples to generate when computing the joint coherence. Defaults to 10000.
nb_samples_for_cross (int): Number of generations per test sample when computing the cross-coherences. Defaults to 1.
give_details_per_class (bool): If True, also report the accuracy for each class. Defaults to False.
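The two options `include_recon` and `give_details_per_class` change how accuracies are aggregated. The sketch below illustrates both ideas in plain Python under stated assumptions: `per_class_accuracy` and `mean_coherence` are hypothetical helpers, a "reconstruction" is assumed to be a (subset, target) pair whose target is inside the subset, and this is not multivae's implementation.

```python
def per_class_accuracy(preds, labels, num_classes):
    """Classifier accuracy broken down by true class.

    Hypothetical helper illustrating `give_details_per_class`.
    Classes with no test examples get None.
    """
    correct = [0] * num_classes
    total = [0] * num_classes
    for pred, label in zip(preds, labels):
        total[label] += 1
        correct[label] += int(pred == label)
    return [c / t if t else None for c, t in zip(correct, total)]


def mean_coherence(coherences, include_recon=False):
    """Average coherences stored as {(subset_tuple, target): accuracy}.

    Hypothetical helper illustrating `include_recon`: pairs whose target
    is in the conditioning subset (reconstructions) are skipped unless
    include_recon is True. Raises ZeroDivisionError if nothing is kept.
    """
    values = [
        acc
        for (subset, target), acc in coherences.items()
        if include_recon or target not in subset
    ]
    return sum(values) / len(values)


print(per_class_accuracy([0, 1, 1, 2], [0, 1, 2, 2], num_classes=3))
# [1.0, 1.0, 0.5]
```

For example, with coherences `{(('mod1',), 'mod2'): 0.8, (('mod2',), 'mod1'): 0.6, (('mod1',), 'mod1'): 1.0}`, the mean is about 0.7 without reconstructions and about 0.8 with them.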
- class multivae.metrics.CoherenceEvaluator(model, classifiers, test_dataset, output=None, eval_config=CoherenceEvaluatorConfig(name='CoherenceEvaluatorConfig', batch_size=512, wandb_path=None, num_classes=10, include_recon=False, nb_samples_for_joint=10000, nb_samples_for_cross=1, give_details_per_class=False), sampler=None)
Class for computing coherence metrics.
- Parameters:
model (BaseMultiVAE): The model to evaluate.
classifiers (dict): A dictionary of pretrained classifiers used for the coherence evaluation.
test_dataset (MultimodalBaseDataset): The dataset used to compute the metrics.
output (str): Path of the folder where the metrics are saved, in a metrics.txt file.
eval_config (CoherenceEvaluatorConfig): The configuration object specifying the evaluation parameters.
sampler (BaseSampler): A custom sampler used to compute the joint coherence. If None, samples are generated from the prior.
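Joint coherence is commonly defined as the fraction of jointly generated samples (drawn from the prior or from a sampler) for which the classifiers of all modalities predict the same label. The sketch below assumes that definition and uses toy inputs: `joint_coherence_sketch` is a hypothetical function, `samples` stands in for already decoded generations, and the `classifiers` dict maps each modality name to a plain callable rather than to a trained network.

```python
def joint_coherence_sketch(samples, classifiers):
    """Fraction of joint generations whose modalities all get the same label.

    Hypothetical sketch (not multivae's code).
    samples: dict modality -> list of generated items, one per joint sample.
    classifiers: dict modality -> callable mapping an item to a label.
    """
    # Classify every generated item of every modality.
    preds = {mod: [classifiers[mod](x) for x in xs] for mod, xs in samples.items()}
    n = len(next(iter(preds.values())))
    # A joint sample is coherent when all modalities share one predicted label.
    coherent = sum(1 for i in range(n) if len({p[i] for p in preds.values()}) == 1)
    return coherent / n


samples = {'mod1': [3, 1, 7, 0], 'mod2': ['3', '2', '7', '0']}
classifiers = {'mod1': lambda x: x, 'mod2': int}
print(joint_coherence_sketch(samples, classifiers))  # 0.75
```

Here the second joint sample is incoherent (labels 1 and 2), so 3 of the 4 samples agree.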
- coherence_from_subset(subset)
Compute the coherences of generating, from the modalities in subset, each modality that is not in subset.
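The sketch below shows one plausible reading of what `coherence_from_subset` measures: for each modality outside `subset`, generate it from the modalities in `subset`, classify the result and compare it with the true label, repeating `nb_samples_for_cross` times per test example. Everything here is an assumption for illustration: `subset_coherences` is hypothetical, `generate` stands in for the model's conditional generation, and the dataset is a list of `(data_dict, label)` pairs rather than a `MultimodalBaseDataset`.

```python
def subset_coherences(subset, modalities, generate, classifiers, dataset,
                      nb_samples_for_cross=1):
    """Accuracy of generating each modality not in `subset` from `subset`.

    Hypothetical sketch (not multivae's code).
    generate(cond, target) -> generated item for modality `target`.
    classifiers: dict modality -> callable mapping an item to a label.
    dataset: list of (data_dict, label) pairs.
    """
    results = {}
    for target in modalities:
        if target in subset:
            continue  # reconstructions are not cross-modal coherences
        correct = total = 0
        for data, label in dataset:
            cond = {m: data[m] for m in subset}
            # Several generations per example, as with nb_samples_for_cross.
            for _ in range(nb_samples_for_cross):
                pred = classifiers[target](generate(cond, target))
                correct += int(pred == label)
                total += 1
        results[target] = correct / total
    return results


modalities = ['mod1', 'mod2', 'mod3']
dataset = [({'mod1': 0, 'mod2': 0, 'mod3': 0}, 0),
           ({'mod1': 1, 'mod2': 2, 'mod3': 2}, 2)]
toy_generate = lambda cond, target: min(cond.values())  # toy "model"
identity = {m: (lambda x: x) for m in modalities}
print(subset_coherences(['mod1', 'mod2'], modalities, toy_generate, identity, dataset))
# {'mod3': 0.5}
```

Because the toy generator is deterministic, raising `nb_samples_for_cross` does not change the result here; with a stochastic model, more generations give a less noisy estimate.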