MMVAE+

Implementation of “MMVAE+: Enhancing the Generative Quality of Multimodal VAEs without Compromises” (https://openreview.net/forum?id=sdQGxouELX).

The MMVAE+ model is an aggregated model that uses multiple latent spaces: \(z\) is the latent code shared across modalities and \(w_j\) is the private latent code of modality \(j \in [|1 , M|]\).

It also uses an auxiliary prior distribution \(r_j(w_j)\) for each private latent space, with a learned scale parameter.

As in MMVAE, the joint posterior for the shared latent code is a Mixture-of-Experts of the unimodal posteriors:

\[q_{\phi_z}(z|X) = \frac{1}{M} \sum_{m =1}^{M} q_{\phi_{z_m}}(z|x_m)\]
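A minimal NumPy sketch of this mixture-of-experts posterior, assuming factorised Laplace unimodal posteriors whose locations and scales are passed in directly (the default 'laplace_with_softmax' family uses Laplace distributions, but how MultiVae parameterises them is not reproduced here). The function names are hypothetical and not part of MultiVae: sampling picks an expert uniformly at random and draws from it, while the density averages the expert densities in log space.

```python
import numpy as np


def laplace_log_prob(z, loc, scale):
    """Log-density of a factorised Laplace distribution, summed over the last axis."""
    return np.sum(-np.log(2.0 * scale) - np.abs(z - loc) / scale, axis=-1)


def moe_log_prob(z, locs, scales):
    """log q(z|X) for a uniform mixture of M unimodal Laplace posteriors.

    locs, scales: arrays of shape (M, latent_dim); z: array of shape (latent_dim,).
    """
    # log q_m(z|x_m) for every expert m, stacked along a new first axis
    log_q = np.stack([laplace_log_prob(z, locs[m], scales[m]) for m in range(len(locs))])
    # log((1/M) * sum_m q_m(z|x_m)), computed stably in log space
    return np.logaddexp.reduce(log_q, axis=0) - np.log(len(locs))


def moe_sample(locs, scales, rng):
    """Draw one z from the mixture: choose an expert uniformly, then sample from it."""
    m = rng.integers(len(locs))
    return rng.laplace(locs[m], scales[m])


# Usage: two experts over a 3-dimensional shared latent space.
rng = np.random.default_rng(0)
locs = np.array([[0.0, 0.0, 0.0], [1.0, -1.0, 0.5]])
scales = np.ones((2, 3))
z = moe_sample(locs, scales, rng)
print(moe_log_prob(z, locs, scales))
```

Working in log space with a log-sum-exp avoids underflow when one expert assigns a very low density to a sample drawn from another expert.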

The MMVAE+ loss is then written as follows:

\[\begin{split}\frac{1}{M}\sum_{m=1}^{M} \mathbb E_{ \substack{ z_m^{1::K} \sim q_{\phi_{z_m}}(z|x_m)\\w_m^{1::K} \sim q_{\phi_{w_m}}(w_m|x_m) \\ \tilde{w}_{n\neq m}^{1::K} \sim r_n(w_n) } } \log \frac{1}{K} \sum_{k=1}^{K} D^{\beta}_{\Phi,\Theta}(X,z^k, \tilde{w}_1^k, \tilde{w}_2^k,.., w_m^k, .., \tilde{w}_M^k)\end{split}\]

with

\[D^{\beta}_{\Phi,\Theta}(X,z^k, \tilde{w}_1^k, \tilde{w}_2^k,.., w_m^k, .., \tilde{w}_M^k) = \frac{p_{\theta_m}(x_m|z^k, w_m^k)(p(z^k)p(w_m^k))^{\beta}}{(q_{\phi_z}(z^k|X)q_{\phi_{w_m}}(w_m^k|x_m))^{\beta}}\prod_{n \neq m}p_{\theta_n}(x_n|z^k,\tilde{w}_n^k)\]

It uses a \(K\)-sample importance-weighted estimator of the likelihood and a \(\beta\) factor that can be tuned to promote disentanglement in the latent space. In this objective, the private information of modality \(m\), carried by \(w_m\), is only used for self-reconstruction and not for cross-modal generation: the other modalities \(x_n\) are reconstructed with \(\tilde{w}_n\) sampled from the auxiliary priors \(r_n\). For cross-modal generation, the shared semantic content therefore flows through the shared latent variable \(z\).
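The per-modality term \(\log \frac{1}{K}\sum_k D^{\beta}_{\Phi,\Theta}\) can be evaluated from log-quantities with a stable log-mean-exp. The sketch below is hypothetical (not MultiVae's implementation) and assumes all log-densities have already been computed as arrays of shape (M, K). It returns the value of the objective to maximise (a training loss would be its negative) and does not reproduce the gradient estimator selected by the `loss` option.

```python
import numpy as np


def log_mean_exp(x, axis):
    """Numerically stable log(mean(exp(x))) along `axis`."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.mean(np.exp(x - m), axis=axis))


def mmvae_plus_objective(log_px_self, log_px_cross, log_pz, log_pw, log_qz, log_qw, beta=1.0):
    """Monte-Carlo value of the MMVAE+ objective for one datapoint X.

    Every argument has shape (M, K); row m holds the K samples drawn with
    modality m's encoders:
      log_px_self[m, k]  = log p(x_m | z^k, w_m^k)
      log_px_cross[m, k] = sum over n != m of log p(x_n | z^k, w~_n^k), w~_n^k ~ r_n
      log_pz, log_pw     = log p(z^k), log p(w_m^k)
      log_qz, log_qw     = log q(z^k | X), log q(w_m^k | x_m)
    """
    # log D^beta for every (m, k): reconstructions plus beta-weighted log-ratio
    log_d = log_px_self + log_px_cross + beta * (log_pz + log_pw - log_qz - log_qw)
    # log (1/K) sum_k D^k for each modality, then the average over modalities
    return log_mean_exp(log_d, axis=1).mean()
```

Note that \(\beta\) only multiplies the prior/posterior log-ratio, never the reconstruction terms, exactly as in \(D^{\beta}_{\Phi,\Theta}\).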

Note

For the partially observed case, the loss can be computed using only the available modalities of each sample instead of all modalities.

If only the modalities in \(S_{obs}(X)\) are observed, the loss for sample \(X\) becomes:

\[\begin{split}\frac{1}{|S_{obs}(X)|}\sum_{m \in S_{obs}(X)} \mathbb E_{ \substack{ z_m^{1::K} \sim q_{\phi_{z_m}}(z|x_m)\\w_m^{1::K} \sim q_{\phi_{w_m}}(w_m|x_m) \\ \tilde{w}_{n\neq m}^{1::K} \sim r_n(w_n) } } \log \frac{1}{K} \sum_{k=1}^{K} D^{\beta}_{\Phi,\Theta}(X,z^k, \tilde{w}_1^k, \tilde{w}_2^k,.., w_m^k, .., \tilde{w}_M^k)\end{split}\]

where:

\[D^{\beta}_{\Phi,\Theta}(X,z^k, \tilde{w}_1^k, \tilde{w}_2^k,.., w_m^k, .., \tilde{w}_M^k) = \frac{p_{\theta_m}(x_m|z^k, w_m^k)(p(z^k)p(w_m^k))^{\beta}}{(q_{\phi_z}(z^k|X)q_{\phi_{w_m}}(w_m^k|x_m))^{\beta}}\prod_{n \in S_{obs}(X), n\neq m}p_{\theta_n}(x_n|z^k,\tilde{w}_n^k)\]

In simpler terms, we reconstruct only the available modalities and compute the joint posterior from the available modalities only:

\[q_{\phi_z}(z|X) = \frac{1}{|S_{obs}(X)|} \sum_{m \in S_{obs}(X)} q_{\phi_{z_m}}(z|x_m)\]
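Assuming log-quantities laid out per encoding modality \(m\) and sample \(k\), the partially observed objective only needs a boolean mask over modalities: the reconstruction terms of missing modalities are dropped and the average is taken over \(S_{obs}(X)\) only. This is a hypothetical sketch (the names and array layout are illustrative, not MultiVae's), in which \(\log q(z^k|X)\) inside `log_reg` is assumed to come from the observed-only mixture above.

```python
import numpy as np


def log_mean_exp(x, axis):
    """Numerically stable log(mean(exp(x))) along `axis`."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.mean(np.exp(x - m), axis=axis))


def partial_objective(log_px, log_reg, observed):
    """MMVAE+ objective for one datapoint X with missing modalities.

    log_px[m, n, k]: log p(x_n | z^k, w_n^k) for z^k drawn from modality m's encoder
                     (w_n^k from q(w_m | x_m) if n == m, else from the prior r_n).
    log_reg[m, k]:   beta * (log p(z^k) + log p(w_m^k) - log q(z^k | X) - log q(w_m^k | x_m)),
                     with q(z | X) the mixture over observed modalities only.
    observed:        booleans of length M, True where x_m is available.
    Entries involving unobserved modalities are ignored.
    """
    obs = np.asarray(observed, dtype=bool)
    # keep only the reconstruction terms of observed modalities n
    recon = np.where(obs[None, :, None], log_px, 0.0).sum(axis=1)  # shape (M, K)
    per_modality = log_mean_exp(recon + log_reg, axis=1)  # shape (M,)
    # average only over the observed modalities m in S_obs(X)
    return per_modality[obs].mean()
```

Masking rather than slicing keeps array shapes fixed across a batch in which each sample may have a different set of observed modalities.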
class multivae.models.MMVAEPlusConfig(n_modalities, latent_dim=10, input_dims=None, uses_likelihood_rescaling=False, rescale_factors=None, decoders_dist=None, decoder_dist_params=None, custom_architectures=<factory>, K=10, prior_and_posterior_dist='laplace_with_softmax', learn_shared_prior=False, learn_modality_prior=True, beta=1.0, modalities_specific_dim=None, reconstruction_option='joint_prior', loss='dreg_looser')

This class is the configuration class for the MMVAE+ model.

Parameters:
  • n_modalities (int) – The number of modalities. Default: None.

  • latent_dim (int) – The dimension of the shared latent space. Default: 10.

  • input_dims (dict[str,tuple]) – The modalities' names (str) and input shapes (tuple).

  • uses_likelihood_rescaling (bool) – To mitigate modality collapse, it is possible to use likelihood rescaling (see: https://proceedings.mlr.press/v162/javaloy22a.html). input_dims must be provided to compute the likelihood rescalings. This option is shared by several models, which is why it is included here. Defaults to False.

  • rescale_factors (dict[str, float]) – The reconstruction rescaling factors per modality. If None is provided but uses_likelihood_rescaling is True, a default value proportional to the input modality size is computed. Defaults to None.

  • decoders_dist (Dict[str, Union[function, str]]) – Per modality, you can provide a string in [‘normal’,’bernoulli’,’laplace’]. For Bernoulli distribution, the decoder is expected to output logits. If None is provided, a normal distribution is used for each modality.

  • decoder_dist_params (Dict[str,dict]) – Parameters for the output decoder distributions, for computing the log-probability. For instance, with normal or laplace distribution, you can pass the scale in this dictionary with decoder_dist_params =  {'mod1' : {"scale" : 0.75}}.

  • K (int) – The number of samples used in the importance-weighted (IWAE) loss. Defaults to 10.

  • prior_and_posterior_dist (str) – The type of distribution used for the posterior and the prior. Possible values: [‘laplace_with_softmax’, ‘normal_with_softplus’, ‘normal’]. Defaults to ‘laplace_with_softmax’, the posterior distribution used in the original paper.

  • learn_shared_prior (bool) – If True, the mean and variance of the shared latent space prior are optimized during training. Defaults to False.

  • learn_modality_prior (bool) – If True, the mean and variance of the modality-specific latent space priors are optimized during training. Defaults to True; learning these priors is key for the method to work.

  • beta (float) – When using K = 1 (ELBO loss), the beta factor regularizes the divergence term. Defaults to 1.0.

  • modalities_specific_dim (int) – The dimension of the modalities' private latent spaces. Must be provided.

  • reconstruction_option (Literal['single_prior','joint_prior']) – Specifies how to sample the modality-specific variable when reconstructing or translating modalities during inference. Defaults to ‘joint_prior’, the option used in the article.

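  • loss (Literal['dreg_looser','iwae_looser']) – The estimator used for the training objective: ‘dreg_looser’ uses the doubly reparameterized gradient (DReG) estimator and ‘iwae_looser’ the standard importance-weighted (IWAE) estimator; in both, the average over modalities is taken outside the logarithm, as in the objective above. Defaults to ‘dreg_looser’.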

class multivae.models.MMVAEPlus(model_config, encoders=None, decoders=None)

The MMVAE+ model.

Parameters:
  • model_config (MMVAEPlusConfig) – An instance of MMVAEPlusConfig in which all of the model's parameters are made available.

  • encoders (Dict[str, BaseMultilatentEncoder]) – A dictionary containing the modality names and the encoder for each modality. Each encoder is an instance of Multivae’s BaseMultilatentEncoder since this model uses multiple latent spaces. Default: None.

  • decoders (Dict[str, BaseDecoder]) – A dictionary containing the modality names and the decoder for each modality. Each decoder is an instance of Pythae’s BaseDecoder. Default: None.

compute_joint_nll(inputs, K=1000, batch_size_K=100)

Estimate the negative joint log-likelihood.

Parameters:
  • inputs (MultimodalBaseDataset) – a batch of samples.

  • K (int) – The number of importance samples used for the estimation. Defaults to 1000.

  • batch_size_K (int) – The number of importance samples processed at once, to limit memory usage. Defaults to 100.

Returns:

The negative log-likelihood summed over the batch.
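Assuming batch_size_K is the number of importance samples processed at once, the K importance weights can be reduced chunk by chunk without changing the estimate, because the log-sum-exp of each chunk can itself be combined with a log-sum-exp. A hypothetical sketch for one datapoint (not the library's code), where each log-weight is \(\log p(X, \text{latents}^k) - \log q(\text{latents}^k|X)\) for one sample drawn from the posterior:

```python
import numpy as np


def chunked_nll(log_w, batch_size_k):
    """Importance-sampled negative log-likelihood for one datapoint.

    log_w[k] = log p(X, latents^k) - log q(latents^k | X) for K samples drawn from q.
    The K log-weights are reduced in chunks of `batch_size_k`, as one would do
    when all K samples do not fit in memory at once.
    """
    log_w = np.asarray(log_w, dtype=float)
    K = log_w.shape[0]
    # log sum_k exp(log_w[k]) inside each chunk
    chunk_lse = [np.logaddexp.reduce(log_w[i:i + batch_size_k])
                 for i in range(0, K, batch_size_k)]
    # merge the chunks, then divide by K: log (1/K) sum_k w_k
    log_px = np.logaddexp.reduce(np.array(chunk_lse)) - np.log(K)
    return -log_px


# Usage: the chunked and unchunked estimates agree.
rng = np.random.default_rng(0)
log_w = rng.normal(size=1000)
print(chunked_nll(log_w, 100), chunked_nll(log_w, 1000))
```

Summing this quantity over the datapoints of a batch gives the batch-level value that compute_joint_nll is documented to return.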

encode(inputs, cond_mod='all', N=1, return_mean=False, **kwargs)

Generate encodings conditioning on all modalities or a subset of modalities.

Parameters:
  • inputs (MultimodalBaseDataset) – The dataset to use for the conditional generation.

  • cond_mod (Union[list, str]) – Either ‘all’ or a list of str containing the names of the modalities to condition on.

  • N (int) – The number of encodings to sample for each datapoint. Defaults to 1.

  • return_mean (bool) – If True, returns the mean of the posterior distribution instead of a sample.

Returns:

Contains the fields:

  • ‘z’ (torch.Tensor of shape (n_data, N, latent_dim))

  • ‘one_latent_space’ (bool) = False

  • ‘modalities_z’ (Dict[str, torch.Tensor of shape (n_data, N, latent_dim)])

Return type:

ModelOutput

forward(inputs, **kwargs)

Compute loss and metrics.

generate_from_prior(n_samples, **kwargs)

Generate latent samples from the prior distribution. The base-class implementation considers a static standard normal prior; subclasses may override this method.

Parameters:
  • n_samples (int) – number of samples to generate

  • **kwargs – additional arguments

Returns:

A ModelOutput instance containing the generated samples

Return type:

ModelOutput

property pz_params

From the prior mean and log_covariance, return the mean and standard deviation, either applying softmax or not depending on the choice of prior distribution.

Returns:

mean, std

Return type:

tuple
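A sketch of what pz_params could compute, assuming the softmax parameterisation used in the original MMVAE code (scale = softmax(log_covariance) × latent_dim) for 'laplace_with_softmax', and the usual log-variance convention (std = exp(0.5 × log_covariance)) for 'normal'. The function name is hypothetical, and 'normal_with_softplus' is deliberately left out because its exact transform is not documented here.

```python
import numpy as np


def prior_mean_std(mean, log_covariance, dist="laplace_with_softmax"):
    """Map raw prior parameters to (mean, std); both transforms below are assumptions."""
    mean = np.asarray(mean, dtype=float)
    log_cov = np.asarray(log_covariance, dtype=float)
    if dist == "laplace_with_softmax":
        # Assumed MMVAE-style scale: softmax over the latent dimensions, times latent_dim
        e = np.exp(log_cov - log_cov.max(axis=-1, keepdims=True))
        std = e / e.sum(axis=-1, keepdims=True) * log_cov.shape[-1]
    elif dist == "normal":
        # Assumed log-variance convention: std = exp(log_var / 2)
        std = np.exp(0.5 * log_cov)
    else:
        raise NotImplementedError(f"transform for {dist!r} not covered by this sketch")
    return mean, std


# With zero-initialised parameters both options give unit scales, i.e. a standard prior.
print(prior_mean_std(np.zeros(4), np.zeros(4)))
```

Under the softmax parameterisation the scales always sum to latent_dim, so learning the prior redistributes scale across dimensions rather than growing it without bound.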