# Source code for multivae.models.dmvae.dmvae_model

import math
from typing import Dict, Union

import torch
import torch.distributions as dist
from pythae.models.base.base_utils import ModelOutput
from pythae.models.nn.base_architectures import BaseDecoder
from torch import nn

from multivae.data.datasets.base import IncompleteDataset, MultimodalBaseDataset
from multivae.models.nn.default_architectures import (
    BaseDictDecodersMultiLatents,
    BaseDictEncoders_MultiLatents,
)

from ..base import BaseMultiVAE
from ..base.base_utils import kl_divergence, rsample_from_gaussian, stable_poe
from ..nn.base_architectures import BaseMultilatentEncoder
from .dmvae_config import DMVAEConfig


class DMVAE(BaseMultiVAE):
    """The DMVAE model from the paper 'Private-Shared Disentangled Multimodal VAE for
    Learning of Latent Representations'. Mihee Lee, Vladimir Pavlovic

    The model uses one latent space shared by all modalities and one private
    ("style") latent space per modality.

    Args:
        model_config (DMVAEConfig): An instance of DMVAEConfig in which all the
            model's parameters are made available.
        encoders (Dict[str, ~multivae.models.nn.base_architectures.BaseMultilatentEncoder]):
            A dictionary containing the modalities names and the encoders for each
            modality. Each encoder is an instance of Multivae's
            BaseMultilatentEncoder since this model uses multiple latent spaces.
            Default: None.
        decoders (Dict[str, ~pythae.models.nn.base_architectures.BaseDecoder]):
            A dictionary containing the modalities names and the decoders for each
            modality. Each decoder is an instance of Pythae's BaseDecoder.
            Default: None.
    """

    def __init__(
        self,
        model_config: DMVAEConfig,
        encoders: Union[Dict[str, BaseMultilatentEncoder], None] = None,
        decoders: Union[Dict[str, BaseDecoder], None] = None,
    ):
        super().__init__(model_config, encoders, decoders)

        self.beta = model_config.beta
        self.model_name = "DMVAE"

        # Both helpers validate their dicts against self.encoders, which is only
        # available after the parent constructor has run.
        self._set_private_betas(model_config.modalities_specific_betas)
        self._set_modalities_specific_dim(model_config)
        self.multiple_latent_spaces = True

    def _set_modalities_specific_dim(self, model_config):
        """Set ``self.style_dims``, the dimension of each private latent space.

        Args:
            model_config (DMVAEConfig): the model configuration. Its
                ``modalities_specific_dim`` field is either None (every private
                space then gets dimension 1) or a dict keyed like the encoders.

        Raises:
            AttributeError: if the keys of ``modalities_specific_dim`` don't match
                the keys of the encoders.
        """
        if model_config.modalities_specific_dim is None:
            # These dims are used as tensor sizes (torch.zeros / rsample shapes),
            # so they must be ints: a float size raises a TypeError.
            self.style_dims = {m: 1 for m in self.encoders}
            return

        if model_config.modalities_specific_dim.keys() != self.encoders.keys():
            raise AttributeError(
                "The keys in modalities_specific_dim doesn't match "
                "the keys in the encoders or input_dims"
            )
        self.style_dims = model_config.modalities_specific_dim

    def _set_private_betas(self, beta_dict):
        """Set ``self.private_betas``, the KL weight of each private latent space.

        Args:
            beta_dict (dict or None): per-modality weights. If None, every
                modality gets a weight of 1.0.

        Raises:
            AttributeError: if the keys of ``beta_dict`` don't match the keys of
                the encoders.
        """
        if beta_dict is None:
            self.private_betas = {mod: 1.0 for mod in self.encoders}
            return

        if not self.encoders.keys() == beta_dict.keys():
            raise AttributeError(
                "The modality_specific_betas doesn't have the same "
                "keys (modalities) as the provided encoders dict."
            )
        self.private_betas = beta_dict

    def default_encoders(self, model_config) -> nn.ModuleDict:
        """Build the default multi-latent encoders from the model configuration."""
        return BaseDictEncoders_MultiLatents(
            input_dims=model_config.input_dims,
            latent_dim=model_config.latent_dim,
            modality_dims=model_config.modalities_specific_dim,
        )

    def default_decoders(self, model_config) -> nn.ModuleDict:
        """Build the default decoders taking (shared, private) latents as input."""
        return BaseDictDecodersMultiLatents(
            input_dims=model_config.input_dims,
            latent_dim=model_config.latent_dim,
            modality_dims=model_config.modalities_specific_dim,
        )

    def _infer_latent_parameters(self, inputs, subset=None):
        """Compute the latent parameters for the shared and private latent spaces,
        taking the product-of-experts on the subset.

        Args:
            inputs (MultimodalBaseDataset): the data. If it has a ``masks`` field,
                the unavailable samples are excluded from the product-of-experts.
            subset (list, optional): the modalities to encode. Defaults to all the
                modalities present in ``inputs.data``.

        Returns:
            tuple: ``(joint_mu, joint_lv, shared_params, private_params)``.
            ``joint_mu`` and ``joint_lv`` are the parameters of the
            product-of-experts over the shared space, which includes an N(0, I)
            prior expert. ``shared_params`` and ``private_params`` map each encoded
            modality to its ``(mean, log_variance)`` pair.
        """
        if subset is None:
            subset = list(inputs.data.keys())

        # First compute all embeddings.
        private_params = {}
        shared_params = {}
        for mod in subset:
            output_mod = self.encoders[mod](inputs.data[mod])

            style_mu = output_mod.style_embedding
            style_lv = output_mod.style_log_covariance
            if len(style_mu.shape) == 1:
                # Add a batch dimension for single-sample outputs.
                style_mu, style_lv = style_mu.unsqueeze(0), style_lv.unsqueeze(0)
            private_params[mod] = (style_mu, style_lv)

            shared_mu = output_mod.embedding
            shared_lv = output_mod.log_covariance
            if len(shared_mu.shape) == 1:
                shared_mu, shared_lv = shared_mu.unsqueeze(0), shared_lv.unsqueeze(0)
            shared_params[mod] = (shared_mu, shared_lv)

        # Build the experts' means and log-variances in the same order.
        list_mu = []
        list_lvs = []
        for mod, (mod_mu, mod_lv) in shared_params.items():
            list_mu.append(mod_mu)
            log_var_mod = mod_lv.clone()
            if hasattr(inputs, "masks"):
                # An infinite variance gives the expert zero weight, so the
                # unavailable samples don't count in the PoE.
                log_var_mod[(1 - inputs.masks[mod].int()).bool().flatten()] = torch.inf
            list_lvs.append(log_var_mod)

        # Add the N(0, I) prior to the product of experts.
        list_mu.append(torch.zeros_like(list_mu[0]))
        list_lvs.append(torch.zeros_like(list_lvs[0]))

        joint_mu, joint_lv = stable_poe(torch.stack(list_mu), torch.stack(list_lvs))

        return joint_mu, joint_lv, shared_params, private_params

    def forward(
        self, inputs: Union[MultimodalBaseDataset, IncompleteDataset], **kwargs
    ):
        """The main function of the model that computes the loss and some
        monitoring metrics. One of the advantages of DMVAE is that we can train
        with incomplete data.

        The loss is the sum of the negative ELBO computed with the joint shared
        posterior and the negative ELBOs computed with each unimodal shared
        posterior, averaged over the batch.

        Args:
            inputs (MultimodalBaseDataset): The data. It can be an instance of
                IncompleteDataset which contains a field masks for weakly supervised
                learning. masks is a dictionary indicating which datasamples are
                missing in each of the modalities. For each modality, a boolean
                tensor indicates which samples are available. (The non available
                samples are assumed to be replaced with zero values in the
                multimodal dataset entry.)

        Returns:
            ModelOutput: with fields ``loss`` (scalar) and ``metrics`` (dict with
            key ``"joint"`` and one key per modality).
        """
        (
            joint_mu,
            joint_lv,
            shared_params,
            private_params,
        ) = self._infer_latent_parameters(inputs)

        metrics = {}

        # Negative ELBO with the joint posterior on the shared latent space.
        joint_elbo = self._compute_elbo(joint_mu, joint_lv, private_params, inputs)
        metrics["joint"] = joint_elbo.mean()

        # Accumulate out-of-place so joint_elbo, which is part of the autograd
        # graph, is never modified in place.
        loss = joint_elbo

        # Cross-modal terms: the shared latent is inferred from one modality only.
        for k, (mod_mu, mod_lv) in shared_params.items():
            mod_elbo = self._compute_elbo(mod_mu, mod_lv, private_params, inputs)
            if hasattr(inputs, "masks"):
                # Only samples where modality k is available contribute.
                mod_elbo = inputs.masks[k] * mod_elbo
            loss = loss + mod_elbo
            metrics[k] = mod_elbo.mean()

        return ModelOutput(loss=loss.mean(), metrics=metrics)

    def _compute_elbo(self, q_mu, q_lv, private_params, inputs):
        """Compute the per-sample negative ELBO for a given shared posterior.

        Args:
            q_mu (torch.Tensor): mean of the shared posterior.
            q_lv (torch.Tensor): log-variance of the shared posterior.
            private_params (dict): ``(mean, log_variance)`` of each modality's
                private posterior.
            inputs (MultimodalBaseDataset): the data. If it has a ``masks`` field,
                unavailable modalities are excluded from the reconstruction and
                private KL terms.

        Returns:
            torch.Tensor: the negative ELBO of each sample, i.e.
            ``-reconstruction + beta * KL_shared + sum_m private_beta_m * KL_m``.
        """
        shared_z = rsample_from_gaussian(q_mu, q_lv)

        # Reconstruction terms: every modality is decoded from
        # (shared latent, its own private latent).
        recon_loss = 0
        for mod in self.encoders:
            mu, logvar = private_params[mod]
            z_mod = rsample_from_gaussian(mu, logvar)
            z = torch.cat([shared_z, z_mod], dim=1)
            recon_mod = self.decoders[mod](z).reconstruction
            recon_mod = (
                self.recon_log_probs[mod](recon_mod, inputs.data[mod])
                * self.rescale_factors[mod]
            )
            recon_mod = recon_mod.reshape(recon_mod.size(0), -1).sum(-1)
            if hasattr(inputs, "masks"):
                # Filter unavailable modalities in the reconstruction loss.
                recon_mod = inputs.masks[mod].float() * recon_mod
            recon_loss += recon_mod

        # KL divergence for the shared variable, reduced to one value per sample
        # like the private KLs below.
        shared_kl = kl_divergence(
            q_mu, q_lv, torch.zeros_like(q_mu), torch.zeros_like(q_lv)
        )
        shared_kl = shared_kl.reshape(shared_kl.size(0), -1).sum(-1)
        kl = shared_kl * self.beta

        # Add the modality-specific KLs.
        for mod in self.encoders:
            mu, lv = private_params[mod]
            kl_mod = kl_divergence(mu, lv, torch.zeros_like(mu), torch.zeros_like(lv))
            kl_mod = kl_mod.reshape(kl_mod.size(0), -1).sum(-1)
            if hasattr(inputs, "masks"):
                kl_mod = inputs.masks[mod].float() * kl_mod
            kl += kl_mod * self.private_betas[mod]

        neg_elbo = -recon_loss + kl
        return neg_elbo

    def encode(
        self,
        inputs: Union[MultimodalBaseDataset, IncompleteDataset],
        cond_mod: Union[list, str] = "all",
        N: int = 1,
        return_mean=False,
        **kwargs,
    ):
        """Generate encodings conditioning on all modalities or a subset of
        modalities.

        Args:
            inputs (MultimodalBaseDataset): The dataset to use for the conditional
                generation.
            cond_mod (Union[list, str]): Either 'all' or a list of str containing
                the modalities names to condition on.
            N (int) : The number of encodings to sample for each datapoint.
                Default to 1.
            return_mean (bool) : if True, returns the mean of the posterior
                distribution (instead of a sample).

        Returns:
            ModelOutput : Contains fields
                'z' (torch.Tensor (N, n_data, latent_dim))
                'one_latent_space' (bool) = False
                'modalities_z' (dict[str,torch.Tensor (N, n_data,mod_latent_dim)])
            Private latents of modalities that are not conditioned on are drawn
            from an N(0, I) prior.
        """
        # Call super to perform some checks and preprocess the cond_mod argument:
        # you obtain a list of the modalities' names to condition on.
        cond_mod = super().encode(inputs, cond_mod, N, **kwargs).cond_mod

        # Compute the shared latent variable conditioning on input modalities.
        sub_mu, sub_logvar, _, private_params = self._infer_latent_parameters(
            inputs, cond_mod
        )

        flatten = kwargs.pop("flatten", False)
        z = rsample_from_gaussian(
            sub_mu, sub_logvar, N=N, return_mean=return_mean, flatten=flatten
        )

        modalities_z = {}
        for mod in self.encoders:
            if mod in cond_mod:
                mod_mu, mod_lv = private_params[mod]
            else:
                # No observation for this modality: use the N(0, I) prior.
                mod_mu = torch.zeros(
                    (sub_mu.shape[0], self.style_dims[mod]), device=sub_mu.device
                )
                mod_lv = torch.zeros_like(mod_mu)
            modalities_z[mod] = rsample_from_gaussian(
                mod_mu, mod_lv, N=N, return_mean=return_mean, flatten=flatten
            )

        return ModelOutput(z=z, one_latent_space=False, modalities_z=modalities_z)

    def generate_from_prior(self, n_samples, **kwargs):
        """Generates latent variables from the prior for the shared latent spaces
        and for each modality specific latent space.

        Args:
            n_samples (int): the number of samples to draw. When it is 1, the
                sample dimension is dropped.

        Returns:
            ModelOutput: with fields ``z`` (shared latent, shape
            ``(n_samples, latent_dim)`` or ``(latent_dim,)``),
            ``one_latent_space=False`` and ``modalities_z`` (dict of private
            latents, shape ``(n_samples, style_dim)`` or ``(style_dim,)``).
        """
        device = self.device if self.device is not None else "cpu"

        # Generate the shared latent variable.
        shared_latent_shape = (
            [n_samples, self.latent_dim] if n_samples > 1 else [self.latent_dim]
        )
        z_shared = dist.Normal(0, 1).rsample(shared_latent_shape).to(device)

        # Generate the modality-specific variables.
        modalities_z = {}
        for k, dim in self.style_dims.items():
            shape = [n_samples, dim] if n_samples > 1 else [dim]
            modalities_z[k] = dist.Normal(0, 1).rsample(shape).to(device)

        return ModelOutput(
            z=z_shared, one_latent_space=False, modalities_z=modalities_z
        )

    @torch.no_grad()
    def compute_joint_nll(
        self,
        inputs: Union[MultimodalBaseDataset, IncompleteDataset],
        K: int = 1000,
        batch_size_K: int = 100,
    ):
        """Estimate the negative joint likelihood with importance sampling.

        The joint posterior on the shared space and each modality's private
        posterior are used as proposal distributions.

        Args:
            inputs (MultimodalBaseDataset) : a batch of samples.
            K (int) : the number of importance samples for the estimation.
                Default to 1000.
            batch_size_K (int) : the number of importance samples processed at
                once. Default to 100.

        Returns:
            The negative log-likelihood summed over the batch.

        Raises:
            AttributeError: if ``inputs`` is an incomplete dataset (has ``masks``).
        """
        self.eval()
        if hasattr(inputs, "masks"):
            raise AttributeError(
                "The compute_joint_nll method is not yet implemented for incomplete datasets."
            )

        # Parameters of the joint posterior for the shared latent space and of
        # the private posteriors, conditioning on all input modalities.
        mu, log_var, _, private_params = self._infer_latent_parameters(inputs)
        sigma = torch.exp(0.5 * log_var)
        qz_xy = dist.Normal(mu, sigma)

        # Sample K latents from the shared joint posterior.
        z_joint = qz_xy.rsample([K]).permute(1, 0, 2)  # n_data x K x latent_dim
        n_data = z_joint.shape[0]
        prior = dist.Normal(0, 1)

        # Iterate on each datapoint to compute the IWAE estimate of ln(p(x)).
        ll = 0
        for i in range(n_data):
            start_idx = 0
            stop_idx = min(start_idx + batch_size_K, K)
            lnpxs = []

            # Iterate over mini-batches of the K importance samples.
            while start_idx < stop_idx:
                shared_latents = z_joint[i][start_idx:stop_idx]

                # Per-sample terms of the current mini-batch. They are reset here
                # so nothing leaks from previous mini-batches or datapoints.
                lpx_zs, ln_prior, ln_posterior = 0, 0, 0

                for mod in inputs.data:
                    # Sample from the modality-specific posterior.
                    mu_private, logvar_private = private_params[mod]
                    mu_private = mu_private[i]
                    sigma_private = torch.exp(0.5 * logvar_private[i])
                    q_private = dist.Normal(mu_private, sigma_private)
                    private_latents = q_private.rsample([len(shared_latents)])

                    latents = torch.cat([shared_latents, private_latents], dim=-1)
                    recon = self.decoders[mod](latents)["reconstruction"]
                    x_m = inputs.data[mod][i]

                    # ln p(x_m | z_shared, z_private)
                    lpx_zs += (
                        self.recon_log_probs[mod](
                            recon, torch.stack([x_m] * len(recon))
                        )
                        .reshape(recon.size(0), -1)
                        .sum(-1)
                    )
                    # ln p(z_private) and ln q(z_private | x_m)
                    ln_prior += prior.log_prob(private_latents).sum(dim=-1)
                    ln_posterior += q_private.log_prob(private_latents).sum(-1)

                # ln p(z_shared) and ln q(z_shared | x_1, ..., x_M)
                ln_prior += prior.log_prob(shared_latents).sum(dim=-1)
                ln_posterior += (
                    dist.Normal(mu[i], sigma[i]).log_prob(shared_latents).sum(dim=-1)
                )

                lnpxs.append(
                    torch.logsumexp(lpx_zs + ln_prior - ln_posterior, dim=0)
                )

                # Next mini-batch.
                start_idx += batch_size_K
                stop_idx = min(stop_idx + batch_size_K, K)

            # torch.stack keeps the values on their device (no float round-trip).
            ll += torch.logsumexp(torch.stack(lnpxs), dim=0) - math.log(K)

        return -ll