Base MultiVae Class

Abstract class.

class multivae.models.BaseMultiVAEConfig(n_modalities, latent_dim=10, input_dims=None, uses_likelihood_rescaling=False, rescale_factors=None, decoders_dist=None, decoder_dist_params=None, custom_architectures=<factory>)

This is the base config for a Multi-Modal VAE model.

Parameters:
  • n_modalities (int) – The number of modalities.

  • latent_dim (int) – The dimension of the latent space. Default: 10.

  • input_dims (dict[str, tuple]) – The modalities’ names (str) and their input shapes (tuple). Default: None.

  • uses_likelihood_rescaling (bool) – Whether to use likelihood rescaling, which can mitigate modality collapse (see: https://proceedings.mlr.press/v162/javaloy22a.html). input_dims must be provided to compute the likelihood rescaling factors. This option is defined in the base config because several models use it. Default: False.

  • rescale_factors (dict[str, float]) – The reconstruction rescaling factor for each modality. If None is provided and uses_likelihood_rescaling is True, a default value proportional to the input modality size is computed. Default: None.

  • decoders_dist (Dict[str, Union[function, str]]) – For each modality, you can provide a string in [‘normal’, ‘bernoulli’, ‘laplace’]. For the Bernoulli distribution, the decoder is expected to output logits. If None is provided, a normal distribution is used for each modality. Default: None.

  • decoder_dist_params (Dict[str, dict]) – Parameters of the decoders’ output distributions, used to compute the log-probabilities. For instance, with a normal or Laplace distribution, you can pass the scale with decoder_dist_params = {'mod1': {"scale": 0.75}}. Default: None.
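The three string options accepted by decoders_dist correspond to standard log-densities. As a rough, library-independent sketch (not MultiVae's implementation), the snippet below evaluates them element-wise in plain Python. The helper names and the default scale of 1.0 are assumptions; "scale" plays the role of an entry in decoder_dist_params.

```python
import math


def elementwise_log_prob(dist: str, x: float, output: float, scale: float = 1.0) -> float:
    """Log-density of one observed value x given one decoder output.

    Hypothetical helper; the default scale of 1.0 is an assumption of this sketch.
    """
    if dist == "normal":
        # output is the mean of Normal(output, scale)
        z = (x - output) / scale
        return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)
    if dist == "laplace":
        # output is the location of Laplace(output, scale)
        return -abs(x - output) / scale - math.log(2 * scale)
    if dist == "bernoulli":
        # output is a logit l: log p(x) = x * l - softplus(l), computed stably
        softplus = max(output, 0.0) + math.log1p(math.exp(-abs(output)))
        return x * output - softplus
    raise ValueError(f"Unsupported distribution: {dist!r}")


def modality_log_prob(dist, xs, outputs, params=None):
    """Sum the element-wise log-probabilities over one modality."""
    params = params or {}
    return sum(elementwise_log_prob(dist, x, o, **params) for x, o in zip(xs, outputs))


# Same layout as decoders_dist / decoder_dist_params in the config.
decoders_dist = {"mod1": "normal", "mod2": "bernoulli"}
decoder_dist_params = {"mod1": {"scale": 0.75}}

data = {"mod1": [0.2, -0.1], "mod2": [1.0, 0.0]}
decoder_outputs = {"mod1": [0.0, 0.0], "mod2": [2.0, -1.0]}  # mod2 outputs are logits

log_probs = {
    mod: modality_log_prob(
        decoders_dist[mod], data[mod], decoder_outputs[mod], decoder_dist_params.get(mod)
    )
    for mod in decoders_dist
}
```

Keeping Bernoulli outputs as logits, as the parameter description requires, lets the log-probability be computed with a numerically stable softplus instead of taking the log of a sigmoid.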

class multivae.models.BaseMultiVAE(model_config, encoders=None, decoders=None)

Base class for Multimodal VAE models.

Parameters:
  • model_config (BaseMultiVAEConfig) – An instance of BaseMultiVAEConfig containing all the model’s parameters.

  • encoders (Dict[str, BaseEncoder]) – A dictionary mapping each modality name to its encoder. Each encoder is an instance of Pythae’s BaseEncoder. Default: None.

  • decoders (Dict[str, BaseDecoder]) – A dictionary mapping each modality name to its decoder. Each decoder is an instance of Pythae’s BaseDecoder. Default: None.

check_input_dims(model_config)

Check that the input dimensions are coherent with the provided number of modalities.
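A minimal stand-alone sketch of this kind of check, assuming it compares the number of entries in input_dims with n_modalities. The function signature and the use of ValueError are assumptions, not the library's code.

```python
from typing import Dict, Optional, Tuple


def check_input_dims(n_modalities: int, input_dims: Optional[Dict[str, Tuple[int, ...]]]) -> None:
    """Hypothetical version of the check: one input shape per modality."""
    if input_dims is None:
        # input_dims is optional in the config, so there is nothing to compare.
        return
    if len(input_dims) != n_modalities:
        raise ValueError(
            f"input_dims has {len(input_dims)} entries but n_modalities={n_modalities}"
        )


check_input_dims(2, {"image": (1, 28, 28), "text": (32,)})  # coherent: no error
```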

compute_cond_nll(inputs, subset, pred_mods, k_iwae=1000)

Compute the conditional log-likelihood ln p(x_pred | x_cond) with Monte Carlo sampling, using the approximation:

$$\ln p(x_{pred} \mid x_{cond}) \approx \frac{1}{K} \sum_{i=1}^{K} \ln p(x_{pred} \mid z^{(i)}), \qquad z^{(i)} \sim q(z \mid x_{cond}),$$

where K = k_iwae.

Parameters:
  • inputs (MultimodalBaseDataset) – The data to compute the likelihood on.

  • subset (list of str) – The modalities to condition on (x_cond).

  • pred_mods (list of str) – The modalities whose likelihood is computed (x_pred).

  • k_iwae (int, optional) – The number of latent samples K drawn from q(z | x_cond). Defaults to 1000.

Returns:

A dictionary containing the negative log-likelihood for each modality in pred_mods.

Return type:

dict
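To make the estimator concrete, here is a toy sketch of the formula above in plain Python. It assumes a one-dimensional Gaussian q(z | x_cond) = N(q_mean, q_std²) and a unit-variance Gaussian decoder p(x_pred | z) = N(z, 1), both hypothetical. In this toy case the expectation being estimated has a closed form, which the Monte Carlo average should approach as K grows.

```python
import math
import random


def gaussian_log_density(x: float, mean: float, std: float) -> float:
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)


def mc_cond_nll(x_pred: float, q_mean: float, q_std: float, k_iwae: int = 1000, seed: int = 0) -> float:
    """Return -(1/K) * sum_i ln p(x_pred | z_i), with z_i ~ q(z | x_cond).

    Toy assumptions: q(z | x_cond) = N(q_mean, q_std^2), p(x_pred | z) = N(z, 1).
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(k_iwae):
        z = rng.gauss(q_mean, q_std)  # sample from the conditional posterior
        total += gaussian_log_density(x_pred, z, 1.0)  # ln p(x_pred | z)
    return -total / k_iwae


nll = mc_cond_nll(x_pred=0.5, q_mean=0.0, q_std=1.0, k_iwae=20000)
# Closed-form value of -E_q[ln p(x_pred | z)] for this toy model, for comparison.
exact = 0.5 * math.log(2 * math.pi) + 0.5 * ((0.5 - 0.0) ** 2 + 1.0 ** 2)
```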

decode(embedding, modalities='all')

Decode a latent variable z in all the modalities specified in modalities.

Parameters:
  • embedding (ModelOutput) – Contains the latent variables. It must have the same format as the output of the encode function.

  • modalities (Union[list, str], optional) – The modalities to decode from z. Defaults to ‘all’.

Returns:

A ModelOutput containing one tensor per modality name.

Return type:

ModelOutput
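A hypothetical sketch of the dispatch that decode performs: ‘all’ expands to every modality, and each selected decoder is applied to the same latent variable. Plain callables stand in for Pythae decoders and a dict stands in for ModelOutput; accepting a single modality name as a string is an assumption based on the Union[list, str] type.

```python
def resolve_modalities(modalities, available):
    """Turn 'all', a single name, or a list of names into a list of names."""
    if modalities == "all":
        return list(available)
    if isinstance(modalities, str):
        return [modalities]
    return list(modalities)


def decode(z, decoders, modalities="all"):
    """Hypothetical sketch: apply each selected decoder to the same latent z."""
    selected = resolve_modalities(modalities, decoders.keys())
    unknown = [m for m in selected if m not in decoders]
    if unknown:
        raise KeyError(f"No decoder for modalities: {unknown}")
    return {m: decoders[m](z) for m in selected}


# Toy "decoders": plain functions on a list-valued latent.
decoders = {
    "image": lambda z: [2 * v for v in z],
    "text": lambda z: [v + 1 for v in z],
}
out = decode([0.5, -1.0], decoders, modalities="text")
```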

encode(inputs, cond_mod='all', N=1, return_mean=False, **kwargs)

Generate latent encodings, conditioning on all modalities or on a subset of them.

Parameters:
  • inputs (MultimodalBaseDataset) – The data to encode.

  • cond_mod (Union[list, str]) – Either ‘all’ or a list of str containing the names of the modalities to condition on.

  • N (int) – The number of encodings to sample for each datapoint. Defaults to 1.
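With a Gaussian posterior, drawing N encodings per datapoint is usually done with the reparameterization trick, z = mu + sigma * eps with eps ~ N(0, I). The sketch below assumes a diagonal Gaussian described by hypothetical mu and log_var lists; the nested (N, batch_size, latent_dim) layout it returns is an assumption chosen for illustration, not the library's encode.

```python
import math
import random


def sample_encodings(mu, log_var, n_samples=1, seed=None):
    """Draw n_samples latent codes per datapoint: z = mu + exp(0.5 * log_var) * eps.

    mu and log_var are hypothetical nested lists of shape (batch_size, latent_dim).
    Returns a nested list of shape (n_samples, batch_size, latent_dim).
    """
    rng = random.Random(seed)
    samples = []
    for _ in range(n_samples):
        batch = []
        for mu_row, lv_row in zip(mu, log_var):
            # eps ~ N(0, 1) independently for every latent coordinate
            batch.append([m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu_row, lv_row)])
        samples.append(batch)
    return samples


mu = [[0.0, 1.0, 2.0, 3.0], [1.0, 1.0, 1.0, 1.0]]
log_var = [[0.0] * 4, [0.0] * 4]
z = sample_encodings(mu, log_var, n_samples=3, seed=0)
```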

forward(inputs, **kwargs)

Main forward pass, returning the VAE outputs. This function should return a ModelOutput instance gathering all the model outputs.

Parameters:
  • inputs (BaseDataset) – The training data, with labels, masks, etc.

Returns:

A ModelOutput instance providing the outputs of the model.

Return type:

ModelOutput

Note

The loss must be computed in this forward pass and can be accessed through loss = model_output.loss.
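The Note above means that training code only needs model_output.loss. The sketch below uses a minimal dict-with-attribute-access class as a stand-in for Pythae's ModelOutput, and a hypothetical toy model whose forward returns its loss alongside other outputs; the loss itself is a placeholder, not a VAE objective.

```python
from collections import OrderedDict


class ModelOutput(OrderedDict):
    """Minimal stand-in: a dict whose keys can also be read as attributes."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as err:
            raise AttributeError(name) from err


class ToyMultimodalModel:
    """Hypothetical model used only to show the forward/loss contract."""

    def __init__(self, weights):
        self.weights = weights  # hypothetical per-modality loss weights

    def forward(self, inputs):
        # Placeholder per-modality "reconstruction loss": sum of squared values.
        recon = {m: sum(v * v for v in xs) for m, xs in inputs.items()}
        loss = sum(self.weights[m] * recon[m] for m in recon)
        # The loss is computed inside forward and returned with the other outputs.
        return ModelOutput(loss=loss, metrics=recon)


model = ToyMultimodalModel({"mod1": 1.0, "mod2": 0.5})
model_output = model.forward({"mod1": [1.0, 2.0], "mod2": [2.0]})
loss = model_output.loss
```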

generate_from_prior(n_samples, **kwargs)

Generate latent samples from the prior distribution. In this base class, the prior is a fixed standard normal distribution; subclasses may override this method.

Parameters:
  • n_samples (int) – The number of samples to generate.

  • **kwargs – additional arguments

Returns:

A ModelOutput instance containing the generated samples.

Return type:

ModelOutput
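Sampling from a static standard normal prior amounts to drawing each latent coordinate independently from N(0, 1). A plain-Python sketch, where the latent_dim argument and the "z" key of the returned dict are assumptions:

```python
import random


def generate_from_prior(n_samples, latent_dim, seed=None):
    """Draw n_samples latent vectors from a standard normal prior N(0, I)."""
    rng = random.Random(seed)
    z = [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)] for _ in range(n_samples)]
    return {"z": z}  # dict stands in for a ModelOutput


prior_samples = generate_from_prior(n_samples=5000, latent_dim=2, seed=0)
```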

predict(inputs, cond_mod='all', gen_mod='all', N=1, flatten=False, **kwargs)

Generate samples in the modalities listed in gen_mod, conditioning on the modalities listed in cond_mod.

Parameters:
  • inputs (MultimodalBaseDataset) – The data to condition on. It must contain at least the modalities contained in cond_mod.

  • cond_mod (Union[list, str], optional) – The modalities to condition on. Defaults to ‘all’.

  • gen_mod (Union[list, str], optional) – The modalities to generate. Defaults to ‘all’.

  • N (int) – The number of samples to generate. Defaults to 1.

  • flatten (bool) – If N > 1 and flatten is False, the returned samples have dimensions (N, len(inputs), …). Otherwise, they have dimensions (len(inputs)*N, …). Defaults to False.

Returns:

A ModelOutput (pythae.models.base.base_utils.ModelOutput) containing the generated samples.

Example:

>>> predictions = model.predict(test_set, cond_mod=['modality1', 'modality2'], gen_mod='modality3')
>>> predictions.modality3
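The rule for flatten can be written as a small function returning the shape of one modality's predictions. It only encodes the documented shapes; the function and argument names are hypothetical. With N = 1, the result is (len(inputs), …) regardless of flatten.

```python
def predicted_shape(n, n_inputs, sample_shape, flatten=False):
    """Shape of one modality's predictions, following the documented flatten rule."""
    if n > 1 and not flatten:
        return (n, n_inputs, *sample_shape)
    return (n_inputs * n, *sample_shape)


# e.g. 10 inputs, 3 samples each, 1x28x28 images
stacked = predicted_shape(3, 10, (1, 28, 28))
flat = predicted_shape(3, 10, (1, 28, 28), flatten=True)
```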

sanity_check(encoders, decoders)

Check the consistency between the encoders, the decoders, and the model configuration.
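One plausible form of these consistency checks, written as a hypothetical stand-alone function: it assumes the model requires exactly one encoder and one decoder per modality, with matching modality names. The actual method may check more, or different, conditions.

```python
def sanity_check(encoders, decoders, n_modalities):
    """Hypothetical checks: one encoder and one decoder per modality, same names."""
    if len(encoders) != n_modalities:
        raise ValueError(f"Expected {n_modalities} encoders, got {len(encoders)}")
    if len(decoders) != n_modalities:
        raise ValueError(f"Expected {n_modalities} decoders, got {len(decoders)}")
    if set(encoders) != set(decoders):
        raise ValueError("Encoders and decoders must cover the same modality names")


# Values would be encoder/decoder modules; strings are placeholders here.
sanity_check({"image": "enc", "text": "enc"}, {"image": "dec", "text": "dec"}, n_modalities=2)
```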

set_decoders(decoders)

Set the decoders of the model.

set_decoders_dist(recon_dict, dist_params_dict)

Set the reconstruction loss functions (decoders_dist) and the log-probability functions (recon_log_probs). recon_log_probs is the normalized, negated counterpart of recon_loss and is used only for likelihood estimation.

set_encoders(encoders)

Set the encoders of the model.

set_rescale_factors()

Set the rescale factors for the reconstruction losses. When likelihood rescaling is used, these factors are applied when computing the reconstruction losses.
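Assuming the factors act as per-modality weights on the reconstruction losses, the combined loss can be sketched as a weighted sum. The function name, the dict inputs, and the fallback weight of 1.0 for modalities without a factor are assumptions of this sketch.

```python
def rescaled_recon_loss(recon_losses, rescale_factors=None):
    """Sum per-modality reconstruction losses, each multiplied by its rescale factor.

    Modalities without a factor (all of them when rescale_factors is None)
    get a weight of 1.0, an assumption of this sketch.
    """
    factors = rescale_factors or {}
    return sum(factors.get(mod, 1.0) * loss for mod, loss in recon_losses.items())


total = rescaled_recon_loss({"image": 2.0, "text": 3.0}, {"image": 0.5, "text": 2.0})
```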

update()

Hook for updating the model during training (at the end of a training epoch).

If needed, this method must be implemented in a child class.

By default, it does nothing.
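A sketch of how such a hook is typically used: the training loop calls update() once per epoch, the base version does nothing, and a child class overrides it. The classes, the loop, and the linearly annealed weight beta are hypothetical illustrations, not part of MultiVae.

```python
class BaseModelSketch:
    def update(self):
        """Default behaviour: do nothing at the end of an epoch."""
        return None


class AnnealedModelSketch(BaseModelSketch):
    """Hypothetical child class that linearly increases a weight each epoch."""

    def __init__(self, warmup_epochs):
        self.warmup_epochs = warmup_epochs
        self.epoch = 0
        self.beta = 0.0

    def update(self):
        self.epoch += 1
        self.beta = min(1.0, self.epoch / self.warmup_epochs)


def train(model, n_epochs):
    for _ in range(n_epochs):
        # ... the training steps of one epoch would run here ...
        model.update()  # called once at the end of each training epoch


model = AnnealedModelSketch(warmup_epochs=4)
train(model, n_epochs=2)
```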