CMVAE

Implementation of “Deep Generative Clustering with Multimodal Diffusion Variational Autoencoders” (Palumbo et al., 2023, https://openreview.net/forum?id=k5THrhXDV3).

This model builds on the MMVAE+ by adding a Mixture-of-Gaussians prior on the shared latent space. The generative model is as follows:

\[
\begin{aligned}
& c \sim \pi \\
& z \mid c \sim \mathcal{N}(\mu_c, \Sigma_c) \\
& \forall m,\; w_m \sim p(w_m) \\
& \forall m,\; x_m \mid z, w_m \sim p_{\theta}(x_m \mid z, w_m)
\end{aligned}
\]
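To make the generative story concrete, the sketch below performs ancestral sampling in pure Python: draw a cluster, then the shared latent, then one private latent and one observation per modality. It is an illustration under stated assumptions, not the library's sampler: the covariances \(\Sigma_c\) are assumed diagonal, \(p(w_m)\) is assumed standard normal, and `sample_cmvae` and `toy_decoder` are hypothetical names standing in for the trained decoders \(p_{\theta}(x_m|z,w_m)\).

```python
import random


def sample_cmvae(pi, mus, sigmas, w_dim, decoders, rng):
    """Ancestral sampling from the CMVAE generative model (toy sketch)."""
    # c ~ pi: categorical draw over clusters
    c = rng.choices(range(len(pi)), weights=pi, k=1)[0]
    # z | c ~ N(mu_c, Sigma_c), Sigma_c assumed diagonal (std devs in sigmas[c])
    z = [rng.gauss(m, s) for m, s in zip(mus[c], sigmas[c])]
    xs = []
    for decode in decoders:  # one (hypothetical) decoder per modality m
        # w_m ~ p(w_m), assumed standard normal in this sketch
        w = [rng.gauss(0.0, 1.0) for _ in range(w_dim)]
        # x_m | z, w_m ~ p_theta(x_m | z, w_m), delegated to the decoder
        xs.append(decode(z, w, rng))
    return c, z, xs


def toy_decoder(z, w, rng):
    # Hypothetical decoder: Gaussian noise around z + w (needs len(z) == len(w)).
    return [rng.gauss(zi + wi, 0.1) for zi, wi in zip(z, w)]


rng = random.Random(0)
c, z, xs = sample_cmvae(
    pi=[0.5, 0.5],
    mus=[[-3.0, -3.0], [3.0, 3.0]],
    sigmas=[[1.0, 1.0], [1.0, 1.0]],
    w_dim=2,
    decoders=[toy_decoder, toy_decoder],
    rng=rng,
)
```

Note that all modalities share the same draw of \(z\) (and hence of \(c\)), while each modality gets its own \(w_m\).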

The joint posterior for the latent space variable \(z\) is a mixture-of-experts:

\[q_{\Phi_z}(z|X) = \frac{1}{M} \sum_m q_{\Phi_z}(z|x_m)\]
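The mixture-of-experts posterior can be sampled by picking a modality uniformly and sampling from its unimodal posterior, and its log-density is a log-mean-exp of the unimodal log-densities. The sketch below assumes diagonal Gaussian experts for simplicity (the library's default is ‘laplace_with_softmax’); all function names are hypothetical.

```python
import math
import random


def log_normal_diag(z, mu, sigma):
    # Log-density of a diagonal Gaussian evaluated at z.
    return sum(
        -0.5 * math.log(2 * math.pi) - math.log(s) - 0.5 * ((zi - m) / s) ** 2
        for zi, m, s in zip(z, mu, sigma)
    )


def logsumexp(values):
    # Numerically stable log(sum(exp(v))).
    vmax = max(values)
    return vmax + math.log(sum(math.exp(v - vmax) for v in values))


def moe_log_density(z, experts):
    # log q(z|X) = log((1/M) * sum_m q(z|x_m)); experts = [(mu_m, sigma_m), ...]
    logs = [log_normal_diag(z, mu, sigma) for mu, sigma in experts]
    return logsumexp(logs) - math.log(len(experts))


def moe_sample(experts, rng):
    # Pick one modality uniformly, then sample z from its unimodal posterior.
    mu, sigma = rng.choice(experts)
    return [rng.gauss(m, s) for m, s in zip(mu, sigma)]
```

Working in log space avoids underflow when the unimodal densities are very small.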

This model uses trainable auxiliary priors \(r_m(w_m)\) for the modality-specific latent spaces during training.

The ELBO of the CMVAE model is:

\[\begin{split}\frac{1}{M}\sum_{m=1}^{M} \mathbb E_{ \substack{ q_{\Phi_{z_m}}(z|x_m) \\ q_{\phi_{w_m}}(w_m|x_m) \\ q(c|z,X) } } \left[ G_{\Phi_{z},\phi_{w_m},\theta, \pi}(X,c,z,w_m) \right]\end{split}\]

where

\[
\begin{aligned}
G_{\Phi_{z},\phi_{w_m},\theta,\pi}(X,c,z,w_m) = {} & \log p_{\theta}(x_m \mid z, w_m) + \sum_{n \neq m} \mathbb{E}_{\tilde{w}_n \sim r_n(w_n)}\left[ \log p_{\theta_n}(x_n \mid z, \tilde{w}_n) \right] \\
& + \beta \log \left( \frac{p_{\pi}(c)\, p_{\theta}(z \mid c)\, p(w_m)}{q_{\Phi_z}(z \mid X)\, q_{\phi_{w_m}}(w_m \mid x_m)\, q(c \mid z, X)} \right)
\end{aligned}
\]
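Given the individual log-density values, \(G\) is just a sum of terms. The sketch below assembles it for one modality \(m\); it assumes the log-densities have already been computed elsewhere, and approximates each expectation over \(r_n\) by the mean over a few samples \(\tilde{w}_n\). The function name `g_term` is hypothetical.

```python
def g_term(log_px_m, cross_log_px, log_pc, log_pz_c, log_pw,
           log_qz, log_qw, log_qc, beta=1.0):
    """Assemble G for one modality m from precomputed log-densities.

    log_px_m     : log p_theta(x_m | z, w_m)
    cross_log_px : for each n != m, a list of log p_theta_n(x_n | z, w~_n)
                   values with w~_n drawn from r_n; their mean is a Monte
                   Carlo estimate of the expectation over r_n
    log_pc, log_pz_c, log_pw : log p_pi(c), log p_theta(z|c), log p(w_m)
    log_qz, log_qw, log_qc   : log q(z|X), log q(w_m|x_m), log q(c|z,X)
    """
    # Cross-modal reconstruction terms, one Monte Carlo average per n != m
    cross = sum(sum(vals) / len(vals) for vals in cross_log_px)
    # Log ratio of the prior terms to the posterior terms, weighted by beta
    log_ratio = (log_pc + log_pz_c + log_pw) - (log_qz + log_qw + log_qc)
    return log_px_m + cross + beta * log_ratio
```

Setting `beta` above 1 puts more weight on the regularization ratio, as in the β-VAE.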

In practice, the ELBO is estimated by importance sampling with K > 1 samples. The model can then be used for clustering, with a dedicated procedure for selecting the number of clusters a posteriori (see prune_clusters).
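A minimal sketch of such an importance-weighted estimate, assuming an IWAE-style bound computed per modality and then averaged over the M modalities (the ‘looser’ variant suggested by the names of the loss options). Here `g_samples[m]` holds K values of \(G\) computed from samples drawn with the encoders of modality \(m\); the names are hypothetical.

```python
import math


def log_mean_exp(values):
    # Numerically stable log((1/K) * sum_k exp(v_k)).
    vmax = max(values)
    return vmax + math.log(sum(math.exp(v - vmax) for v in values) / len(values))


def iw_elbo(g_samples):
    """Importance-weighted objective: per-modality log-mean-exp over the
    K samples, then an average over the M modalities."""
    return sum(log_mean_exp(g_m) for g_m in g_samples) / len(g_samples)
```

With K = 1 this reduces to the plain Monte Carlo ELBO; by Jensen's inequality the estimate never falls below the average of the sampled values.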

Note

This model can be used in the partially observed setting. In that case, the model is adapted in the same way as the MMVAE+.

Note

The diffusion decoders are not yet supported.

class multivae.models.CMVAEConfig(n_modalities, latent_dim=10, input_dims=None, uses_likelihood_rescaling=False, rescale_factors=None, decoders_dist=None, decoder_dist_params=None, custom_architectures=<factory>, K=10, prior_and_posterior_dist='laplace_with_softmax', learn_modality_prior=True, beta=1.0, modalities_specific_dim=None, reconstruction_option='joint_prior', loss='dreg_looser', number_of_clusters=10)[source]

This class is the configuration class for the CMVAE model.

Parameters:
  • n_modalities (int) – The number of modalities.

  • latent_dim (int) – The dimension of the shared latent space. Default: 10.

  • input_dims (dict[str,tuple]) – The modalities’ names (str) and input shapes (tuple). Default: None.

  • uses_likelihood_rescaling (bool) – Whether to use likelihood rescaling to mitigate modality collapse (see https://proceedings.mlr.press/v162/javaloy22a.html). input_dims must be provided to compute the rescaling factors. This option is shared by several models, which is why it is included here. Default: False.

  • rescale_factors (dict[str, float]) – The reconstruction rescaling factors per modality. If None is provided but uses_likelihood_rescaling is True, a default value proportional to the input modality size is computed. Default: None.

  • decoders_dist (Dict[str, Union[function, str]]) – Per modality, you can provide a string in [‘normal’, ‘bernoulli’, ‘laplace’]. For the Bernoulli distribution, the decoder is expected to output logits. If None is provided, a normal distribution is used for each modality.

  • decoder_dist_params (Dict[str,dict]) – Parameters of the decoder output distributions, used to compute the log-probability. For instance, with a normal or Laplace distribution, you can pass the scale in this dictionary: decoder_dist_params = {'mod1': {'scale': 0.75}}.

  • K (int) – The number of importance samples used in the IWAE or DReG loss. Default: 10.

  • prior_and_posterior_dist (str) – The type of distribution used for the posterior and the prior. Possible values: ‘laplace_with_softmax’, ‘normal_with_softplus’, ‘normal’. Default: ‘laplace_with_softmax’, the distribution used in the original paper.

  • learn_modality_prior (bool) – Whether to learn the modality-specific priors. Should be True for the method to work. Default: True.

  • beta (float) – Weight of the divergence term, as in the beta-VAE. Default: 1.0.

  • modalities_specific_dim (int) – The dimension of the modality-specific (private) latent spaces. Must be provided.

  • reconstruction_option (Literal['single_prior','joint_prior']) – How to sample the modality-specific variable when reconstructing or translating modalities. Default: ‘joint_prior’, the option used in the article.

  • loss (Literal['dreg_looser','iwae_looser']) – The importance-weighted loss to optimize (DReG or IWAE variant). Default: ‘dreg_looser’.

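  • number_of_clusters (int) – The number of clusters, i.e. of components of the Mixture-of-Gaussians prior on the shared latent space. It can be reduced a posteriori with prune_clusters. Default: 10.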
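A minimal configuration sketch, assuming the multivae package is installed and a dataset with two modalities. The modality names ‘mnist’ and ‘svhn’ and their input shapes are illustrative only; every keyword used appears in the signature above.

```python
from multivae.models import CMVAEConfig

# Illustrative two-modality setup; names and shapes are placeholders.
config = CMVAEConfig(
    n_modalities=2,
    input_dims={"mnist": (1, 28, 28), "svhn": (3, 32, 32)},
    latent_dim=10,               # shared latent space z
    modalities_specific_dim=10,  # private latent spaces w_m (must be provided)
    number_of_clusters=10,       # components of the Mixture-of-Gaussians prior
    K=10,                        # importance samples in the loss
    loss="dreg_looser",
    decoders_dist={"mnist": "bernoulli", "svhn": "laplace"},
)
```

Starting with a generous number_of_clusters and pruning afterwards matches the a posteriori selection procedure described earlier.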

class multivae.models.CMVAE(model_config, encoders=None, decoders=None)[source]

The CMVAE model from “Deep Generative Clustering with Multimodal Diffusion Variational Autoencoders” (Palumbo et al., 2023). The diffusion decoders are not implemented in this version.

Parameters:
  • model_config (CMVAEConfig) – An instance of CMVAEConfig containing all the model’s parameters.

  • encoders (Dict[str, BaseMultilatentEncoder]) – A dictionary mapping each modality name to its encoder. Each encoder must be an instance of Multivae’s BaseMultilatentEncoder, since this model uses multiple latent spaces. Default: None.

  • decoders (Dict[str, BaseDecoder]) – A dictionary mapping each modality name to its decoder. Each decoder must be an instance of Pythae’s BaseDecoder. Default: None.

compute_joint_nll(inputs, K=1000, **kwargs)[source]

Estimate the joint negative log-likelihood.

Parameters:
  • inputs (MultimodalBaseDataset) – A batch of samples.

  • K (int) – The number of importance samples used for the estimation. Default: 1000.

Returns:

The negative log-likelihood summed over the batch.

encode(inputs, cond_mod='all', N=1, return_mean=False, **kwargs)[source]

Generate encodings conditioning on all modalities or a subset of modalities.

Parameters:
  • inputs (MultimodalBaseDataset) – The dataset to use for the conditional generation.

  • cond_mod (Union[list, str]) – Either ‘all’ or a list of the names of the modalities to condition on.

  • N (int) – The number of encodings to sample for each datapoint. Default: 1.

  • return_mean (bool) – If True, returns the mean of the posterior distribution instead of a sample.

Returns:

Contains the following fields:

  • ‘z’ (torch.Tensor (n_data, N, latent_dim))

  • ‘one_latent_space’ (bool)

  • ‘modalities_z’ (Dict[str, torch.Tensor (n_data, N, latent_dim)])

Return type:

ModelOutput

forward(inputs, **kwargs)[source]

Forward pass of the CMVAE model. Returns the loss on the batch.

generate_from_prior(n_samples, **kwargs)[source]

Generate latent variables by sampling from the prior distribution.

property pc_params

Parameters of the prior distribution \(p_{\pi}(c)\) over the latent clusters.

predict_clusters(inputs, **kwargs)[source]

Returns the clusters for all samples in inputs.

Returns:

An output with fields clusters and pc_zs (dict).

Return type:

ModelOutput

Note

The cluster assignments can be accessed with clusters = model_output.clusters.
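With a Mixture-of-Gaussians prior, a standard way to assign a latent code \(z\) to a cluster is to compute the posterior \(p(c|z) \propto \pi_c\, \mathcal{N}(z; \mu_c, \Sigma_c)\) and take its argmax. The sketch below does this for a single \(z\) with diagonal covariances. Whether predict_clusters follows exactly this rule, and whether pc_zs holds these probabilities, is an assumption; the function names are hypothetical.

```python
import math


def cluster_posterior(z, log_pi, mus, sigmas):
    """Return p(c|z) proportional to pi_c * N(z; mu_c, diag(sigma_c^2))."""
    logs = []
    for lp, mu, sigma in zip(log_pi, mus, sigmas):
        # Diagonal Gaussian log-likelihood of z under cluster c
        ll = sum(
            -0.5 * math.log(2 * math.pi) - math.log(s) - 0.5 * ((zi - m) / s) ** 2
            for zi, m, s in zip(z, mu, sigma)
        )
        logs.append(lp + ll)
    # Normalize in a numerically stable way (softmax over clusters)
    vmax = max(logs)
    unnorm = [math.exp(v - vmax) for v in logs]
    total = sum(unnorm)
    return [u / total for u in unnorm]


def assign_cluster(z, log_pi, mus, sigmas):
    # Hard assignment: index of the most probable cluster under p(c|z).
    probs = cluster_posterior(z, log_pi, mus, sigmas)
    return max(range(len(probs)), key=probs.__getitem__)
```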

prune_clusters(train_data, batch_size=128)[source]

Apply the pruning procedure described in the paper to select the optimal number of clusters. After pruning, model._pc_params is updated to correspond to the selected clusters.

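Parameters:
  • train_data – The training data used for the pruning.

  • batch_size (int) – The batch size. Default: 128.

Returns: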

The list of entropy values from 0 to max_clusters.

Return type:

h_values (list)
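The docstring does not state which entropy is tracked in h_values. Purely as an illustration, the sketch below computes the entropy (in nats) of the empirical distribution of cluster assignments, such as the clusters returned by predict_clusters; treating this as the quantity in h_values is an assumption, and the function name is hypothetical.

```python
import math
from collections import Counter


def assignment_entropy(clusters):
    """Entropy (in nats) of the empirical distribution of cluster labels."""
    counts = Counter(clusters)
    n = len(clusters)
    return -sum((k / n) * math.log(k / n) for k in counts.values())
```

Balanced assignments over k clusters give the maximum value log(k), while a single occupied cluster gives 0.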