JNF: Joint Normalizing Flows
Joint Normalizing Flows (JNF) from https://arxiv.org/abs/2502.03952.
JNF uses a joint encoder to model \(q_{\phi}(z|X)\) and surrogate unimodal encoders \(q_{\phi_j}(z|x_j)\) for \(1\leq j\leq M\).
The loss used is the same as the JMVAE (with \(\alpha = 1\)) but the unimodal encoders \(q_{\phi_j}(z|x_j)\) are modelled with Masked Autoregressive Flows.
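For reference, a sketch of this objective (to be maximized), assuming the standard JMVAE-KL form. The decoders \(p_{\theta}(x_j|z)\) and the prior \(p(z)\) are symbols not introduced above, and \(\beta\) is taken to be the beta weight of the config:

\[
\mathcal{L}(X) = \mathbb{E}_{q_{\phi}(z|X)}\Big[\sum_{j=1}^{M} \log p_{\theta}(x_j|z)\Big] - \beta\,\mathrm{KL}\big(q_{\phi}(z|X)\,\|\,p(z)\big) - \alpha \sum_{j=1}^{M} \mathrm{KL}\big(q_{\phi}(z|X)\,\|\,q_{\phi_j}(z|x_j)\big), \qquad \alpha = 1.
\]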
Unlike the JMVAE, this model is trained in two separate stages: first the joint encoder and the decoders are trained, then the unimodal encoders are trained.
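The staged schedule can be illustrated with a minimal sketch. It assumes that epochs are counted from 0 and that the switch happens exactly once the warmup epochs (the warmup parameter of the config) have elapsed. The function name training_stage is hypothetical and is not part of the library.

```python
def training_stage(epoch: int, warmup: int = 10) -> str:
    """Return which group of parameters is optimised at a 0-indexed epoch.

    Assumption: the first `warmup` epochs train the joint encoder and the
    decoders; every later epoch trains the unimodal encoders (flows).
    """
    if epoch < 0:
        raise ValueError("epoch must be non-negative")
    return "joint" if epoch < warmup else "unimodal"


# Stage assigned to each of five epochs when warmup is 3.
schedule = [training_stage(epoch, warmup=3) for epoch in range(5)]
print(schedule)
```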
Note
This model must be trained with multivae.trainers.multistage_trainer.MultiStageTrainer.
Note
Because it uses a joint encoder, this model cannot be used in the partially observed setting.
- class multivae.models.JNFConfig(n_modalities, latent_dim=10, input_dims=None, uses_likelihood_rescaling=False, rescale_factors=None, decoders_dist=None, decoder_dist_params=None, custom_architectures=<factory>, warmup=10, beta=1)[source]
This is the base config for the JNF model.
- Parameters:
n_modalities (int) – The number of modalities. Required.
latent_dim (int) – The dimension of the latent space. Default: 10.
input_dims (dict[str,tuple]) – The modalities' names (str) and input shapes (tuple).
uses_likelihood_rescaling (bool) – To mitigate modality collapse, it is possible to use likelihood rescaling (see https://proceedings.mlr.press/v162/javaloy22a.html). input_dims must be provided to compute the likelihood rescaling factors. This option is used in a number of models, which is why it is included here. Defaults to False.
rescale_factors (dict[str, float]) – The reconstruction rescaling factors per modality. If None is provided but uses_likelihood_rescaling is True, a default value proportional to the input modality size is computed. Defaults to None.
decoders_dist (Dict[str, Union[function, str]]) – Per modality, you can provide a string in ['normal', 'bernoulli', 'laplace']. For the Bernoulli distribution, the decoder is expected to output logits. If None is provided, a normal distribution is used for each modality.
decoder_dist_params (Dict[str,dict]) – Parameters for the output decoder distributions, used to compute the log-probability. For instance, with a normal or Laplace distribution, you can pass the scale in this dictionary with
decoder_dist_params = {'mod1': {"scale": 0.75}}.
warmup (int) – The number of warmup epochs for training the joint encoder and decoders. Defaults to 10.
beta (float) – Weighting factor for the regularization of the joint VAE. Defaults to 1.
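As an illustration of how these parameters fit together, here is a minimal sketch that assembles the keyword arguments for a two-modality setup and checks that the per-modality dictionaries are consistent. The modality names, shapes, and rescaling values are made up for the example, and check_config is a hypothetical helper, not a library function.

```python
# Hypothetical two-modality setup; names, shapes and values are illustrative.
input_dims = {"image": (1, 28, 28), "label": (10,)}

jnf_config_kwargs = dict(
    n_modalities=len(input_dims),
    latent_dim=16,
    input_dims=input_dims,
    uses_likelihood_rescaling=True,
    rescale_factors={"image": 1.0, "label": 50.0},
    decoders_dist={"image": "bernoulli", "label": "normal"},
    decoder_dist_params={"label": {"scale": 0.75}},
    warmup=10,
    beta=1.0,
)

ALLOWED_DISTS = {"normal", "bernoulli", "laplace"}


def check_config(kwargs: dict) -> bool:
    """Raise ValueError if the per-modality entries disagree with input_dims."""
    names = set(kwargs["input_dims"])
    if kwargs["n_modalities"] != len(names):
        raise ValueError("n_modalities must match the entries of input_dims")
    for field in ("rescale_factors", "decoders_dist", "decoder_dist_params"):
        unknown = set(kwargs.get(field) or {}) - names
        if unknown:
            raise ValueError(f"{field} has unknown modalities: {sorted(unknown)}")
    for name, dist in (kwargs.get("decoders_dist") or {}).items():
        # Callables are also accepted per the type hint; only strings are checked.
        if isinstance(dist, str) and dist not in ALLOWED_DISTS:
            raise ValueError(f"unsupported distribution for {name}: {dist}")
    return True


check_config(jnf_config_kwargs)
# With the library installed, the same keywords would be passed as
# multivae.models.JNFConfig(**jnf_config_kwargs).
```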
- class multivae.models.JNF(model_config, encoders=None, decoders=None, joint_encoder=None, flows=None, **kwargs)[source]
The JNF model.
- Parameters:
model_config (JNFConfig) – Contains parameters for the JNF model.
encoders (Dict[str, BaseEncoder]) – A dictionary containing the modality names and the encoder for each modality. Each encoder is an instance of Pythae’s BaseEncoder.
decoders (Dict[str, BaseDecoder]) – A dictionary containing the modality names and the decoder for each modality. Each decoder is an instance of Pythae’s BaseDecoder.
joint_encoder (BaseJointEncoder) – Takes all the modalities as input. If None is provided, one is created from the unimodal encoders. Default: None.
flows (Dict[str, BaseNF]) – A dictionary containing the modalities names and the flows to use for each modality. If None is provided, a default MAF flow is used for each modality.
- encode(inputs, cond_mod='all', N=1, return_mean=False, **kwargs)[source]
Generate encodings conditioning on all modalities or a subset of modalities.
- Parameters:
inputs (MultimodalBaseDataset) – The dataset to use for the conditional generation.
cond_mod (Union[list, str]) – Either ‘all’ or a list of str containing the modalities names to condition on.
N (int) – The number of encodings to sample for each datapoint. Default to 1.
return_mean (bool) – if True, returns the mean of the posterior distribution (instead of a sample).
- **kwargs:
- mcmc_steps (int): The number of Monte Carlo steps to perform when sampling from the product of experts. Defaults to 100. If the coherence results are poor and the latent space is large, consider increasing this number.
- n_lf (int): The number of leapfrog steps in the Hamiltonian Monte Carlo sampling. Defaults to 10.
- eps_lf (float): The time step to use in the Hamiltonian Monte Carlo sampling. Defaults to 0.01.
- Returns:
Contains the fields 'z' (torch.Tensor of shape (N, n_data, latent_dim)) and 'one_latent_space' (bool), which is True for this model.
- Return type:
ModelOutput
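To make the roles of mcmc_steps, n_lf and eps_lf concrete, the following is a minimal pure-Python sketch of Hamiltonian Monte Carlo on a one-dimensional product of Gaussian experts. It is not the library's implementation: the helper names (log_poe, grad_log_poe, hmc_sample), the 1-D Gaussian experts, and the standard-normal momentum are assumptions chosen to keep the example self-contained.

```python
import math
import random


def log_poe(z, experts):
    """Unnormalised log-density of a product of 1-D Gaussian experts (mean, std)."""
    return sum(-0.5 * ((z - m) / s) ** 2 for m, s in experts)


def grad_log_poe(z, experts):
    """Derivative of log_poe with respect to z."""
    return sum(-(z - m) / s ** 2 for m, s in experts)


def hmc_sample(experts, z0=0.0, mcmc_steps=100, n_lf=10, eps_lf=0.01, rng=None):
    """Run `mcmc_steps` HMC transitions and return the final state."""
    if rng is None:
        rng = random.Random(0)
    z = z0
    for _ in range(mcmc_steps):
        p0 = rng.gauss(0.0, 1.0)  # resample the momentum
        z_new, p = z, p0
        # Leapfrog integration: half step on p, n_lf full steps on z,
        # full steps on p in between, and a final half step on p.
        p += 0.5 * eps_lf * grad_log_poe(z_new, experts)
        for step in range(n_lf):
            z_new += eps_lf * p
            if step < n_lf - 1:
                p += eps_lf * grad_log_poe(z_new, experts)
        p += 0.5 * eps_lf * grad_log_poe(z_new, experts)
        # Metropolis correction based on the change in total energy.
        log_accept = (log_poe(z_new, experts) - 0.5 * p ** 2) - (
            log_poe(z, experts) - 0.5 * p0 ** 2
        )
        if rng.random() < math.exp(min(0.0, log_accept)):
            z = z_new
    return z


# Two unit-variance experts centred at 0 and 2; their product is centred at 1.
z = hmc_sample([(0.0, 1.0), (2.0, 1.0)], mcmc_steps=100, n_lf=10, eps_lf=0.1)
print(z)
```

In this sketch, each transition moves the state along a trajectory of length roughly n_lf * eps_lf, so small values of either parameter mean that more transitions (mcmc_steps) are needed to explore the target.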