Source code for multivae.metrics.visualization.visualization_class

import os
from typing import Union

import torch
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.utils import make_grid

from multivae.data import MultimodalBaseDataset
from multivae.data.datasets.utils import adapt_shape
from multivae.data.utils import set_inputs_to_device
from multivae.models.base import BaseMultiVAE, ModelOutput
from multivae.models.cvae import CVAE
from multivae.samplers.base import BaseSampler

from ..base.evaluator_class import Evaluator
from .visualize_config import VisualizationConfig


class Visualization(Evaluator):
    """Visualization module for unconditional and conditional samples of a model.

    Args:
        model (Union[BaseMultiVAE, CVAE]): The model to evaluate.
        test_dataset (MultimodalBaseDataset): The dataset from which the
            conditioning datapoints are drawn for conditional generation.
        output (str): The directory where images are saved. Default to None,
            in which case no image is written to disk.
        eval_config (VisualizationConfig): The configuration for this
            evaluation module. Optional. When None, a fresh default
            ``VisualizationConfig()`` is built for this instance.
        sampler (BaseSampler): The sampler to use for unconditional (joint)
            generation. Optional. When no sampler is set,
            ``unconditional_samples`` draws from the model's prior with
            ``model.generate_from_prior``.

    .. code-block::

        >>> from multivae.metrics.visualization import Visualization, VisualizationConfig
        >>> vis_config = VisualizationConfig(
        ...     wandb_path='your_wandb_path', # optional, if you have initialized a wandb run
        ...     n_samples=5, # number of generated samples
        ...     n_data_cond=8, # For conditional generation, the number of datapoints to use.
        ...     )
        >>> vis_module = Visualization(
        ...     model,
        ...     test_dataset=test_set,
        ...     output='./metrics',
        ...     eval_config=vis_config)
        >>> # Compute conditional generations
        >>> generations = vis_module.conditional_samples_subset(['name_of_conditioning_modality1'])
        >>> # Compute unconditional generations
        >>> generations = vis_module.unconditional_samples()
    """

    def __init__(
        self,
        model: Union[BaseMultiVAE, CVAE],
        test_dataset: MultimodalBaseDataset,
        output: str = None,
        eval_config: VisualizationConfig = None,
        sampler: BaseSampler = None,
    ) -> None:
        # Build the default config per instance rather than sharing one
        # instance that would be created once, at function-definition time.
        if eval_config is None:
            eval_config = VisualizationConfig()
        super().__init__(model, test_dataset, output, eval_config, sampler)
        self.n_samples = eval_config.n_samples
        self.n_data_cond = eval_config.n_data_cond

    @staticmethod
    def _grid_to_pil(images: torch.Tensor, nrow: int) -> Image.Image:
        """Arrange ``images`` in a grid with ``nrow`` images per row and return it as a PIL image."""
        grid = make_grid(images, nrow=nrow)
        # Scale to [0, 255] (pixel values are expected in [0, 1]) and add 0.5
        # so that the uint8 cast rounds to the nearest integer.
        ndarr = (
            grid.mul(255)
            .add_(0.5)
            .clamp_(0, 255)
            .permute(1, 2, 0)
            .to("cpu", torch.uint8)
            .numpy()
        )
        return Image.fromarray(ndarr)

    def _save_and_log_image(self, image, file_name: str, wandb_key: str) -> None:
        """Save ``image`` under ``self.output`` (if set) and log it to wandb (if a run is set)."""
        if self.output is not None:
            image.save(os.path.join(self.output, file_name))
        if self.wandb_run is not None:
            # wandb is imported lazily: it is only needed when a run is set.
            import wandb

            self.wandb_run.log({wandb_key: wandb.Image(image)})

    def unconditional_samples(self, **kwargs):
        """Generate an image of unconditional samples.

        Samples come from ``self.sampler`` when one is set, otherwise from
        ``model.generate_from_prior``. They are then decoded by the model.

        Keyword Args:
            device: Device the samples are moved to before decoding. Defaults
                to ``self.device``, the same device ``conditional_samples_subset``
                moves its inputs to. Other keyword arguments are ignored.

        Returns:
            PIL.Image: An image containing a grid of the generated samples.
            It is also saved as ``unconditional.png`` under ``output`` (if set)
            and logged to wandb (if a run is set).
        """
        # Use the evaluator's device instead of guessing from CUDA availability,
        # so the samples land on the same device as the data in
        # conditional_samples_subset.
        device = kwargs.pop("device", self.device)
        if self.sampler is None:
            samples = self.model.generate_from_prior(self.n_samples)
        else:
            samples = self.sampler.sample(self.n_samples)
        samples = set_inputs_to_device(samples, device=device)
        recon = self.model.decode(samples)

        if hasattr(self.test_dataset, "transform_for_plotting"):
            recon = {
                m: self.test_dataset.transform_for_plotting(recon[m], m)
                for m in recon
            }
        recon, _ = adapt_shape(recon)

        recon_image = self._grid_to_pil(
            torch.cat(list(recon.values())), nrow=self.n_samples
        )
        self._save_and_log_image(
            recon_image, "unconditional.png", "unconditional_generation"
        )
        return recon_image

    def conditional_samples_subset(
        self, subset: list, gen_mod: Union[list, str] = "all"
    ):
        """Generate samples conditioned on the modalities in a subset.

        A shuffled batch of ``n_data_cond`` datapoints is drawn from the test
        dataset, and ``model.predict`` is called on it with ``N=self.n_samples``.
        In the output grid, the conditioning data comes first, followed by the
        generated modalities.

        Args:
            subset (list): The subset of modalities to condition on.
            gen_mod (Union[list, str], optional): The modalities to generate.
                Defaults to "all".

        Returns:
            PIL.Image: A PIL image containing a grid of the generated samples.
            It is also saved as ``conditional_from_subset_{subset}.png`` under
            ``output`` (if set) and logged to wandb (if a run is set).
        """
        dataloader = DataLoader(
            self.test_dataset, batch_size=self.n_data_cond, shuffle=True
        )
        data = next(iter(dataloader))
        data = set_inputs_to_device(data, self.device)

        recon = self.model.predict(
            data,
            cond_mod=subset,
            gen_mod=gen_mod,
            N=self.n_samples,
            flatten=True,
            ignore_incomplete=True,
        )

        # The conditioning data is stored next to the generations under these
        # keys, so that every image goes through adapt_shape together.
        original_keys = [f"original_{m}" for m in subset]
        if hasattr(self.test_dataset, "transform_for_plotting"):
            recon = {
                m: self.test_dataset.transform_for_plotting(recon[m], m)
                for m in recon
            }
            recon.update(
                {
                    key: self.test_dataset.transform_for_plotting(data.data[m], m)
                    for key, m in zip(original_keys, subset)
                }
            )
        else:
            recon.update(
                {key: data.data[m] for key, m in zip(original_keys, subset)}
            )
        recon, _ = adapt_shape(recon)

        # Conditioning data first, then the generated modalities. Filter on the
        # exact conditioning keys, not on a substring test, so that a generated
        # modality whose name contains "original" is not dropped.
        images = [recon[key] for key in original_keys]
        images += [recon[m] for m in recon if m not in original_keys]

        recon_image = self._grid_to_pil(torch.cat(images), nrow=self.n_data_cond)
        self._save_and_log_image(
            recon_image,
            f"conditional_from_subset_{subset}.png",
            f"conditional_from_subset_{subset}",
        )
        return recon_image

    def reconstruction(self, modality: str, **kwargs):
        """Generate ``modality`` conditioned on itself.

        Delegates to ``conditional_samples_subset([modality], gen_mod=modality)``.
        Extra keyword arguments are accepted but ignored.

        Returns:
            PIL.Image: The image returned by ``conditional_samples_subset``.
        """
        return self.conditional_samples_subset([modality], gen_mod=modality)

    def eval(self):
        """Run the default visualization, which produces unconditional samples only.

        Returns:
            ModelOutput: With field ``unconditional_generation`` holding the
            PIL image from ``unconditional_samples``.
        """
        image = self.unconditional_samples()
        return ModelOutput(unconditional_generation=image)