Reconstruction
This module computes reconstruction metrics for any MultiVae model. It relies on torchmetrics to compute the SSIM.
Basic code example:
from multivae.metrics import Reconstruction, ReconstructionConfig
# your_model: a trained MultiVae model; test_set: a MultimodalBaseDataset
eval_config = ReconstructionConfig(
    batch_size=128,
    wandb_path='your_wandb_path',  # optional: 'entity/project_name/run_id' of the run to log to
    metric='SSIM',  # reconstruction metric: 'SSIM' or 'MSE'
)
eval_module = Reconstruction(
    model=your_model,
    test_dataset=test_set,
    output='./metrics',  # folder where metrics.txt is saved
    eval_config=eval_config,
)
# Compute metrics
eval_module.eval()
eval_module.finish()  # finishes the wandb run
- class multivae.metrics.ReconstructionConfig(batch_size=512, wandb_path=None, metric='SSIM')
Config class for a quantitative evaluation of the reconstruction quality.
- Parameters:
batch_size (int) – The batch size to use in the evaluation.
wandb_path (str) – The path of the wandb run where the metrics should be logged, in the format 'entity/project_name/run_id'. See 'Where to find the WandB path for a trained model?' for more information. If None, the metrics are not logged on wandb. Defaults to None.
metric (Literal['SSIM', 'MSE']) – The metric used to assess reconstruction quality. Defaults to 'SSIM'.
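The wandb_path string is expected to contain exactly three slash-separated parts. As an illustration only, the hypothetical helper split_wandb_path below (it is not part of MultiVae) checks that layout before the string is passed to ReconstructionConfig. It treats None as "do not log", matching the default described above.

```python
from typing import Optional, Tuple


def split_wandb_path(wandb_path: Optional[str]) -> Optional[Tuple[str, str, str]]:
    """Hypothetical helper: return (entity, project_name, run_id), or None.

    Not part of MultiVae; it only illustrates the expected
    'entity/project_name/run_id' layout of ReconstructionConfig.wandb_path.
    """
    if wandb_path is None:
        # None means the metrics are not logged on wandb.
        return None
    parts = wandb_path.split("/")
    # Require exactly three non-empty components.
    if len(parts) != 3 or not all(parts):
        raise ValueError(f"expected 'entity/project_name/run_id', got {wandb_path!r}")
    entity, project_name, run_id = parts
    return entity, project_name, run_id


print(split_wandb_path("my-team/multivae-mnist/abc123"))
# ('my-team', 'multivae-mnist', 'abc123')
```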
- class multivae.metrics.Reconstruction(model, test_dataset, output=None, eval_config=ReconstructionConfig(name='ReconstructionConfig', batch_size=512, wandb_path=None, metric='SSIM'))
Class for computing reconstruction metrics.
Available metrics are:
MSE (Mean Squared Error): https://en.wikipedia.org/wiki/Mean_squared_error
SSIM (Structural Similarity Index Measure): https://en.wikipedia.org/wiki/Structural_similarity, available only for image modalities.
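To make the two metrics concrete, here is a minimal sketch on flattened pixel lists. MSE is the mean of squared differences between the original and its reconstruction (lower is better). SSIM compares means, variances and covariance: SSIM(x, y) = ((2·μx·μy + C1)(2·σxy + C2)) / ((μx² + μy² + C1)(σx² + σy² + C2)), with C1 = (k1·L)², C2 = (k2·L)², k1 = 0.01, k2 = 0.03 and L the data range (higher is better, 1 for identical images). Note the simplification: global_ssim below treats the whole image as a single window, whereas torchmetrics averages SSIM over a sliding Gaussian window, so its values will differ. This is not the code MultiVae runs.

```python
from typing import Sequence


def mse(x: Sequence[float], y: Sequence[float]) -> float:
    """Mean squared error between an original and its reconstruction."""
    if len(x) != len(y) or len(x) == 0:
        raise ValueError("inputs must be non-empty and of equal length")
    return sum((a - b) ** 2 for a, b in zip(x, y)) / len(x)


def global_ssim(
    x: Sequence[float],
    y: Sequence[float],
    data_range: float = 1.0,
    k1: float = 0.01,
    k2: float = 0.03,
) -> float:
    """Simplified SSIM: one window covering the whole (flattened) image."""
    n = len(x)
    mu_x, mu_y = sum(x) / n, sum(y) / n
    var_x = sum((a - mu_x) ** 2 for a in x) / n
    var_y = sum((b - mu_y) ** 2 for b in y) / n
    cov_xy = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n
    # Stabilising constants from the SSIM definition.
    c1 = (k1 * data_range) ** 2
    c2 = (k2 * data_range) ** 2
    numerator = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    denominator = (mu_x ** 2 + mu_y ** 2 + c1) * (var_x + var_y + c2)
    return numerator / denominator


original = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]         # flattened pixels in [0, 1]
reconstruction = [0.05, 0.2, 0.35, 0.6, 0.85, 0.95]
print(mse(original, reconstruction), global_ssim(original, reconstruction))
```

Both functions expect values on the same scale as data_range, which is why the pixels here lie in [0, 1].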
- Parameters:
model (BaseMultiVAE) – The model to evaluate.
test_dataset (MultimodalBaseDataset) – The dataset used to compute the metrics.
output (str) – The path of the folder where the metrics are saved, in a metrics.txt file.
eval_config (ReconstructionConfig) – The configuration class specifying the parameters of the evaluation.
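The parameters above suggest an evaluation that reconstructs the test set batch by batch, averages a metric and writes the result to output/metrics.txt. The sketch below illustrates that flow under stated assumptions: evaluate_reconstruction and toy_model are hypothetical stand-ins (the real Reconstruction class works on a MultiVae model and a MultimodalBaseDataset, may aggregate per modality, and may format metrics.txt differently), and data are plain lists of floats.

```python
import os
import tempfile
from typing import Callable, List, Sequence

Sample = List[float]


def mse(x: Sequence[float], y: Sequence[float]) -> float:
    return sum((a - b) ** 2 for a, b in zip(x, y)) / len(x)


def evaluate_reconstruction(
    reconstruct: Callable[[List[Sample]], List[Sample]],
    dataset: Sequence[Sample],
    batch_size: int,
    output: str,
) -> float:
    """Hypothetical sketch: mean MSE over the dataset, saved to output/metrics.txt."""
    total, count = 0.0, 0
    for start in range(0, len(dataset), batch_size):
        batch = list(dataset[start:start + batch_size])
        recon_batch = reconstruct(batch)  # one "model" call per mini-batch
        for original, recon in zip(batch, recon_batch):
            total += mse(original, recon)
            count += 1
    mean_value = total / count
    os.makedirs(output, exist_ok=True)
    with open(os.path.join(output, "metrics.txt"), "w") as f:
        f.write(f"MSE: {mean_value}\n")
    return mean_value


def toy_model(batch: List[Sample]) -> List[Sample]:
    """Hypothetical stand-in for a model: scales every pixel by 0.9."""
    return [[0.9 * v for v in sample] for sample in batch]


data = [[0.0, 0.5, 1.0], [0.2, 0.4, 0.6], [1.0, 1.0, 0.0]]
with tempfile.TemporaryDirectory() as out_dir:
    value = evaluate_reconstruction(toy_model, data, batch_size=2, output=out_dir)
    with open(os.path.join(out_dir, "metrics.txt")) as f:
        print(f.read())
```

The batch size only changes how many samples are processed per call, not the averaged result, since the mean is taken over all samples rather than over batch means.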