Reconstruction

This module can be used to compute reconstruction metrics for any MultiVae model. It uses torchmetrics for computing the SSIM.
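To show what SSIM measures, here is a minimal pure-Python sketch of a single-window (global) SSIM between two flattened images. It uses the usual constants k1 = 0.01 and k2 = 0.03. It is only an illustration: torchmetrics computes SSIM over local windows and averages the results, so its values will generally differ from this global version. The function name global_ssim is hypothetical.

```python
from statistics import fmean


def global_ssim(x, y, data_range=1.0, k1=0.01, k2=0.03):
    """Single-window SSIM between two equally sized flat sequences of pixel values.

    Hypothetical helper for illustration; not part of MultiVae or torchmetrics.
    """
    if len(x) != len(y) or not x:
        raise ValueError("x and y must be non-empty and of equal length")
    # Stabilising constants derived from the dynamic range of the pixel values
    c1 = (k1 * data_range) ** 2
    c2 = (k2 * data_range) ** 2
    mu_x, mu_y = fmean(x), fmean(y)
    # Population (biased) variances and covariance over the whole image
    var_x = fmean([(a - mu_x) ** 2 for a in x])
    var_y = fmean([(b - mu_y) ** 2 for b in y])
    cov = fmean([(a - mu_x) * (b - mu_y) for a, b in zip(x, y)])
    num = (2 * mu_x * mu_y + c1) * (2 * cov + c2)
    den = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    return num / den
```

Identical inputs give an SSIM of 1. Any difference in mean, contrast or structure lowers the score.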

Basic code example:

from multivae.metrics import Reconstruction, ReconstructionConfig

eval_config = ReconstructionConfig(
    batch_size=128,
    wandb_path='your_wandb_path',  # 'entity/project_name/run_id', or None to skip wandb logging
    metric='SSIM',  # or 'MSE'
)

eval_module = Reconstruction(
    model=your_model,  # a trained MultiVae model
    test_dataset=test_set,  # a MultimodalBaseDataset
    output='./metrics',  # folder where metrics.txt is written
    eval_config=eval_config,
)

# Compute metrics
eval_module.eval()

eval_module.finish()  # finishes the wandb run

class multivae.metrics.ReconstructionConfig(batch_size=512, wandb_path=None, metric='SSIM')

Config class for a quantitative evaluation of the reconstruction quality.

Parameters:
  • batch_size (int) – The batch size to use in the evaluation.

  • wandb_path (str) – The path of the wandb run where the metrics should be logged, in the format 'entity/project_name/run_id'. See "Where to find the WandB path for a trained model?" for more information. If None is provided, the metrics are not logged on wandb. Defaults to None.

  • metric (Literal['SSIM', 'MSE']) – The metric used to assess reconstruction quality. Defaults to 'SSIM'.
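The sketch below mirrors the three documented fields as a plain dataclass. It shows the defaults and the expected 'entity/project_name/run_id' shape of wandb_path. ReconstructionConfigSketch and its validation checks are hypothetical. This page does not say whether the real ReconstructionConfig performs these checks.

```python
from dataclasses import dataclass
from typing import Literal, Optional


@dataclass
class ReconstructionConfigSketch:
    """Hypothetical stand-in mirroring the documented fields of ReconstructionConfig."""

    batch_size: int = 512
    wandb_path: Optional[str] = None
    metric: Literal["SSIM", "MSE"] = "SSIM"

    def __post_init__(self):
        # Assumed validation, for illustration only
        if self.metric not in ("SSIM", "MSE"):
            raise ValueError(f"metric must be 'SSIM' or 'MSE', got {self.metric!r}")
        if self.batch_size <= 0:
            raise ValueError("batch_size must be positive")
        if self.wandb_path is not None:
            parts = self.wandb_path.split("/")
            if len(parts) != 3 or not all(parts):
                raise ValueError("wandb_path must look like 'entity/project_name/run_id'")

    def wandb_parts(self):
        """Split wandb_path into (entity, project_name, run_id), or return None."""
        if self.wandb_path is None:
            return None
        entity, project_name, run_id = self.wandb_path.split("/")
        return entity, project_name, run_id
```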

class multivae.metrics.Reconstruction(model, test_dataset, output=None, eval_config=ReconstructionConfig(name='ReconstructionConfig', batch_size=512, wandb_path=None, metric='SSIM'))

Class for computing reconstruction metrics.

Available metrics are SSIM and MSE. Select one with the metric field of ReconstructionConfig.

Parameters:
  • model (BaseMultiVAE) – The model to evaluate.

  • test_dataset (MultimodalBaseDataset) – The dataset to use for computing the metrics.

  • output (str) – The folder in which to save the metrics. They are written to a metrics.txt file.

  • eval_config (ReconstructionConfig) – The configuration object specifying the parameters of the evaluation.
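The following minimal sketch shows how batch_size and output interact. It runs a batched evaluation that averages a per-sample MSE over a dataset and writes the score to metrics.txt. It assumes that samples are flat lists of floats and that reconstruct is a hypothetical callable standing in for the model. It is not the library's implementation.

```python
import os


def mse(x, y):
    """Mean squared error between two equally sized flat sequences."""
    return sum((a - b) ** 2 for a, b in zip(x, y)) / len(x)


def evaluate_reconstruction(reconstruct, dataset, batch_size=512, output=None):
    """Hypothetical sketch: average a per-sample MSE over the dataset in batches.

    reconstruct maps a list of samples to a list of reconstructions;
    dataset is a list of flat sequences of floats.
    """
    if not dataset:
        raise ValueError("dataset must not be empty")
    total, count = 0.0, 0
    for start in range(0, len(dataset), batch_size):
        batch = dataset[start:start + batch_size]
        recons = reconstruct(batch)
        total += sum(mse(x, r) for x, r in zip(batch, recons))
        count += len(batch)
    score = total / count
    if output is not None:
        # Mimic the documented behaviour of saving metrics to <output>/metrics.txt
        os.makedirs(output, exist_ok=True)
        with open(os.path.join(output, "metrics.txt"), "w") as f:
            f.write(f"MSE: {score}\n")
    return score
```

The batch size only controls how many samples are processed at once. It does not change the averaged score.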

eval()

Compute metrics for joint reconstruction and unimodal reconstruction.
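This sketch assumes that "joint" means reconstructing from all modalities at once and "unimodal" means reconstructing from each modality alone. Under that assumption, it enumerates the input subsets that eval() would cover. evaluate_subsets and score_subset are hypothetical names.

```python
def evaluate_subsets(modalities, score_subset):
    """Hypothetical sketch: one score per input subset.

    The joint subset (all modalities) comes first, then each single modality.
    score_subset is any callable taking a tuple of modality names.
    """
    mods = list(modalities)
    subsets = [tuple(mods)] + [(m,) for m in mods]
    return {subset: score_subset(subset) for subset in subsets}
```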

reconstruction_from_subset(subset)

Take a subset of modalities as input and compute reconstructions for those modalities.
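Here is a minimal sketch of the per-subset step. It assumes that the reconstructions of the subset's own modalities are scored against the inputs and averaged over samples. The batch, reconstruct and metric arguments are hypothetical; the documented method takes only subset.

```python
def reconstruction_from_subset_sketch(subset, batch, reconstruct, metric):
    """Hypothetical sketch: encode only `subset`, then decode and score those same modalities.

    batch: dict mapping modality name -> list of samples (flat lists of floats)
    reconstruct: callable(inputs, targets) -> dict mapping modality -> list of reconstructions
    metric: callable(x, x_hat) -> float
    """
    inputs = {m: batch[m] for m in subset}
    recons = reconstruct(inputs, list(subset))
    scores = {}
    for m in subset:
        pairs = list(zip(batch[m], recons[m]))
        # Average the metric over the samples of this modality
        scores[m] = sum(metric(x, r) for x, r in pairs) / len(pairs)
    return scores
```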