MoPoE: Mixture of Product of Experts
The MoPoE model from ‘Generalized Multimodal ELBO’ (Sutter et al., 2021).
The MoPoE-VAE uses a Mixture of Product of Experts.
Formally, for each non-empty subset \(S \in \mathcal P(\{1,\dots,M\})\), a product-of-experts (PoE) distribution is defined as \(\tilde{q}_{\phi}(z|(x_j)_{j \in S}) = PoE((q_{\phi_j}(z|x_j))_{j \in S})\).
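Then the joint posterior is defined as the uniform mixture of these subset posteriors:
\[q_{\phi}(z|X) = \frac{1}{2^M - 1}\sum_{S \in \mathcal P(\{1,\dots,M\}),\, S \neq \emptyset} \tilde{q}_{\phi}(z|(x_j)_{j \in S})\]
The ELBO is optimized:
\[\mathcal{L}(X) = \mathbb{E}_{q_{\phi}(z|X)}\left[\log p_{\theta}(X|z)\right] - KL\left(q_{\phi}(z|X) \,\|\, p(z)\right)\]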
The MoPoE model can be used with additional modality-specific latent spaces.
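The two definitions above can be sketched in plain Python. This is a minimal, hypothetical sketch (the helpers `poe` and `sample_mopoe` are not part of multivae), assuming diagonal Gaussian experts, uniform weights over the non-empty subsets, and no prior expert inside the product. It only shows how a PoE fuses unimodal posteriors by precision weighting, and how sampling from the mixture amounts to picking one subset and sampling from its PoE.

```python
import math
import random
from itertools import combinations


def poe(mus, logvars):
    """Product of diagonal Gaussian experts (precision-weighted fusion).

    mus, logvars: one list of length latent_dim per expert.
    Returns the (mu, logvar) of the product distribution.
    """
    joint_mu, joint_logvar = [], []
    for d in range(len(mus[0])):
        precisions = [math.exp(-lv[d]) for lv in logvars]  # 1 / variance
        total = sum(precisions)
        joint_mu.append(sum(p * m[d] for p, m in zip(precisions, mus)) / total)
        joint_logvar.append(-math.log(total))  # variance of the product = 1 / total
    return joint_mu, joint_logvar


def sample_mopoe(unimodal, rng=random):
    """Draw one z from the uniform mixture of PoEs over all non-empty subsets.

    unimodal: dict mapping modality name -> (mu, logvar) of q(z|x_j).
    Returns the selected subset and the sampled latent code.
    """
    names = sorted(unimodal)
    subsets = [s for r in range(1, len(names) + 1) for s in combinations(names, r)]
    subset = rng.choice(subsets)  # uniform mixture weights
    mu, logvar = poe([unimodal[m][0] for m in subset], [unimodal[m][1] for m in subset])
    # Reparameterized sample from the selected PoE component.
    z = [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]
    return subset, z
```

For two experts with unit variance and means 0 and 2, `poe` returns mean 1 and variance 1/2: the product is sharper than either expert.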
Note
To adapt this model to the partially observed setting, the loss is computed only with the subsets \(S \in S_{obs}(X)\), where \(S_{obs}(X)\) is the set of non-empty subsets of the modalities observed for the sample \(X\) at hand, instead of all subsets \(S \in \mathcal P(\{1,\dots,M\})\).
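The availability of each subset can be derived from per-sample modality masks: a subset is usable for a sample only if all of its modalities are observed. The sketch below is a hypothetical helper (plain Python lists rather than tensors) producing an n_subsets x n_samples boolean matrix of the same kind as the availabilities argument of random_mixture_component_selection.

```python
def subset_availabilities(subsets, masks):
    """Boolean matrix (n_subsets x n_samples): is every modality of the subset observed?

    subsets: list of lists of modality names.
    masks: dict mapping modality name -> list of bools, one per sample.
    """
    n_samples = len(next(iter(masks.values())))
    return [
        [all(masks[m][i] for m in subset) for i in range(n_samples)]
        for subset in subsets
    ]
```

For a sample where mod_1 and mod_3 are observed but mod_2 is missing, only {mod_1}, {mod_3} and {mod_1, mod_3} are marked available.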
- class multivae.models.MoPoEConfig(n_modalities, latent_dim=10, input_dims=None, uses_likelihood_rescaling=False, rescale_factors=None, decoders_dist=None, decoder_dist_params=None, custom_architectures=<factory>, subsets=None, beta=1.0, beta_style=1.0, modalities_specific_dim=None)
This class is the configuration class for the MoPoE model, from ‘Generalized Multimodal ELBO’ Sutter 2021 (https://arxiv.org/abs/2105.02470).
- Parameters:
n_modalities (int) – The number of modalities. Default: None.
latent_dim (int) – The dimension of the latent space. Default: 10.
input_dims (dict[str,tuple]) – The modalities’ names (str) and input shapes (tuple). Default: None.
uses_likelihood_rescaling (bool) – To mitigate modality collapse, it is possible to use likelihood rescaling (see https://proceedings.mlr.press/v162/javaloy22a.html). The input_dims must be provided to compute the likelihood rescalings. This option is used in a number of models, which is why it is included here. Default: False.
rescale_factors (dict[str, float]) – The reconstruction rescaling factors per modality. If None is provided but uses_likelihood_rescaling is True, a default value proportional to the input modality size is computed. Default: None.
decoders_dist (Dict[str, Union[function, str]]) – Per modality, you can provide a string in [‘normal’, ‘bernoulli’, ‘laplace’]. For the Bernoulli distribution, the decoder is expected to output logits. If None is provided, a normal distribution is used for each modality.
decoder_dist_params (Dict[str,dict]) – Parameters for the output decoder distributions, used to compute the log-probability. For instance, with a normal or laplace distribution, you can pass the scale in this dictionary with decoder_dist_params = {'mod1' : {"scale" : 0.75}}.
subsets (Union[List[list], Dict[list]]) – The subsets to consider. If None is provided, all non-empty subsets are considered. Example of valid input: [[‘mod_1’, ‘mod_2’], [‘mod_1’], [‘mod_2’]]. Default: None.
beta (float) – The weight of the KL divergence term in the ELBO. Default: 1.0.
beta_style (float) – The beta factor for the additional ELBO terms when there are multiple latent spaces. Default: 1.0.
modalities_specific_dim (dict) – A dictionary containing the modality names and the dimension of the additional latent space for each modality. Default: None.
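For illustration, a configuration for two hypothetical modalities could look as follows. It uses only the parameters listed in the signature above; the modality names, input shapes and chosen distributions are assumptions made for the example, not recommended settings.

```python
from multivae.models import MoPoEConfig

# Two hypothetical modalities: a grayscale image and a colour image.
config = MoPoEConfig(
    n_modalities=2,
    latent_dim=16,
    input_dims={"mod_1": (1, 28, 28), "mod_2": (3, 32, 32)},
    decoders_dist={"mod_1": "bernoulli", "mod_2": "laplace"},
    # Scale of the Laplace likelihood used for mod_2.
    decoder_dist_params={"mod_2": {"scale": 0.75}},
    # Same subsets as the default (None) for two modalities.
    subsets=[["mod_1", "mod_2"], ["mod_1"], ["mod_2"]],
    beta=1.0,
    # Optional: a 4-dimensional modality-specific latent space per modality.
    # modalities_specific_dim={"mod_1": 4, "mod_2": 4},
)
```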
- class multivae.models.MoPoE(model_config, encoders=None, decoders=None)
Implementation of the Mixture of Product of Experts model from ‘Generalized Multimodal ELBO’ Sutter 2021 (https://arxiv.org/abs/2105.02470).
This implementation is heavily based on the official one at https://github.com/thomassutter/MoPoE.
- Parameters:
model_config (MoPoEConfig) – Contains all the parameters for the model.
encoders (dict) – Contains the encoder for each modality. When using modality-specific latent spaces, the encoders must be instances of multivae.models.nn.base_architectures.BaseMultilatentEncoder. Otherwise, the encoders must be instances of pythae.models.nn.base_architectures.BaseEncoder. If None is provided, default MLP architectures are used.
decoders (dict) – Contains the decoder for each modality. Each decoder must be an instance of pythae.models.nn.base_architectures.BaseDecoder. When using modality-specific latent spaces, the decoder takes as input the concatenation of both latent codes. If None is provided, default MLP architectures are used.
- all_subsets()
Returns a list of all non-empty subsets of the modalities.
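The enumeration can be sketched with itertools.combinations. `nonempty_subsets` is a hypothetical standalone function; it returns subsets ordered by size, which may differ from the order used by the library.

```python
from itertools import combinations


def nonempty_subsets(modalities):
    """All non-empty subsets of the modality names (2**M - 1 of them)."""
    return [
        list(s)
        for r in range(1, len(modalities) + 1)
        for s in combinations(modalities, r)
    ]
```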
- calc_joint_divergence(mus, logvars, weights)
Approximates the KL divergence between the mixture of experts and the prior by the weighted sum of the tractable KL divergences of the individual experts. By convexity of the KL divergence, this sum is an upper bound on the KL divergence of the mixture.
- Parameters:
mus (Tensor) – The means of the experts, of shape (n_subset, n_samples, latent_dim).
logvars (Tensor) – The log-variances of the experts, of shape (n_subset, n_samples, latent_dim).
weights (Tensor) – The weights of the experts, of shape (n_subset, n_samples).
- Returns:
The group divergence summed over the experts, and a tensor containing the KL term of each expert.
- Return type:
Tensor, Tensor
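The decomposition can be sketched for a single sample with plain Python lists, assuming diagonal Gaussian experts and a standard normal prior. `gaussian_kl_to_std_normal` and `joint_divergence` are hypothetical helpers; the documented method works on batched tensors of shape (n_subset, n_samples, latent_dim).

```python
import math


def gaussian_kl_to_std_normal(mu, logvar):
    """KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def joint_divergence(mus, logvars, weights):
    """Weighted sum of the per-expert KL terms for a single sample.

    mus, logvars: one list of length latent_dim per expert (subset).
    weights: one mixture weight per expert.
    Returns (weighted sum, list of per-expert KL terms).
    """
    klds = [gaussian_kl_to_std_normal(m, lv) for m, lv in zip(mus, logvars)]
    return sum(w * k for w, k in zip(weights, klds)), klds
```

An expert equal to the prior contributes a zero KL term, so only the experts that move away from N(0, I) increase the divergence.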
- compute_joint_nll(inputs, K=1000, batch_size_K=100)
Estimates the joint negative log-likelihood, using the MoPoE joint posterior as the importance sampling distribution. The original code uses the product of experts as inference distribution instead (see compute_joint_nll_paper), which is less coherent with the definition of the MoPoE as the joint posterior.
- Parameters:
inputs (MultimodalBaseDataset) – a batch of samples.
K (int) – The number of importance samples for the estimation. Default: 1000.
batch_size_K (int) – The number of importance samples processed per batch. Default: 100.
- Returns:
The negative log-likelihood summed over the batch.
- compute_joint_nll_paper(inputs, K=1000, batch_size_K=100)
Estimates the negative joint likelihood using the PoE posterior as importance sampling distribution. The result is summed over the input batch.
This is the method used in the original paper implementation.
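Both estimators rely on importance sampling, \(\log p(X) \approx \log \frac{1}{K}\sum_{k=1}^{K} \frac{p_\theta(X|z_k)\,p(z_k)}{q(z_k|X)}\) with \(z_k \sim q(z|X)\), and differ only in the proposal \(q\): the MoPoE joint posterior for compute_joint_nll, the PoE for compute_joint_nll_paper. The sketch below illustrates the estimator on a hypothetical one-dimensional toy model whose exact posterior is known, so every importance weight equals \(p(x)\) and the estimate is exact. The chunking over K samples controlled by batch_size_K is left out.

```python
import math
import random


def log_normal(x, mean, var):
    """Log-density of N(mean, var) at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def importance_nll(x, K=1000, rng=random):
    """-log p(x) estimated with K importance samples, for a toy 1-D model.

    Toy model (hypothetical): p(z) = N(0, 1), p(x|z) = N(z, 1).
    Proposal q(z|x) = N(x / 2, 1 / 2), which here is the exact posterior.
    """
    q_mean, q_var = x / 2.0, 0.5
    log_w = []
    for _ in range(K):
        z = rng.gauss(q_mean, math.sqrt(q_var))
        log_w.append(
            log_normal(x, z, 1.0) + log_normal(z, 0.0, 1.0) - log_normal(z, q_mean, q_var)
        )
    # log p(x) ~= logsumexp(log_w) - log K
    return -(logsumexp(log_w) - math.log(K))
```

With a less accurate proposal the individual weights vary, and the estimate becomes a stochastic lower bound on \(\log p(x)\) that tightens as K grows.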
- deterministic_mixture_component_selection(mus, logvars, w_modalities)
Associates a subset mu and log-covariance to each sample in a balanced way, so that the proportion of samples assigned to each subset corresponds to w_modalities.
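A balanced assignment can be obtained by giving each component a contiguous block of samples. The sketch below assumes this block scheme (component k gets floor(n_samples * w_k) samples and the last component takes the remainder). `deterministic_assignment` is a hypothetical helper that only returns the component index of each sample, without gathering the corresponding mus and logvars.

```python
import math


def deterministic_assignment(n_samples, weights):
    """Assign each sample to one mixture component in contiguous, balanced blocks.

    Component k receives floor(n_samples * weights[k]) samples; the last
    component takes whatever remains so that every sample is assigned.
    """
    assignment = []
    for k, w in enumerate(weights):
        if k == len(weights) - 1:
            count = n_samples - len(assignment)
        else:
            count = math.floor(n_samples * w)
        assignment.extend([k] * count)
    return assignment
```

With 10 samples and three equal weights, the components receive 3, 3 and 4 samples.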
- encode(inputs, cond_mod='all', N=1, return_mean=False, **kwargs)
Generates encodings conditioned on all modalities or on a subset of modalities, using the product of experts of the conditioning modalities.
- Parameters:
inputs (MultimodalBaseDataset) – The dataset to use for the conditional generation.
cond_mod (Union[list, str]) – Either ‘all’ or a list of str containing the names of the modalities to condition on.
N (int) – The number of encodings to sample for each datapoint. Default to 1.
return_mean (bool) – If True, returns the mean of the posterior distribution (instead of a sample). Default: False.
- Returns:
A ModelOutput instance with fields z (torch.Tensor (n_data, N, latent_dim)), one_latent_space (bool) and modalities_z (Dict[str, torch.Tensor (n_data, N, latent_dim)]).
- Return type:
ModelOutput
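Sampling N encodings per datapoint typically relies on the reparameterization trick, \(z = \mu + \sigma \odot \epsilon\) with \(\epsilon \sim \mathcal N(0, I)\). The hypothetical helper below sketches this for one datapoint, given the mean and log-variance of the PoE of the conditioning modalities. Returning the mean repeated N times when return_mean is True is an assumption made to keep the output shape fixed.

```python
import math
import random


def sample_encodings(mu, logvar, N=1, return_mean=False, rng=random):
    """Return N latent codes from N(mu, diag(exp(logvar))), or N copies of mu."""
    if return_mean:
        return [list(mu) for _ in range(N)]
    std = [math.exp(0.5 * lv) for lv in logvar]
    # z = mu + std * eps, with eps drawn from a standard normal per dimension.
    return [[m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, std)] for _ in range(N)]
```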
- forward(inputs, **kwargs)
Main forward pass outputting the VAE outputs. This function should output a ModelOutput instance gathering all the model outputs.
- Parameters:
inputs (BaseDataset) – The training data, with labels, masks, etc.
- Returns:
A ModelOutput instance providing the outputs of the model.
- Return type:
ModelOutput
Note
The loss must be computed in this forward pass and accessed through
loss = model_output.loss
- inference(inputs, **kwargs)
- Parameters:
inputs (MultimodalBaseDataset) – The data.
- Returns:
All the subset and joint posterior parameters.
- Return type:
- modality_encode(inputs, **kwargs)
Computes, for each modality, the parameters mu and logvar of the unimodal posterior.
- Parameters:
inputs (MultimodalBaseDataset) – The data to encode.
- Returns:
Contains the encoder output for each modality.
- Return type:
- random_mixture_component_selection(mus, logvars, availabilities)
Randomly select a subset for each sample among the available subsets.
- Parameters:
mus (tensor) – (n_subset, n_samples, latent_dim) the means of the subset posteriors.
logvars (tensor) – (n_subset, n_samples, latent_dim) the log-covariances of the subset posteriors.
availabilities (tensor) – (n_subset, n_samples) boolean tensor.
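A minimal sketch of this selection with plain Python lists, assuming uniform probabilities over the available subsets and at least one available subset per sample (this holds whenever a sample has at least one observed modality, since that modality's singleton subset is then available). `random_component_selection` is a hypothetical helper that returns the chosen subset index per sample; the documented method additionally receives mus and logvars, presumably to gather the selected parameters.

```python
import random


def random_component_selection(availabilities, rng=random):
    """For each sample, pick uniformly one subset index among the available ones.

    availabilities: list of n_subsets rows, each a list of n_samples bools.
    Returns a list with one chosen subset index per sample.
    """
    n_subsets, n_samples = len(availabilities), len(availabilities[0])
    chosen = []
    for i in range(n_samples):
        candidates = [k for k in range(n_subsets) if availabilities[k][i]]
        # rng.choice raises IndexError if no subset is available for this sample.
        chosen.append(rng.choice(candidates))
    return chosen
```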