Gaussian Mixture Model

Implements a Gaussian Mixture Sampler in the latent space of MultiVae models for improved unconditional generation. A Gaussian Mixture is fitted on the training embeddings.

class multivae.samplers.GaussianMixtureSampler(model, sampler_config=None)

Fits a Gaussian mixture in the multimodal autoencoder’s latent space. If the model has several latent spaces, one Gaussian mixture is fitted per latent space.

Parameters:
  • model (BaseMultiVAE) – The model to sample from.

  • sampler_config (BaseSamplerConfig) – An instance of BaseSamplerConfig that holds the sampler’s parameters. If None, a default configuration is used. Default: None.

Note

The method fit must be called to fit the sampler before sampling.
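
The call order required by this note can be sketched as follows. This is a minimal sketch rather than an official example: fit_then_sample is a hypothetical helper, and model and train_data are assumed to be an already trained BaseMultiVAE and a MultimodalBaseDataset. Only the constructor, fit and sample signatures are taken from this page.

```python
def fit_then_sample(model, train_data, n_components=10, n_samples=100, batch_size=500):
    """Hypothetical helper: fit the sampler first, then draw samples."""
    # Imported lazily so the helper can be defined without multivae installed.
    from multivae.samplers import (
        GaussianMixtureSampler,
        GaussianMixtureSamplerConfig,
    )

    config = GaussianMixtureSamplerConfig(n_components=n_components)
    sampler = GaussianMixtureSampler(model, sampler_config=config)

    # fit must be called before sample.
    sampler.fit(train_data)

    # The return value is the ModelOutput documented under sample.
    return sampler.sample(n_samples=n_samples, batch_size=batch_size)
```

Wrapping both calls in one function is only a convenience that makes it hard to sample from an unfitted sampler.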

fit(train_data, **kwargs)

Method to fit the sampler from the training data.

Parameters:

train_data (MultimodalBaseDataset) – The training data used to retrieve the training embeddings and fit the mixture in the latent space. Must be an instance of MultimodalBaseDataset.

sample(n_samples=1, batch_size=500, **kwargs)

Main sampling function of the sampler.

Parameters:
  • n_samples (int) – The number of samples to generate. Default: 1.

  • batch_size (int) – The batch size to use during sampling. Default: 500.

  • save_sampler_config (bool) – Whether to save the sampler config. It is saved in output_dir

Returns:

A ModelOutput similar to the one returned by the encode or generate_from_prior functions.
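
How the sample count and batch_size interact can be illustrated with a short sketch. It assumes the common pattern of generating full batches followed by one smaller final batch. batch_sizes is a hypothetical helper and is not taken from multivae's code.

```python
def batch_sizes(n_samples, batch_size):
    """Hypothetical helper: sizes of the batches needed to reach n_samples."""
    if n_samples < 0 or batch_size <= 0:
        raise ValueError("n_samples must be >= 0 and batch_size must be > 0")
    full_batches, remainder = divmod(n_samples, batch_size)
    sizes = [batch_size] * full_batches
    if remainder:
        sizes.append(remainder)  # last, smaller batch
    return sizes


print(batch_sizes(1200, 500))  # [500, 500, 200]
```

Under this assumption, a larger batch_size means fewer but larger batches, so it mainly trades memory use against the number of iterations.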

class multivae.samplers.GaussianMixtureSamplerConfig(n_components=10)

Gaussian mixture sampler config class.

Parameters:

n_components (int) – The number of Gaussians in the mixture. Default: 10.
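
To make the role of n_components concrete, the sketch below draws points from a small Gaussian mixture using only the Python standard library. It assumes diagonal covariances and hand-picked parameters for brevity. sample_mixture and its arguments are hypothetical and unrelated to multivae's internals, where the mixture is fitted on the training embeddings.

```python
import random


def sample_mixture(weights, means, stds, n_samples, seed=0):
    """Draw n_samples points from a diagonal-covariance Gaussian mixture.

    weights: one mixing weight per component
    means, stds: one list of per-dimension values per component
    """
    rng = random.Random(seed)
    components = list(range(len(weights)))
    samples = []
    for _ in range(n_samples):
        # 1) Pick a component according to the mixing weights.
        k = rng.choices(components, weights=weights, k=1)[0]
        # 2) Draw each latent dimension from that component's Gaussian.
        samples.append([rng.gauss(m, s) for m, s in zip(means[k], stds[k])])
    return samples


# Two components (the analogue of n_components=2) in a 2-dimensional latent space.
points = sample_mixture(
    weights=[0.7, 0.3],
    means=[[0.0, 0.0], [5.0, 5.0]],
    stds=[[1.0, 1.0], [0.5, 0.5]],
    n_samples=4,
)
print(len(points), len(points[0]))  # 4 2
```

More components let the mixture follow a more irregular distribution of embeddings, at the cost of more parameters to estimate from the training data.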