MVAE
Implementation of the MVAE model from “Multimodal Generative Models for Scalable Weakly-Supervised Learning” (https://arxiv.org/abs/1802.05335).
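The MVAE model was the first aggregated model, proposed by [1]. The joint posterior is modelled as a Product-of-Experts (PoE) \(q_{\phi}(z|X) \propto p(z)\prod_j q_{\phi_j}(z|x_j)\), and the following ELBO is optimized:
\[\mathcal{L}(X) = \mathbb{E}_{q_{\phi}(z|X)}\Big[\sum_{j=1}^{M} \log p_{\theta_j}(x_j|z)\Big] - KL\big(q_{\phi}(z|X)\,\|\,p(z)\big)\]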
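When the unimodal posteriors are diagonal Gaussians and the prior is \(\mathcal{N}(0, I)\), the PoE has a closed form: the precisions of the experts add up, and the joint mean is the precision-weighted average of the expert means. The sketch below illustrates this under those assumptions; `product_of_experts` is a hypothetical helper written with plain Python lists, not part of the multivae API.

```python
import math
from typing import List, Tuple


def product_of_experts(
    mus: List[List[float]], logvars: List[List[float]]
) -> Tuple[List[float], List[float]]:
    """Combine diagonal Gaussian experts N(mu_j, exp(logvar_j)) with the prior N(0, I).

    Assumes at least one expert and that all experts share the same latent_dim.
    Returns the mean and log-variance of the (normalised) product.
    """
    latent_dim = len(mus[0])
    joint_mu, joint_logvar = [], []
    for d in range(latent_dim):
        # The prior N(0, 1) contributes precision 1 and mean 0 in every dimension.
        precision = 1.0
        weighted_mean = 0.0
        for mu, logvar in zip(mus, logvars):
            p = math.exp(-logvar[d])  # precision of this expert in dimension d
            precision += p
            weighted_mean += p * mu[d]
        joint_mu.append(weighted_mean / precision)
        joint_logvar.append(-math.log(precision))
    return joint_mu, joint_logvar
```

Two identical experts with mean 1 and unit variance give a joint precision of 3 (two experts plus the prior), so the joint mean is 2/3: the prior pulls the estimate towards 0.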
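This ELBO can be computed on a subset of modalities \(S\) by using only the modalities in the subset to compute the PoE:
\[\mathcal{L}(X_S) = \mathbb{E}_{q_{\phi}(z|X_S)}\Big[\sum_{j\in S} \log p_{\theta_j}(x_j|z)\Big] - KL\big(q_{\phi}(z|X_S)\,\|\,p(z)\big), \quad q_{\phi}(z|X_S) \propto p(z)\prod_{j\in S} q_{\phi_j}(z|x_j)\]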
To ensure all unimodal encoders are correctly trained, the MVAE uses a sub-sampling training paradigm: at each iteration, the ELBO is computed for several subsets, namely the joint subset \(\{1,..,M\}\), the unimodal subsets and \(K\) random subsets. For each sample, the objective then becomes:
\[\mathcal{L}(X) + \sum_{j=1}^{M}\mathcal{L}(x_j) + \sum_{k=1}^{K}\mathcal{L}(X_{s_k})\]
where the \(s_k\) are random subsets of \(\{1,..,M\}\).
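The sub-sampling objective can be sketched as follows: list the subsets evaluated at one iteration, then sum the ELBO over them. This is a minimal illustration, not the library's implementation; the helpers `mvae_subsets` and `mvae_objective` are hypothetical, and how the random subsets are drawn (uniform size, then uniform modalities, possibly coinciding with the joint or unimodal subsets) is an assumption of this sketch.

```python
import random
from typing import Callable, List, Optional


def mvae_subsets(
    modalities: List[str], k: int, rng: Optional[random.Random] = None
) -> List[List[str]]:
    """Subsets on which the ELBO is evaluated at one training iteration."""
    rng = rng or random.Random()
    subsets = [list(modalities)]            # joint subset {1, ..., M}
    subsets += [[m] for m in modalities]    # the M unimodal subsets
    for _ in range(k):
        # Assumed sampling scheme: draw a size, then a non-empty random subset of that size.
        size = rng.randint(1, len(modalities))
        subsets.append(sorted(rng.sample(modalities, size)))
    return subsets


def mvae_objective(
    elbo: Callable[[List[str]], float],
    modalities: List[str],
    k: int,
    rng: Optional[random.Random] = None,
) -> float:
    """Sum of the subset ELBOs (to be maximised); `elbo` evaluates one subset."""
    return sum(elbo(s) for s in mvae_subsets(modalities, k, rng))
```

With `k=0` (the MVAEConfig default), the objective reduces to the joint ELBO plus the \(M\) unimodal ELBOs.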
Note
As an aggregated model, this model can be used in the partially observed setting. In that setting, we don’t use the sub-sampling paradigm since the dataset is naturally sub-sampled; for each sample \(X\), we compute the ELBO with only the observed modalities in \(S_{obs}(X)\), using the posterior:
\[q_{\phi}(z|X) \propto p(z)\prod_{j\in S_{obs}(X)} q_{\phi_j}(z|x_j)\]
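For a single sample, restricting the PoE to \(S_{obs}(X)\) amounts to skipping the experts of the missing modalities. The sketch below assumes diagonal Gaussian experts, a standard normal prior and a per-sample availability flag for each modality; `observed_posterior` is a hypothetical helper, not the library's API.

```python
import math
from typing import Dict, List, Tuple


def observed_posterior(
    mus: Dict[str, List[float]],
    logvars: Dict[str, List[float]],
    observed: Dict[str, bool],
) -> Tuple[List[float], List[float]]:
    """PoE over the observed modalities of one sample, including the N(0, I) prior."""
    latent_dim = len(next(iter(mus.values())))
    precision = [1.0] * latent_dim  # prior precision
    weighted = [0.0] * latent_dim   # prior mean is 0
    for mod, is_observed in observed.items():
        if not is_observed:
            continue  # missing modality: its expert is left out of the product
        for d in range(latent_dim):
            p = math.exp(-logvars[mod][d])
            precision[d] += p
            weighted[d] += p * mus[mod][d]
    mu = [w / p for w, p in zip(weighted, precision)]
    logvar = [-math.log(p) for p in precision]
    return mu, logvar
```

If no modality is observed, the posterior falls back to the prior, which is why each sample should have at least one available modality.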
[1] Wu and Goodman (2018), “Multimodal Generative Models for Scalable Weakly-Supervised Learning”, https://arxiv.org/abs/1802.05335
- class multivae.models.MVAEConfig(n_modalities, latent_dim=10, input_dims=None, uses_likelihood_rescaling=False, rescale_factors=None, decoders_dist=None, decoder_dist_params=None, custom_architectures=<factory>, use_subsampling=True, k=0, warmup=10, beta=1)
Config class for the MVAE model from ‘Multimodal Generative Models for Scalable Weakly-Supervised Learning’. https://proceedings.neurips.cc/paper/2018/hash/1102a326d5f7c9e04fc3c89d0ede88c9-Abstract.html.
- Parameters:
n_modalities (int) – The number of modalities.
latent_dim (int) – The dimension of the latent space. Default: 10.
input_dims (dict[str,tuple]) – The modalities’ names (str) and input shapes (tuple).
uses_likelihood_rescaling (bool) – To mitigate modality collapse, it is possible to use likelihood rescaling (see: https://proceedings.mlr.press/v162/javaloy22a.html). The input_dims must be provided to compute the likelihood rescaling factors. This option is used in a number of models, which is why it is included here. Defaults to False.
rescale_factors (dict[str, float]) – The reconstruction rescaling factors per modality. If None is provided but uses_likelihood_rescaling is True, a default value proportional to the input modality size is computed. Defaults to None.
decoders_dist (Dict[str, Union[function, str]]) – Per modality, you can provide a string in [‘normal’,’bernoulli’,’laplace’]. For Bernoulli distribution, the decoder is expected to output logits. If None is provided, a normal distribution is used for each modality.
decoder_dist_params (Dict[str,dict]) – Parameters for the output decoder distributions, used to compute the log-probability. For instance, with a normal or Laplace distribution, you can pass the scale in this dictionary with decoder_dist_params = {'mod1' : {"scale" : 0.75}}.
use_subsampling (bool) – If True, the subsampling paradigm described in the article is used: the objective includes not only the joint ELBO but also the unimodal ELBOs and k random subset ELBOs. This is useful when training with a complete dataset but should be set to False when the dataset is already incomplete. Defaults to True.
k (int) – The number of random subsets to use in the objective. The MVAE objective is the sum of the unimodal ELBOs, the joint ELBO and k random subset ELBOs. Defaults to 0.
warmup (int) – If warmup > 0, the MVAE model uses KL annealing during the first warmup epochs: in the objective, the KL terms are weighted by a factor that increases linearly to 1 over the first warmup epochs. Defaults to 10.
beta (float) – The scaling factor for the divergence term. Defaults to 1.
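The warmup behaviour can be pictured as a weight on the KL terms that ramps up linearly and then stays constant. The function below is a hypothetical sketch of one such schedule: the 0-indexed epoch convention and the choice to multiply the annealing factor by beta are assumptions, and multivae may compute the weight differently.

```python
def kl_weight(epoch: int, warmup: int, beta: float = 1.0) -> float:
    """Hypothetical KL weight at a given 0-indexed epoch.

    Assumption: the annealing factor grows linearly from 0 to 1 over `warmup`
    epochs, stays at 1 afterwards, and is multiplied by `beta`.
    """
    if warmup <= 0:
        return beta  # no annealing
    return beta * min(1.0, epoch / warmup)
```

In a training loop, the per-subset loss would then read `recon_loss + kl_weight(epoch, warmup, beta) * kl`, so that early epochs focus on reconstruction before the posterior is pulled towards the prior.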
- class multivae.models.MVAE(model_config, encoders=None, decoders=None)
The Multi-modal VAE model.
- Parameters:
model_config (MVAEConfig) – An instance of MVAEConfig in which all the model’s parameters are made available.
encoders (Dict[str, BaseEncoder]) – A dictionary containing the modality names and the encoders for each modality. Each encoder is an instance of Pythae’s BaseEncoder. Default: None.
decoders (Dict[str, BaseDecoder]) – A dictionary containing the modality names and the decoders for each modality. Each decoder is an instance of Pythae’s BaseDecoder. Default: None.
- compute_joint_nll(inputs, K=1000, batch_size_K=100)
Estimates the joint negative log-likelihood.
- Parameters:
inputs (MultimodalBaseDataset) – a batch of samples.
K (int) – The number of importance samples for the estimation. Defaults to 1000.
batch_size_K (int) – The number of importance samples processed per batch when computing the estimate. Defaults to 100.
- Returns:
The negative log-likelihood summed over the batch.
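An importance-sampling estimate of \(\log p(X)\) averages the importance weights \(p(X, z_k)/q(z_k|X)\) over \(K\) samples, which is computed stably in log space with a log-sum-exp. The sketch below assumes that batch_size_K only controls how many of the \(K\) samples are processed at once, so chunked and unchunked computations agree; `importance_sampled_nll` is a hypothetical helper for one datapoint, and the library's exact computation may differ.

```python
import math
from typing import List


def logsumexp(values: List[float]) -> float:
    """Numerically stable log(sum(exp(v) for v in values)); values must be non-empty."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def importance_sampled_nll(log_weights: List[float], batch_size_K: int = 100) -> float:
    """-log p(X) estimated from K log-weights log p(X, z_k) - log q(z_k | X).

    The K weights are reduced chunk by chunk (batch_size_K at a time), then
    the chunk results are combined, which gives the same value as a single pass.
    """
    K = len(log_weights)
    chunk_lse = [
        logsumexp(log_weights[i : i + batch_size_K]) for i in range(0, K, batch_size_K)
    ]
    # log p(X) ~ log( (1/K) * sum_k w_k )
    return -(logsumexp(chunk_lse) - math.log(K))
```

Summing this quantity over the datapoints of a batch gives the batch-level value described in Returns.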
- compute_mu_log_var_subset(inputs, subset)
Computes the parameters of the posterior when conditioning on the modalities contained in subset.
- encode(inputs, cond_mod='all', N=1, return_mean=False, **kwargs)
Generates encodings conditioning on all modalities or on a subset of modalities.
- Parameters:
inputs (MultimodalBaseDataset) – The data to encode.
cond_mod (Union[list, str]) – Either ‘all’ or a list of str containing the names of the modalities to condition on.
N (int) – The number of encodings to sample for each datapoint. Defaults to 1.
return_mean (bool) – If True, returns the mean of the posterior distribution (instead of a sample). Defaults to False.
- Returns:
A ModelOutput containing z (torch.Tensor of shape (n_data, N, latent_dim)) and one_latent_space (bool) = True.
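Sampling N encodings from a diagonal Gaussian posterior is usually done with the reparameterization \(z = \mu + \sigma \odot \epsilon\), \(\epsilon \sim \mathcal{N}(0, I)\). The sketch below shows this for one datapoint with plain Python lists; `sample_encodings` is hypothetical, and repeating the mean N times when return_mean is True is an assumption about how the two options interact.

```python
import math
import random
from typing import List, Optional


def sample_encodings(
    mu: List[float],
    logvar: List[float],
    N: int = 1,
    return_mean: bool = False,
    rng: Optional[random.Random] = None,
) -> List[List[float]]:
    """N latent codes (each of length latent_dim) from N(mu, diag(exp(logvar)))."""
    if return_mean:
        # Assumption: the posterior mean is returned once per requested encoding.
        return [list(mu) for _ in range(N)]
    rng = rng or random.Random()
    std = [math.exp(0.5 * lv) for lv in logvar]
    # Reparameterization: z = mu + std * eps with eps ~ N(0, 1) per dimension.
    return [[m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, std)] for _ in range(N)]
```

Stacking the results over n_data datapoints yields a nested structure of shape (n_data, N, latent_dim), matching the z field described above.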
- Return type:
ModelOutput
- forward(inputs, **kwargs)
The main function of the model that computes the loss and some monitoring metrics. One of the advantages of MVAE is that we can train with incomplete data.
- Parameters:
inputs (MultimodalBaseDataset) – The data. It can be an instance of IncompleteDataset, which contains a masks field for weakly supervised learning. masks is a dictionary indicating which data samples are missing in each modality: for each modality, a boolean tensor indicates which samples are available. (Unavailable samples are assumed to be replaced with zero values in the multimodal dataset entry.)
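The masks dictionary determines, for each sample, the set of observed modalities \(S_{obs}(X)\) used in the PoE. The sketch below illustrates that mapping with plain lists of booleans standing in for the boolean tensors; `observed_modalities` and the modality names are hypothetical.

```python
from typing import Dict, List


def observed_modalities(masks: Dict[str, List[bool]]) -> List[List[str]]:
    """For each sample i, the modalities whose mask entry is True (i.e. S_obs)."""
    n_samples = len(next(iter(masks.values())))
    return [[mod for mod, mask in masks.items() if mask[i]] for i in range(n_samples)]


# Hypothetical masks for three samples and two modalities.
masks = {"image": [True, True, False], "text": [True, False, True]}
subsets = observed_modalities(masks)
```

Here the first sample uses both modalities, the second only the image and the third only the text; the zero values stored for missing entries never enter the posterior because their experts are excluded.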