TELBO

Implementation of the TELBO algorithm from “Generative Models of Visually Grounded Imagination” (https://arxiv.org/abs/1705.10762).

The TELBO model uses a joint encoder \(q_{\phi}(z|X)\), as JMVAE does, but is trained with the following Triple ELBO objective, where \(M\) is the number of modalities:

\[\mathcal L(X) = \mathbb E_{q_{\phi}(z|X)}\left[ \log \frac{p_{\theta}(z,X)}{q_{\phi}(z|X)} \right] + \sum_{j=1}^{M} \mathbb E_{q_{\phi}(z|x_j)}\left[ \log \frac{p_{\theta}(z,x_j)}{q_{\phi}(z|x_j)} \right]\]
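To make the objective concrete, here is a minimal Monte Carlo sketch of the triple ELBO on a toy model chosen purely for illustration: a scalar latent \(z \sim \mathcal N(0,1)\) and modalities \(x_j \mid z \sim \mathcal N(z, \sigma^2)\). The helper names (normal_logpdf, gaussian_posterior, elbo, telbo) are hypothetical and not part of multivae, and the lambda/gamma reconstruction weights are omitted.

```python
import math
import random
from typing import List, Sequence, Tuple

Gaussian = Tuple[float, float]  # (mean, variance) of a univariate normal


def normal_logpdf(x: float, mean: float, var: float) -> float:
    """Log-density of the univariate normal N(mean, var) at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)


def gaussian_posterior(xs: Sequence[float], noise_var: float) -> Gaussian:
    """Exact posterior p(z | xs) for the toy model z ~ N(0, 1), x_j | z ~ N(z, noise_var)."""
    precision = 1.0 + len(xs) / noise_var
    return (sum(xs) / noise_var) / precision, 1.0 / precision


def elbo(xs: Sequence[float], q: Gaussian, noise_var: float,
         n_samples: int, rng: random.Random) -> float:
    """Monte Carlo estimate of E_q[log p(z, xs) - log q(z | xs)]."""
    q_mean, q_var = q
    total = 0.0
    for _ in range(n_samples):
        z = rng.gauss(q_mean, math.sqrt(q_var))  # z ~ q(z | xs)
        log_joint = normal_logpdf(z, 0.0, 1.0) + sum(normal_logpdf(x, z, noise_var) for x in xs)
        total += log_joint - normal_logpdf(z, q_mean, q_var)
    return total / n_samples


def telbo(xs: Sequence[float], joint_q: Gaussian, unimodal_qs: List[Gaussian],
          noise_var: float, n_samples: int = 100, rng: random.Random = None) -> float:
    """Joint ELBO on all M modalities plus one unimodal ELBO per modality."""
    rng = rng or random.Random()
    value = elbo(xs, joint_q, noise_var, n_samples, rng)
    for x_j, q_j in zip(xs, unimodal_qs):
        value += elbo([x_j], q_j, noise_var, n_samples, rng)
    return value


if __name__ == "__main__":
    xs, noise_var = [0.3, -0.1], 0.5
    joint_q = gaussian_posterior(xs, noise_var)
    unimodal_qs = [gaussian_posterior([x], noise_var) for x in xs]
    print(telbo(xs, joint_q, unimodal_qs, noise_var, rng=random.Random(0)))
```

When every \(q\) is the exact Gaussian posterior, each log-ratio inside the expectations equals the corresponding log-marginal likelihood, so the estimate is exact; any other choice of \(q\) gives a lower value in expectation.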

It is trained in two stages: first the joint encoder and the decoders are learned, then the unimodal encoders \(q_{\phi}(z|x_j)\) are trained with the previously learned parameters fixed.
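The two-stage schedule can be sketched as a per-epoch switch driven by a warmup epoch count, as in the warmup parameter of TELBOConfig. ParamGroup and configure_stage are hypothetical stand-ins, not multivae APIs, and epochs are assumed to be 0-indexed.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ParamGroup:
    """Hypothetical stand-in for a network's parameters (e.g. a torch.nn.Module)."""
    name: str
    trainable: bool = True


def configure_stage(epoch: int, warmup: int, joint_encoder: ParamGroup,
                    decoders: List[ParamGroup], unimodal_encoders: List[ParamGroup]) -> str:
    """Set which groups are trained at a 0-indexed epoch and return the stage name.

    Stage 1 (epoch < warmup): train the joint encoder and the decoders.
    Stage 2 (epoch >= warmup): freeze them and train only the unimodal encoders.
    """
    first_stage = epoch < warmup
    for group in [joint_encoder, *decoders]:
        # In PyTorch this would typically be module.requires_grad_(first_stage).
        group.trainable = first_stage
    for group in unimodal_encoders:
        group.trainable = not first_stage
    return "joint" if first_stage else "unimodal"


if __name__ == "__main__":
    joint = ParamGroup("joint_encoder")
    decoders = [ParamGroup("decoder_img"), ParamGroup("decoder_txt")]
    encoders = [ParamGroup("encoder_img"), ParamGroup("encoder_txt")]
    for epoch in range(4):
        print(epoch, configure_stage(epoch, 2, joint, decoders, encoders))
```

Freezing the joint encoder and decoders in the second stage means the unimodal encoders are fitted against a fixed generative model, which is the point of training them last.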

Note

This model must be trained with the multivae.trainers.multistage.MultiStageTrainer.

Note

As it uses a joint encoder network, this model cannot be trained on partially observed samples.

class multivae.models.TELBOConfig(n_modalities, latent_dim=10, input_dims=None, uses_likelihood_rescaling=True, rescale_factors=None, decoders_dist=None, decoder_dist_params=None, custom_architectures=<factory>, warmup=10, lambda_factors=None, gamma_factors=None)

Configuration class for the TELBO model from “Generative Models of Visually Grounded Imagination” (Vedantam et al., 2018; arXiv:1705.10762 [cs, stat]).

Parameters:
  • n_modalities (int) – The number of modalities. Default: None.

  • latent_dim (int) – The dimension of the latent space. Default: 10.

  • input_dims (dict[str,tuple]) – The modality names (str) and their input shapes (tuple). Default: None.

  • uses_likelihood_rescaling (bool) – Whether to use likelihood rescaling to mitigate modality collapse (see: https://proceedings.mlr.press/v162/javaloy22a.html). input_dims must be provided to compute the rescaling factors. This option is shared by several models, which is why it is included here. Default: True.

  • rescale_factors (dict[str, float]) – The reconstruction rescaling factors per modality. If None is provided and uses_likelihood_rescaling is True, a default value proportional to the input modality size is computed. Default: None.

  • decoders_dist (Dict[str, Union[function, str]]) – For each modality, you can provide a string in ['normal', 'bernoulli', 'laplace']. For the Bernoulli distribution, the decoder is expected to output logits. If None is provided, a normal distribution is used for each modality. Default: None.

  • decoder_dist_params (Dict[str,dict]) – Parameters of the decoder output distributions, used to compute the log-probability. For instance, with a normal or Laplace distribution, you can pass the scale in this dictionary: decoder_dist_params = {"mod1": {"scale": 0.75}}. Default: None.

  • warmup (int) – Number of epochs during which the joint encoder and decoders are trained before they are frozen and the unimodal encoders are learned. Using half of the total training time for the first stage is recommended. Default: 10.

  • lambda_factors (dict[str,float]) – Weighting factors for the reconstruction terms in the joint ELBO. If None is provided and uses_likelihood_rescaling is True, the inverse product of each modality's input dimensions is used as its rescaling factor. If None is provided and uses_likelihood_rescaling is False, each factor is set to one. Default: None.

  • gamma_factors (dict[str,float]) – Weighting factors for the reconstruction terms in the unimodal ELBOs. If None is provided and uses_likelihood_rescaling is True, the inverse product of each modality's input dimensions is used as its rescaling factor. If None is provided and uses_likelihood_rescaling is False, each factor is set to one. Default: None.

  • uses_likelihood_rescaling (bool) – Also determines how the lambda and gamma factors are set when they are None. Ignored when both lambda_factors and gamma_factors are provided. Default: True.
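The fallback rule for lambda_factors and gamma_factors can be sketched as follows. This illustrates the documented behaviour under the assumption that input_dims maps each modality name to its shape tuple; default_factors is a hypothetical helper, not a multivae function.

```python
import math
from typing import Dict, Optional, Tuple


def default_factors(input_dims: Dict[str, Tuple[int, ...]],
                    uses_likelihood_rescaling: bool = True,
                    factors: Optional[Dict[str, float]] = None) -> Dict[str, float]:
    """Resolve per-modality reconstruction weights (lambda or gamma factors)."""
    if factors is not None:
        return dict(factors)  # user-provided factors take precedence
    if uses_likelihood_rescaling:
        # inverse product of the modality's input dimensions
        return {mod: 1.0 / math.prod(shape) for mod, shape in input_dims.items()}
    return {mod: 1.0 for mod in input_dims}  # no rescaling: every factor is one


if __name__ == "__main__":
    dims = {"img": (1, 28, 28), "txt": (32,)}
    print(default_factors(dims))         # img weighted by 1/784, txt by 1/32
    print(default_factors(dims, False))  # both factors equal to 1.0
```

Dividing by the number of input dimensions keeps a high-dimensional modality (such as an image) from dominating the summed reconstruction terms.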

class multivae.models.TELBO(model_config, encoders=None, decoders=None, joint_encoder=None, **kwargs)

The Triple ELBO VAE model.

Parameters:
  • model_config (TELBOConfig) – An instance of TELBOConfig containing all the model's parameters.

  • encoders (Dict[str, BaseEncoder]) – A dictionary mapping each modality name to its encoder. Each encoder is an instance of Pythae's BaseEncoder. Default: None.

  • decoders (Dict[str, BaseDecoder]) – A dictionary mapping each modality name to its decoder. Each decoder is an instance of Pythae's BaseDecoder. Default: None.

  • joint_encoder (BaseJointEncoder) – An encoder that takes all the modalities as input. If None is provided, one is created from the unimodal encoders. Default: None.

encode(inputs, cond_mod='all', N=1, return_mean=False, **kwargs)

Generate encodings conditioning on all modalities or a subset of modalities.

Parameters:
  • inputs (MultimodalBaseDataset) – The dataset to use for the conditional generation.

  • cond_mod (Union[list, str]) – Either 'all' or a list of the names of the modalities to condition on.

  • N (int) – The number of encodings to sample for each datapoint. Default: 1.

  • return_mean (bool) – If True, returns the mean of the posterior distribution instead of a sample. Default: False.

Returns:

ModelOutput instance with fields:
  • z (torch.Tensor) – the latent codes, of shape (N, n_data, latent_dim).
  • one_latent_space (bool) – set to True.

Return type:

ModelOutput
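The documented output shape (N, n_data, latent_dim) can be illustrated with a plain-Python sketch that samples from diagonal Gaussian posteriors via the reparameterization trick. sample_encodings is hypothetical and uses nested lists instead of tensors; in particular, returning the mean repeated N times when return_mean=True is an assumption of this sketch, not confirmed behaviour of TELBO.encode.

```python
import random
from typing import List, Optional, Sequence


def sample_encodings(mus: Sequence[Sequence[float]], sigmas: Sequence[Sequence[float]],
                     N: int = 1, return_mean: bool = False,
                     rng: Optional[random.Random] = None) -> List[List[List[float]]]:
    """Return N latent codes per datapoint as nested lists of shape (N, n_data, latent_dim)."""
    rng = rng or random.Random()
    out = []
    for _ in range(N):
        batch = []
        for mu, sigma in zip(mus, sigmas):
            if return_mean:
                batch.append(list(mu))  # assumption: the mean is repeated for each of the N draws
            else:
                # reparameterization: z = mu + sigma * eps with eps ~ N(0, 1)
                batch.append([m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)])
        out.append(batch)
    return out


if __name__ == "__main__":
    mus = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]  # n_data = 3, latent_dim = 2
    sigmas = [[1.0, 1.0]] * 3
    z = sample_encodings(mus, sigmas, N=5, rng=random.Random(0))
    print(len(z), len(z[0]), len(z[0][0]))  # 5 3 2
```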

forward(inputs, **kwargs)

Forward pass of the model.