Source code for multivae.samplers.gaussian_mixture.gaussian_mixture_sampler

import logging

import torch
from sklearn import mixture
from torch.utils.data import DataLoader

from multivae.data import MultimodalBaseDataset
from multivae.data.utils import set_inputs_to_device
from multivae.models.base import ModelOutput

from ...models import BaseMultiVAE
from ..base import BaseSampler
from .gaussian_mixture_config import GaussianMixtureSamplerConfig

logger = logging.getLogger(__name__)

# make it print to the console.
console = logging.StreamHandler()
logger.addHandler(console)
logger.setLevel(logging.INFO)


[docs] class GaussianMixtureSampler(BaseSampler): """Fits a Gaussian Mixture in the Multimodal Autoencoder's latent space. If the model has several latent spaces, it fits a gaussian mixture per latent space. Args: model (BaseMultiVAE): The model to sample from. sampler_config (BaseSamplerConfig): An instance of BaseSamplerConfig in which any sampler's parameters is made available. If None a default configuration is used. Default: None. .. note:: The method :class:`~multivae.samplers.GaussianMixtureSampler.fit` must be called to fit the sampler before sampling. """ def __init__( self, model: BaseMultiVAE, sampler_config: GaussianMixtureSamplerConfig = None ): if sampler_config is None: sampler_config = GaussianMixtureSamplerConfig() BaseSampler.__init__(self, model=model, sampler_config=sampler_config) self.n_components = sampler_config.n_components self.device = "cuda" if torch.cuda.is_available() else "cpu" self.model.to(self.device) self.name = "GaussianMixtureSampler"
[docs] def fit(self, train_data: MultimodalBaseDataset, **kwargs): """Method to fit the sampler from the training data. Args: train_data (MultimodalBaseDataset): The train data needed to retreive the training embeddings and fit the mixture in the latent space. Must be an instance of MultimodalBaseDataset. """ train_loader = DataLoader( dataset=train_data, batch_size=100, shuffle=False, ) z = [] if self.model.multiple_latent_spaces: mod_z = {m: [] for m in self.model.encoders} # Compute all embeddings with torch.no_grad(): for _, inputs in enumerate(train_loader): inputs = set_inputs_to_device(inputs, self.device) output = self.model.encode(inputs) z.append(output.z) if self.model.multiple_latent_spaces: for m in mod_z: mod_z[m].append(output.modalities_z[m]) z = torch.cat(z) if self.model.multiple_latent_spaces: mod_z = {m: torch.cat(mod_z[m]) for m in mod_z} if self.n_components > z.shape[0]: self.n_components = z.shape[0] logger.warning( f"Setting the number of component to {z.shape[0]} since" "n_components > n_samples when fitting the gmm" ) gmm = mixture.GaussianMixture( n_components=self.n_components, covariance_type="full", max_iter=2000, verbose=0, tol=1e-3, ) gmm.fit(z.cpu().detach()) self.gmm = gmm if self.model.multiple_latent_spaces: self.mod_gmms = dict() for m in mod_z: gmm = mixture.GaussianMixture( n_components=self.n_components, covariance_type="full", max_iter=2000, verbose=0, tol=1e-3, ) gmm.fit(mod_z[m].cpu().detach()) self.mod_gmms[m] = gmm self.is_fitted = True
[docs] def sample( self, n_samples: int = 1, batch_size: int = 500, **kwargs ) -> torch.Tensor: """Main sampling function of the sampler. Args: num_samples (int): The number of samples to generate batch_size (int): The batch size to use during sampling save_sampler_config (bool): Whether to save the sampler config. It is saved in output_dir Returns: ModelOutput similar as the one returned by the encode function or generate_from_prior function. """ if not self.is_fitted: raise ArithmeticError( "The sampler needs to be fitted by calling sampler.fit() method" "before sampling." ) full_batch_nbr = int(n_samples / batch_size) last_batch_samples_nbr = n_samples % batch_size batches_sizes = [batch_size] * full_batch_nbr if last_batch_samples_nbr != 0: batches_sizes = batches_sizes + [last_batch_samples_nbr] z_list = [] if self.model.multiple_latent_spaces: mod_z = {m: [] for m in self.model.encoders} for batch_size in batches_sizes: z = ( torch.tensor(self.gmm.sample(batch_size)[0]) .to(self.device) .type(torch.float) ) z_list.append(z) if self.model.multiple_latent_spaces: for m in mod_z: z = ( torch.tensor(self.mod_gmms[m].sample(batch_size)[0]) .to(self.device) .type(torch.float) ) mod_z[m].append(z) output = ModelOutput(z=torch.cat(z_list, dim=0)) if self.model.multiple_latent_spaces: output["one_latent_space"] = False output["modalities_z"] = {m: torch.cat(mod_z[m]) for m in mod_z} else: output["one_latent_space"] = True return output