CVAE: Conditional Variational Autoencoder
Conditional Variational Autoencoder model (https://arxiv.org/abs/1906.02691).
This model learns the conditional distribution of \(y\) given \(x\). The general generative model is:
\[ p_{\theta}(y, z|x) = p_{\theta}(y|z,x)\, p_{\theta}(z|x) \]
where \(p_{\theta}(y|z,x)\) is called the decoder and \(p_{\theta}(z|x)\) is the prior distribution. The prior may depend on \(x\), or not if \(z\) is considered independent of \(x\). The approximate posterior distribution \(q_{\phi}(z|y,x)\) may also depend on \(x\), or not.
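The Evidence Lower Bound (ELBO) is:
\[ \mathcal{L}(y|x) = \mathbb{E}_{q_{\phi}(z|y,x)}\left[\log p_{\theta}(y|z,x)\right] - \beta\, \mathrm{KL}\left(q_{\phi}(z|y,x) \,\|\, p_{\theta}(z|x)\right) \]
where \(\beta\) is the beta parameter of the configuration class; \(\beta = 1\) gives the standard ELBO.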
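When both the approximate posterior \(q_{\phi}(z|y,x)\) and the prior \(p_{\theta}(z|x)\) are diagonal Gaussians, the KL term of the ELBO has a closed form. The sketch below illustrates that computation in plain Python. It assumes Gaussian distributions parameterized by a mean and a log-variance, and the function names `gaussian_kl` and `negative_elbo` are hypothetical; this is not the library's implementation.

```python
import math


def gaussian_kl(mu_q, logvar_q, mu_p=None, logvar_p=None):
    """KL(q || p) between two diagonal Gaussians, summed over latent dimensions.

    If mu_p / logvar_p are omitted, p is the standard normal N(0, I),
    i.e. the case where the prior does not depend on x (assumption).
    """
    dim = len(mu_q)
    mu_p = [0.0] * dim if mu_p is None else mu_p
    logvar_p = [0.0] * dim if logvar_p is None else logvar_p
    kl = 0.0
    for mq, lq, mp, lp in zip(mu_q, logvar_q, mu_p, logvar_p):
        kl += 0.5 * (lp - lq + (math.exp(lq) + (mq - mp) ** 2) / math.exp(lp) - 1.0)
    return kl


def negative_elbo(recon_log_prob, kl, beta=1.0):
    """Loss to minimize: -(E_q[log p(y|z,x)] - beta * KL)."""
    return -(recon_log_prob - beta * kl)
```

With a prior that depends on \(x\), `mu_p` and `logvar_p` would come from a network applied to the conditioning data; with an independent prior, they can be left at their defaults.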
- class multivae.models.CVAEConfig(conditioning_modalities, main_modality, input_dims=None, latent_dim=10, beta=1.0, decoder_dist='normal', decoder_dist_params=<factory>, custom_architectures=<factory>)[source]
This is the configuration class for the Conditional Variational Autoencoder model.
- Parameters:
input_dims (dict[str,tuple]) – The modalities’ names (str) and input shapes (tuple).
latent_dim (int) – The dimension of the latent space. Default: 10.
conditioning_modalities (List[str]) – The modalities to condition the model on.
main_modality (str) – The main modality to reconstruct.
beta (float) – The weight of the KL divergence term in the ELBO. Default: 1.0.
decoder_dist (str) – The decoder distribution to use. Possible values: [‘normal’, ‘bernoulli’, ‘laplace’, ‘categorical’]. For the Bernoulli distribution, the decoder is expected to output logits.
decoder_dist_params (dict) – Optional parameters for the decoder output distribution. Default: None.
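The choice of decoder_dist determines how the reconstruction term \(\log p_{\theta}(y|z,x)\) is evaluated. The sketch below shows the elementwise log-likelihood for three of the listed options, and in particular why a Bernoulli decoder should output logits: the log-probability can then be computed stably without applying a sigmoid first. The function names and the `scale` argument are hypothetical illustrations, not the library's code, and the categorical case is omitted.

```python
import math


def softplus(t):
    """Numerically stable log(1 + exp(t))."""
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))


def log_prob(y, out, dist="normal", scale=1.0):
    """Log-likelihood of a scalar target y given a scalar decoder output."""
    if dist == "normal":
        # out is the mean; scale is the standard deviation (assumed fixed)
        return -0.5 * ((y - out) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)
    if dist == "laplace":
        # out is the location; scale is the diversity parameter (assumed fixed)
        return -abs(y - out) / scale - math.log(2 * scale)
    if dist == "bernoulli":
        # out is a logit: log sigmoid(out) = -softplus(-out)
        return -y * softplus(-out) - (1.0 - y) * softplus(out)
    raise ValueError(f"unsupported distribution: {dist}")
```

For a whole datapoint, these elementwise values would be summed over the output dimensions.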
- class multivae.models.CVAE(model_config, encoder=None, decoder=None, prior_network=None)[source]
Main class for the Conditional Variational Autoencoder.
- Parameters:
model_config (CVAEConfig) – the model configuration class.
encoder (BaseEncoder) – The encoder network.
decoder (BaseConditionalDecoder) – The conditional decoder network.
prior_network (BaseJointEncoder) – Takes the conditional modalities and returns the parameters for the prior distribution.
- decode(embedding, **kwargs)[source]
Decode embeddings to reconstruct the main modality.
- Returns:
A ModelOutput instance containing the reconstruction.
>>> embeddings = model.encode(inputs, N=2)
>>> output = model.decode(embeddings)
>>> output.reconstruction
- encode(inputs, N=1, **kwargs)[source]
Generate latent code by encoding the data and sampling from the posterior distribution.
- Parameters:
inputs (MultimodalBaseDataset) – The data to encode.
N (int, optional) – number of samples per datapoint to sample from the posterior. Defaults to 1.
- Returns:
A ModelOutput instance containing the embeddings. The shape of the embeddings is (N, batch_size, latent_dim).
>>> output = model.encode(inputs, N=2)
>>> z = output.z
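To make the (N, batch_size, latent_dim) layout concrete, here is a minimal sketch of drawing N samples per datapoint from a diagonal Gaussian posterior with the reparameterization trick. It assumes the encoder returns a mean and a log-variance per datapoint; the function `sample_posterior` and its `seed` argument are hypothetical and use nested lists in place of tensors.

```python
import math
import random


def sample_posterior(mu, logvar, N=1, seed=None):
    """Draw N latent samples per datapoint: z = mu + sigma * eps, eps ~ N(0, 1).

    mu, logvar: nested lists of shape (batch_size, latent_dim).
    Returns a nested list of shape (N, batch_size, latent_dim).
    """
    rng = random.Random(seed)
    samples = []
    for _ in range(N):
        batch = []
        for mu_row, logvar_row in zip(mu, logvar):
            batch.append([
                m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
                for m, lv in zip(mu_row, logvar_row)
            ])
        samples.append(batch)
    return samples
```

The sample index comes first, so the i-th entry along the leading axis is one latent code for every datapoint in the batch.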
- forward(inputs, **kwargs)[source]
Forward pass of the Conditional Variational Autoencoder.
- Parameters:
inputs (dict) – A dictionary containing the input data for each modality.
- Returns:
A ModelOutput instance containing the loss and metrics.
- Return type:
ModelOutput
- generate_from_prior(cond_mod_data, N=1, **kwargs)[source]
Generates latent variables from the prior, conditioning on cond_mod_data.
- Parameters:
cond_mod_data (Dict[str,torch.Tensor]) – Data from the conditioning modalities.
N (int) – Number of latent codes to sample from the prior per datapoint.
- Returns:
A ModelOutput instance containing the embeddings.
- predict(inputs, cond_mod='all', N=1, **kwargs)[source]
Reconstruct from the input or from the conditioning modalities.
- Parameters:
inputs (MultimodalBaseDataset) – The data to use for prediction.
cond_mod (Union[str, list]) – Either ‘all’ to perform reconstruction or the list of conditioning modalities to generate from the prior.
N (int) – number of samples per datapoint to sample from the posterior or prior.
- Returns:
A ModelOutput instance containing the reconstruction / generation.
>>> # reconstructions
>>> output = model.predict(inputs, cond_mod='all')
>>> reconstruction = output.reconstruction
- Return type:
ModelOutput