# Source code for multivae.samplers.maf_sampler.maf_sampler

import os
import shutil

import torch
from pythae.data.datasets import BaseDataset
from pythae.models.normalizing_flows import MAF, MAFConfig, NFModel
from pythae.trainers import BaseTrainer, BaseTrainerConfig
from torch.distributions import MultivariateNormal
from torch.utils.data import DataLoader

from multivae.data.utils import set_inputs_to_device
from multivae.models.base import ModelOutput

from ...models import BaseMultiVAE
from ..base.base_sampler import BaseSampler
from .maf_sampler_config import MAFSamplerConfig


class MAFSampler(BaseSampler):
    """Fits a Masked Autoregressive Flow (MAF) in the multimodal autoencoder's latent space.

    If the model has multiple latent spaces, we fit one flow per latent space:
    one stored under the key ``"shared"`` and one per key of ``model.style_dims``.

    Args:
        model (BaseMultiVAE): The model to sample from.
        sampler_config (MAFSamplerConfig): A MAFSamplerConfig instance containing
            the main parameters of the sampler. If None, a pre-defined
            configuration is used. Default: None

    .. note::

        The method :meth:`~multivae.samplers.MAFSampler.fit` must be called to
        fit the sampler before sampling (or the flows must be reloaded with
        :meth:`load_flows_from_folder`).
    """

    def __init__(self, model: BaseMultiVAE, sampler_config: MAFSamplerConfig = None):
        self.is_fitted = False

        if sampler_config is None:
            sampler_config = MAFSamplerConfig()

        super().__init__(model=model, sampler_config=sampler_config)

        # One flow for the shared latent space, plus one per entry of
        # ``style_dims`` when the model uses multiple latent spaces.
        self.flows_dims = {"shared": model.model_config.latent_dim}
        if self.model.multiple_latent_spaces:
            self.flows_dims.update(self.model.style_dims)

        self.priors = {}
        self.flows_models = {}

        for key, dim in self.flows_dims.items():
            # Standard normal base distribution of the flow.
            self.priors[key] = MultivariateNormal(
                torch.zeros(dim).to(self.device),
                torch.eye(dim).to(self.device),
            )

            maf_config = MAFConfig(
                input_dim=(dim,),
                n_made_blocks=sampler_config.n_made_blocks,
                n_hidden_in_made=sampler_config.n_hidden_in_made,
                hidden_size=sampler_config.hidden_size,
                include_batch_norm=sampler_config.include_batch_norm,
            )
            maf_model = MAF(model_config=maf_config)

            self.flows_models[key] = NFModel(self.priors[key], maf_model).to(
                self.device
            )

        self.name = "MAFsampler"

    def _check_is_fitted(self, action: str) -> None:
        """Raise ``ArithmeticError`` if the flows have not been fitted or loaded yet.

        Args:
            action (str): Name of the operation being attempted, inserted in the
                error message (e.g. ``"sampling"``).
        """
        if not self.is_fitted:
            raise ArithmeticError(
                "The sampler needs to be fitted by calling sampler.fit() method "
                f"before {action}."
            )

    def _encode_dataset(self, dataset, shuffle: bool, batch_size: int = 100) -> dict:
        """Encode ``dataset`` with the model and gather the detached latent codes.

        Args:
            dataset: The dataset to iterate over with a ``DataLoader``.
            shuffle (bool): Whether the ``DataLoader`` shuffles the dataset.
            batch_size (int): Batch size used for encoding. Default: 100

        Returns:
            dict: For each key of ``self.flows_models``, the latent codes of the
            whole dataset concatenated along dimension 0.
        """
        loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)
        latents = {m: [] for m in self.flows_models}

        for inputs in loader:
            inputs = set_inputs_to_device(inputs, self.device)
            encoder_output = self.model.encode(inputs)

            # Detached codes do not require grad: they are plain training data
            # for the flows.
            latents["shared"].append(encoder_output.z.detach())
            if self.model.multiple_latent_spaces:
                for m in encoder_output.modalities_z:
                    latents[m].append(encoder_output.modalities_z[m].detach())

        return {m: torch.cat(codes, dim=0) for m, codes in latents.items()}

    def fit(
        self, train_data, eval_data=None, training_config: BaseTrainerConfig = None
    ):
        """Method to fit the sampler from the training data.

        The data are first encoded with the model; one flow per latent space is
        then trained on the resulting embeddings.

        Args:
            train_data (MultimodalBaseDataset): The train data needed to retrieve
                the training embeddings and fit the flows in the latent space.
                Must be of shape n_imgs x im_channels x ... and in range [0-1]
            eval_data (MultimodalBaseDataset): The evaluation data needed to
                retrieve the evaluation embeddings. Must be of shape
                n_imgs x im_channels x ... and in range [0-1]. Default: None
            training_config (BaseTrainerConfig): The training config to use to
                fit the flows. Default: None
        """
        train_latents = self._encode_dataset(train_data, shuffle=True)
        eval_latents = (
            None
            if eval_data is None
            else self._encode_dataset(eval_data, shuffle=False)
        )

        for m, train_codes in train_latents.items():  # one flow per latent space
            train_dataset = BaseDataset(
                data=train_codes, labels=torch.zeros((len(train_codes),))
            )

            eval_dataset = None
            if eval_latents is not None:
                eval_codes = eval_latents[m]
                eval_dataset = BaseDataset(
                    data=eval_codes, labels=torch.zeros((len(eval_codes),))
                )

            trainer = BaseTrainer(
                model=self.flows_models[m],
                train_dataset=train_dataset,
                eval_dataset=eval_dataset,
                training_config=training_config,
            )
            trainer.train()

            # Update the flow model with the result of training, reloaded from
            # the trainer's "final_model" folder, then remove the training folder.
            self.flows_models[m] = MAF.load_from_folder(
                os.path.join(trainer.training_dir, "final_model")
            ).to(self.device)
            shutil.rmtree(trainer.training_dir)

        self.is_fitted = True

    def sample(
        self, n_samples: int = 1, batch_size: int = 500, **kwargs
    ) -> ModelOutput:
        """Main sampling function of the sampler.

        Args:
            n_samples (int): The number of samples to generate. Must be >= 1.
            batch_size (int): The batch size to use during sampling. Must be >= 1.

        Returns:
            ModelOutput: The generated latent codes, in the same format as the
            output of the model's ``encode`` or ``generate_from_prior``
            functions: ``z`` holds the shared latent codes,
            ``one_latent_space`` tells whether the model has a single latent
            space and, if it has several, ``modalities_z`` maps each other
            latent space to its generated codes.

        Raises:
            ArithmeticError: If the sampler has not been fitted.
            ValueError: If ``n_samples`` or ``batch_size`` is smaller than 1.
        """
        self._check_is_fitted("sampling")

        if n_samples < 1 or batch_size < 1:
            raise ValueError(
                "n_samples and batch_size must be positive integers, "
                f"got n_samples={n_samples} and batch_size={batch_size}."
            )

        # Split the request into full batches plus an optional smaller last one.
        full_batch_nbr, last_batch_samples_nbr = divmod(n_samples, batch_size)
        batches = [batch_size] * full_batch_nbr
        if last_batch_samples_nbr != 0:
            batches.append(last_batch_samples_nbr)

        z_gen = {m: [] for m in self.flows_models}

        for batch in batches:
            for m, flow in self.flows_models.items():
                # Draw from the base distribution and map it through the
                # inverse of the flow.
                u = self.priors[m].sample((batch,))
                z_gen[m].append(flow.inverse(u).out)

        # Output with the same format as the output of encode or
        # generate_from_prior functions.
        output = ModelOutput(
            z=torch.cat(z_gen.pop("shared")),
            one_latent_space=not self.model.multiple_latent_spaces,
        )
        if self.model.multiple_latent_spaces:
            output["modalities_z"] = {m: torch.cat(z_gen[m]) for m in z_gen}

        return output

    def save(self, dir_path):
        """Save the config and trained models.

        Each flow is saved in the sub-folder of ``dir_path`` named after its key
        in ``self.flows_models``.

        Args:
            dir_path: The folder in which to save the sampler.

        Raises:
            ArithmeticError: If the sampler has not been fitted. The check is
                done first, so that nothing is written in that case.
        """
        self._check_is_fitted("saving")

        super().save(dir_path=dir_path)

        # Save the flow models, one folder per latent space.
        for m, model in self.flows_models.items():
            path = os.path.join(dir_path, m)
            os.makedirs(path, exist_ok=True)
            model.save(path)

    def load_flows_from_folder(self, dir_path):
        """Instead of calling fit, you can reload weights from a previous training.

        .. code-block:: python

            >>> sampler.save(dir_path)
            >>> new_sampler = MAFSampler(model, sampler_config) # must be the same model and config
            >>> new_sampler.load_flows_from_folder(dir_path)

        Args:
            dir_path: The folder previously passed to :meth:`save`.

        Raises:
            AttributeError: If any of the flows cannot be loaded. In that case
                the current flows are left untouched.
        """
        loaded = {}
        for m in self.flows_models:
            try:
                loaded[m] = MAF.load_from_folder(os.path.join(dir_path, m)).to(
                    self.device
                )
            except Exception as exc:
                raise AttributeError(
                    "Error when trying to load the flows from the folder. "
                    f"Check that you provided the right path. Exception raised: {exc}"
                ) from exc

        # Only replace the flows once all of them have been loaded successfully.
        self.flows_models.update(loaded)
        self.is_fitted = True