Source code for multivae.models.mmvae.mmvae_model

import logging
import math
from typing import Union

import numpy as np
import torch
import torch.nn.functional as F
from pythae.models.base.base_utils import ModelOutput
from torch.distributions import Laplace, Normal

from multivae.data.datasets.base import MultimodalBaseDataset
from multivae.data.utils import drop_unused_modalities

from ..base import BaseMultiVAE
from .mmvae_config import MMVAEConfig

logger = logging.getLogger(__name__)
console = logging.StreamHandler()
logger.addHandler(console)
logger.setLevel(logging.INFO)


[docs] class MMVAE(BaseMultiVAE): """The Variational Mixture-of-Experts Autoencoder model. Args: model_config (MMVAEConfig): An instance of MMVAEConfig in which any model's parameters is made available. encoders (Dict[str, ~pythae.models.nn.base_architectures.BaseEncoder]): A dictionary containing the modalities names and the encoders for each modality. Each encoder is an instance of Pythae's BaseEncoder. Default: None. decoders (Dict[str, ~pythae.models.nn.base_architectures.BaseDecoder]): A dictionary containing the modalities names and the decoders for each modality. Each decoder is an instance of Pythae's BaseDecoder. """ def __init__( self, model_config: MMVAEConfig, encoders: dict = None, decoders: dict = None ): super().__init__(model_config, encoders, decoders) if model_config.prior_and_posterior_dist == "laplace_with_softmax": self.post_dist = Laplace self.prior_dist = Laplace elif model_config.prior_and_posterior_dist == "normal": self.post_dist = Normal self.prior_dist = Normal else: raise AttributeError( " The posterior_dist parameter must be " " either 'laplace_with_softmax' or 'normal'. " f" {model_config.prior_and_posterior_dist} was provided." ) self.prior_mean = torch.nn.Parameter( torch.zeros(1, self.latent_dim), requires_grad=False ) self.prior_log_var = torch.nn.Parameter( torch.zeros(1, self.latent_dim), requires_grad=model_config.learn_prior ) self.model_name = "MMVAE"
[docs] def log_var_to_std(self, log_var): """For latent distributions parameters, transform the log covariance to the standard deviation of the distribution either applying softmax or not. This follows the original implementation. """ if self.model_config.prior_and_posterior_dist == "laplace_with_softmax": return F.softmax(log_var, dim=-1) * log_var.size(-1) + 1e-6 else: return torch.exp(0.5 * log_var)
@property def pz_params(self): """From the prior mean and log_covariance, return the mean and standard deviation, either applying softmax or not depending on the choice of prior distribution. Returns: tuple: mean, std """ mean = self.prior_mean if self.model_config.prior_and_posterior_dist == "laplace_with_softmax": std = ( F.softmax(self.prior_log_var, dim=-1) * self.prior_log_var.size(-1) + 1e-6 ) else: std = torch.exp(0.5 * self.prior_log_var) return mean, std
[docs] def forward(self, inputs: MultimodalBaseDataset, **kwargs): """Forward pass of the model. Outputs the loss and metrics.""" # First compute all the encodings for all modalities # drop modalities that are completely unavailable in the batch to avoid Nan in backward inputs = drop_unused_modalities(inputs) embeddings = {} qz_xs = {} qz_xs_detach = {} reconstructions = {} compute_loss = kwargs.pop("compute_loss", True) detailed_output = kwargs.pop("detailed_output", False) k_iwae = kwargs.pop("K", self.model_config.K) for cond_mod in inputs.data: output = self.encoders[cond_mod](inputs.data[cond_mod]) mu, log_var = output.embedding, output.log_covariance sigma = self.log_var_to_std(log_var) qz_x = self.post_dist(mu, sigma) z_x = qz_x.rsample([k_iwae]) # K,n_batch,latent_dim # The DREG loss uses detached parameters in the loss computation afterwards. qz_x_detach = self.post_dist(mu.detach(), sigma.detach()) # Then compute all the cross-modal reconstructions reconstructions[cond_mod] = {} for recon_mod in inputs.data: decoder = self.decoders[recon_mod] z = z_x.reshape(-1, z_x.shape[-1]) # (K*n_batch, latent_dim) recon = decoder(z)["reconstruction"] recon = recon.reshape((*z_x.shape[:-1], *recon.shape[1:])) reconstructions[cond_mod][recon_mod] = recon qz_xs[cond_mod] = qz_x embeddings[cond_mod] = z_x qz_xs_detach[cond_mod] = qz_x_detach # Compute DREG loss if compute_loss: if self.model_config.loss == "dreg_looser": loss_output = self.dreg_looser( qz_xs_detach, embeddings, reconstructions, inputs ) elif self.model_config.loss == "iwae_looser": loss_output = self.iwae_looser( qz_xs, embeddings, reconstructions, inputs ) else: raise NotImplementedError() else: loss_output = ModelOutput() if detailed_output: loss_output["qz_xs"] = qz_xs loss_output["qz_xs_detach"] = qz_xs_detach loss_output["zss"] = embeddings loss_output["recon"] = reconstructions return loss_output
[docs] def compute_k_lws(self, qz_xs, embeddings, reconstructions, inputs): """Compute likelihood terms for all modalities and for all k. returns : dict containing the likelihoods terms (not aggregated) for all modalities. """ if hasattr(inputs, "masks"): # Compute the number of available modalities per sample n_mods_sample = torch.sum( torch.stack(tuple(inputs.masks.values())).int(), dim=0 ) else: n_mods_sample = torch.tensor([self.n_modalities]) lws = {} # to collect likelihoods for mod in embeddings: z = embeddings[mod] # (K, n_batch, latent_dim) n_mods_sample = n_mods_sample.to(z.device) prior = self.prior_dist(*self.pz_params) ### Compute log p(z) lpz = prior.log_prob(z).sum(-1) ### Compute log q(z|X) # Get all the log(q(z|x_i)) if hasattr(inputs, "masks"): # For incomplete data, we only use available modalities lqz_x = [] for m in qz_xs: qz = qz_xs[m].log_prob(z).sum(-1) # Set the probability to 0 for unavailable modalities, # so that they don't weigh in the Mixture-of-Experts qz[torch.stack([inputs.masks[m] == False] * len(z))] = -torch.inf lqz_x.append(qz) lqz_x = torch.stack(lqz_x) # n_modalities,K,nbatch else: lqz_x = torch.stack( [qz_xs[m].log_prob(z).sum(-1) for m in qz_xs] ) # n_modalities,K,nbatch # Compute the mixture of expert probability lqz_x = torch.logsumexp(lqz_x, dim=0) - torch.log( n_mods_sample ) # log_mean_exp ### Compute log p(X|z) lpx_z = 0 for recon_mod in reconstructions[mod]: x_recon = reconstructions[mod][recon_mod] K, n_batch = x_recon.shape[0], x_recon.shape[1] lpx_z_mod = ( self.recon_log_probs[recon_mod](x_recon, inputs.data[recon_mod]) .view(K, n_batch, -1) .mul(self.rescale_factors[recon_mod]) .sum(-1) ) if hasattr(inputs, "masks"): # cancel unavailable modalities lpx_z_mod *= inputs.masks[recon_mod].float() lpx_z += lpx_z_mod ### Compute the entire likelihood lw = lpx_z + lpz - lqz_x if hasattr(inputs, "masks"): # cancel unavailable modalities lw *= inputs.masks[mod].float() lws[mod] = lw # K, batch_size return lws, n_mods_sample
[docs] def dreg_looser(self, qz_xs, embeddings, reconstructions, inputs): """The DreG estimation for IWAE. losses components in lws needs to have been computed on **detached** posteriors. """ lws, n_mods_sample = self.compute_k_lws( qz_xs, embeddings, reconstructions, inputs ) # Compute all the wk weights for individual likelihoods wk = {} with torch.no_grad(): for mod, lw in lws.items(): wk[mod] = ( lw - torch.logsumexp(lw, 0, keepdim=True) ).exp() # K, n_batch # Compute the loss lws = torch.stack( [(lws[mod] * wk[mod]) for mod in embeddings], dim=0 ) # n_modalities,K, n_batch lws = lws.sum(1) # sum on K # The gradient with respect to \phi is multiplied one more time by wk # To achieve that, we register a hook on the latent variables z for mod in embeddings: embeddings[mod].register_hook( lambda grad, w=wk[mod]: w.unsqueeze(-1) * grad ) # Take the mean over modalities lws = lws.sum(0) / n_mods_sample # Return the sum over the batch return ModelOutput(loss=-lws.sum(), loss_sum=-lws.sum(), metrics={})
[docs] def iwae_looser(self, qz_xs, embeddings, reconstructions, inputs): """Compute the iwae loss without the DReG estimator for the gradient.""" lws, n_mods_sample = self.compute_k_lws( qz_xs, embeddings, reconstructions, inputs ) # Transform into a tensor lws = torch.stack(list(lws.values()), dim=0) # n_modalities, K, n_batch # Take log_mean_exp on K to compute the IWAE estimation lws = torch.logsumexp(lws, dim=1) - math.log( lws.size(1) ) # n_modalities, n_batch # Take the mean on modalities outside the log lws = lws.sum(0) / n_mods_sample # Return the sum over the batch return ModelOutput(loss=-lws.sum(), loss_sum=-lws.sum(), metrics={})
def iwae(self, qz_xs, embeddings, reconstructions, inputs): lws, n_mods_sample = self.compute_k_lws( qz_xs, embeddings, reconstructions, inputs ) # Transform into a tensor lws = torch.stack(list(lws.values()), dim=0) # n_modalities, K, n_batch # Take log_mean_exp on K to compute the IWAE estimation lws = torch.logsumexp(lws, dim=1) - math.log( lws.size(1) ) # n_modalities, n_batch # Take log_mean_exp on the modalities lws = torch.logsumexp(lws, dim=0) - n_mods_sample.log() # Return the sum over the batch return ModelOutput(loss=-lws.sum(), loss_sum=-lws.sum(), metrics={})
[docs] def encode( self, inputs: MultimodalBaseDataset, cond_mod: Union[list, str] = "all", N: int = 1, return_mean=False, **kwargs, ): """Generate encodings conditioning on all modalities or a subset of modalities. Args: inputs (MultimodalBaseDataset): The dataset to use for the conditional generation. cond_mod (Union[list, str]): Either 'all' or a list of str containing the modalities names to condition on. N (int) : The number of encodings to sample for each datapoint. Default to 1. return_mean (bool) : if True, returns the mean of the posterior distribution (instead of a sample). Returns: ModelOutput instance with fields 'z' (torch.Tensor (n_data, N, latent_dim)),'one_latent_space' (bool) = True """ cond_mod = super().encode(inputs, cond_mod, N, **kwargs).cond_mod if all([s in self.encoders.keys() for s in cond_mod]): if return_mean: emb = torch.stack( [self.encoders[mod](inputs.data[mod]).embedding for mod in cond_mod] ).mean(0) if N > 1: z = torch.stack([emb] * N) else: z = emb else: # Choose one of the conditioning modalities at random mod = np.random.choice(cond_mod) output = self.encoders[mod](inputs.data[mod]) mu, log_var = output.embedding, output.log_covariance sigma = self.log_var_to_std(log_var) qz_x = self.post_dist(mu, sigma) sample_shape = torch.Size([]) if N == 1 else torch.Size([N]) z = qz_x.rsample(sample_shape) flatten = kwargs.pop("flatten", False) if flatten: z = z.reshape(-1, self.latent_dim) return ModelOutput(z=z, one_latent_space=True)
[docs] @torch.no_grad() def compute_joint_nll( self, inputs: MultimodalBaseDataset, K: int = 1000, batch_size_K: int = 100 ): """Estimate the negative joint likelihood. Args: inputs (MultimodalBaseDataset) : a batch of samples. K (int) : the number of importance samples for the estimation. Default to 1000. batch_size_K (int) : Default to 100. Returns: The negative log-likelihood summed over the batch. """ # Check the dataset is not incomplete self.eval() if hasattr(inputs, "masks"): raise AttributeError( "The compute_joint_nll method is not yet implemented for incomplete datasets." ) # Compute all the parameters of the joint posterior q(z|x_1), q(z|x_2), ... post_params = [] for cond_mod in self.encoders: output = self.encoders[cond_mod](inputs.data[cond_mod]) mu, log_var = output.embedding, output.log_covariance sigma = self.log_var_to_std(log_var) post_params.append((mu, sigma)) # Sample K latents from the joint posterior z_joint = self.encode(inputs, N=K).z.permute(1, 0, 2) # n_data x K x latent_dim n_data, _, _ = z_joint.shape # Then iter on each datapoint to compute the iwae estimate of ln(p(x)) ll = 0 for i in range(n_data): start_idx = 0 stop_idx = min(start_idx + batch_size_K, K) lnpxs = [] while start_idx < stop_idx: latents = z_joint[i][start_idx:stop_idx] # Compute ln p(x_m|z) for z in latents and for each modality m lpx_zs = 0 for mod in inputs.data: decoder = self.decoders[mod] recon = decoder(latents)[ "reconstruction" ] # (batch_size_K, nb_channels, w, h) x_m = inputs.data[mod][i] # (nb_channels, w, h) lpx_zs += ( self.recon_log_probs[mod]( recon, torch.stack([x_m] * len(recon)) ) .reshape(recon.size(0), -1) .sum(-1) ) # Compute ln(p(z)) prior = self.prior_dist(*self.pz_params) lpz = prior.log_prob(latents).sum(dim=-1) # Compute posteriors ln(q(z|X)) = ln(1/M \sum_m q(z|x_m)) qz_xs = [self.post_dist(p[0][i], p[1][i]) for p in post_params] lqz_xy = torch.logsumexp( torch.stack([q.log_prob(latents).sum(-1) for q in qz_xs]), dim=0 ) - np.log(self.n_modalities) ln_px = torch.logsumexp(lpx_zs + lpz - lqz_xy, dim=0) lnpxs.append(ln_px) # next batch start_idx += batch_size_K stop_idx = min(stop_idx + batch_size_K, K) ll += torch.logsumexp(torch.Tensor(lnpxs), dim=0) - np.log(K) return -ll
[docs] @torch.no_grad() def compute_joint_nll_paper( self, inputs: MultimodalBaseDataset, K: int = 1000, batch_size_K: int = 10 ): """Computes the joint likelihood like in the original dataset, using all Mixture of experts samples and modality rescaling. """ self.eval() lws = [] nb_computed_samples = 0 while nb_computed_samples < K: n_samples = min(batch_size_K, K - nb_computed_samples) nb_computed_samples += n_samples # Compute a iwae likelihood estimate using n_samples output = self.forward( inputs, compute_loss=False, K=n_samples, detailed_output=True ) lw = -self.iwae(output.qz_xs, output.zss, output.recon, inputs).loss lws.append(lw + np.log(n_samples * self.n_modalities)) ll = torch.logsumexp(torch.stack(lws), dim=0) - np.log( nb_computed_samples * self.n_modalities ) # n_batch return -ll # we return the negative log liklihood
[docs] def generate_from_prior(self, n_samples, **kwargs): sample_shape = [n_samples] if n_samples > 1 else [] z = self.prior_dist(*self.pz_params).rsample(sample_shape).to(self.device) return ModelOutput(z=z.squeeze(), one_latent_space=True)