JMVAE

Implementation of the Joint Multimodal VAE model from the paper “Joint Multimodal Learning with Deep Generative Models” (http://arxiv.org/abs/1611.01891).

The JMVAE model is one of the first multimodal variational autoencoder models. It has a dedicated joint encoder network \(q_{\phi}(z|X)\) and surrogate unimodal encoders \(q_{\phi_j}(z|x_j)\). The JMVAE loss adds terms to the ELBO to fit the unimodal encoders:

\[\mathcal{L}_{JMVAE}(X) = \mathbb{E}_{q_{\phi}(z|X)}\left[ \log p_{\theta}(X|z) \right] - KL\left(q_{\phi}(z|X)||p_{\theta}(z)\right) - \alpha \sum_{j=1}^{M} KL \left (q_{\phi}(z|X) || q_{\phi_j}(z|x_j) \right)\]

where \(M\) is the number of modalities. This loss can be linked to the Variation of Information (VI) between modalities [1]. \(\alpha\) controls the trade-off between the quality of reconstruction and the quality of cross-modal generation [1]. The model was originally proposed for two modalities only; an extension to more modalities was proposed in [2].
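To make the three terms concrete, here is a minimal NumPy sketch of the objective for one batch. It assumes diagonal Gaussian posteriors parameterized by means and log-variances, a standard normal prior \(p_{\theta}(z) = \mathcal{N}(0, I)\), and a reconstruction term \(\mathbb{E}_{q_{\phi}(z|X)}[\log p_{\theta}(X|z)]\) already estimated elsewhere. The function names are hypothetical and this is not the MultiVae implementation; in particular it omits the beta weight and the annealing factor.

```python
import numpy as np


def gaussian_kl(mu_q, logvar_q, mu_p, logvar_p):
    """KL(q || p) between diagonal Gaussians, summed over the latent dimensions.

    All arguments have shape (..., latent_dim); the result has shape (...,).
    """
    var_q = np.exp(logvar_q)
    var_p = np.exp(logvar_p)
    kl = 0.5 * (logvar_p - logvar_q + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)
    return kl.sum(axis=-1)


def jmvae_objective(recon_loglik, joint_mu, joint_logvar, unimodal_params, alpha=0.1):
    """Per-sample JMVAE objective (to be maximized), hypothetical helper.

    recon_loglik: estimate of E_q[log p(X|z)], shape (batch,).
    joint_mu, joint_logvar: parameters of q_phi(z|X), shape (batch, latent_dim).
    unimodal_params: list of (mu_j, logvar_j) pairs for each q_phi_j(z|x_j).
    """
    # KL to the (assumed) standard normal prior p(z) = N(0, I)
    zeros = np.zeros_like(joint_mu)
    kl_prior = gaussian_kl(joint_mu, joint_logvar, zeros, zeros)
    # Sum over modalities of KL(q(z|X) || q_j(z|x_j)): fits the unimodal encoders
    kl_uni = sum(
        gaussian_kl(joint_mu, joint_logvar, mu_j, lv_j) for mu_j, lv_j in unimodal_params
    )
    return recon_loglik - kl_prior - alpha * kl_uni
```

The objective is maximized; a training loop would typically minimize its negative averaged over the batch. Note that the alpha-weighted terms pull the unimodal encoders toward the joint posterior, not the other way around, because the KL is taken from \(q_{\phi}(z|X)\) to each \(q_{\phi_j}(z|x_j)\).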

During inference, when \(M > 2\), the posteriors for subsets \(S\) of modalities, \(p_{\theta}(z|(x_j)_{j \in S})\), can be approximated by the product of experts (PoE) of the already trained unimodal encoders \((q_{\phi_j}(z|x_j))_{j \in S}\). Since the unimodal posteriors are normal distributions, the PoE has a closed form and is easy to compute.
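The closed form follows from the fact that a product of Gaussian densities is, up to normalization, a Gaussian whose precision is the sum of the experts' precisions and whose mean is the precision-weighted average of their means. A minimal NumPy sketch for diagonal Gaussians is given below. The function name is hypothetical, and whether a prior expert \(\mathcal{N}(0, I)\) is also included in the product is an assumption left out of this sketch.

```python
import numpy as np


def product_of_experts(mus, logvars):
    """Combine diagonal Gaussian experts N(mu_j, var_j) into one Gaussian.

    mus, logvars: arrays of shape (n_experts, batch, latent_dim).
    Returns the mean and log-variance of the renormalized product.
    """
    precisions = np.exp(-np.asarray(logvars))  # 1 / var_j for each expert
    joint_precision = precisions.sum(axis=0)   # precisions add up
    joint_var = 1.0 / joint_precision
    # Mean is the precision-weighted average of the experts' means
    joint_mu = joint_var * (precisions * np.asarray(mus)).sum(axis=0)
    return joint_mu, np.log(joint_var)
```

For a subset \(S\), one would stack the means and log-variances returned by the unimodal encoders \(q_{\phi_j}\), \(j \in S\), along the first axis and pass them to this function. Experts with a smaller variance dominate the combined mean.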

The JMVAE model uses annealing during training: the weighting factor applied to the regularization terms increases linearly from 0 to 1 during the first epochs (the number of these epochs is set by the warmup parameter of JMVAEConfig).
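A minimal sketch of such a linear schedule, assuming 0-indexed epochs and a weight of min(1, epoch / warmup). The function name is hypothetical, and the exact convention used by the library (for instance whether epochs start at 0 or 1) may differ.

```python
def annealing_weight(epoch, warmup):
    """Weight applied to the regularization terms at a given (0-indexed) epoch.

    Increases linearly from 0 to 1 over the first `warmup` epochs, then stays at 1.
    """
    if warmup <= 0:
        # No warmup: regularization is fully weighted from the start
        return 1.0
    return min(1.0, epoch / warmup)
```

With warmup=10, the weight is 0.0 at epoch 0, 0.5 at epoch 5 and 1.0 from epoch 10 onward, so early training focuses on the reconstruction term.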

Note

Because its joint encoder network takes all the modalities as input, this model cannot be trained with partially observed samples.

[1] Suzuki et al., 2016. Joint Multimodal Learning with Deep Generative Models.

[2] Senellart et al., 2023. Improving Multimodal Variational Autoencoders with Normalizing Flows and Deep Canonical Correlation Analysis.

class multivae.models.JMVAEConfig(n_modalities, latent_dim=10, input_dims=None, uses_likelihood_rescaling=False, rescale_factors=None, decoders_dist=None, decoder_dist_params=None, custom_architectures=<factory>, alpha=0.1, warmup=10, beta=1.0)[source]

This is the base config for the JMVAE model.

Parameters:
  • n_modalities (int) – The number of modalities.

  • latent_dim (int) – The dimension of the latent space. Default: 10.

  • input_dims (dict[str,tuple]) – The modalities' names (str) and input shapes (tuple). Default: None.

  • uses_likelihood_rescaling (bool) – To mitigate modality collapse, likelihood rescaling can be used (see https://proceedings.mlr.press/v162/javaloy22a.html). input_dims must be provided to compute the likelihood rescaling factors. This option is shared by several models, which is why it is included here. Default: False.

  • rescale_factors (dict[str, float]) – The reconstruction rescaling factors per modality. If None is provided but uses_likelihood_rescaling is True, a default value proportional to the input modality size is computed. Default: None.

  • decoders_dist (Dict[str, Union[function, str]]) – For each modality, you can provide a string in ['normal', 'bernoulli', 'laplace']. For the Bernoulli distribution, the decoder is expected to output logits. If None is provided, a normal distribution is used for each modality. Default: None.

  • decoder_dist_params (Dict[str,dict]) – Parameters of the decoder output distributions, used to compute the log-probability. For instance, with the normal or Laplace distribution, you can pass the scale in this dictionary: decoder_dist_params = {'mod1': {'scale': 0.75}}. Default: None.

  • alpha (float) – Controls the trade-off between the ELBO and the regularization terms that fit the unimodal encoders to the joint posterior. Default: 0.1.

  • warmup (int) – The number of warmup epochs during training. The JMVAE model uses annealing: the KL terms in the objective are weighted by an annealing factor that increases linearly to 1 during the first warmup epochs. Default: 10.

  • beta (float) – Weighting term for the regularization of the joint posterior toward the prior. This parameter does not exist in the original method; it is a simple add-on. Default: 1.0.
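As an illustration, the snippet below builds a configuration for two hypothetical modalities named "mnist" and "svhn", using only the parameters listed above. The modality names, shapes and values are placeholders chosen for the example, not recommendations.

```python
from multivae.models import JMVAEConfig

# Hypothetical two-modality setup: names, shapes and values are placeholders
config = JMVAEConfig(
    n_modalities=2,
    latent_dim=16,
    input_dims={"mnist": (1, 28, 28), "svhn": (3, 32, 32)},
    decoders_dist={"mnist": "bernoulli", "svhn": "laplace"},
    decoder_dist_params={"svhn": {"scale": 0.75}},
    alpha=0.1,   # weight of the KL terms fitting the unimodal encoders
    warmup=10,   # number of annealing epochs
    beta=1.0,    # weight of the KL to the prior (add-on to the original method)
)
```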

class multivae.models.JMVAE(model_config, encoders=None, decoders=None, joint_encoder=None, **kwargs)[source]

The Joint Multimodal Variational Autoencoder model.

Parameters:
  • model_config (JMVAEConfig) – An instance of JMVAEConfig containing all the model's parameters.

  • encoders (Dict[str, BaseEncoder]) – A dictionary containing the modalities names and the encoders for each modality. Each encoder is an instance of Pythae’s BaseEncoder. Default: None.

  • decoders (Dict[str, BaseDecoder]) – A dictionary containing the modalities names and the decoders for each modality. Each decoder is an instance of Pythae's BaseDecoder. Default: None.

  • joint_encoder (BaseJointEncoder) – Takes all the modalities as input. If None is provided, one is created from the unimodal encoders. Default: None.

encode(inputs, cond_mod='all', N=1, return_mean=False, **kwargs)[source]

Generate encodings conditioning on all modalities or a subset of modalities.

Parameters:
  • inputs (MultimodalBaseDataset) – The data to encode.

  • cond_mod (Union[list, str], optional) – The modalities to use to compute the posterior distribution. Defaults to 'all'.

  • N (int, optional) – The number of samples to generate from the posterior distribution for each datapoint. Defaults to 1.

  • return_mean (bool) – If True, returns the mean of the posterior distribution (instead of a sample). Defaults to False.

Raises:
  • AttributeError

Returns:

ModelOutput instance with fields:
  • z (torch.Tensor (N, n_data, latent_dim))
  • one_latent_space (bool) = True

Return type:

ModelOutput
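The returned tensor z has shape (N, n_data, latent_dim). The NumPy sketch below illustrates that layout with reparameterized sampling from a diagonal Gaussian posterior. It is not the library's code, the function name is hypothetical, and repeating the mean N times when return_mean is True is an assumption.

```python
import numpy as np


def sample_latents(mu, logvar, N=1, return_mean=False, rng=None):
    """Draw N reparameterized samples from a diagonal Gaussian posterior.

    mu, logvar: arrays of shape (n_data, latent_dim).
    Returns an array of shape (N, n_data, latent_dim).
    """
    if return_mean:
        # Assumed behaviour: repeat the posterior mean along the sample axis
        return np.broadcast_to(mu, (N, *mu.shape)).copy()
    rng = np.random.default_rng() if rng is None else rng
    eps = rng.standard_normal((N, *mu.shape))
    # z = mu + sigma * eps, with sigma = exp(logvar / 2)
    return mu + np.exp(0.5 * logvar) * eps
```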

forward(inputs, **kwargs)[source]

Performs a forward pass of the JMVAE model on inputs.

Parameters:
  • inputs (MultimodalBaseDataset) – The data to process.

  • warmup (int) – The number of warmup epochs. The weight of the regularization increases linearly to reach 1 at the end of the warmup, so that optimization focuses on the reconstruction term at first.

  • epoch (int) – The epoch during which forward is called.

Returns:

ModelOutput