DMVAE

Implementation of the DMVAE model from “Private-Shared Disentangled Multimodal VAE for Learning of Latent Representations” (Lee & Pavlovic, 2021; https://par.nsf.gov/servlets/purl/10297662).

DMVAE is an aggregated model with a shared latent variable \(z_s\) and modality-specific (private) latent variables \(z_{p_i}\). The joint posterior over the shared variable is a product of experts:

\[q(z_s|X) \propto p(z_s)\prod_{i=1}^{M} q(z_s|x_i)\]
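For Gaussian experts, the product above has a closed form: precisions add up, and the joint mean is the precision-weighted average of the expert means. The sketch below illustrates this for one-dimensional Gaussians with a standard normal prior expert. It is a standalone illustration in plain Python, not the library's implementation, and `product_of_experts` is a hypothetical helper name.

```python
def product_of_experts(means, variances):
    """Combine 1-D Gaussian experts with a standard normal prior expert.

    means, variances: lists with one entry per modality expert q(z_s | x_i).
    Returns the mean and variance of the normalized product.
    """
    # The prior p(z_s) = N(0, 1) acts as one additional expert.
    all_means = [0.0] + list(means)
    all_vars = [1.0] + list(variances)

    # Precisions (inverse variances) add up under a product of Gaussians.
    precisions = [1.0 / v for v in all_vars]
    total_precision = sum(precisions)

    joint_var = 1.0 / total_precision
    joint_mean = joint_var * sum(m * p for m, p in zip(all_means, precisions))
    return joint_mean, joint_var


# Two unit-variance experts centered at 1.0 and 3.0, plus the prior.
joint_mean, joint_var = product_of_experts(means=[1.0, 3.0], variances=[1.0, 1.0])
```

With more experts the joint variance shrinks, which is why the joint posterior is typically sharper than any unimodal posterior.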

The training objective is:

\[\begin{split}\sum_i \Big[ & \lambda_i \mathbb{E}_{\substack{q_{\phi}(z_{p_i}|x_i) \\ q_{\phi}(z_s|X)}}\left[ \log p_{\theta}(x_i|z_{p_i},z_s)\right] \\ & -KL(q_{\phi}(z_{p_i}|x_i)||p(z_{p_i})) - KL(q_{\phi}(z_s|X)||p(z_s))\\ & + \sum_j \Big( \lambda_i \mathbb{E}_{\substack{q_{\phi}(z_{p_i}|x_i) \\ q_{\phi}(z_s|x_j)}}\left[ \log p_{\theta}(x_i|z_{p_i},z_s)\right] \\ & -KL(q_{\phi}(z_{p_i}|x_i)||p(z_{p_i})) - KL(q_{\phi}(z_s|x_j)||p(z_s)) \Big) \Big]\end{split}\]

This loss combines several ELBOs, in which the shared latent variable is drawn from either the joint posterior or one of the unimodal posteriors.

Note

This model can be used in the partially observed setting. In that case, for each sample \(X\), we take the loss and the product-of-experts on available modalities only.
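To make the structure of the loss concrete in the partially observed case, the sketch below lists which ELBO terms would be kept for a single sample, under the assumption that every term involving a missing modality is simply dropped. `elbo_terms` is a hypothetical helper written for illustration; it is not part of the library.

```python
def elbo_terms(available):
    """List the ELBO terms kept for one sample.

    available: dict mapping modality name -> True if observed for this sample.
    Each term is a pair (reconstructed modality, source of the shared latent z_s).
    """
    observed = [mod for mod, is_observed in available.items() if is_observed]
    terms = []
    for i in observed:
        # Joint term: z_s is drawn from the product of experts
        # built from the observed modalities only.
        terms.append((i, "joint"))
        # Unimodal terms: z_s is drawn from a single observed modality j.
        for j in observed:
            terms.append((i, j))
    return terms


# "audio" is missing for this sample, so no term reconstructs it or encodes from it.
terms = elbo_terms({"image": True, "text": True, "audio": False})
```

For a fully observed sample with M modalities, this enumeration yields M joint terms and M × M unimodal terms.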

class multivae.models.DMVAEConfig(n_modalities, latent_dim=10, input_dims=None, uses_likelihood_rescaling=False, rescale_factors=None, decoders_dist=None, decoder_dist_params=None, custom_architectures=<factory>, modalities_specific_dim=None, modalities_specific_betas=None, beta=1.0)[source]

Config class for the DMVAE model from “Private-Shared Disentangled Multimodal VAE for Learning of Latent Representations”.

Parameters:
  • n_modalities (int) – The number of modalities. Required.

  • latent_dim (int) – The dimension of the latent space. Default: 10.

  • input_dims (dict[str,tuple]) – The modalities’ names (str) and input shapes (tuple).

  • uses_likelihood_rescaling (bool) – Whether to use likelihood rescaling to mitigate modality collapse (see https://proceedings.mlr.press/v162/javaloy22a.html). input_dims must be provided to compute the rescaling factors. This option is shared by several models, which is why it is included here. Default to False.

  • rescale_factors (dict[str, float]) – The reconstruction rescaling factors per modality. If None is provided but uses_likelihood_rescaling is True, a default value proportional to the input modality size is computed. Default to None.

  • decoders_dist (Dict[str, Union[function, str]]) – Per modality, you can provide a string in [‘normal’, ‘bernoulli’, ‘laplace’]. For the Bernoulli distribution, the decoder is expected to output logits. If None is provided, a normal distribution is used for each modality.

  • decoder_dist_params (Dict[str,dict]) – Parameters for the output decoder distributions, for computing the log-probability. For instance, with normal or laplace distribution, you can pass the scale in this dictionary with decoder_dist_params =  {'mod1' : {"scale" : 0.75}}.

  • modalities_specific_dim (dict) – The latent dimensions of the private spaces, per modality. Default to None.

  • beta (float) – The scaling factor for the joint divergence term. Default to 1.

  • modalities_specific_betas (dict) – The scaling factors for the private KL divergence terms, per modality. Default to None.
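As an illustration of how these parameters fit together, the sketch below assembles keyword arguments using the names from the class signature above. The modality names, shapes and dimensions are made-up values, and the import is guarded so the snippet still runs when multivae is not installed.

```python
# Keyword arguments named after the DMVAEConfig signature documented above.
# Modality names, shapes and dimensions are illustrative assumptions.
config_kwargs = dict(
    n_modalities=2,
    latent_dim=10,
    input_dims={"image": (1, 28, 28), "label": (10,)},
    decoders_dist={"image": "bernoulli", "label": "normal"},
    decoder_dist_params={"label": {"scale": 0.75}},
    modalities_specific_dim={"image": 4, "label": 2},
    modalities_specific_betas={"image": 1.0, "label": 1.0},
    beta=1.0,
)

# Sanity check: the per-modality dictionaries should share the same keys.
modalities = set(config_kwargs["input_dims"])
assert config_kwargs["n_modalities"] == len(modalities)
assert set(config_kwargs["modalities_specific_dim"]) == modalities
assert set(config_kwargs["modalities_specific_betas"]) == modalities

try:
    from multivae.models import DMVAEConfig

    config = DMVAEConfig(**config_kwargs)
except ImportError:
    # multivae is not available in this environment; keep the raw kwargs only.
    config = None
```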

class multivae.models.DMVAE(model_config, encoders=None, decoders=None)[source]

The DMVAE model from the paper ‘Private-Shared Disentangled Multimodal VAE for Learning of Latent Representations’.

Mihee Lee, Vladimir Pavlovic

Parameters:
  • model_config (DMVAEConfig) – An instance of DMVAEConfig holding all of the model’s parameters.

  • encoders (Dict[str, BaseMultilatentEncoder]) – A dictionary mapping each modality name to its encoder. Each encoder is an instance of Multivae’s BaseMultilatentEncoder, since this model uses multiple latent spaces. Default: None.

  • decoders (Dict[str, BaseDecoder]) – A dictionary mapping each modality name to its decoder. Each decoder is an instance of Pythae’s BaseDecoder. Default: None.

compute_joint_nll(inputs, K=1000, batch_size_K=100)[source]

Estimate the negative joint log-likelihood.

Parameters:
  • inputs (MultimodalBaseDataset) – a batch of samples.

  • K (int) – the number of importance samples for the estimation. Default to 1000.

  • batch_size_K (int) – The number of importance samples processed at once. Default to 100.

Returns:

The negative log-likelihood summed over the batch.
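An importance-sampled estimate of this kind is usually computed in log space for numerical stability. The sketch below shows the generic estimator \(-\log \frac{1}{K}\sum_k w_k\) for a single data point, reducing the log-weights chunk by chunk to mimic the role of batch_size_K. It is a generic illustration with hypothetical function names, not the library's code.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def nll_estimate(log_weights, batch_size_K=100):
    """Estimate -log p(x) from K importance log-weights.

    log_weights[k] = log p(x, z_k) - log q(z_k | x), with z_k drawn from q(z | x).
    """
    K = len(log_weights)
    # Reduce chunk by chunk so only batch_size_K weights are handled at once.
    chunk_lse = [
        logsumexp(log_weights[start:start + batch_size_K])
        for start in range(0, K, batch_size_K)
    ]
    # log(1/K * sum_k w_k) = logsumexp(all log-weights) - log(K)
    return -(logsumexp(chunk_lse) - math.log(K))


# With constant log-weights, the estimate equals minus that constant.
estimate = nll_estimate([-5.0] * 1000, batch_size_K=100)
```

Summing such per-sample estimates over a batch gives a batch-level value like the one returned above.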

encode(inputs, cond_mod='all', N=1, return_mean=False, **kwargs)[source]

Generate encodings conditioned on all modalities or on a subset of modalities.

Parameters:
  • inputs (MultimodalBaseDataset) – The dataset to use for the conditional generation.

  • cond_mod (Union[list, str]) – Either ‘all’ or a list of str containing the names of the modalities to condition on.

  • N (int) – The number of encodings to sample for each datapoint. Default to 1.

  • return_mean (bool) – if True, returns the mean of the posterior distribution (instead of a sample).

Returns:

Contains fields:

  • ‘z’ (torch.Tensor (N, n_data, latent_dim))

  • ‘one_latent_space’ (bool) = False

  • ‘modalities_z’ (dict[str, torch.Tensor (N, n_data, mod_latent_dim)])

Return type:

ModelOutput

forward(inputs, **kwargs)[source]

The main function of the model that computes the loss and some monitoring metrics. One of the advantages of DMVAE is that we can train with incomplete data.

Parameters:

inputs (MultimodalBaseDataset) – The data. It can be an instance of IncompleteDataset, which contains a masks field for weakly supervised learning. masks is a dictionary that holds, for each modality, a boolean tensor indicating which samples are available. (Unavailable samples are assumed to be replaced with zeros in the multimodal dataset entry.)
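The layout of masks described above can be sketched with plain Python lists standing in for tensors. The modality names and values are illustrative assumptions, and the actual IncompleteDataset constructor is deliberately not shown.

```python
# Two modalities, three samples; the second "text" sample is missing.
data = {
    "image": [[0.2, 0.8], [0.5, 0.1], [0.9, 0.4]],
    "text": [[1.0, 0.0], [0.3, 0.7], [0.0, 1.0]],
}
masks = {
    "image": [True, True, True],
    "text": [True, False, True],
}

# Replace unavailable samples with zeros, as the description above assumes.
filled = {
    mod: [row if ok else [0.0] * len(row) for row, ok in zip(rows, masks[mod])]
    for mod, rows in data.items()
}

# Modalities available for each sample, e.g. to restrict the loss terms
# and the product of experts to observed modalities.
n_samples = len(masks["image"])
available = [[mod for mod in data if masks[mod][n]] for n in range(n_samples)]
```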

generate_from_prior(n_samples, **kwargs)[source]

Generates latent variables from the prior for the shared latent space and for each modality-specific latent space.

Parameters:

n_samples