# Source code for multivae.models.mvtcae.mvtcae_model

from typing import Union

import numpy as np
import torch
import torch.distributions as dist
from pythae.models.base.base_utils import ModelOutput

from multivae.data.datasets.base import IncompleteDataset, MultimodalBaseDataset

from ..base import BaseMultiVAE
from ..base.base_utils import poe, rsample_from_gaussian
from .mvtcae_config import MVTCAEConfig


class MVTCAE(BaseMultiVAE):
    """MVTCAE model.

    Args:
        model_config (MVTCAEConfig): An instance of MVTCAEConfig in which any model's
            parameters is made available.
        encoders (Dict[str, ~pythae.models.nn.base_architectures.BaseEncoder]): A dictionary
            containing the modalities names and the encoders for each modality. Each encoder
            is an instance of Pythae's BaseEncoder. Default: None.
        decoders (Dict[str, ~pythae.models.nn.base_architectures.BaseDecoder]): A dictionary
            containing the modalities names and the decoders for each modality. Each decoder
            is an instance of Pythae's BaseDecoder.
    """

    def __init__(
        self, model_config: MVTCAEConfig, encoders: dict = None, decoders: dict = None
    ):
        super().__init__(model_config, encoders, decoders)

        # Loss weighting hyper-parameters, read once from the config (see `forward`).
        self.alpha = model_config.alpha
        self.beta = model_config.beta
        self.model_name = "MVTCAE"

    def forward(self, inputs: MultimodalBaseDataset, **kwargs) -> ModelOutput:
        """Forward pass of the model that returns the loss.

        The loss combines three terms, each summed over the batch:

        * ``loss_rec``: rescaled negative reconstruction log-probability of every
          modality decoded from a sample of the joint posterior;
        * ``kld_losses``: for every modality, the divergence between the joint posterior
          and that modality's unimodal posterior;
        * ``joint_kld``: the divergence between the joint posterior and a standard normal.

        They are mixed as ``rec_weight * loss_rec + beta * (cvib_weight * kld_losses +
        vib_weight * joint_kld)`` with ``rec_weight = (n_modalities - alpha) / n_modalities``,
        ``cvib_weight = alpha / n_modalities`` and ``vib_weight = 1 - alpha``.

        Args:
            inputs (MultimodalBaseDataset): The batch. ``inputs.data`` is indexed by
                modality name; if ``inputs`` has a ``masks`` attribute, the per-sample
                terms of unavailable modalities are zeroed.

        Returns:
            ModelOutput: with fields ``loss`` (total loss divided by the number of
            samples), ``loss_sum`` (total loss) and ``metrics`` (dict of the individual
            terms).
        """
        # Compute latents parameters for all modalities and for the joint posterior
        latents = self._inference(inputs)
        results = {}
        has_masks = hasattr(inputs, "masks")

        # Sample from the joint posterior
        joint_mu, joint_logvar = latents["joint"][0], latents["joint"][1]
        shared_embeddings = rsample_from_gaussian(joint_mu, joint_logvar)
        ndata = len(shared_embeddings)

        joint_kld = -0.5 * torch.sum(
            1 - joint_logvar.exp() - joint_mu.pow(2) + joint_logvar
        )
        assert not torch.isnan(joint_kld)
        results["joint_divergence"] = joint_kld

        # Compute the reconstruction losses for each modality
        loss_rec = 0
        for m_key in self.encoders.keys():
            # Reconstruct this modality from the shared embeddings representation
            recon = self.decoders[m_key](shared_embeddings).reconstruction
            m_rec = (
                -self.recon_log_probs[m_key](recon, inputs.data[m_key])
                * self.rescale_factors[m_key]
            )
            # One value per sample
            m_rec = m_rec.reshape(m_rec.size(0), -1).sum(-1)

            # Keep only the available samples
            if has_masks:
                m_rec = inputs.masks[m_key].float() * m_rec

            m_rec_sum = m_rec.sum()
            results[m_key] = m_rec_sum
            loss_rec += m_rec_sum
        assert not torch.isnan(loss_rec)

        # Divergence between the joint posterior and each unimodal posterior
        kld_losses = 0.0
        for m_key, enc_out in latents["modalities"].items():
            mu, logvar = enc_out.embedding, enc_out.log_covariance
            var = logvar.exp()
            kld = -0.5 * (
                1
                - joint_logvar.exp() / var
                - (joint_mu - mu).pow(2) / var
                + joint_logvar
                - logvar
            ).reshape(mu.size(0), -1).sum(-1)

            # Keep only the available samples. `_modality_encode` sets the log-variance
            # of unavailable samples to +inf, so their entries must be overwritten here.
            if has_masks:
                kld[~inputs.masks[m_key].bool()] = 0

            kld = kld.sum()
            results["kld_" + m_key] = kld
            kld_losses += kld
        assert not torch.isnan(kld_losses)

        rec_weight = (self.n_modalities - self.alpha) / self.n_modalities
        cvib_weight = self.alpha / self.n_modalities
        vib_weight = 1 - self.alpha

        kld_weighted = cvib_weight * kld_losses + vib_weight * joint_kld
        total_loss = rec_weight * loss_rec + self.beta * kld_weighted

        return ModelOutput(
            loss=total_loss / ndata, loss_sum=total_loss, metrics=results
        )

    def _modality_encode(
        self, inputs: Union[MultimodalBaseDataset, IncompleteDataset], **kwargs
    ):
        """Computes for each modality, the parameters mu and logvar of the unimodal posterior.

        Args:
            inputs (MultimodalBaseDataset): The data to encode.

        Returns:
            dict: Containing for each modality the encoder output.
        """
        encoders_outputs = dict()
        for m_key in inputs.data.keys():
            output = self.encoders[m_key](inputs.data[m_key])

            # For unavailable samples, set the log-variance to infty so that they don't
            # contribute to the product of experts
            if hasattr(inputs, "masks"):
                output.log_covariance[~inputs.masks[m_key].bool()] = torch.inf

            encoders_outputs[m_key] = output

        return encoders_outputs

    def _inference(self, inputs: MultimodalBaseDataset, **kwargs):
        """Encode all the modalities contained in inputs and compute the product of
        experts of the modalities encoders.

        Args:
            inputs (MultimodalBaseDataset): The data.

        Returns:
            dict: ``"modalities"`` maps each modality name to its encoder output and
            ``"joint"`` is the list ``[joint_mu, joint_logvar]`` returned by ``poe``.
        """
        enc_mods = self._modality_encode(inputs)
        mods = list(inputs.data.keys())

        # Stack the unimodal parameters along a new leading "modality" axis in one call
        mus = torch.stack([enc_mods[mod].embedding for mod in mods], dim=0)
        logvars = torch.stack([enc_mods[mod].log_covariance for mod in mods], dim=0)

        # Case with only one sample : adapt the shape
        if len(mus.shape) == 2:
            mus = mus.unsqueeze(1)
            logvars = logvars.unsqueeze(1)

        joint_mu, joint_logvar = poe(mus, logvars)

        return {"modalities": enc_mods, "joint": [joint_mu, joint_logvar]}

    def encode(
        self,
        inputs: MultimodalBaseDataset,
        cond_mod: Union[list, str] = "all",
        N: int = 1,
        return_mean=False,
        **kwargs,
    ) -> ModelOutput:
        """Generate encodings conditioning on all modalities or a subset of modalities.

        Args:
            inputs (MultimodalBaseDataset): The dataset to use for the conditional generation.
            cond_mod (Union[list, str]): Either 'all' or a list of str containing the
                modalities names to condition on.
            N (int) : The number of encodings to sample for each datapoint. Default to 1.
            return_mean (bool) : if True, returns the mean of the posterior distribution
                (instead of a sample).
            **kwargs: Forwarded to the parent ``encode``; the optional ``flatten`` entry
                (default False) is also passed on to ``rsample_from_gaussian``.

        Returns:
            ModelOutput instance with fields:
                z (torch.Tensor (n_data, N, latent_dim))
                one_latent_space (bool) = True
        """
        # Call super() function to preprocess cond_mod. You obtain a list of
        # the modalities' names to condition on.
        cond_mod = super().encode(inputs, cond_mod, N, **kwargs).cond_mod

        # Only keep the relevant modalities for prediction
        # NOTE(review): only `data` is passed on, so a `masks` attribute on `inputs` is
        # not visible to `_inference` here -- confirm this is intended for incomplete data.
        cond_inputs = MultimodalBaseDataset(
            data={k: inputs.data[k] for k in cond_mod},
        )

        mu, log_var = self._inference(cond_inputs)["joint"]

        # Read only after the parent call, which therefore still receives `flatten`
        flatten = kwargs.pop("flatten", False)
        z = rsample_from_gaussian(
            mu, log_var, N=N, return_mean=return_mean, flatten=flatten
        )

        return ModelOutput(z=z, one_latent_space=True)

    @torch.no_grad()
    def compute_joint_nll(
        self,
        inputs: Union[MultimodalBaseDataset, IncompleteDataset],
        K: int = 1000,
        batch_size_K: int = 100,
    ):
        """Estimate the negative joint likelihood.

        Puts the model in eval mode (``self.eval()``) as a side effect.

        Args:
            inputs (MultimodalBaseDataset) : a batch of samples.
            K (int) : the number of importance samples for the estimation. Default to 1000.
            batch_size_K (int) : number of importance samples decoded at once.
                Default to 100.

        Returns:
            The negative log-likelihood summed over the batch.

        Raises:
            AttributeError: if ``inputs`` has a ``masks`` attribute (incomplete dataset).
        """
        self.eval()

        # Check that the dataset is complete
        if hasattr(inputs, "masks"):
            raise AttributeError(
                "The compute_joint_nll method is not yet implemented for incomplete datasets."
            )

        # Compute the parameters of the joint posterior
        mu, log_var = self._inference(inputs)["joint"]
        sigma = torch.exp(0.5 * log_var)
        qz_xy = dist.Normal(mu, sigma)

        # Sample K latents from the joint posterior
        z_joint = qz_xy.rsample([K]).permute(
            1, 0, 2
        )  # shape : n_data x K x latent_dim
        n_data, _, _ = z_joint.shape

        # The prior does not depend on the datapoint nor on the chunk: build it once
        prior = dist.Normal(0, 1)

        # Then iter on each datapoint to compute the iwae estimate of ln(p(x))
        ll = 0
        for i in range(n_data):
            # Posterior of this datapoint, shared by all its chunks
            qz_xy_i = dist.Normal(mu[i], sigma[i])

            lnpxs = []
            # Process the K samples by chunks of at most batch_size_K
            for start_idx in range(0, K, batch_size_K):
                latents = z_joint[i][start_idx : start_idx + batch_size_K]

                # Compute p(x_m|z) for z in latents and for each modality m
                lpx_zs = 0
                for mod in inputs.data:
                    recon = self.decoders[mod](latents)["reconstruction"]
                    x_m = inputs.data[mod][i]

                    # Repeat the datapoint so that it lines up with every reconstruction
                    lpx_zs += (
                        self.recon_log_probs[mod](
                            recon, torch.stack([x_m] * len(recon))
                        )
                        .reshape(recon.size(0), -1)
                        .sum(-1)
                    )

                # Compute ln(p(z))
                lpz = prior.log_prob(latents).sum(dim=-1)

                # Compute posteriors ln(q(z|X))
                lqz_xy = qz_xy_i.log_prob(latents).sum(dim=-1)

                lnpxs.append(torch.logsumexp(lpx_zs + lpz - lqz_xy, dim=0))

            # Merge the per-chunk partial log-sums and normalise by the number of samples.
            # NOTE(review): torch.Tensor(...) builds a fresh tensor from the chunk scalars;
            # presumably default dtype on CPU -- kept as is to preserve the returned value.
            ll += torch.logsumexp(torch.Tensor(lnpxs), dim=0) - np.log(K)

        return -ll