Source code for multivae.models.crmvae.crmvae_model

from typing import Union

import numpy as np
import torch
import torch.distributions as dist
from pythae.models.base.base_utils import ModelOutput

from multivae.data.datasets.base import IncompleteDataset, MultimodalBaseDataset

from ..base import BaseMultiVAE
from ..base.base_utils import kl_divergence, poe, rsample_from_gaussian
from .crmvae_config import CRMVAEConfig


class CRMVAE(BaseMultiVAE):
    """Main class for the CRMVAE model proposed in https://openreview.net/forum?id=Rn8u4MYgeNJ.

    Args:
        model_config (CRMVAEConfig): An instance of CRMVAEConfig containing all the parameters for
            the model.

        encoders (Dict[str, ~pythae.models.nn.base_architectures.BaseEncoder]): A dictionary
            containing the modalities names and the encoders for each modality. Each encoder is
            an instance of Pythae's BaseEncoder. Default: None.

        decoders (Dict[str, ~pythae.models.nn.base_architectures.BaseDecoder]): A dictionary
            containing the modalities names and the decoders for each modality. Each decoder is an
            instance of Pythae's BaseDecoder.
    """

    def __init__(
        self, model_config: CRMVAEConfig, encoders: dict = None, decoders: dict = None
    ):
        super().__init__(model_config, encoders, decoders)

        self.model_name = "CRMVAE"

    def _rsample(
        self, latent_params: ModelOutput, N=1, return_mean=False, flatten=False
    ):
        """Sample from the Gaussian described by ``latent_params``.

        Args:
            latent_params (ModelOutput): Must expose ``embedding`` (mean) and
                ``log_covariance`` (log-variance).
            N (int): Forwarded to ``rsample_from_gaussian``.
            return_mean (bool): Forwarded to ``rsample_from_gaussian``.
            flatten (bool): Forwarded to ``rsample_from_gaussian``.

        Returns:
            tuple: ``(mean, log_var, z)`` where ``z`` is the output of
            ``rsample_from_gaussian``.
        """
        mean = latent_params.embedding
        log_var = latent_params.log_covariance
        z = rsample_from_gaussian(mean, log_var, N, return_mean, flatten)
        return mean, log_var, z

    def forward(self, inputs: MultimodalBaseDataset, **kwargs) -> ModelOutput:
        """Forward pass of the model.

        Args:
            inputs (MultimodalBaseDataset): The batch. If it has a ``masks`` attribute,
                the per-modality terms of unavailable samples are zeroed out.

        Returns:
            ModelOutput: with fields ``loss`` and ``loss_sum`` (both the total loss summed
            over the batch) and ``metrics`` (a dict of monitoring values).
        """
        # Compute latents parameters for q(z|x_i) and q(z|X)
        latents = self._infer_all_latent_parameters(inputs)
        has_masks = hasattr(inputs, "masks")

        results = {}
        z_samples = {}  # latent samples reused for the reconstructions

        # Sample from q(z|X)
        joint_mu, joint_logvar, shared_embeddings = self._rsample(latents["joint"])
        z_samples["joint"] = shared_embeddings

        # Compute KL(q(z|X) | p(z)). p(z) is a normal distribution N(0, I)
        joint_kld = kl_divergence(
            joint_mu,
            joint_logvar,
            torch.zeros_like(joint_mu),
            torch.zeros_like(joint_logvar),
        )
        results["joint_divergence"] = joint_kld.mean()

        # Accumulate out-of-place so that `joint_kld` itself is never mutated.
        divergence = joint_kld

        # Sample from each unimodal posterior q(z|x_i) and compute KL(q(z|X) | q(z|x_i))
        for m, latent_params in latents["modalities_no_mask"].items():
            mu, log_var, embeddings = self._rsample(latent_params)
            z_samples[m] = embeddings

            kl_m = kl_divergence(joint_mu, joint_logvar, mu, log_var)

            # Remove unavailable samples
            if has_masks:
                kl_m = kl_m * inputs.masks[m].float()

            divergence = divergence + kl_m
            results[f"kl_{m}"] = kl_m.mean()

        # Compute E_{q(z|X)} log p(x_i|z) + E_{q(z|x_i)} log p(x_i|z)
        loss_rec = 0
        for gen_mod, decoder in self.decoders.items():
            for m in ["joint", gen_mod]:
                recon = decoder(z_samples[m]).reconstruction

                # Apply rescaling factors if any are available
                m_rec = (
                    -self.recon_log_probs[gen_mod](recon, inputs.data[gen_mod])
                    * self.rescale_factors[gen_mod]
                )
                m_rec = m_rec.reshape(m_rec.size(0), -1).sum(-1)

                # Cancel out unavailable samples in the reconstruction
                if has_masks:
                    m_rec = inputs.masks[gen_mod].float() * m_rec

                loss_rec = loss_rec + m_rec

                # Save metric for monitoring
                results[f"recon_{gen_mod}_from_{m}"] = m_rec.mean()

        # Average across the posteriors.
        # TODO: average over available posteriors only?
        loss_rec = loss_rec / (2 * (self.n_modalities + 1))
        divergence = divergence / (self.n_modalities + 1)

        total_loss = loss_rec + self.model_config.beta * divergence
        summed_loss = total_loss.sum()

        return ModelOutput(loss=summed_loss, loss_sum=summed_loss, metrics=results)

    def _modality_encode(
        self, inputs: Union[MultimodalBaseDataset, IncompleteDataset], **kwargs
    ):
        """Computes for each modality, the parameters mu and logvar of the unimodal posterior.

        Args:
            inputs (MultimodalBaseDataset): The data to encode.

        Returns:
            tuple: ``(encoders_outputs, masked_outputs)``. Both are dicts keyed by modality
            name. The first holds the raw encoder outputs. The second holds copies whose
            log-covariance is set to ``+inf`` for the samples flagged as unavailable in
            ``inputs.masks`` (when ``inputs`` has masks).
        """
        encoders_outputs = {}
        masked_outputs = {}

        for m, data_m in inputs.data.items():
            output = self.encoders[m](data_m)
            encoders_outputs[m] = output

            # Clone the log-covariance so that the masking below does not alter the
            # raw encoder output.
            masked_outputs[m] = ModelOutput(
                embedding=output.embedding,
                log_covariance=output.log_covariance.clone(),
            )

            # For unavailable samples, set the log-variance to infinity so that they
            # don't contribute to the product of experts
            if hasattr(inputs, "masks"):
                masked_outputs[m].log_covariance[
                    (1 - inputs.masks[m].int()).bool()
                ] = torch.inf

        return encoders_outputs, masked_outputs

    def _infer_all_latent_parameters(self, inputs: MultimodalBaseDataset, **kwargs):
        """Encode every modality in ``inputs`` and compute the product of experts.

        Args:
            inputs (MultimodalBaseDataset): The data.

        Returns:
            dict: with keys ``"modalities_no_mask"`` (raw encoder outputs),
            ``"modalities_masked"`` (masked encoder outputs) and ``"joint"``
            (a ModelOutput holding the ``poe`` mean and log-covariance).
        """
        enc_mods, masked_mods = self._modality_encode(inputs)

        latents = {
            "modalities_no_mask": enc_mods,
            "modalities_masked": masked_mods,
        }

        # Stack all the masked params along a new leading "modality" axis for the PoE.
        # A single stack keeps the dtype/device of the encoder outputs and avoids
        # re-copying the accumulated tensor once per modality.
        masked_params = list(masked_mods.values())
        mus = torch.stack([p.embedding for p in masked_params], dim=0)
        logvars = torch.stack([p.log_covariance for p in masked_params], dim=0)

        # Case with only one sample : adapt the shape
        if len(mus.shape) == 2:
            mus = mus.unsqueeze(1)
            logvars = logvars.unsqueeze(1)

        # Compute the PoE and save the result
        joint_mu, joint_logvar = poe(mus, logvars)
        latents["joint"] = ModelOutput(embedding=joint_mu, log_covariance=joint_logvar)

        return latents

    def encode(
        self,
        inputs: MultimodalBaseDataset,
        cond_mod: Union[list, str] = "all",
        N: int = 1,
        return_mean=False,
        **kwargs,
    ) -> ModelOutput:
        """Generate encodings conditioning on all modalities or a subset of modalities.

        Args:
            inputs (MultimodalBaseDataset): The dataset to use for the conditional generation.
            cond_mod (Union[list, str]): Either 'all' or a list of str containing the modalities
                names to condition on.
            N (int) : The number of encodings to sample for each datapoint. Default to 1.
            return_mean (bool) : if True, returns the mean of the posterior distribution
                (instead of a sample).

        Returns:
            ModelOutput instance with fields:
                z (torch.Tensor (n_data, N, latent_dim))
                one_latent_space (bool) = True
        """
        # Call the parent method to preprocess cond_mod: it yields the list of the
        # modalities' names to condition on.
        cond_mod = super().encode(inputs, cond_mod, N, **kwargs).cond_mod

        flatten = kwargs.pop("flatten", False)

        # Only keep the relevant modalities for prediction
        cond_inputs = MultimodalBaseDataset(
            data={k: inputs.data[k] for k in cond_mod},
        )

        # Compute the product of experts of the modalities in cond_mod
        latents_subsets = self._infer_all_latent_parameters(cond_inputs)
        _, _, z = self._rsample(
            latents_subsets["joint"], N=N, return_mean=return_mean, flatten=flatten
        )

        return ModelOutput(z=z, one_latent_space=True)

    @torch.no_grad()
    def compute_joint_nll(
        self,
        inputs: Union[MultimodalBaseDataset, IncompleteDataset],
        K: int = 1000,
        batch_size_K: int = 100,
    ) -> torch.Tensor:
        """Estimate the negative joint likelihood.

        Args:
            inputs (MultimodalBaseDataset) : a batch of samples.
            K (int) : the number of importance samples for the estimation. Must be
                positive. Default to 1000.
            batch_size_K (int) : the number of importance samples decoded at once.
                Must be positive. Default to 100.

        Returns:
            The negative log-likelihood summed over the batch.

        Raises:
            AttributeError: if ``inputs`` has a ``masks`` attribute (incomplete datasets
                are not supported by this method).
        """
        self.eval()
        if hasattr(inputs, "masks"):
            raise AttributeError(
                "The compute_joint_nll method is not yet implemented for incomplete datasets."
            )

        # Compute the parameters of the joint posterior
        joint_params = self._infer_all_latent_parameters(inputs)["joint"]

        # Sample K latents from the joint posterior
        mu, logvar, z_joint = self._rsample(joint_params, N=K)
        z_joint = z_joint.permute(1, 0, 2)  # shape : n_data x K x latent_dim
        n_data = z_joint.shape[0]

        # The prior p(z) = N(0, I) does not depend on the datapoint: build it once.
        prior = dist.Normal(0, 1)

        # Then iterate on each datapoint to compute the IWAE estimate of ln(p(x))
        ll = 0
        for i in range(n_data):
            # The posterior q(z|X) of datapoint i is shared by all of its chunks.
            qz_xy = dist.Normal(mu[i], torch.exp(logvar[i] * 0.5))

            lnpxs = []
            # Process the K importance samples in chunks of at most batch_size_K
            for start_idx in range(0, K, batch_size_K):
                latents = z_joint[i][start_idx : start_idx + batch_size_K]

                # Compute ln p(x_m|z) for z in latents and for each modality m
                lpx_zs = 0
                for mod in inputs.data:
                    recon = self.decoders[mod](latents)["reconstruction"]
                    x_m = inputs.data[mod][i]

                    # Repeat the datapoint so that it matches the chunk of reconstructions
                    lpx_zs += (
                        self.recon_log_probs[mod](
                            recon, torch.stack([x_m] * len(recon))
                        )
                        .reshape(recon.size(0), -1)
                        .sum(-1)
                    )

                # Compute ln(p(z))
                lpz = prior.log_prob(latents).sum(dim=-1)

                # Compute ln(q(z|X))
                lqz_xy = qz_xy.log_prob(latents).sum(dim=-1)

                lnpxs.append(torch.logsumexp(lpx_zs + lpz - lqz_xy, dim=0))

            # torch.stack keeps the device and dtype of the partial estimates.
            ll += torch.logsumexp(torch.stack(lnpxs), dim=0) - np.log(K)

        return -ll