"""Source code for multivae.models.nexus.nexus_model."""

import logging
from typing import Dict, Union

import numpy as np
import torch
import torch.distributions as dist
from pythae.models.base.base_utils import ModelOutput

from multivae.data.datasets.base import MultimodalBaseDataset
from multivae.models.base import BaseDecoder, BaseEncoder
from multivae.models.nn.default_architectures import (
    BaseAEConfig,
    Decoder_AE_MLP,
    Encoder_VAE_MLP,
    nn,
)

from ..base import BaseMultiVAE
from ..base.base_utils import rsample_from_gaussian
from .nexus_config import NexusConfig

# Module-level logger: emits records of level INFO and above through a
# stream handler (``console``) attached at import time.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

console = logging.StreamHandler()
logger.addHandler(console)


class Nexus(BaseMultiVAE):
    """The Nexus model from (Vasco et al 2022)
    "Leveraging hierarchy in multimodal generative models for effective cross-modality inference".

    Args:
        model_config (NexusConfig): An instance of NexusConfig in which any model's parameters is made available.

        encoders (Dict[str, ~pythae.models.nn.BaseEncoder]): A dictionary containing the modalities names and the encoders for each
            modality. Each encoder is an instance of ~pythae.models.nn.BaseEncoder

        decoders (Dict[str, ~pythae.models.nn.BaseDecoder]): A dictionary containing the modalities names and the decoders for each
            modality. Each decoder is an instance of ~pythae.models.nn.BaseDecoder

        top_encoders (Dict[str, ~pythae.models.nn.BaseEncoder]) : A dictionary containing for each modality, the top encoder to use.

        joint_encoder (~multivae.models.base.BaseEncoder): The encoder that takes the aggregated message and
            encode it to obtain the high level latent distribution.

        top_decoders (Dict[str, ~pythae.models.nn.BaseDecoder]) : A dictionary containing for each modality, the top decoder to use.
    """

    def __init__(
        self,
        model_config: NexusConfig,
        encoders: Dict[str, BaseEncoder] = None,
        decoders: Dict[str, BaseDecoder] = None,
        top_encoders: Dict[str, BaseEncoder] = None,
        joint_encoder: Union[BaseEncoder, None] = None,
        top_decoders: Dict[str, BaseDecoder] = None,
        **kwargs,
    ):
        super().__init__(model_config, encoders, decoders, **kwargs)

        self.model_name = "NEXUS"

        # Set all architectures (defaults are built when None is provided)
        self._set_top_decoders(top_decoders, model_config)
        self._set_top_encoders(top_encoders, model_config)
        self._set_joint_encoder(joint_encoder, model_config)

        self._set_bottom_betas(model_config.bottom_betas)
        self._set_gammas(model_config.gammas)

        # important for training: presumably read by the trainer so that the
        # best model is only tracked once the KL warmup is over -- confirm.
        self.start_keep_best_epoch = model_config.warmup + 1

        self.adapt_top_decoder_variance = self._set_top_decoder_variance(model_config)
        self.check_aggregator(model_config)

    def _get_annealing(self, epoch) -> float:
        """Return the KL annealing coefficient ``min(epoch / warmup, 1.0)``.

        A non-positive ``model_config.warmup`` disables annealing (returns 1.0)
        instead of dividing by zero or producing a negative coefficient.
        """
        warmup = self.model_config.warmup
        if warmup <= 0:
            return 1.0
        return min(epoch / warmup, 1.0)

    def _compute_bottom_elbos(self, inputs: MultimodalBaseDataset, **kwargs):
        """Pass the inputs through the first level of encoding and compute the bottom elbos.

        Args:
            inputs (MultimodalBaseDataset): The batch. If it has a ``masks``
                attribute, unavailable samples are zeroed out in the loss.
            **kwargs: ``epoch`` (default 1) is used for KL annealing.

        Returns:
            tuple: ``(bottom_loss, modalities_msg, first_level_z, metrics)``:
            the per-sample sum over modalities of the negative bottom ELBOs,
            the top-encoder message of each modality, the detached first-level
            latent sample of each modality, and monitoring metrics.
        """
        annealing = self._get_annealing(kwargs.pop("epoch", 1))

        modalities_msg = {}
        bottom_loss = 0
        first_level_z = {}
        metrics = {}

        for m, x_m in inputs.data.items():
            # Encode the modality
            output_m = self.encoders[m](x_m)
            z_m = rsample_from_gaussian(output_m.embedding, output_m.log_covariance)

            # Decode and reconstruct
            recon_x_m = self.decoders[m](z_m).reconstruction

            # Compute -log p(x_m|z_m), summed over every non-batch dimension
            nlogprob = (
                -(self.recon_log_probs[m](recon_x_m, x_m) * self.rescale_factors[m])
                .reshape(recon_x_m.size(0), -1)
                .sum(-1)
            )

            # Compute KL(q(z_m|x_m)||p(z_m)). p(z_m) is a standard gaussian
            KLD = -0.5 * torch.sum(
                1
                + output_m.log_covariance
                - output_m.embedding.pow(2)
                - output_m.log_covariance.exp(),
                dim=-1,
            )

            # Compute the negative elbo loss
            m_elbo = nlogprob + KLD * self.bottom_betas[m] * annealing

            # Save a detached z: the top-level losses computed from it do not
            # backpropagate into the bottom encoder.
            first_level_z[m] = z_m.clone().detach()

            # Pass the modality specific latent variable through the top encoder to compute the message
            modalities_msg[m] = self.top_encoders[m](first_level_z[m]).embedding

            # Save some metrics for monitoring the training.
            metrics["recon_loss_" + m] = nlogprob.mean()
            metrics["kl_" + m] = KLD.mean()

            # Partial dataset : use masks to filter out unavailable samples in the loss
            if hasattr(inputs, "masks"):
                m_elbo = m_elbo * inputs.masks[m].float()

            bottom_loss += m_elbo

        return bottom_loss, modalities_msg, first_level_z, metrics

    def forward(self, inputs: MultimodalBaseDataset, **kwargs):
        """Forward pass of the model. Returns loss and metrics.

        Args:
            inputs (MultimodalBaseDataset): The batch. If it has a ``masks``
                attribute, missing modalities are excluded from the losses.
            **kwargs: ``epoch`` (default 1) is used for KL annealing.

        Returns:
            ModelOutput: with fields ``loss`` (mean over the batch), ``loss_sum``
            (sum over the batch) and ``metrics``.
        """
        # Compute bottom level elbos
        bottom_loss, modalities_msg, first_level_z, metrics = (
            self._compute_bottom_elbos(inputs, **kwargs)
        )

        # Aggregate the modalities messages
        aggregated_msg = self._aggregate_during_training(inputs, modalities_msg)

        # Compute the higher level latent variable z_\sigma
        joint_output = self.joint_encoder(aggregated_msg)
        joint_z = rsample_from_gaussian(
            joint_output.embedding, joint_output.log_covariance
        )

        # Compute -log p(z_m|z_sigma) for each modality
        z_recon_loss = 0
        for m in self.top_decoders:
            z_m_recon = self.top_decoders[m](joint_z).reconstruction

            # Eventually adapt the scale of the top decoder to the empirical
            # reconstruction error (averaged over dims 0 and 1).
            if m in self.adapt_top_decoder_variance:
                scale = (
                    ((first_level_z[m] - z_m_recon) ** 2)
                    .mean([0, 1], keepdim=True)
                    .sqrt()
                )
            else:
                scale = 1

            z_m_recon_loss = (
                -(dist.Normal(z_m_recon, scale).log_prob(first_level_z[m])).sum(-1)
                * self.gammas[m]
            )

            # Partial dataset, we don't reconstruct the missing modalities
            if hasattr(inputs, "masks"):
                z_m_recon_loss = z_m_recon_loss * inputs.masks[m]

            z_recon_loss += z_m_recon_loss
            metrics["recon_z_" + m] = z_m_recon_loss.mean()

        # Compute KL(q(z_sigma|z1::M) | p(z_sigma)). The prior p(z_sigma) is standard gaussian
        joint_KLD = -0.5 * torch.sum(
            1
            + joint_output.log_covariance
            - joint_output.embedding.pow(2)
            - joint_output.log_covariance.exp(),
            dim=1,
        )

        annealing = self._get_annealing(kwargs.pop("epoch", 1))

        # Compute top loss and total loss
        top_loss = z_recon_loss + self.model_config.top_beta * joint_KLD * annealing
        total_loss = top_loss + bottom_loss

        metrics.update(
            {
                "annealing": annealing,
                "bottom_loss": bottom_loss.mean(0),
                "top_loss": top_loss.mean(0),
                "joint_KLD": joint_KLD.mean(0),
            }
        )

        # Return the mean averaged on the batch.
        return ModelOutput(
            loss=total_loss.mean(0),
            loss_sum=total_loss.sum(),
            metrics=metrics,
        )

    def _aggregate_during_training(
        self, inputs: MultimodalBaseDataset, modalities_msg: dict
    ):
        """Aggregate the modalities messages during training.

        With the ``"mean"`` aggregator:

        * if ``inputs`` has ``masks`` (already incomplete dataset), the messages of
          the available modalities are averaged per sample and no dropout is applied;
        * otherwise Forced Perceptual Dropout is applied: for each sample, with
          probability ``model_config.dropout_rate``, only a random strict non-empty
          subset of the modalities is kept before averaging.

        Returns:
            torch.Tensor: one aggregated message per sample, or None when the
            aggregator is not ``"mean"`` (which ``check_aggregator`` rejects).
        """
        if self.model_config.aggregator != "mean":
            return None

        # With an already incomplete dataset, we don't apply dropout
        if hasattr(inputs, "masks"):
            # Number of available modalities for each sample
            normalization_per_sample = torch.stack(
                [inputs.masks[m] for m in inputs.masks], dim=0
            ).sum(0)

            # Apply the masks and sum
            aggregated_msg = 0
            for m, msg in modalities_msg.items():
                aggregated_msg += msg * inputs.masks[m].unsqueeze(1)

            # Normalize
            return (aggregated_msg.t() / normalization_per_sample).t()

        # With a complete dataset, we apply Forced Perceptual Dropout during training.
        # Stack the n_modalities tensors of shape (n_data, msg_dim) into a single
        # tensor of shape (n_data, n_modalities, msg_dim).
        tensor_modalities_msg = torch.stack(list(modalities_msg.values()), dim=1)

        bernoulli = dist.Bernoulli(self.model_config.dropout_rate)
        batch_msgs = []
        for msgs in tensor_modalities_msg:
            # msgs shape : n_modalities, msg_dim
            bernoulli_drop = bernoulli.sample().item()
            # At least one modality is always kept, so dropping is only possible
            # with two or more modalities (np.random.randint(1, 1) would raise).
            if bernoulli_drop == 1 and self.n_modalities > 1:
                # choose a random subset of size in [1, n_modalities - 1] to keep
                subset_size = np.random.randint(1, self.n_modalities)
                msgs = msgs[torch.randperm(self.n_modalities)]
                msgs = msgs[:subset_size]
            batch_msgs.append(msgs.mean(0))

        return torch.stack(batch_msgs, dim=0)

    def encode(
        self,
        inputs: MultimodalBaseDataset,
        cond_mod: Union[list, str] = "all",
        N: int = 1,
        return_mean=False,
        **kwargs,
    ):
        """Generate encodings conditioning on all modalities or a subset of modalities.

        Args:
            inputs (MultimodalBaseDataset): The dataset to use for the conditional generation.
            cond_mod (Union[list, str]): Either 'all' or a list of str containing the modalities names to condition on.
            N (int) : The number of encodings to sample for each datapoint. Default to 1.
            return_mean (bool) : if True, returns the mean of the posterior distribution (instead of a sample).
            **kwargs: ``flatten`` (bool, default False): when N > 1, keep the
                samples flattened along the first dimension instead of reshaping
                them to (N, n_data, ...).

        Returns:
            ModelOutput instance with fields:
                z (torch.Tensor): the joint latent samples, of shape (N, n_data, latent_dim)
                    when N > 1 and ``flatten`` is False; otherwise the samples are stacked
                    along the first dimension.
                one_latent_space (bool) = True
                modalities_z (dict): the bottom latent samples of each conditioning
                    modality, reshaped like ``z``.
        """
        cond_mod = super().encode(inputs, cond_mod, N, **kwargs).cond_mod
        modalities_z = {}
        modalities_msg = {}
        flatten = kwargs.pop("flatten", False)

        # Encode each modality with the bottom encoders, then compute its message
        for m in cond_mod:
            output_m = self.encoders[m](inputs.data[m])
            modalities_z[m] = rsample_from_gaussian(
                output_m.embedding,
                output_m.log_covariance,
                N,
                return_mean,
                flatten=True,
            )
            modalities_msg[m] = self.top_encoders[m](modalities_z[m]).embedding

        # Compute aggregated msg
        aggregated_msg = None
        if self.model_config.aggregator == "mean":
            aggregated_msg = torch.stack(list(modalities_msg.values()), dim=0).mean(0)

        # The N samples are already flattened in the batch dimension, so a single
        # joint sample is drawn per row.
        nexus_output = self.joint_encoder(aggregated_msg)
        z = rsample_from_gaussian(
            nexus_output.embedding,
            nexus_output.log_covariance,
            N=1,
            return_mean=return_mean,
        )

        if N > 1 and not flatten:
            z = z.reshape(N, -1, *z.shape[1:])
            modalities_z = {
                m: modalities_z[m].reshape(N, -1, *modalities_z[m].shape[1:])
                for m in modalities_z
            }

        return ModelOutput(z=z, one_latent_space=True, modalities_z=modalities_z)

    def decode(
        self, embedding: ModelOutput, modalities: Union[list, str] = "all", **kwargs
    ):
        """Decode the embeddings given by the ``encode`` function.

        Side effect: puts the model in eval mode (``self.eval()``); decoding runs
        under ``torch.no_grad()``.

        Args:
            embedding (ModelOutput): output of ``encode``; must have ``z`` and may
                have ``modalities_z``.
            modalities (Union[list, str]): 'all', a modality name or a list of names.
            **kwargs: ``use_bottom_z_for_recon`` (bool, default True): reconstruct a
                modality from its own bottom latent when it is in
                ``embedding.modalities_z``, instead of going through the joint latent.

        Returns:
            ModelOutput: one reconstruction per requested modality. If ``z`` is 3D
            (N, n_data, latent_dim), reconstructions are shaped (N, n_data, ...).
        """
        self.eval()
        with torch.no_grad():
            if modalities == "all":
                modalities = list(self.encoders.keys())
            elif isinstance(modalities, str):
                modalities = [modalities]

            # For self reconstruction, we use the bottom encodings.
            use_bottom_z_for_reconstruction = kwargs.pop("use_bottom_z_for_recon", True)
            if not hasattr(embedding, "modalities_z"):
                use_bottom_z_for_reconstruction = False

            outputs = ModelOutput()

            # If the embedding has three dimensions, we flatten it and then reshape it at the end.
            reshape = False
            z = embedding.z
            if len(z.shape) == 3:
                N, bs, _ = z.shape
                reshape = True
                z = z.reshape(N * bs, -1)

            for m in modalities:
                if use_bottom_z_for_reconstruction and (
                    m in embedding.modalities_z.keys()
                ):
                    z_m = embedding.modalities_z[m]
                    if reshape:
                        z_m = z_m.reshape(N * bs, -1)
                else:
                    z_m = self.top_decoders[m](z).reconstruction

                recon = self.decoders[m](z_m).reconstruction
                if reshape:
                    recon = recon.reshape(N, bs, *recon.shape[1:])

                outputs[m] = recon

            return outputs

    def _set_top_decoder_variance(self, config):
        """Returns a list of the modalities for which the variance needs to be adapted.

        Raises:
            AttributeError: if a listed name is not one of the modalities.
        """
        if config.adapt_top_decoder_variance is None:
            return []
        for m in config.adapt_top_decoder_variance:
            if m not in self.modalities_name:
                raise AttributeError(
                    f"A string provided in *adapt_top_decoder_variance* field doesn't match any of the modalities name : {m} is not in {self.modalities_name}"
                )
        return config.adapt_top_decoder_variance

    def _set_bottom_betas(self, bottom_betas):
        """Set the per-modality bottom KL weights (1.0 for every modality if None)."""
        if bottom_betas is None:
            bottom_betas = {m: 1.0 for m in self.encoders}
        if bottom_betas.keys() != self.encoders.keys():
            raise AttributeError(
                "The bottom_betas keys do not match the modalitiesnames in encoders."
            )
        self.bottom_betas = bottom_betas

    def _set_gammas(self, gammas):
        """Set the per-modality weights of the top reconstruction terms (1.0 if None)."""
        if gammas is None:
            self.gammas = {m: 1.0 for m in self.encoders}
        elif gammas.keys() != self.encoders.keys():
            raise AttributeError(
                "The gammas keys do not match the modalitiesnames in encoders."
            )
        else:
            self.gammas = gammas

    def default_encoders(self, model_config: NexusConfig):
        """Build default MLP bottom encoders from input_dims and modalities_specific_dim."""
        if (
            model_config.input_dims is None
            or model_config.modalities_specific_dim is None
        ):
            raise AttributeError(
                "Please provide encoders architectures or "
                "valid input_dims and modalities_specific_dim in the "
                "model configuration"
            )

        encoders = nn.ModuleDict()
        for mod in model_config.input_dims:
            config = BaseAEConfig(
                input_dim=model_config.input_dims[mod],
                latent_dim=model_config.modalities_specific_dim[mod],
            )
            encoders[mod] = Encoder_VAE_MLP(config)
        return encoders

    def default_decoders(self, model_config: NexusConfig):
        """Build default MLP bottom decoders from input_dims and modalities_specific_dim."""
        if (
            model_config.input_dims is None
            or model_config.modalities_specific_dim is None
        ):
            raise AttributeError(
                "Please provide decoders architectures or "
                "valid input_dims and modalities_specific_dim in the "
                "model configuration"
            )

        decoders = nn.ModuleDict()
        for mod in model_config.input_dims:
            config = BaseAEConfig(
                input_dim=model_config.input_dims[mod],
                latent_dim=model_config.modalities_specific_dim[mod],
            )
            decoders[mod] = Decoder_AE_MLP(config)
        return decoders

    def _default_top_encoders(self, model_config: NexusConfig):
        """Build default MLP top encoders mapping each z_m to a msg_dim message."""
        if model_config.modalities_specific_dim is None:
            raise AttributeError(
                "Please provide top_encoders architectures or "
                "valid modalities_specific_dim in the "
                "model configuration"
            )

        # input_dims may be None when custom bottom architectures are provided:
        # only modalities_specific_dim is needed here.
        modalities = (
            model_config.input_dims
            if model_config.input_dims is not None
            else model_config.modalities_specific_dim
        )

        encoders = nn.ModuleDict()
        for mod in modalities:
            config = BaseAEConfig(
                input_dim=(model_config.modalities_specific_dim[mod],),
                latent_dim=model_config.msg_dim,
            )
            encoders[mod] = Encoder_VAE_MLP(config)
        return encoders

    def _default_top_decoders(self, model_config: NexusConfig):
        """Build default MLP top decoders mapping the joint latent to each z_m."""
        if model_config.modalities_specific_dim is None:
            raise AttributeError(
                "Please provide top_decoders architectures or "
                "valid modalities_specific_dim in the "
                "model configuration"
            )

        # input_dims may be None when custom bottom architectures are provided:
        # only modalities_specific_dim is needed here.
        modalities = (
            model_config.input_dims
            if model_config.input_dims is not None
            else model_config.modalities_specific_dim
        )

        decoders = nn.ModuleDict()
        for mod in modalities:
            config = BaseAEConfig(
                input_dim=(model_config.modalities_specific_dim[mod],),
                latent_dim=model_config.latent_dim,
            )
            decoders[mod] = Decoder_AE_MLP(config)
        return decoders

    def _default_joint_encoder(self, model_config: NexusConfig):
        """Build the default MLP joint encoder from msg_dim to latent_dim."""
        return Encoder_VAE_MLP(
            BaseAEConfig(
                input_dim=(model_config.msg_dim,), latent_dim=model_config.latent_dim
            )
        )

    def _set_top_encoders(self, top_encoders, model_config):
        """Validate and set the top encoders, building defaults if None are provided."""
        # Provide default encoders if None are provided
        if top_encoders is None:
            top_encoders = self._default_top_encoders(model_config)
        else:
            self.model_config.custom_architectures.append("top_encoders")

        # Check top encoders type and set the attribute
        self.top_encoders = nn.ModuleDict()
        for k in top_encoders:
            if not isinstance(top_encoders[k], BaseEncoder):
                raise AttributeError(
                    "Top Encoders must be instances of multivae.models.base.BaseEncoder"
                )
            self.top_encoders[k] = top_encoders[k]

    def _set_top_decoders(self, top_decoders, model_config):
        """Validate and set the top decoders, building defaults if None are provided."""
        # Provide default MLP decoders if None are provided.
        if top_decoders is None:
            top_decoders = self._default_top_decoders(model_config)
        else:
            self.model_config.custom_architectures.append("top_decoders")

        # Check the decoders type and set the attribute
        self.top_decoders = nn.ModuleDict()
        for k in top_decoders:
            if not isinstance(top_decoders[k], BaseDecoder):
                raise AttributeError(
                    "Top Decoders must be instances of multivae.models.base.BaseDecoder"
                )
            self.top_decoders[k] = top_decoders[k]

    def _set_joint_encoder(self, joint_encoder, model_config):
        """Validate and set the joint encoder, building a default if None is provided."""
        # Provide default Nexus encoder if None is Provided
        if joint_encoder is None:
            joint_encoder = self._default_joint_encoder(model_config)
        else:
            self.model_config.custom_architectures.append("joint_encoder")

        # Check encoder type and set the attribute
        if not isinstance(joint_encoder, BaseEncoder):
            raise AttributeError(
                "Joint encoder must be an instance of multivae.models.base.BaseEncoder"
            )
        self.joint_encoder = joint_encoder

    def check_aggregator(self, model_config):
        """Raise AttributeError if the configured aggregator is not supported."""
        if model_config.aggregator not in ["mean"]:
            raise AttributeError(
                f"This aggregator {model_config.aggregator} is not supported at the moment"
            )