CVAE: Conditional Variational Autoencoder

Conditional Variational Autoencoder model (https://arxiv.org/abs/1906.02691).

This model learns the conditional distribution of \(y\) given \(x\). The general generative model is:

\[p_{\theta}(y,z|x) = p_{\theta}(y|z,x)p_{\theta}(z|x)\]

where \(p_{\theta}(y|z,x)\) is called the decoder and \(p_{\theta}(z|x)\) is the prior distribution. The prior may depend on \(x\), or not if \(z\) is assumed to be independent of \(x\). Likewise, the approximate posterior distribution \(q_{\phi}(z|y,x)\) may or may not depend on \(x\).

The Evidence Lower Bound writes:

\[\mathcal{L}(y|x) = \mathbb{E}_{q_{\phi}(z|y,x)}\left[ \ln p_{\theta}(y|z,x) \right] - \mathrm{KL}\left(q_{\phi}(z|y,x)\,\|\,p_{\theta}(z|x)\right)\]

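If one assumes (this is not stated above) that \(q_{\phi}(z|y,x)\) and \(p_{\theta}(z|x)\) are both diagonal Gaussians, the KL term has a closed form and the expectation can be estimated with a single reparameterized sample. The plain-Python sketch below illustrates this; `posterior`, `prior` and `decoder_log_lik` are hypothetical stand-ins for the encoder, the prior network and the decoder, not multivae functions.

```python
import math
import random


def kl_diag_gaussians(mu_q, logvar_q, mu_p, logvar_p):
    """Closed-form KL(q || p) between two diagonal Gaussians, summed over dimensions."""
    kl = 0.0
    for mq, lq, mp, lp in zip(mu_q, logvar_q, mu_p, logvar_p):
        var_q, var_p = math.exp(lq), math.exp(lp)
        kl += 0.5 * (lp - lq + (var_q + (mq - mp) ** 2) / var_p - 1.0)
    return kl


def elbo_estimate(y, x, posterior, prior, decoder_log_lik, rng=random):
    """Single-sample Monte Carlo estimate of L(y|x) = E_q[ln p(y|z,x)] - KL(q || p).

    Hypothetical interface (assumed Gaussian parameterization):
      posterior(y, x) -> (mu, logvar) of q(z | y, x)
      prior(x)        -> (mu, logvar) of p(z | x)
      decoder_log_lik(y, z, x) -> ln p(y | z, x)
    """
    mu_q, logvar_q = posterior(y, x)
    mu_p, logvar_p = prior(x)
    # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, 1).
    z = [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu_q, logvar_q)]
    return decoder_log_lik(y, z, x) - kl_diag_gaussians(mu_q, logvar_q, mu_p, logvar_p)
```

When the posterior and the conditional prior coincide, the KL term vanishes and the estimate reduces to the reconstruction term alone.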
class multivae.models.CVAEConfig(conditioning_modalities, main_modality, input_dims=None, latent_dim=10, beta=1.0, decoder_dist='normal', decoder_dist_params=<factory>, custom_architectures=<factory>)[source]

This is the configuration class for the Conditional Variational Autoencoder model.

Parameters:
  • input_dims (dict[str,tuple]) – The modalities’ names (str) and input shapes (tuple).

  • latent_dim (int) – The dimension of the latent space. Default: 10.

  • conditioning_modalities (List[str]) – The modalities to condition the model on.

  • main_modality (str) – The main modality to reconstruct.

  • beta (float) – The weight of the KL divergence term in the ELBO. Default: 1.0.

  • decoder_dist (str) – The decoder distribution to use. Possible values: ‘normal’, ‘bernoulli’, ‘laplace’, ‘categorical’. Default: ‘normal’. For the Bernoulli distribution, the decoder is expected to output logits.

  • decoder_dist_params (dict) – Optional parameters for the decoder output distribution. Default: None.
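To make the `decoder_dist` options concrete, here is a per-element log-likelihood sketch for each listed distribution. It is not multivae's implementation: the helper name, the fixed unit `scale` for ‘normal’ and ‘laplace’, and the use of logits for ‘categorical’ are all assumptions. Following the note above, the Bernoulli branch takes logits.

```python
import math


def softplus(t):
    """Numerically stable ln(1 + exp(t))."""
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))


def log_likelihood(y, out, dist="normal", scale=1.0):
    """Log-likelihood of one observation y given a decoder output `out`.

    Hypothetical helper: `out` is the mean ('normal'), the location ('laplace'),
    a logit ('bernoulli') or a list of class logits ('categorical').
    `scale` is an assumed fixed standard deviation / Laplace scale.
    """
    if dist == "normal":
        return -0.5 * math.log(2 * math.pi * scale**2) - (y - out) ** 2 / (2 * scale**2)
    if dist == "laplace":
        return -math.log(2 * scale) - abs(y - out) / scale
    if dist == "bernoulli":
        # y in {0, 1}; ln p(y) = y * logit - ln(1 + exp(logit))
        return y * out - softplus(out)
    if dist == "categorical":
        # y is a class index; log-softmax computed with the max-shift trick.
        m = max(out)
        log_norm = m + math.log(sum(math.exp(v - m) for v in out))
        return out[y] - log_norm
    raise ValueError(f"unsupported decoder_dist: {dist!r}")
```

Working from logits keeps the Bernoulli and categorical branches numerically stable, which is presumably why the decoder is expected to output logits rather than probabilities.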

class multivae.models.CVAE(model_config, encoder=None, decoder=None, prior_network=None)[source]

Main class for the Conditional Variational Autoencoder.

Parameters:
  • model_config (CVAEConfig) – the model configuration class.

  • encoder (BaseEncoder) – The encoder network.

  • decoder (BaseConditionalDecoder) – The conditional decoder network.

  • prior_network (BaseJointEncoder) – Takes the conditioning modalities and returns the parameters of the prior distribution.

decode(embedding, **kwargs)[source]

Decode embeddings to reconstruct the main modality.

Returns:

A ModelOutput instance containing the reconstruction.

>>> embeddings = model.encode(inputs, N=2)
>>> output = model.decode(embeddings)
>>> output.reconstruction

encode(inputs, N=1, **kwargs)[source]

Generate latent code by encoding the data and sampling from the posterior distribution.

Parameters:
  • inputs (MultimodalBaseDataset) – The data to encode.

  • N (int, optional) – Number of samples to draw from the posterior per datapoint. Defaults to 1.

Returns:

A ModelOutput instance containing the embeddings. The shape of the embeddings is (N, batch_size, latent_dim).

>>> output = model.encode(inputs, N=2)
>>> z = output.z
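The (N, batch_size, latent_dim) layout can be pictured with the plain-Python sketch below. It assumes a diagonal Gaussian posterior given by a mean and a log-variance per datapoint and uses the reparameterization trick; `sample_latents` is a hypothetical helper, not part of multivae, which works with torch tensors rather than nested lists.

```python
import math
import random


def sample_latents(mu, logvar, N=1, rng=None):
    """Draw N reparameterized samples per datapoint (hypothetical helper).

    mu, logvar: nested lists of shape (batch_size, latent_dim) holding the
    assumed diagonal Gaussian posterior parameters.
    Returns a nested list of shape (N, batch_size, latent_dim).
    """
    rng = rng or random.Random()
    return [
        [
            [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu_i, logvar_i)]
            for mu_i, logvar_i in zip(mu, logvar)
        ]
        for _ in range(N)
    ]


# Batch of 3 datapoints, latent_dim = 10 (the CVAEConfig default), N = 2.
mu = [[0.0] * 10 for _ in range(3)]
logvar = [[0.0] * 10 for _ in range(3)]
z = sample_latents(mu, logvar, N=2, rng=random.Random(0))
print(len(z), len(z[0]), len(z[0][0]))  # 2 3 10
```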

forward(inputs, **kwargs)[source]

Forward pass of the Conditional Variational Autoencoder.

Parameters:

inputs (dict) – A dictionary containing the input data for each modality.

Returns:

A ModelOutput instance containing the loss and metrics.

Return type:

ModelOutput
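`forward` returns a loss. A common convention, assumed here rather than taken from multivae's source, is to minimize the negative ELBO with the KL term scaled by `beta` from `CVAEConfig`. Whether the library averages or sums over the batch is not specified in this documentation; the hypothetical `beta_elbo_loss` below averages.

```python
def beta_elbo_loss(recon_log_liks, kls, beta=1.0):
    """Batch-averaged negative beta-ELBO (hypothetical helper).

    recon_log_liks[i]: estimate of E_q[ln p(y_i | z, x_i)] for datapoint i.
    kls[i]:            KL(q(z | y_i, x_i) || p(z | x_i)) for datapoint i.
    beta:              weight on the KL term; beta = 1.0 gives the plain ELBO.
    """
    if len(recon_log_liks) != len(kls) or not kls:
        raise ValueError("need one reconstruction term and one KL term per datapoint")
    per_point = [-(r - beta * k) for r, k in zip(recon_log_liks, kls)]
    return sum(per_point) / len(per_point)


print(beta_elbo_loss([-10.0, -12.0], [2.0, 4.0]))            # 14.0
print(beta_elbo_loss([-10.0, -12.0], [2.0, 4.0], beta=0.5))  # 12.5
```

Lowering `beta` below 1.0 reduces the penalty on the KL term, so the loss is dominated more by reconstruction quality.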

generate_from_prior(cond_mod_data, N=1, **kwargs)[source]

Generates latent variables from the prior, conditioning on cond_mod_data.

Parameters:
  • cond_mod_data (Dict[str,torch.Tensor]) – Data from the conditioning modalities.

  • N (int) – Number of latent codes to sample from the prior per datapoint. Defaults to 1.

Returns:

A ModelOutput instance containing the embeddings.

predict(inputs, cond_mod='all', N=1, **kwargs)[source]

Reconstruct the main modality from the inputs, or generate it from the conditioning modalities only.

Parameters:
  • inputs (MultimodalBaseDataset) – The data to use for prediction.

  • cond_mod (Union[str, list]) – Either ‘all’, to reconstruct the inputs, or the list of conditioning modalities from which to generate via the prior. Default: ‘all’.

  • N (int) – Number of samples to draw per datapoint from the posterior or the prior. Defaults to 1.

Returns:

A ModelOutput instance containing the reconstruction / generation.

>>> # reconstructions
>>> output = model.predict(inputs, cond_mod = 'all')
>>> reconstruction = output.reconstruction

Return type:

ModelOutput
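The two modes of `predict` can be sketched as follows: with cond_mod=‘all’ the latent codes come from the approximate posterior (which sees the main modality), otherwise from the conditional prior (which sees only the conditioning modalities), and the conditional decoder is applied to each of the N samples in both cases. `predict_sketch`, `posterior`, `prior` and `decoder` are hypothetical stand-ins for the library's method and networks, and the Gaussian parameterization is an assumption.

```python
import math
import random


def predict_sketch(y, x, posterior, prior, decoder, cond_mod="all", N=1, rng=None):
    """Return N decoder outputs for one datapoint (hypothetical sketch).

    y: main-modality data, only used when cond_mod == "all".
    x: conditioning-modality data.
    posterior(y, x) and prior(x) return (mu, logvar) lists; decoder(z, x) returns an output.
    """
    rng = rng or random.Random()
    if cond_mod == "all":
        mu, logvar = posterior(y, x)  # reconstruction: z ~ q(z | y, x)
    else:
        mu, logvar = prior(x)  # generation: z ~ p(z | x); y is ignored
    outputs = []
    for _ in range(N):
        # Reparameterized sample, then conditional decoding p(y | z, x).
        z = [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]
        outputs.append(decoder(z, x))
    return outputs
```

In this sketch any value other than "all" selects the prior path, so generation needs no main-modality data at all.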