Visualization
This module can be used to generate, visualize and save samples with any MultiVae model.
Basic usage example:
from multivae.metrics import Visualization, VisualizationConfig

eval_config = VisualizationConfig(
    wandb_path='your_wandb_path',
    n_data_cond=10,  # take ten datapoints for conditional generation
    n_samples=5,  # generate 5 samples per datapoint
)
eval_module = Visualization(
    model=your_model,
    test_dataset=test_set,
    output='./metrics',  # where to save images
    eval_config=eval_config,
    sampler=None,  # you can use a trained MultiVae sampler for joint generation
)
# Generate unconditional samples
eval_module.eval()
# Generate conditional samples from a subset of modalities
eval_module.conditional_samples_subset(subset=['modality_1', 'modality_2'], gen_mod='all')
eval_module.finish()  # finishes the wandb run
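Reading the parameter descriptions, a conditional call should produce `n_samples` generations for each of the `n_data_cond` datapoints, for every generated modality. The minimal sketch below tallies that for the example above (10 datapoints, 5 samples each). `resolve_gen_modalities` and `expected_sample_counts` are hypothetical helpers, not part of MultiVae, and the assumption that `gen_mod='all'` expands to every modality of the model is ours.

```python
from typing import Dict, List, Sequence


def resolve_gen_modalities(all_modalities: Sequence[str], gen_mod: str = "all") -> List[str]:
    """Hypothetical helper (not MultiVae API): list the modalities to generate.

    Assumption: gen_mod='all' means every modality of the model.
    """
    if gen_mod == "all":
        return list(all_modalities)
    if gen_mod not in all_modalities:
        raise ValueError(f"Unknown modality: {gen_mod!r}")
    return [gen_mod]


def expected_sample_counts(
    all_modalities: Sequence[str], gen_mod: str, n_data_cond: int, n_samples: int
) -> Dict[str, int]:
    """Hypothetical helper: n_samples per datapoint, for each generated modality."""
    return {
        mod: n_data_cond * n_samples
        for mod in resolve_gen_modalities(all_modalities, gen_mod)
    }


modalities = ["modality_1", "modality_2", "modality_3"]
print(expected_sample_counts(modalities, "all", n_data_cond=10, n_samples=5))
# {'modality_1': 50, 'modality_2': 50, 'modality_3': 50}
```

How the real module lays these samples out in the saved images is not covered by this sketch.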
- class multivae.metrics.VisualizationConfig(batch_size=20, wandb_path=None, n_samples=5, n_data_cond=5)
Config class for the visualization module.
- Parameters:
batch_size (int) – The batch size to use in the evaluation. Defaults to 20.
wandb_path (str) – The path of the wandb run where the metrics should be logged, in the format "entity/project_name/run_id". See "Where to find the WandB path for a trained model?" for more information. If None is provided, the metrics are not logged to wandb. Defaults to None.
n_samples (int) – The number of samples to generate per modality and per datapoint for conditional generation. Defaults to 5.
n_data_cond (int) – The number of datapoints to use for conditional generation. Defaults to 5.
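The `wandb_path` string is expected in the form `entity/project_name/run_id`. As a minimal sketch, a hypothetical helper `parse_wandb_path` (not part of MultiVae) could validate that format before the config is built; `None` is passed through, matching the documented "do not log" default. If you already have an active wandb run, its `wandb.run.path` attribute should return a string in this format.

```python
from typing import NamedTuple, Optional


class WandbRunPath(NamedTuple):
    """Hypothetical container for the three parts of a wandb run path."""

    entity: str
    project_name: str
    run_id: str


def parse_wandb_path(wandb_path: Optional[str]) -> Optional[WandbRunPath]:
    """Split an 'entity/project_name/run_id' string.

    None is returned unchanged, meaning "do not log to wandb".
    """
    if wandb_path is None:
        return None
    parts = wandb_path.strip("/").split("/")
    # Exactly three non-empty components are required
    if len(parts) != 3 or not all(parts):
        raise ValueError(f"Expected 'entity/project_name/run_id', got {wandb_path!r}")
    return WandbRunPath(*parts)


print(parse_wandb_path("my-team/multimodal-vae/1a2b3c4d"))
# WandbRunPath(entity='my-team', project_name='multimodal-vae', run_id='1a2b3c4d')
```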
- class multivae.metrics.Visualization(model, test_dataset, output=None, eval_config=VisualizationConfig(name='VisualizationConfig', batch_size=20, wandb_path=None, n_samples=5, n_data_cond=5), sampler=None)
Visualization module for visualizing unconditional and conditional samples from models.
- Parameters:
model (BaseMultiVAE) – The model to evaluate.
test_dataset (MultimodalBaseDataset) – The dataset to use for conditional image generation.
output (str) – The path where images and metrics are saved. Defaults to None.
eval_config (VisualizationConfig) – The configuration object for this evaluation module. Optional.
sampler (BaseSampler) – The sampler to use for joint generation. Optional. If None is provided, samples are drawn from the model's prior.
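The `sampler` argument only changes where the latent codes for joint (unconditional) generation come from. Below is a minimal sketch of that fallback, assuming that with `sampler=None` codes are drawn from the model's prior. `ToyModel`, `ToySampler`, `sample_prior` and `draw_latents` are hypothetical stand-ins, not MultiVae classes, and the latents are plain floats instead of tensors.

```python
import random
from typing import List, Optional, Protocol


class LatentSampler(Protocol):
    """Anything exposing sample(n), like a trained sampler (hypothetical interface)."""

    def sample(self, n: int) -> List[float]: ...


class ToyModel:
    """Hypothetical stand-in for a model with a standard normal prior."""

    def sample_prior(self, n: int) -> List[float]:
        return [random.gauss(0.0, 1.0) for _ in range(n)]


class ToySampler:
    """Hypothetical stand-in for a sampler fitted on the model's latent codes."""

    def __init__(self, mean: float) -> None:
        self.mean = mean

    def sample(self, n: int) -> List[float]:
        return [random.gauss(self.mean, 0.1) for _ in range(n)]


def draw_latents(model: ToyModel, n: int, sampler: Optional[LatentSampler] = None) -> List[float]:
    # No sampler given: fall back to the prior (assumed default behaviour)
    if sampler is None:
        return model.sample_prior(n)
    # Otherwise delegate joint generation to the trained sampler
    return sampler.sample(n)


random.seed(0)
from_prior = draw_latents(ToyModel(), n=4)
from_sampler = draw_latents(ToyModel(), n=4, sampler=ToySampler(mean=3.0))
```

Decoding the drawn latents into samples is left out here, since that step depends on the model.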
>>> from multivae.metrics.visualization import Visualization, VisualizationConfig
>>> vis_config = VisualizationConfig(
...     wandb_path='your_wandb_path',  # optional, if you have initialized a wandb run
...     n_samples=5,  # number of generated samples
...     n_data_cond=8,  # for conditional generation, the number of datapoints to use
... )
>>> vis_module = Visualization(
...     model,
...     test_dataset=test_set,
...     output='./metrics',
...     eval_config=vis_config)
>>> # Compute conditional generations
>>> generations = vis_module.conditional_samples_subset(['name_of_conditioning_modality1'])
>>> # Compute unconditional generations
>>> generations = vis_module.unconditional_samples()