MMVAE

Implementation of the Variational Mixture-of-Experts Autoencoder model from the paper “Variational Mixture-of-Experts Autoencoders for Multi-Modal Deep Generative Models” (https://arxiv.org/abs/1911.03393).

The MMVAE model aggregates the unimodal posteriors with a mixture of experts (MoE) and is trained with a K-sample IWAE lower bound. The MMVAE loss is written as follows:

\[\frac{1}{M}\sum_{j=1}^{M} \mathbb E_{z^{(1)},\dots,z^{(K)} \sim q_{\phi_j}(z|x_j)} \left [ \log \frac{1}{K} \sum_{k=1}^{K} \frac{p_{\theta}(z^{(k)},X)}{q_{\phi}(z^{(k)}|X)} \right]\]
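
As a minimal sketch of how this objective can be evaluated (plain Python, not the library's implementation), the snippet below assumes that the log importance weights log p(z^(k), X) - log q(z^(k)|X) have already been computed for the samples drawn from each expert. The helper names log_mean_exp and moe_iwae_bound are hypothetical, and the log-weights are made-up values.

```python
import math


def log_mean_exp(log_values):
    """Numerically stable log of the mean of exp(log_values)."""
    top = max(log_values)
    return top + math.log(sum(math.exp(v - top) for v in log_values) / len(log_values))


def moe_iwae_bound(log_weights_per_expert):
    """Average the K-sample IWAE bound over the M experts.

    log_weights_per_expert[j][k] holds log p(z_k, X) - log q(z_k | X)
    for the k-th sample drawn from the j-th unimodal posterior.
    """
    bounds = [log_mean_exp(lw) for lw in log_weights_per_expert]
    return sum(bounds) / len(bounds)


# Two experts (M = 2) with K = 3 samples each; illustrative values only.
log_w = [[-1.0, -2.0, -1.5], [-0.5, -3.0, -2.0]]
bound = moe_iwae_bound(log_w)
```

The inner log-mean-exp is what separates the IWAE bound from a plain average of the log-weights. By Jensen's inequality it is never smaller than that average.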

The original MMVAE model uses Laplace posteriors while constraining their scaling in each direction to sum to \(D\), the dimension of the latent space.
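
The scaling constraint can be illustrated with a short sketch: a softmax over the unconstrained encoder outputs, multiplied by \(D\), yields positive scales that sum to \(D\). This is an illustration in plain Python under that assumption. The function name softmax_scales is hypothetical, and any numerical-stability constant the actual implementation may add is omitted.

```python
import math


def softmax_scales(raw_values):
    """Map unconstrained values to positive scales that sum to D = len(raw_values)."""
    latent_dim = len(raw_values)
    top = max(raw_values)  # subtract the max for numerical stability
    exps = [math.exp(r - top) for r in raw_values]
    total = sum(exps)
    return [latent_dim * e / total for e in exps]


# Latent space of dimension D = 4; the inputs are arbitrary example values.
scales = softmax_scales([0.3, -1.2, 0.0, 2.1])
```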

A DReG estimator can be used to compute the gradient of the IWAE loss. See https://yugeten.github.io/posts/2020/06/elbo/ for an explanation of the DReG estimator.

Note

In the partially observed setting, we take the mixture of experts \(q_{\phi}(z|X)\) over the available modalities.

For instance, if \(S_{obs}(X)\) is the subset of observed modalities for sample \(X\), the loss becomes:

\[\frac{1}{|S_{obs}(X)|}\sum_{j \in S_{obs}(X)} \mathbb E_{z^{(1)},\dots,z^{(K)} \sim q_{\phi_j}(z|x_j)} \left [ \log \frac{1}{K} \sum_{k=1}^{K} \frac{p_{\theta}(z^{(k)},X)}{q_{\phi}(z^{(k)}|X)} \right]\]

with the joint posterior \(q_{\phi}(z|X)\) computed as the mixture of available experts.
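
The sketch below shows one way to evaluate such a mixture density when some modalities are missing. It assumes a uniform mixture over the observed experts and, for brevity, one-dimensional Laplace posteriors. The names laplace_log_pdf and log_mixture_density and the modality names are hypothetical, and this is not the library's code.

```python
import math


def laplace_log_pdf(z, loc, scale):
    """Log-density of a univariate Laplace distribution."""
    return -math.log(2.0 * scale) - abs(z - loc) / scale


def log_mixture_density(z, experts, observed):
    """log q(z|X) for a uniform mixture restricted to the observed experts.

    experts maps a modality name to the (loc, scale) of its posterior;
    observed lists the modalities available for this sample.
    """
    logs = [laplace_log_pdf(z, *experts[name]) for name in observed]
    top = max(logs)  # log-sum-exp with the max subtracted for stability
    return top + math.log(sum(math.exp(v - top) for v in logs)) - math.log(len(logs))


experts = {"image": (0.0, 1.0), "text": (0.5, 2.0), "audio": (-1.0, 0.5)}
# The "audio" modality is missing for this sample, so its expert is left out.
log_q = log_mixture_density(0.2, experts, observed=["image", "text"])
```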

class multivae.models.MMVAEConfig(n_modalities, latent_dim=10, input_dims=None, uses_likelihood_rescaling=False, rescale_factors=None, decoders_dist=None, decoder_dist_params=None, custom_architectures=<factory>, K=10, prior_and_posterior_dist='laplace_with_softmax', learn_prior=True, beta=1.0, loss='dreg_looser')[source]

This class is the configuration class for the MMVAE model, from (Variational Mixture-of-Experts Autoencoders for Multi-Modal Deep Generative Models, Shi et al 2019, https://proceedings.neurips.cc/paper/2019/hash/0ae775a8cb3b499ad1fca944e6f5c836-Abstract.html).

Parameters:
  • n_modalities (int) – The number of modalities. Required.

  • latent_dim (int) – The dimension of the latent space. Default: 10.

  • input_dims (dict[str,tuple]) – The modalities' names (str) and input shapes (tuple).

  • uses_likelihood_rescaling (bool) – Whether to use likelihood rescaling to mitigate modality collapse (see https://proceedings.mlr.press/v162/javaloy22a.html). input_dims must be provided to compute the rescaling factors. This option is shared by several models, which is why it is included here. Default to False.

  • rescale_factors (dict[str, float]) – The reconstruction rescaling factors per modality. If None is provided but uses_likelihood_rescaling is True, a default value proportional to the input modality size is computed. Default to None.

  • decoders_dist (Dict[str, Union[function, str]]) – Per modality, you can provide a string in ['normal', 'bernoulli', 'laplace']. For the Bernoulli distribution, the decoder is expected to output logits. If None is provided, a normal distribution is used for each modality.

  • decoder_dist_params (Dict[str,dict]) – Parameters for the output decoder distributions, for computing the log-probability. For instance, with normal or laplace distribution, you can pass the scale in this dictionary with decoder_dist_params =  {'mod1' : {"scale" : 0.75}}.

  • K (int) – The number of samples to use for importance sampling. Default to 10.

  • prior_and_posterior_dist (str) – The type of distribution to use for the posterior and the prior. Possible values: ['laplace_with_softmax', 'normal']. Default to 'laplace_with_softmax', the posterior distribution used in the original paper.

  • learn_prior (bool) – If True, the mean and variance of the prior are optimized during the training. Default to True.

  • beta (float) – Regularizes the divergence term. Default to 1.

  • loss (Literal) – Either 'iwae_looser' or 'dreg_looser'. Default to 'dreg_looser'.

class multivae.models.MMVAE(model_config, encoders=None, decoders=None)[source]

The Variational Mixture-of-Experts Autoencoder model.

Parameters:
  • model_config (MMVAEConfig) – An instance of MMVAEConfig in which all of the model's parameters are made available.

  • encoders (Dict[str, BaseEncoder]) – A dictionary containing the modalities names and the encoders for each modality. Each encoder is an instance of Pythae’s BaseEncoder. Default: None.

  • decoders (Dict[str, BaseDecoder]) – A dictionary containing the modalities names and the decoders for each modality. Each decoder is an instance of Pythae’s BaseDecoder.

compute_joint_nll(inputs, K=1000, batch_size_K=100)[source]

Estimate the negative joint likelihood.

Parameters:
  • inputs (MultimodalBaseDataset) – a batch of samples.

  • K (int) – the number of importance samples for the estimation. Default to 1000.

  • batch_size_K (int) – Default to 100.

Returns:

The negative log-likelihood summed over the batch.
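
To illustrate how K importance samples can be processed in chunks of batch_size_K, here is a minimal sketch on a toy one-dimensional model for a single data point, not the library's implementation. It assumes a prior z ~ N(0, 1), a decoder x|z ~ N(z, 1) and the proposal q(z|x) = N(x/2, 1/2), which is the exact posterior of this toy model. The function names are hypothetical.

```python
import math
import random


def normal_log_pdf(value, mean, var):
    """Log-density of a univariate normal distribution."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (value - mean) ** 2 / var)


def estimate_nll(x, K=1000, batch_size_K=100, seed=0):
    """Importance-sampling estimate of -log p(x), accumulated chunk by chunk."""
    rng = random.Random(seed)
    running_max, running_sum, drawn = -math.inf, 0.0, 0
    while drawn < K:
        n = min(batch_size_K, K - drawn)
        chunk = []
        for _ in range(n):
            z = rng.gauss(x / 2.0, math.sqrt(0.5))  # sample from q(z|x)
            chunk.append(
                normal_log_pdf(z, 0.0, 1.0)          # log p(z)
                + normal_log_pdf(x, z, 1.0)          # log p(x|z)
                - normal_log_pdf(z, x / 2.0, 0.5)    # log q(z|x)
            )
        # Streaming log-sum-exp: rescale the running sum to the new maximum.
        new_max = max(running_max, max(chunk))
        running_sum = running_sum * math.exp(running_max - new_max) + sum(
            math.exp(lw - new_max) for lw in chunk
        )
        running_max = new_max
        drawn += n
    return -(running_max + math.log(running_sum) - math.log(K))


nll = estimate_nll(1.0)
```

Because the proposal in this toy model is the exact posterior, every importance weight equals p(x) and the estimate matches the true value -log N(x; 0, 2) regardless of K. With an approximate posterior, larger K tightens the estimate. For a batch, the per-sample estimates would be summed.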

compute_joint_nll_paper(inputs, K=1000, batch_size_K=10)[source]

Computes the joint likelihood as in the original paper, using samples from all mixture-of-experts components and modality rescaling.

compute_k_lws(qz_xs, embeddings, reconstructions, inputs)[source]

Compute likelihood terms for all modalities and for all k.

Returns:

A dict containing the likelihood terms (not aggregated) for all modalities.

dreg_looser(qz_xs, embeddings, reconstructions, inputs)[source]

The DReG estimator for the IWAE loss. The loss components in lws must have been computed on detached posteriors.
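
As background, the doubly reparameterized gradient (DReG) estimator weights the pathwise gradient of each log-weight by the square of its self-normalized importance weight. The posterior parameters inside log q are treated as constants, so gradients reach the encoder only through the reparameterized samples. The sketch below computes only these squared coefficients in plain Python. The name dreg_coefficients is hypothetical, and the automatic-differentiation part is omitted.

```python
import math


def dreg_coefficients(log_weights):
    """Squared self-normalized importance weights used by the DReG estimator."""
    top = max(log_weights)  # subtract the max for numerical stability
    exps = [math.exp(lw - top) for lw in log_weights]
    total = sum(exps)
    normalized = [e / total for e in exps]
    return [w * w for w in normalized]


# Log importance weights for K = 3 samples; illustrative values only.
log_w = [-1.0, -2.0, -0.5]
coeffs = dreg_coefficients(log_w)
```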

encode(inputs, cond_mod='all', N=1, return_mean=False, **kwargs)[source]

Generate encodings conditioning on all modalities or a subset of modalities.

Parameters:
  • inputs (MultimodalBaseDataset) – The dataset to use for the conditional generation.

  • cond_mod (Union[list, str]) – Either ‘all’ or a list of str containing the modalities names to condition on.

  • N (int) – The number of encodings to sample for each datapoint. Default to 1.

  • return_mean (bool) – if True, returns the mean of the posterior distribution (instead of a sample).

Returns:

ModelOutput instance with fields 'z' (torch.Tensor of shape (n_data, N, latent_dim)) and 'one_latent_space' (bool) = True.

forward(inputs, **kwargs)[source]

Forward pass of the model. Outputs the loss and metrics.

generate_from_prior(n_samples, **kwargs)[source]

Generate latent samples from the prior distribution. This is the base class in which we consider a static standard Normal Prior. This may be overwritten in subclasses.

Parameters:
  • n_samples (int) – number of samples to generate

  • **kwargs – additional arguments

Returns:

A ModelOutput instance containing the generated samples

Return type:

ModelOutput

iwae_looser(qz_xs, embeddings, reconstructions, inputs)[source]

Compute the IWAE loss without the DReG estimator for the gradient.

log_var_to_std(log_var)[source]

For the parameters of the latent distributions, transform the log-covariance into the standard deviation of the distribution, applying a softmax or not depending on the chosen distribution. This follows the original implementation.

property pz_params

From the prior mean and log_covariance, return the mean and standard deviation, either applying softmax or not depending on the choice of prior distribution.

Returns:

mean, std

Return type:

tuple