Visualization

This module can be used to generate, visualize and save samples from any MultiVae model.

Basic usage example:

from multivae.metrics import Visualization, VisualizationConfig

eval_config = VisualizationConfig(
    wandb_path='your_wandb_path',
    n_data_cond=10,  # take ten datapoints for conditional generation
    n_samples=5,  # generate 5 samples per datapoint
)

eval_module = Visualization(
    model=your_model,
    test_dataset=test_set,
    output='./metrics',  # where to save images
    eval_config=eval_config,
    sampler=None,  # you can use a trained MultiVae sampler for joint generation
)

# Generate unconditional samples
eval_module.eval()

# Generate conditional samples from a subset of modalities
eval_module.conditional_samples_subset(subset=['modality_1', 'modality_2'], gen_mod='all')

eval_module.finish()  # finishes the wandb run

class multivae.metrics.VisualizationConfig(batch_size=20, wandb_path=None, n_samples=5, n_data_cond=5)

Config class for the visualization module.

Parameters:
  • batch_size (int) – The batch size to use in the evaluation. Defaults to 20.

  • wandb_path (str) – The path of the wandb run where the metrics should be logged, in the format 'entity/project_name/run_id'. See Where to find the WandB path for a trained model? for more information. If None is provided, the metrics are not logged on wandb. Defaults to None.

  • n_samples (int) – The number of samples to generate per modality and per datapoint for conditional generation. Defaults to 5.

  • n_data_cond (int) – The number of datapoints to use for conditional generation. Defaults to 5.
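The wandb_path value must follow the 'entity/project_name/run_id' format described above. The sketch below shows one way to check a string against that format before passing it to VisualizationConfig. `split_wandb_path` and `WandbRunPath` are hypothetical helpers written for this illustration; they are not part of MultiVae or of the wandb API.

```python
from typing import NamedTuple


class WandbRunPath(NamedTuple):
    """Components of a wandb run path (hypothetical helper, not part of MultiVae)."""
    entity: str
    project_name: str
    run_id: str


def split_wandb_path(wandb_path: str) -> WandbRunPath:
    """Split an 'entity/project_name/run_id' string into its three parts.

    Raises ValueError unless the string has exactly three non-empty
    components separated by '/'.
    """
    parts = wandb_path.split("/")
    if len(parts) != 3 or not all(parts):
        raise ValueError(
            f"Expected 'entity/project_name/run_id', got {wandb_path!r}"
        )
    return WandbRunPath(*parts)
```

For example, split_wandb_path("my_team/multivae_experiments/abc123") returns the three components, while a path with a missing component such as "my_team/abc123" raises a ValueError instead of silently logging to the wrong place.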

class multivae.metrics.Visualization(model, test_dataset, output=None, eval_config=VisualizationConfig(name='VisualizationConfig', batch_size=20, wandb_path=None, n_samples=5, n_data_cond=5), sampler=None)

Visualization module for generating and visualizing unconditional and conditional samples from models.

Parameters:
  • model (BaseMultiVAE) – The model to evaluate.

  • test_dataset (MultimodalBaseDataset) – The dataset to use for conditional generation.

  • output (str) – The path where images and metrics are saved. Defaults to None.

  • eval_config (VisualizationConfig) – The configuration for this evaluation module. Optional.

  • sampler (BaseSampler) – The sampler to use for joint generation. Optional. If None is provided, samples are drawn from the model's prior.

>>> from multivae.metrics.visualization import Visualization, VisualizationConfig

>>> vis_config = VisualizationConfig(
...     wandb_path='your_wandb_path',  # optional, if you have initialized a wandb run
...     n_samples=5,  # number of generated samples
...     n_data_cond=8,  # for conditional generation, the number of datapoints to use
... )

>>> vis_module = Visualization(
...     model,
...     test_dataset=test_set,
...     output='./metrics',
...     eval_config=vis_config,
... )

# Compute conditional generations
>>> generations = vis_module.conditional_samples_subset(['name_of_conditioning_modality1'])

# Compute unconditional generations
>>> generations = vis_module.unconditional_samples()

conditional_samples_subset(subset, gen_mod='all')

Generate samples conditioning on the modalities in a subset.

Parameters:
  • subset (list) – The subset of modalities to condition on.

  • gen_mod (Union[list, str], optional) – The modalities to generate. Defaults to "all".

Returns:

A PIL image containing a grid of the generated samples.

Return type:

PIL.Image
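Since n_samples is counted per generated modality and per datapoint, a call to conditional_samples_subset should produce n_data_cond * n_samples samples for each generated modality. The sketch below makes that arithmetic explicit. The helpers `resolve_gen_mod` and `expected_sample_counts` are hypothetical, and treating gen_mod='all' as every modality of the model (conditioning modalities included) is an assumption, not confirmed MultiVae behavior.

```python
from typing import Dict, List, Sequence, Union


def resolve_gen_mod(gen_mod: Union[List[str], str],
                    all_modalities: Sequence[str]) -> List[str]:
    """Turn a gen_mod argument into an explicit list of modality names.

    Assumption: the string "all" expands to every modality of the model;
    any other value must be a list of known modality names.
    """
    if isinstance(gen_mod, str):
        if gen_mod != "all":
            raise ValueError(f"Unsupported gen_mod string: {gen_mod!r}")
        return list(all_modalities)
    unknown = [m for m in gen_mod if m not in all_modalities]
    if unknown:
        raise ValueError(f"Unknown modalities: {unknown}")
    return list(gen_mod)


def expected_sample_counts(subset: Sequence[str],
                           gen_mod: Union[List[str], str],
                           all_modalities: Sequence[str],
                           n_data_cond: int = 5,
                           n_samples: int = 5) -> Dict[str, int]:
    """Number of samples generated per modality when conditioning on `subset`.

    Defaults mirror the VisualizationConfig defaults (5 and 5).
    """
    missing = [m for m in subset if m not in all_modalities]
    if missing:
        raise ValueError(f"Cannot condition on unknown modalities: {missing}")
    # n_samples are drawn for each of the n_data_cond conditioning datapoints.
    return {m: n_data_cond * n_samples
            for m in resolve_gen_mod(gen_mod, all_modalities)}
```

With the settings of the basic usage example (n_data_cond=10, n_samples=5), each generated modality would receive 50 samples under this reading.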

unconditional_samples(**kwargs)

Generate an image of unconditional samples.

Returns:

An image containing a grid of the generated samples.

Return type:

PIL.Image
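Both methods return a single image in which the individual samples are tiled into a grid. As an illustration of that idea only, here is a pure-Python sketch that tiles equally sized grayscale images, stored as lists of pixel rows, row by row, in the spirit of torchvision.utils.make_grid. `make_grid` here is a hypothetical helper; MultiVae's actual layout (ordering by datapoint or modality, padding, normalization) may differ.

```python
from typing import List

Image2D = List[List[float]]  # a grayscale image stored as rows of pixel values


def make_grid(images: List[Image2D], n_cols: int, pad_value: float = 0.0) -> Image2D:
    """Tile equally sized 2D images into one 2D grid, filling it row by row.

    Cells left empty in the last grid row are filled with pad_value.
    """
    if not images:
        raise ValueError("need at least one image")
    if n_cols < 1:
        raise ValueError("n_cols must be at least 1")
    h, w = len(images[0]), len(images[0][0])
    if any(len(img) != h or any(len(row) != w for row in img) for img in images):
        raise ValueError("all images must have the same shape")
    n_rows = -(-len(images) // n_cols)  # ceiling division
    blank = [[pad_value] * w for _ in range(h)]
    grid: Image2D = []
    for r in range(n_rows):
        # The images (or blank placeholders) that make up grid row r.
        cells = [images[i] if i < len(images) else blank
                 for i in range(r * n_cols, (r + 1) * n_cols)]
        # Concatenate pixel row y of every cell to form one output row.
        for y in range(h):
            grid.append([px for cell in cells for px in cell[y]])
    return grid
```

The output has n_rows * h rows and n_cols * w columns, so three 2x2 samples with n_cols=2 give a 4x4 grid whose last cell is padding.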