Source code for multivae.models.mhvae.mhvae_model

import logging
from itertools import combinations
from typing import Union

import torch

from multivae.data.datasets.base import MultimodalBaseDataset
from multivae.models.base import BaseMultiVAE, ModelOutput
from multivae.models.nn.base_architectures import BaseEncoder

from ..base.base_utils import kl_divergence, poe, rsample_from_gaussian
from .mhvae_config import MHVAEConfig

# Module-level logger. Its level is forced to INFO and a console handler is
# attached so that informational messages emitted by this module (for example
# the posterior-weight-sharing notice) are printed.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
console = logging.StreamHandler()
logger.addHandler(console)


class MHVAE(BaseMultiVAE):
    """MHVAE model.

    Args:
        model_config (MHVAEConfig) : the model configuration.
        encoders (Dict[str, BaseEncoder]) : contains the first layer encoder per modality.
        decoders (Dict[str, BaseDecoder]) : contains the last layer decoder per modality.
        bottom_up_blocks (Dict[str, list]) : For each modality, contains the (n_latent-1)
            bottom-up layers. Each layer must be an instance of nn.Module.
            The last layer must be an instance of BaseEncoder and must return the mean
            and log_covariance for the deepest latent variable.
        top_down_blocks (List[nn.Module]): contains the (n_latent-1) top-down layers.
            Each layer must be an instance of nn.Module.
        posterior_blocks (List or Dict): contains the (n_latent - 1) posterior layers
            for each modality. Each layer must be an instance of BaseEncoder.
            A posterior block receives the concatenation, along dimension 1, of the
            output of the corresponding top-down block and of the matching bottom-up
            skip connection, so its input dimension must match the sum of both
            (2 * the output dimension of the top_down block when they are equal).
            Provide a list if the weights are shared between modalities, and a
            dictionary if they are not.
        prior_blocks (List): contains the (n_latent - 1) prior layers. Each layer must
            be an instance of BaseEncoder. The input dimension of each prior block must
            match the output dimension of the corresponding top_down_blocks.
    """

    def __init__(
        self,
        model_config: MHVAEConfig,
        encoders: dict,
        decoders: dict,
        bottom_up_blocks: dict,
        top_down_blocks: list,
        posterior_blocks: Union[list, dict],
        prior_blocks: list,
    ):
        # Super method sets up the base fields as well as encoders / decoders
        super().__init__(model_config, encoders, decoders)

        self.n_latent = model_config.n_latent
        self.beta = model_config.beta
        self.model_name = "MHVAE"

        self.sanity_check_bottom_up(encoders, bottom_up_blocks)
        self.set_bottom_up_blocks(bottom_up_blocks)

        self.sanity_check_top_down_blocks(top_down_blocks)
        self.set_top_down_blocks(top_down_blocks)

        self.sanity_check_prior_blocks(prior_blocks)
        self.prior_blocks = torch.nn.ModuleList(prior_blocks)

        self.check_and_set_posterior_blocks(posterior_blocks)

        # Record the names of the custom architectures in the config. Names that are
        # already listed are skipped so that building several models from the same
        # config object does not accumulate duplicate entries.
        for name in (
            "bottom_up_blocks",
            "top_down_blocks",
            "prior_blocks",
            "posterior_blocks",
        ):
            if name not in self.model_config.custom_architectures:
                self.model_config.custom_architectures.append(name)

    def _subsets(self):
        """Return all the non-empty subsets of the modalities.

        Returns:
            subsets (list): tuples of modality names, ordered by increasing size.
            The last element therefore contains every modality.
        """
        subsets = []
        for i in range(1, self.n_modalities + 1):
            subsets += combinations(list(self.encoders.keys()), r=i)
        return subsets

    def _adapt_log_var_to_missing_data(self, dict_params, inputs):
        """Build the lists of means and log-variances fed to the product of experts.

        For incomplete datasets (``inputs`` has a ``masks`` attribute), the
        log-variance of the posterior of every missing modality is set to +inf. Its
        precision is then zero, so it does not contribute to the PoE.

        ``dict_params`` is left untouched: masked copies of the log-variances are
        returned instead of being written in place. An in-place write would mutate
        the caller's tensors, which ``forward`` re-uses across subsets. It could also
        break backpropagation when the operation that produced them saved its output
        for the backward pass.

        Args:
            dict_params (Dict[str, ModelOutput]): per-modality parameters exposing
                ``embedding`` and ``log_covariance``.
            inputs (MultimodalBaseDataset): the batch data, possibly with ``masks``.

        Returns:
            list of means, list of log-variances (in the order of ``dict_params``).
        """
        has_masks = hasattr(inputs, "masks")
        list_means, list_log_vars = [], []
        for m, item in dict_params.items():
            log_var = item.log_covariance
            if has_masks:
                # The mask indexes the first (batch) dimension: reshape it so that it
                # broadcasts over the remaining dimensions of the log-variance.
                present = inputs.masks[m].bool().to(log_var.device)
                present = present.reshape(-1, *([1] * (log_var.dim() - 1)))
                log_var = log_var.masked_fill(~present, torch.inf)
            list_means.append(item.embedding)
            list_log_vars.append(log_var)
        return list_means, list_log_vars

    def subset_encode(
        self, z_deepest_params, skips, subset, inputs, return_mean=False
    ):
        """Compute all the latent variables and KL divergences for a given subset of modalities.

        Args:
            z_deepest_params (Dict[str, ModelOutput]): dictionary containing the mean
                and logvar of the deepest latent variable for each modality.
            skips (Dict[str, List[torch.Tensor]]): dictionary containing the
                intermediate results of the bottom-up layers for each modality.
            subset (List[str]): list of modalities to consider to compute the joint
                posterior.
            inputs (MultimodalBaseDataset) : the batch data.
            return_mean (bool): If True, we return the mean everytime we sample from a
                distribution. Default to False.

        Returns:
            z_dict (Dict[str, torch.Tensor]): all the latent variables, keyed
                ``"z_1"`` ... ``"z_{n_latent}"``.
            kl_dict (Dict[str, torch.Tensor]): the KL divergence (summed over the
                batch) at each level, keyed ``"kl_1"`` ... ``"kl_{n_latent}"``.
        """
        # Only keep the modalities in subset
        z_deepest_params_subset = {m: z_deepest_params[m] for m in subset}

        # For missing modalities in the dataset, the variance is set to infinity
        list_mus, list_log_vars = self._adapt_log_var_to_missing_data(
            z_deepest_params_subset, inputs
        )
        # Add the prior p(z_L): mean 0 ...
        list_mus.append(torch.zeros_like(list_mus[0]))
        # ... and log-variance 0 (unit variance)
        list_log_vars.append(torch.zeros_like(list_log_vars[0]))

        # Joint posterior q(z_L | x) = p(z_L) * prod_i q(z_L | x_i)
        joint_mu, joint_lv = poe(torch.stack(list_mus), torch.stack(list_log_vars))

        # Sample z_L
        z_l_deepest = rsample_from_gaussian(
            joint_mu, joint_lv, N=1, return_mean=return_mean
        )

        # KL(q(z_L | x) || p(z_L)) with p(z_L) = N(0, I)
        kl_l_deepest = kl_divergence(
            joint_mu, joint_lv, torch.zeros_like(joint_mu), torch.zeros_like(joint_lv)
        ).sum()

        # Keep track of all latent variables and KLs
        z_dict = {f"z_{self.n_latent}": z_l_deepest}
        kl_dict = {f"kl_{self.n_latent}": kl_l_deepest}

        # Sample the remaining latent variables, from z_{L-1} down to z_1
        for i in range(self.n_latent - 1, 0, -1):
            h = self.top_down_blocks[i - 1](z_dict[f"z_{i + 1}"])

            # Prior p(z_l | z_>l)
            prior_params = self.prior_blocks[i - 1](h)

            # q(z_l | x, z_>l) = p(z_l | z_>l) * prod_i q(z_l | x_i, z_>l)
            zl_params = {}
            for mod in subset:
                # Parameters of q(z_l | x_i, z_>l).
                # skips[mod] holds n_latent - 1 tensors; index i - 1 belongs to level i.
                d = skips[mod][i - 1]
                concat = torch.cat([h, d], dim=1)  # concatenate on the channels
                zl_params[mod] = self._get_posterior_block(mod, i - 1)(concat)

            # For missing modalities, the variance is set to infinity
            list_mus, list_log_vars = self._adapt_log_var_to_missing_data(
                zl_params, inputs
            )

            # Add the prior to the product of experts
            list_mus.append(prior_params.embedding)
            list_log_vars.append(prior_params.log_covariance)
            joint_mu, joint_lv = poe(torch.stack(list_mus), torch.stack(list_log_vars))

            # Sample z_l
            z_dict[f"z_{i}"] = rsample_from_gaussian(
                joint_mu, joint_lv, N=1, return_mean=return_mean
            )

            # KL(q(z_l | x, z_>l) || p(z_l | z_>l))
            kl_dict[f"kl_{i}"] = kl_divergence(
                joint_mu, joint_lv, prior_params.embedding, prior_params.log_covariance
            ).sum()

        return z_dict, kl_dict

    def _get_posterior_block(self, mod, i):
        """Return the posterior block for modality ``mod`` at index ``i``.

        When the weights are shared between modalities, the same block is returned
        for every modality.
        """
        if self.share_posterior_weights:
            return self.posterior_blocks[i]
        return self.posterior_blocks[mod][i]

    def _loss_subset(self, inputs, z_l_deepest_params, skips, subset):
        """Compute the negative ELBO loss using a subset of modalities for the posterior.

        Args:
            inputs (MultimodalBaseDataset): the input data.
            z_l_deepest_params (Dict[str, ModelOutput]): dictionary containing the
                mean and logvar of the deepest latent variable for each modality.
            skips (Dict[str, List[torch.Tensor]]): dictionary containing the
                intermediate results of the bottom-up layers for each modality.
            subset (List[str]): list of modalities to consider to compute the joint
                posterior.

        Returns:
            loss (torch.Tensor): the negative ELBO loss, summed over the batch.
            kl_dict (Dict[str, torch.Tensor]): dictionary containing all the KL
                divergences at each level.
        """
        # Get all the latent variables and KLs in the hierarchy
        z_dict, kl_dict = self.subset_encode(z_l_deepest_params, skips, subset, inputs)

        # Reconstruct all modalities (not only those in subset) from z_1
        recon_loss = 0
        for mod in self.decoders:
            output = self.decoders[mod](z_dict["z_1"])
            recon = output.reconstruction
            mod_loss = (
                -self.recon_log_probs[mod](recon, inputs.data[mod])
                * self.rescale_factors[mod]
            )
            # One value per sample: sum over every non-batch dimension
            mod_loss = mod_loss.reshape(mod_loss.shape[0], -1).sum(-1)

            # Missing modalities do not contribute to the reconstruction loss
            if hasattr(inputs, "masks"):
                mod_loss = mod_loss * inputs.masks[mod]

            recon_loss += mod_loss.sum()

        # Sum the KLs of all levels
        kl = 0
        for i in range(1, self.n_latent + 1):
            kl += kl_dict[f"kl_{i}"]

        loss = recon_loss + self.beta * kl

        return loss, kl_dict

    def forward(self, inputs: MultimodalBaseDataset, **kwargs):
        """Compute the average negative ELBO loss using all possible subsets of modalities for the posterior.

        Args:
            inputs (MultimodalBaseDataset): the input data.

        Returns:
            ModelOutput: ``loss`` / ``loss_sum`` (the loss averaged over the subsets)
            and ``metrics``, the per-level KL divergences of the last subset evaluated.
            That subset is the one containing all the modalities.
        """
        # The bottom-up pass is computed once and shared by every subset
        z_l_deepest_params, skips = self.modality_encode(inputs.data)
        subsets = self._subsets()

        losses = []
        for subset in subsets:
            loss, kl_dict = self._loss_subset(inputs, z_l_deepest_params, skips, subset)
            losses.append(loss)

        loss = torch.stack(losses).mean()  # average on all subsets

        return ModelOutput(loss=loss, loss_sum=loss, metrics=kl_dict)

    def encode(self, inputs, cond_mod="all", N=1, return_mean=False, **kwargs):
        """Encode the input data conditioning on the modalities in cond_mod and return the latent variables.

        Args:
            inputs (MultimodalBaseDataset): the input data.
            cond_mod (str, list): the modality to condition on. Either 'all' or a list
                of modalities.
            N (int): the number of samples to draw from the posterior for each sample.
                Generated latent variables have shape (N, n_data, ...), or
                (N * n_data, ...) when ``flatten=True`` is passed.
            return_mean (bool) : if True, returns the mean of the posterior
                distribution (instead of a sample).
            **kwargs: forwarded to the base-class ``encode``. ``flatten`` (bool,
                default False) is also read here.

        Returns:
            ModelOutput: ``z`` (the latent variable z_1), ``all_z`` (dict of every
            latent variable) and ``one_latent_space=True``.

        Note:
            When ``N > 1`` and ``inputs`` has masks, ``inputs.masks`` is temporarily
            replaced by replicated masks. The original object is always restored,
            even if encoding raises.
        """
        cond_mod = super().encode(inputs, cond_mod, N, **kwargs).cond_mod
        z_ls_params, skips = self.modality_encode(inputs.data)

        # Get the batch size
        n_data = len(next(iter(z_ls_params.values())).embedding)

        if N > 1:
            # Replicate every quantity of the hierarchy N times along the batch axis
            for mod, z_l in z_ls_params.items():
                z_l.embedding = torch.cat([z_l.embedding] * N, dim=0)
                z_l.log_covariance = torch.cat([z_l.log_covariance] * N, dim=0)
                skips[mod] = [torch.cat([t] * N, dim=0) for t in skips[mod]]

        # Replicate masks if necessary (N > 1), restoring them no matter what
        replicate_masks = hasattr(inputs, "masks") and N > 1
        if replicate_masks:
            original_masks = inputs.masks
            inputs.masks = {
                m: torch.cat([original_masks[m]] * N, dim=0) for m in original_masks
            }
        try:
            z_dict, _ = self.subset_encode(
                z_ls_params, skips, cond_mod, inputs, return_mean=return_mean
            )
        finally:
            if replicate_masks:
                inputs.masks = original_masks

        flatten = kwargs.pop("flatten", False)
        if not flatten and N > 1:
            for k in z_dict:
                z_dict[k] = z_dict[k].reshape(N, n_data, *z_dict[k].shape[1:])

        return ModelOutput(z=z_dict["z_1"], all_z=z_dict, one_latent_space=True)

    def modality_encode(self, data: dict):
        """Encode each modality on its own.

        Args:
            data (Dict[str, torch.Tensor]): the input data for each modality.

        Returns:
            z_Ls_params: a dictionary containing for each modality a ModelOutput
                instance with embedding and log_covariance for the deepest latent.
            skips: a dictionary containing, for each modality, the list of the
                n_latent - 1 intermediate tensors. Index 0 is the output of the first
                encoder.
        """
        # Apply all bottom_up layers, save the intermediate results
        skips = {mod: [] for mod in data}
        z_ls_params = {}
        for mod in data:
            # Apply first encoder layer
            output = self.encoders[mod](data[mod])
            z = output.embedding
            skips[mod].append(z)

            # Apply all intermediate layers
            for i in range(self.n_latent - 2):
                z = self.bottom_up_blocks[mod][i](z)
                skips[mod].append(z)

            # Apply last layer: it outputs the parameters of the deepest latent
            output = self.bottom_up_blocks[mod][-1](z)
            z_ls_params[mod] = output

        return z_ls_params, skips

    def sanity_check_bottom_up(self, encoders, bottom_up_blocks):
        """Check the coherence of the bottom_up_blocks with the encoders.

        Raises:
            AttributeError: if the modalities do not match, if a modality does not
                have exactly n_latent - 1 blocks, or if its last block is not a
                BaseEncoder.
        """
        # Check the number of modalities
        if self.n_modalities != len(bottom_up_blocks):
            raise AttributeError(
                f"The provided number of bottom_up_blocks ({len(bottom_up_blocks)}) "
                f"doesn't match the number of modalities ({self.n_modalities}) in "
                "model config."
            )

        # Check coherence with the encoders keys
        if encoders.keys() != bottom_up_blocks.keys():
            raise AttributeError(
                "The names of the modalities in the encoders dict doesn't match the names of the modalities"
                " in the bottom_up_blocks dict."
            )

        for mod, blocks in bottom_up_blocks.items():
            # Check that the number of layers is correct
            if len(blocks) != self.n_latent - 1:
                raise AttributeError(
                    f"There must be {self.n_latent - 1} bottom_up_blocks for modality"
                    f" {mod} but you provided {len(blocks)} layers."
                )

            # Check that the last layer is an instance of BaseEncoder. An empty list
            # has no last layer: report it with this error instead of an IndexError.
            if len(blocks) == 0 or not isinstance(blocks[-1], BaseEncoder):
                raise AttributeError(
                    f"The last layer in bottom_up_blocks for modality {mod}"
                    " must be an instance of BaseEncoder"
                )

    def sanity_check_top_down_blocks(self, top_down_blocks):
        """Check the coherence of the top_down_blocks with the model configuration.

        Raises:
            AttributeError: if there are not exactly n_latent - 1 modules.
        """
        if len(top_down_blocks) != self.n_latent - 1:
            raise AttributeError(
                f"There must be {self.n_latent - 1} modules in top_down_blocks."
            )

    def check_and_set_posterior_blocks(self, posterior_blocks):
        """Check the coherence of the posterior_blocks with the model configuration and register them.

        A list (or ModuleList) means the weights are shared between modalities. A
        dict (or ModuleDict) keyed like the encoders means one list per modality.
        Sets ``self.share_posterior_weights`` accordingly.

        Raises:
            AttributeError: on a wrong container type, wrong keys, wrong number of
                modules, or modules that are not BaseEncoder instances.
        """
        # Shared weights : a list of modules was provided
        if isinstance(posterior_blocks, (list, torch.nn.ModuleList)):
            logger.info("Shared weights for the posterior blocks")
            self.share_posterior_weights = True
            if len(posterior_blocks) != self.n_latent - 1:
                raise AttributeError(
                    f"There must be {self.n_latent - 1} modules in posterior_blocks."
                )
            for block in posterior_blocks:
                if not isinstance(block, BaseEncoder):
                    raise AttributeError(
                        "The modules in posterior_blocks must be instances of BaseEncoder"
                    )
            self.posterior_blocks = torch.nn.ModuleList(posterior_blocks)
            return

        # Not shared weights : a dict of lists of modules was provided
        if isinstance(posterior_blocks, (dict, torch.nn.ModuleDict)):
            logger.info("Not shared weights for the posterior blocks")
            self.share_posterior_weights = False
            if posterior_blocks.keys() != self.encoders.keys():
                raise AttributeError(
                    "The keys of posterior_blocks must match the keys of encoders."
                )
            for m, p in posterior_blocks.items():
                if len(p) != self.n_latent - 1:
                    raise AttributeError(
                        f"There must be {self.n_latent - 1} modules in posterior_blocks[{m}]."
                    )
                for block in p:
                    if not isinstance(block, BaseEncoder):
                        raise AttributeError(
                            f"The modules in posterior_blocks[{m}] must be instances of BaseEncoder"
                        )
            self.posterior_blocks = torch.nn.ModuleDict()
            for mod in posterior_blocks:
                self.posterior_blocks[mod] = torch.nn.ModuleList(posterior_blocks[mod])
            return

        raise AttributeError("posterior_blocks must be a list or a dict")

    def sanity_check_prior_blocks(self, prior_blocks):
        """Check the coherence of the prior_blocks with the model configuration.

        Raises:
            AttributeError: if there are not exactly n_latent - 1 modules or if one
                of them is not a BaseEncoder.
        """
        if len(prior_blocks) != self.n_latent - 1:
            raise AttributeError(f"There must be {self.n_latent - 1} modules in prior.")
        for block in prior_blocks:
            if not isinstance(block, BaseEncoder):
                raise AttributeError(
                    "The modules in prior_blocks must be instances of BaseEncoder"
                )

    def set_top_down_blocks(self, top_down_blocks):
        """Set the top_down_blocks attribute as a torch.nn.ModuleList."""
        self.top_down_blocks = torch.nn.ModuleList(top_down_blocks)

    def set_bottom_up_blocks(self, bottom_up_blocks):
        """Set the bottom_up_blocks attribute.

        The result is a torch.nn.ModuleDict mapping each modality to a ModuleList.
        """
        self.bottom_up_blocks = torch.nn.ModuleDict()
        for mod in bottom_up_blocks:
            self.bottom_up_blocks[mod] = torch.nn.ModuleList(bottom_up_blocks[mod])