MVTCAE : Multi-View Total Correlation Autoencoder

The MVTCAE model from ‘Multi-View Representation Learning via Total Correlation Objective’ (NeurIPS 2021).

MVTCAE uses a Product-of-Experts in a similar fashion to the MVAE, but without including the prior as an expert:

\[q_{\phi}(z|X) \propto \prod_j q_{\phi_j}(z|x_j)\]
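For diagonal Gaussian experts, this product has a closed form: the joint precision is the sum of the experts' precisions, and the joint mean is the precision-weighted average of the experts' means. The NumPy sketch below illustrates this; the function name `product_of_experts` and the (n_experts, batch, latent_dim) array layout are assumptions for illustration, not MultiVAE's API.

```python
import numpy as np


def product_of_experts(mus, logvars):
    """Combine diagonal Gaussian experts N(mu_j, var_j) into a single Gaussian.

    Hypothetical helper. No N(0, I) prior expert is added, unlike the MVAE.
    mus, logvars: arrays of shape (n_experts, batch, latent_dim).
    Returns the joint mean and log-variance, each of shape (batch, latent_dim).
    """
    precisions = np.exp(-logvars)              # 1 / var_j for each expert
    joint_var = 1.0 / precisions.sum(axis=0)   # (sum_j 1 / var_j)^-1
    joint_mu = joint_var * (mus * precisions).sum(axis=0)
    return joint_mu, np.log(joint_var)


# Two unit-variance experts centred at 0 and 2 give N(1, 0.5)
mu, logvar = product_of_experts(np.array([[[0.0]], [[2.0]]]), np.zeros((2, 1, 1)))
```

Appending an expert with mean 0 and log-variance 0 to the stacked inputs would reproduce the MVAE variant, which includes the standard normal prior in the product.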

The MVTCAE loss is derived from a total correlation analysis and reads as follows:

\[\begin{split} \mathcal L(X) &= \frac{M - \alpha}{M}\mathbb{E}_{q_{\phi}(z|X)}\left[\log p_{\theta}(X|z) \right] \\&- \beta \left[(1- \alpha)\, KL(q_{\phi}(z|X)\,||\, p_{\theta}(z)) + \frac{\alpha}{M} \sum_{j=1}^{M} KL(q_{\phi}(z|X)\,||\, q_{\phi_j}(z|x_j)) \right] \end{split}\]

Although this loss is derived from a different analysis, it uses the same terms as the JMVAE model. A factor \(\beta\) weights the regularization, while the parameter \(\alpha\) balances the different divergence terms.
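The expression above is an objective to be maximized; in a minimization setting its negative would serve as the loss. A minimal sketch of how \(\alpha\) and \(\beta\) combine the terms, assuming the reconstruction term and the KL divergences have already been computed; `kl_diag_gaussians` and `mvtcae_objective` are hypothetical helpers, and the closed-form KL assumes diagonal Gaussian distributions.

```python
import numpy as np


def kl_diag_gaussians(mu_q, logvar_q, mu_p, logvar_p):
    """KL(N(mu_q, var_q) || N(mu_p, var_p)) for diagonal Gaussians, summed over the last axis."""
    var_q, var_p = np.exp(logvar_q), np.exp(logvar_p)
    kl = 0.5 * (logvar_p - logvar_q + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)
    return kl.sum(axis=-1)


def mvtcae_objective(log_px, kl_prior, kl_unimodal, alpha=0.1, beta=2.5):
    """Evaluate the MVTCAE objective (to be maximized) from precomputed terms.

    log_px:      estimate of E_q[log p(X|z)], summed over the M modalities.
    kl_prior:    KL(q(z|X) || p(z)).
    kl_unimodal: sequence of the M terms KL(q(z|X) || q_j(z|x_j)).
    """
    M = len(kl_unimodal)
    rec_weight = (M - alpha) / M
    divergence = (1.0 - alpha) * kl_prior + alpha / M * sum(kl_unimodal)
    return rec_weight * log_px - beta * divergence
```

With `alpha=0` the unimodal KL terms vanish and the objective reduces to a β-VAE objective on the joint posterior \(q_{\phi}(z|X)\); with the default `alpha=0.1` and \(M = 2\), the reconstruction weight is \((2 - 0.1)/2 = 0.95\).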

Note

For the partially observed setting, we follow the authors’ indications and set the variance of the missing modalities’ decoders to \(\infty\), which amounts to setting the reconstruction loss to 0 for those modalities. The KL terms for missing modalities are also set to 0.
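Under this convention, missing modalities contribute nothing to the per-sample sums. The sketch below assumes per-modality terms stored in (batch, M) arrays together with a boolean observation mask; `masked_terms` is a hypothetical helper, and treating the unimodal divergences \(KL(q_{\phi}(z|X)\,||\,q_{\phi_j}(z|x_j))\) as the KL terms the note refers to is an assumption.

```python
import numpy as np


def masked_terms(rec_log_lik, kl_unimodal, observed):
    """Drop per-modality terms for missing modalities (hypothetical helper).

    rec_log_lik, kl_unimodal: arrays of shape (batch, M) with per-modality values.
    observed: boolean array of shape (batch, M), True where the modality is present.
    Returns the per-sample sums of the reconstruction and unimodal KL terms.
    """
    # Infinite decoder variance for a missing modality: no reconstruction contribution
    rec = np.where(observed, rec_log_lik, 0.0).sum(axis=1)
    # The assumed unimodal KL term of a missing modality is also set to 0
    kl = np.where(observed, kl_unimodal, 0.0).sum(axis=1)
    return rec, kl
```

Using `np.where` rather than multiplying by the mask keeps placeholder values for missing data (for example NaN) from leaking into the sums.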

class multivae.models.MVTCAEConfig(n_modalities, latent_dim=10, input_dims=None, uses_likelihood_rescaling=False, rescale_factors=None, decoders_dist=None, decoder_dist_params=None, custom_architectures=<factory>, alpha=0.1, beta=2.5)

This is the base config class for the MVTCAE model from ‘Multi-View Representation Learning via Total Correlation Objective’ (NeurIPS 2021). The code is based on the original implementation, available at https://github.com/gr8joo/MVTCAE/blob/master/run_epochs.py.

Parameters:
  • n_modalities (int) – The number of modalities. Default: None.

  • latent_dim (int) – The dimension of the latent space. Default: 10.

  • input_dims (dict[str,tuple]) – The modalities’ names (str) and input shapes (tuple). Default: None.

  • uses_likelihood_rescaling (bool) – To mitigate modality collapse, it is possible to use likelihood rescaling (see https://proceedings.mlr.press/v162/javaloy22a.html). The input_dims must be provided to compute the likelihood rescaling factors. It is used in a number of models, which is why it is included here. Default: False.

  • rescale_factors (dict[str, float]) – The reconstruction rescaling factors per modality. If None is provided but uses_likelihood_rescaling is True, a default value proportional to the input modality size is computed. Default: None.

  • decoders_dist (Dict[str, Union[function, str]]) – Per modality, you can provide a string in [‘normal’, ‘bernoulli’, ‘laplace’]. For the Bernoulli distribution, the decoder is expected to output logits. If None is provided, a normal distribution is used for each modality. Default: None.

  • decoder_dist_params (Dict[str,dict]) – Parameters for the output decoder distributions, used to compute the log-probability. For instance, with a normal or Laplace distribution, you can pass the scale in this dictionary: decoder_dist_params = {'mod1': {'scale': 0.75}}. Default: None.

  • alpha (float) – The parameter that weights the total correlation ratio in the loss. Default: 0.1.

  • beta (float) – The parameter that weights the sum of all KL terms. Default: 2.5.
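A possible configuration for a two-modality dataset, using only the parameters documented above. The modality names `mod1`/`mod2`, their shapes, and the `my_encoders`/`my_decoders` dictionaries are placeholders, not values prescribed by the library.

```python
from multivae.models import MVTCAEConfig

# Placeholder modality names and input shapes (illustrative only)
model_config = MVTCAEConfig(
    n_modalities=2,
    latent_dim=16,
    input_dims={"mod1": (1, 28, 28), "mod2": (3, 32, 32)},
    decoders_dist={"mod1": "bernoulli", "mod2": "normal"},  # the mod1 decoder must output logits
    decoder_dist_params={"mod2": {"scale": 0.75}},          # scale of the normal decoder for mod2
    alpha=0.1,  # weight of the total correlation ratio
    beta=2.5,   # weight of the sum of all KL terms
)

# `my_encoders` and `my_decoders` are hypothetical dicts keyed by modality name:
# model = MVTCAE(model_config, encoders=my_encoders, decoders=my_decoders)
```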

class multivae.models.MVTCAE(model_config, encoders=None, decoders=None)

MVTCAE model.

Parameters:
  • model_config (MVTCAEConfig) – An instance of MVTCAEConfig in which all of the model’s parameters are made available.

  • encoders (Dict[str, BaseEncoder]) – A dictionary containing the modality names and the encoder for each modality. Each encoder is an instance of Pythae’s BaseEncoder. Default: None.

  • decoders (Dict[str, BaseDecoder]) – A dictionary containing the modality names and the decoder for each modality. Each decoder is an instance of Pythae’s BaseDecoder. Default: None.

compute_joint_nll(inputs, K=1000, batch_size_K=100)

Estimate the negative joint log-likelihood.

Parameters:
  • inputs (MultimodalBaseDataset) – a batch of samples.

  • K (int) – The number of importance samples used for the estimation. Default: 1000.

  • batch_size_K (int) – The batch size used to process the K importance samples. Default: 100.

Returns:

The negative log-likelihood summed over the batch.
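Importance sampling estimates the marginal as \(-\log p(x) \approx -\log \frac{1}{K}\sum_{k=1}^{K} \frac{p(x|z_k)\,p(z_k)}{q(z_k|x)}\) with \(z_k \sim q(z|x)\). The sketch below applies this estimator to a toy one-dimensional Gaussian model whose exact marginal is known, drawing the K samples in chunks of `batch_size_K`. It illustrates the technique only and is not MultiVAE's implementation; which posterior the library uses as the proposal is not specified here, and the toy model, `noise_var` and `q_shift` are hypothetical.

```python
import numpy as np


def log_normal(x, mean, var):
    """Elementwise log-density of N(mean, var)."""
    return -0.5 * (np.log(2.0 * np.pi * var) + (x - mean) ** 2 / var)


def log_mean_exp(a):
    """Numerically stable log((1/K) * sum_k exp(a_k))."""
    m = np.max(a)
    return m + np.log(np.mean(np.exp(a - m)))


def importance_sampled_nll(x, K=1000, batch_size_K=100, noise_var=0.5, q_shift=0.0, seed=0):
    """Estimate -log p(x) for the toy model z ~ N(0, 1), x | z ~ N(z, noise_var).

    The proposal is the exact posterior, shifted by q_shift to mimic an imperfect encoder.
    """
    rng = np.random.default_rng(seed)
    post_var = noise_var / (1.0 + noise_var)
    post_mean = x / (1.0 + noise_var) + q_shift
    chunks = []
    for start in range(0, K, batch_size_K):
        k = min(batch_size_K, K - start)  # the last chunk may be smaller
        z = rng.normal(post_mean, np.sqrt(post_var), size=k)
        chunks.append(
            log_normal(x, z, noise_var)            # log p(x | z)
            + log_normal(z, 0.0, 1.0)              # log p(z)
            - log_normal(z, post_mean, post_var)   # log q(z | x)
        )
    return -log_mean_exp(np.concatenate(chunks))
```

With the exact posterior as the proposal (`q_shift=0`), every importance weight equals \(p(x)\), so the estimate is exact for any K; a shifted proposal makes the estimate noisy, and increasing K reduces the error.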

encode(inputs, cond_mod='all', N=1, return_mean=False, **kwargs)

Generate encodings conditioning on all modalities or a subset of modalities.

Parameters:
  • inputs (MultimodalBaseDataset) – The dataset to use for the conditional generation.

  • cond_mod (Union[list, str]) – Either ‘all’ or a list of str containing the names of the modalities to condition on. Default: ‘all’.

  • N (int) – The number of encodings to sample for each datapoint. Default: 1.

  • return_mean (bool) – If True, returns the mean of the posterior distribution (instead of a sample). Default: False.

Returns:

A ModelOutput instance with fields:
  • z (torch.Tensor) – The encodings, of shape (n_data, N, latent_dim).

  • one_latent_space (bool) – True.

Return type:

ModelOutput
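Sampling N encodings per datapoint from a diagonal Gaussian posterior can be sketched with the reparameterization trick. The function `sample_encodings`, the (mu, logvar) parameterization, and repeating the mean N times when `return_mean=True` are assumptions chosen to match the documented output shape (n_data, N, latent_dim); this is not the library's code.

```python
import numpy as np


def sample_encodings(mu, logvar, N=1, return_mean=False, seed=0):
    """Draw N latent codes per datapoint from N(mu, exp(logvar)) (hypothetical helper).

    mu, logvar: arrays of shape (n_data, latent_dim).
    Returns an array of shape (n_data, N, latent_dim).
    """
    if return_mean:
        # Assumed behaviour: repeat the posterior mean instead of sampling
        return np.repeat(mu[:, None, :], N, axis=1)
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal((mu.shape[0], N, mu.shape[1]))
    # Reparameterization: z = mu + sigma * eps, broadcast over the N axis
    return mu[:, None, :] + np.exp(0.5 * logvar)[:, None, :] * eps
```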

forward(inputs, **kwargs)

Forward pass of the model that returns the loss.