Gaussian Mixture Model
Implements a Gaussian Mixture Sampler in the latent space of MultiVae models for improved unconditional generation. A Gaussian Mixture is fitted on the training embeddings.
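A minimal sketch of the idea, assuming the training embeddings have already been extracted into a NumPy array of shape (n_samples, latent_dim). It fits a Gaussian mixture with a few EM iterations and then draws new latent codes from it, which a model's decoders would then map back to data space. The function names `fit_diag_gmm` and `sample_diag_gmm`, the diagonal covariances, the random initialisation and the fixed iteration count are all illustrative assumptions. This is not MultiVae's implementation.

```python
import numpy as np


def fit_diag_gmm(z, n_components=3, n_iter=100, seed=0, eps=1e-6):
    """Hypothetical helper: fit a diagonal-covariance Gaussian mixture to z (n, d) with EM."""
    rng = np.random.default_rng(seed)
    n = z.shape[0]
    # Initialise means on randomly chosen embeddings and variances on the global variance.
    means = z[rng.choice(n, size=n_components, replace=False)].copy()
    variances = np.tile(z.var(axis=0) + eps, (n_components, 1))
    weights = np.full(n_components, 1.0 / n_components)
    for _ in range(n_iter):
        # E-step: log pi_k + log N(z_i | mu_k, diag(var_k)), shape (n, K).
        diff = z[:, None, :] - means[None, :, :]
        log_prob = np.log(weights) - 0.5 * np.sum(
            diff**2 / variances + np.log(2 * np.pi * variances), axis=2
        )
        # Subtract the row maximum for numerical stability, then normalise to responsibilities.
        log_prob -= log_prob.max(axis=1, keepdims=True)
        resp = np.exp(log_prob)
        resp /= resp.sum(axis=1, keepdims=True)
        # M-step: re-estimate weights, means and variances from the responsibilities.
        nk = resp.sum(axis=0) + eps
        weights = nk / nk.sum()
        means = (resp.T @ z) / nk[:, None]
        variances = np.maximum((resp.T @ z**2) / nk[:, None] - means**2, eps)
    return weights, means, variances


def sample_diag_gmm(weights, means, variances, n_samples, seed=0):
    """Hypothetical helper: pick a component by weight, then sample from its Gaussian."""
    rng = np.random.default_rng(seed)
    comps = rng.choice(len(weights), size=n_samples, p=weights)
    noise = rng.standard_normal((n_samples, means.shape[1]))
    return means[comps] + np.sqrt(variances[comps]) * noise


# Toy "training embeddings": two clusters in a 2-D latent space.
rng = np.random.default_rng(1)
z_train = np.vstack([rng.normal(-3.0, 0.5, size=(200, 2)),
                     rng.normal(3.0, 0.5, size=(200, 2))])
weights, means, variances = fit_diag_gmm(z_train, n_components=2)
z_new = sample_diag_gmm(weights, means, variances, n_samples=10)
print(z_new.shape)  # (10, 2)
```

Sampling from a mixture fitted on the embeddings places new latent codes in regions the encoder actually uses, which is the motivation for this kind of sampler. A full-covariance mixture would capture correlations between latent dimensions at a higher fitting cost.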
- class multivae.samplers.GaussianMixtureSampler(model, sampler_config=None)
Fits a Gaussian mixture in the multimodal autoencoder’s latent space. If the model has several latent spaces, a separate Gaussian mixture is fitted for each latent space.
- Parameters:
model (BaseMultiVAE) – The model to sample from.
sampler_config (BaseSamplerConfig) – An instance of BaseSamplerConfig holding the sampler’s parameters. If None, a default configuration is used. Default: None.
Note
The method fit must be called to fit the sampler before sampling.
- fit(train_data, **kwargs)
Method to fit the sampler from the training data.
- Parameters:
train_data (MultimodalBaseDataset) – The training data used to retrieve the training embeddings and fit the mixture in the latent space. Must be an instance of MultimodalBaseDataset.
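The requirement that fit runs before sampling, and the one-mixture-per-latent-space behaviour, can be illustrated with a small self-contained toy. `ToyPerSpaceGMMSampler` is a hypothetical class, not MultiVae’s GaussianMixtureSampler. Its `fit` takes a dict of precomputed embeddings keyed by latent-space name, whereas the real fit takes a MultimodalBaseDataset and retrieves the embeddings through the model. The diagonal-covariance EM fit, the `sample` method and the space names "shared" and "private_image" are all illustrative assumptions.

```python
import numpy as np


class ToyPerSpaceGMMSampler:
    """Hypothetical sampler: one diagonal Gaussian mixture per latent space."""

    def __init__(self, n_components=2, n_iter=50, seed=0):
        self.n_components = n_components
        self.n_iter = n_iter
        self.rng = np.random.default_rng(seed)
        self.mixtures = None  # set by fit(); sample() refuses to run before that

    def _fit_one(self, z, eps=1e-6):
        n, k = z.shape[0], self.n_components
        means = z[self.rng.choice(n, size=k, replace=False)].copy()
        var = np.tile(z.var(axis=0) + eps, (k, 1))
        w = np.full(k, 1.0 / k)
        for _ in range(self.n_iter):
            # E-step: responsibilities from log-densities, stabilised by the row maximum.
            diff = z[:, None, :] - means[None]
            logp = np.log(w) - 0.5 * np.sum(diff**2 / var + np.log(2 * np.pi * var), axis=2)
            logp -= logp.max(axis=1, keepdims=True)
            resp = np.exp(logp)
            resp /= resp.sum(axis=1, keepdims=True)
            # M-step: weighted re-estimation of weights, means and variances.
            nk = resp.sum(axis=0) + eps
            w = nk / nk.sum()
            means = resp.T @ z / nk[:, None]
            var = np.maximum(resp.T @ z**2 / nk[:, None] - means**2, eps)
        return w, means, var

    def fit(self, embeddings_by_space):
        # One independent mixture per latent space.
        self.mixtures = {name: self._fit_one(np.asarray(z, dtype=float))
                         for name, z in embeddings_by_space.items()}
        return self

    def sample(self, n_samples):
        if self.mixtures is None:
            raise RuntimeError("fit() must be called before sampling.")
        out = {}
        for name, (w, means, var) in self.mixtures.items():
            comps = self.rng.choice(len(w), size=n_samples, p=w)
            noise = self.rng.standard_normal((n_samples, means.shape[1]))
            out[name] = means[comps] + np.sqrt(var[comps]) * noise
        return out


sampler = ToyPerSpaceGMMSampler(n_components=2)
try:
    sampler.sample(5)
except RuntimeError as err:
    print(err)  # fit() must be called before sampling.

rng = np.random.default_rng(0)
embeddings = {
    "shared": rng.normal(size=(300, 4)),
    "private_image": rng.normal(size=(300, 2)),
}
samples = sampler.fit(embeddings).sample(5)
print({name: z.shape for name, z in samples.items()})
# {'shared': (5, 4), 'private_image': (5, 2)}
```

Keeping each latent space's mixture independent mirrors the per-space fitting described above. Any joint dependence between spaces is not modelled in this toy.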