CRMVAE

The CRMVAE model proposed in https://openreview.net/forum?id=Rn8u4MYgeNJ.

It builds upon the MVTCAE model by adding unimodal reconstruction terms.

The joint posterior is a Product-of-Experts (PoE): \(q_{\phi}(z|X) \propto \prod_{m=1}^{M} q_{\phi_m}(z|x_m)\).
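If every unimodal posterior \(q_{\phi_m}(z|x_m)\) is a diagonal Gaussian (an assumption here, since the text above does not fix the family), the product is again Gaussian: precisions add, and the joint mean is the precision-weighted average of the expert means. The NumPy sketch below illustrates this. `product_of_experts` is a hypothetical helper, not a multivae function.

```python
import numpy as np


def product_of_experts(mus, logvars):
    """Combine diagonal Gaussian experts N(mu_m, exp(logvar_m)) into one Gaussian.

    mus, logvars: arrays of shape (M, latent_dim), one row per modality.
    Returns the mean and log-variance of the renormalised product.
    """
    precisions = np.exp(-logvars)                          # 1 / sigma_m^2 for each expert
    joint_var = 1.0 / precisions.sum(axis=0)               # precisions add up
    joint_mu = joint_var * (mus * precisions).sum(axis=0)  # precision-weighted mean
    return joint_mu, np.log(joint_var)
```

For two unit-variance experts centred at 0 and 2, the sketch returns a mean of 1 and a variance of 0.5: combining experts sharpens the posterior.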

The model is then trained by maximizing the following objective:

\[\begin{split}\mathcal{L}(X) = \sum_{m=1}^{M} \pi_m \mathbb{E}_{q_\phi(z|X)} \left( \log p_{\theta}(x_m|z) \right) + \pi_m \mathbb{E}_{q_\phi(z|x_m)} \left( \log p_{\theta}(x_m|z) \right) \\ - \sum_{m=1}^{M} \pi_m KL(q_{\phi}(z|X)||q_{\phi}(z|x_m)) - \pi_{M+1}KL(q_{\phi}(z|X)||p(z))\end{split}\]

In practice, all weights are equal: \(\pi_m = \frac{1}{M+1}\) for \(m = 1, \dots, M+1\).
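To make the weighting concrete, here is a minimal sketch that evaluates \(\mathcal{L}(X)\) for one sample, assuming diagonal Gaussian posteriors, a standard normal prior \(p(z) = \mathcal{N}(0, I)\), and reconstruction expectations that have already been estimated (for instance by Monte Carlo). Under these assumptions the KL terms have a closed form. The names `gaussian_kl` and `crmvae_objective` are hypothetical, and the `beta` weight from the configuration is left out so that the code mirrors the formula term by term.

```python
import numpy as np


def gaussian_kl(mu_q, logvar_q, mu_p, logvar_p):
    """KL(N(mu_q, var_q) || N(mu_p, var_p)) for diagonal Gaussians, summed over dimensions."""
    var_q, var_p = np.exp(logvar_q), np.exp(logvar_p)
    return 0.5 * np.sum(
        logvar_p - logvar_q + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0, axis=-1
    )


def crmvae_objective(rec_joint, rec_uni, mu_joint, logvar_joint, mus, logvars):
    """Evaluate L(X) for one sample with pi_m = 1 / (M + 1).

    rec_joint[m]: estimate of E_{q(z|X)}[log p(x_m|z)]
    rec_uni[m]:   estimate of E_{q(z|x_m)}[log p(x_m|z)]
    mus, logvars: unimodal posterior parameters, shape (M, latent_dim)
    """
    M = len(rec_joint)
    pi = 1.0 / (M + 1)  # same weight for every term
    # Joint and unimodal reconstruction terms, each weighted by pi_m
    recon = pi * (np.sum(rec_joint) + np.sum(rec_uni))
    # KL(q(z|X) || q(z|x_m)) for each modality
    kl_uni = sum(gaussian_kl(mu_joint, logvar_joint, mus[m], logvars[m]) for m in range(M))
    # KL(q(z|X) || p(z)) with an assumed standard normal prior
    kl_prior = gaussian_kl(
        mu_joint, logvar_joint, np.zeros_like(mu_joint), np.zeros_like(logvar_joint)
    )
    return recon - pi * kl_uni - pi * kl_prior
```

When all posteriors coincide with the prior, every KL term vanishes and the objective reduces to the weighted sum of reconstruction terms.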

Note

This model can be used on incomplete datasets. In that case, the product of experts and the reconstruction terms are computed using only the modalities available for each sample.
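One way to restrict the product and the reconstructions to observed modalities is to give missing experts zero precision and to mask out their reconstruction terms. The sketch below assumes Gaussian experts, a boolean availability mask of shape (M, batch), finite placeholder values for missing modalities, and at least one observed modality per sample. The helper names are hypothetical, and this is not necessarily how multivae handles missing data internally.

```python
import numpy as np


def masked_product_of_experts(mus, logvars, mask):
    """Gaussian PoE over the observed modalities of each sample.

    mus, logvars: (M, batch, latent_dim); mask: (M, batch), True where x_m is observed.
    Assumes each sample has at least one observed modality.
    """
    # Missing modalities get zero precision, so they drop out of the product
    precisions = np.exp(-logvars) * mask[..., None]
    joint_var = 1.0 / precisions.sum(axis=0)
    joint_mu = joint_var * (mus * precisions).sum(axis=0)
    return joint_mu, np.log(joint_var)


def masked_reconstruction(log_liks, mask):
    """Sum per-modality log-likelihoods of shape (M, batch) over observed modalities only."""
    return (log_liks * mask).sum(axis=0)
```

A sample with a single observed modality simply recovers that modality's unimodal posterior.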

class multivae.models.CRMVAEConfig(n_modalities, latent_dim=10, input_dims=None, uses_likelihood_rescaling=False, rescale_factors=None, decoders_dist=None, decoder_dist_params=None, custom_architectures=<factory>, beta=2.5)

This is the base config class for the CRMVAE model.

Parameters:
  • n_modalities (int) – The number of modalities.

  • latent_dim (int) – The dimension of the latent space. Default: 10.

  • input_dims (dict[str, tuple]) – The modalities’ names (str) and input shapes (tuple). Default: None.

  • uses_likelihood_rescaling (bool) – Whether to use likelihood rescaling to mitigate modality collapse (see https://proceedings.mlr.press/v162/javaloy22a.html). If True, input_dims must be provided to compute the likelihood rescaling factors. This option is shared by several models, which is why it is included here. Default: False.

  • rescale_factors (dict[str, float]) – The reconstruction rescaling factors per modality. If None is provided but uses_likelihood_rescaling is True, default factors are computed from the input size of each modality. Default: None.

  • decoders_dist (Dict[str, Union[function, str]]) – Per modality, you can provide a string in [‘normal’, ‘bernoulli’, ‘laplace’]. For the Bernoulli distribution, the decoder is expected to output logits. If None is provided, a normal distribution is used for each modality. Default: None.

  • decoder_dist_params (Dict[str, dict]) – Parameters of the decoder output distributions, used to compute the log-probabilities. For instance, with a normal or Laplace distribution, you can pass the scale: decoder_dist_params = {'mod1': {"scale": 0.75}}. Default: None.

  • beta (float) – The weight applied to the sum of all KL terms. Default: 2.5.
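For reference, the three decoder distributions listed above correspond to the following standard log-likelihoods, with `scale` playing the role of the entry in `decoder_dist_params` and Bernoulli outputs interpreted as logits. This NumPy sketch uses textbook densities. `log_prob` is a hypothetical helper, and summing over dimensions is an assumption, not documented multivae behaviour.

```python
import numpy as np


def log_prob(x, out, dist="normal", scale=1.0):
    """Log-likelihood of x given a decoder output `out`, summed over dimensions.

    'normal' / 'laplace': `out` is the location and `scale` a fixed scale.
    'bernoulli': `out` holds logits; `scale` is ignored.
    """
    x, out = np.asarray(x, dtype=float), np.asarray(out, dtype=float)
    if dist == "normal":
        lp = -0.5 * ((x - out) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)
    elif dist == "laplace":
        lp = -np.abs(x - out) / scale - np.log(2 * scale)
    elif dist == "bernoulli":
        # x * log(sigmoid(l)) + (1 - x) * log(1 - sigmoid(l)), written stably
        lp = x * out - np.logaddexp(0.0, out)
    else:
        raise ValueError(f"unknown distribution: {dist}")
    return lp.sum()
```

For example, with decoders_dist = {'mod1': 'normal'} and decoder_dist_params = {'mod1': {'scale': 0.75}}, this sketch would score a reconstruction of mod1 as `log_prob(x, out, 'normal', scale=0.75)`.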

class multivae.models.CRMVAE(model_config, encoders=None, decoders=None)

Main class for the CRMVAE model proposed in https://openreview.net/forum?id=Rn8u4MYgeNJ.

Parameters:
  • model_config (CRMVAEConfig) – An instance of CRMVAEConfig containing all the parameters for the model.

  • encoders (Dict[str, BaseEncoder]) – A dictionary containing the modalities names and the encoders for each modality. Each encoder is an instance of Pythae’s BaseEncoder. Default: None.

  • decoders (Dict[str, BaseDecoder]) – A dictionary containing the modalities names and the decoders for each modality. Each decoder is an instance of Pythae’s BaseDecoder. Default: None.

compute_joint_nll(inputs, K=1000, batch_size_K=100)

Estimate the joint negative log-likelihood using importance sampling.

Parameters:
  • inputs (MultimodalBaseDataset) – A batch of samples.

  • K (int) – The number of importance samples used for the estimation. Default: 1000.

  • batch_size_K (int) – The number of importance samples processed at a time, to limit memory usage. Default: 100.

Returns:

The negative log-likelihood summed over the batch.
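The estimator behind `K` and `batch_size_K` is presumably standard importance sampling: \(\log p(X) \approx \log \frac{1}{K}\sum_{k=1}^{K} \frac{p(X|z_k)\,p(z_k)}{q(z_k|X)}\) with \(z_k \sim q(z|X)\). The sketch below applies it to a hypothetical one-dimensional toy model whose exact marginal is known, drawing the K samples in chunks of `batch_size_K` and merging the chunk results with a log-sum-exp. It is not multivae's implementation; in CRMVAE the proposal would plausibly be the joint posterior \(q_\phi(z|X)\), but that is an assumption.

```python
import numpy as np


def log_normal(x, mu, var):
    """Log-density of N(mu, var) evaluated at x (elementwise)."""
    return -0.5 * (np.log(2 * np.pi * var) + (x - mu) ** 2 / var)


def logsumexp(a):
    """Numerically stable log(sum(exp(a)))."""
    m = np.max(a)
    return m + np.log(np.sum(np.exp(a - m)))


def toy_joint_nll(x, q_mu, q_var, K=1000, batch_size_K=100, seed=0):
    """Importance-sampling estimate of -log p(x) for the toy model
    z ~ N(0, 1), x | z ~ N(z, 1), using the proposal q(z|x) = N(q_mu, q_var)."""
    rng = np.random.default_rng(seed)
    chunk_lse = []
    for start in range(0, K, batch_size_K):
        k = min(batch_size_K, K - start)  # the last chunk may be smaller
        z = rng.normal(q_mu, np.sqrt(q_var), size=k)
        # Log importance weights: log p(x|z) + log p(z) - log q(z|x)
        log_w = log_normal(x, z, 1.0) + log_normal(z, 0.0, 1.0) - log_normal(z, q_mu, q_var)
        chunk_lse.append(logsumexp(log_w))
    # Log-sum-exp over chunk results equals log-sum-exp over all K weights
    return -(logsumexp(np.array(chunk_lse)) - np.log(K))
```

When the proposal equals the exact posterior \(\mathcal{N}(x/2, 1/2)\), every importance weight equals \(p(x)\) and the estimate is exact; a poorer proposal such as the prior still converges, only with higher variance.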

encode(inputs, cond_mod='all', N=1, return_mean=False, **kwargs)

Generate encodings conditioned on all modalities or on a subset of them.

Parameters:
  • inputs (MultimodalBaseDataset) – The batch of samples to encode.

  • cond_mod (Union[list, str]) – Either ‘all’ or a list of str containing the names of the modalities to condition on. Default: ‘all’.

  • N (int) – The number of encodings to sample for each datapoint. Default: 1.

  • return_mean (bool) – If True, return the mean of the posterior distribution instead of a sample. Default: False.

Returns:

A ModelOutput instance with fields:

  • z (torch.Tensor) – The encodings, of shape (n_data, N, latent_dim).

  • one_latent_space (bool) – Set to True.

Return type:

ModelOutput
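The (n_data, N, latent_dim) shape follows from drawing N reparameterised samples per datapoint. The sketch below assumes a diagonal Gaussian posterior given by `mu` and `logvar`. `sample_encodings` is hypothetical, and repeating the mean N times when `return_mean=True` is a choice made here, not documented behaviour.

```python
import numpy as np


def sample_encodings(mu, logvar, N=1, return_mean=False, seed=0):
    """Draw N latent codes per datapoint from N(mu, exp(logvar)).

    mu, logvar: (n_data, latent_dim). Returns an array of shape (n_data, N, latent_dim).
    """
    mu = mu[:, None, :].repeat(N, axis=1)  # (n_data, N, latent_dim)
    if return_mean:
        return mu  # posterior mean, repeated N times in this sketch
    std = np.exp(0.5 * logvar)[:, None, :]  # broadcast over the N axis
    eps = np.random.default_rng(seed).standard_normal(mu.shape)
    return mu + std * eps  # reparameterised samples
```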

forward(inputs, **kwargs)

Forward pass of the model. Returns the loss and additional metrics in a ModelOutput instance.