import logging
import math
from typing import Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pythae.models.base.base_utils import ModelOutput
from torch.distributions import Laplace, Normal
from multivae.data.datasets.base import MultimodalBaseDataset
from multivae.data.utils import drop_unused_modalities
from multivae.models.nn.default_architectures import (
BaseDictDecodersMultiLatents,
BaseDictEncoders_MultiLatents,
)
from ..base import BaseMultiVAE
from .mmvaePlus_config import MMVAEPlusConfig
# Module-level logger: INFO-and-above records are echoed through a stream handler.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
console = logging.StreamHandler()
logger.addHandler(console)
[docs]
class MMVAEPlus(BaseMultiVAE):
"""The MMVAE+ model.
Args:
model_config (MMVAEPlusConfig): An instance of MMVAEPlusConfig in which all of the model's
parameters are made available.
encoders (Dict[str, ~multivae.models.nn.base_architectures.BaseMultilatentEncoder]): A dictionary containing
the modalities names and the encoders for each modality. Each encoder is an instance of
Multivae's BaseMultilatentEncoder since this model uses multiple latent spaces. Default: None.
decoders (Dict[str, ~pythae.models.nn.base_architectures.BaseDecoder]): A dictionary containing
the modalities names and the decoders for each modality. Each decoder is an instance of
Pythae's BaseDecoder.
"""
def __init__(
    self,
    model_config: MMVAEPlusConfig,
    encoders: dict = None,
    decoders: dict = None,
):
    """Build the model, pick the latent distribution family and create the priors.

    Args:
        model_config (MMVAEPlusConfig): The model configuration. Its
            ``modalities_specific_dim`` attribute must be set.
        encoders (dict): Optional dictionary of encoders, forwarded to the
            parent constructor. Default: None.
        decoders (dict): Optional dictionary of decoders, forwarded to the
            parent constructor. Default: None.

    Raises:
        AttributeError: If ``model_config.modalities_specific_dim`` is None, or if
            ``model_config.prior_and_posterior_dist`` is not one of
            'laplace_with_softmax', 'normal' or 'normal_with_softplus'.
    """
    if model_config.modalities_specific_dim is None:
        raise AttributeError(
            "The modalities_specific_dim attribute must"
            " be provided in the model config."
        )

    super().__init__(model_config, encoders, decoders)

    # The same distribution family is used for the posteriors and the priors.
    # The two "normal" variants only differ in how the scale is derived from
    # the log-variance (see _log_var_to_std).
    dist_name = model_config.prior_and_posterior_dist
    if dist_name == "laplace_with_softmax":
        dist_family = Laplace
    elif dist_name in ("normal", "normal_with_softplus"):
        dist_family = Normal
    else:
        raise AttributeError(
            "The prior_and_posterior_dist parameter must be either "
            "'laplace_with_softmax', 'normal' or 'normal_with_softplus'. "
            f"{dist_name} was provided."
        )
    self.post_dist = dist_family
    self.prior_dist = dist_family

    # Set the priors for shared and private spaces.
    self.mean_priors = torch.nn.ParameterDict()
    self.logvars_priors = torch.nn.ParameterDict()

    self.beta = model_config.beta
    self.modalities_specific_dim = model_config.modalities_specific_dim
    self.reconstruction_option = model_config.reconstruction_option
    self.multiple_latent_spaces = True
    self.style_dims = {m: self.modalities_specific_dim for m in self.encoders}

    # Add the private and shared latents priors.
    private_dim = model_config.modalities_specific_dim
    joint_dim = model_config.latent_dim + private_dim

    # Modality specific priors (referred to as r distributions in paper).
    # The means are frozen at zero; only the log-variances may be learned.
    for mod in list(self.encoders.keys()):
        self.mean_priors[mod] = torch.nn.Parameter(
            torch.zeros(1, private_dim),
            requires_grad=False,
        )
        self.logvars_priors[mod] = torch.nn.Parameter(
            torch.zeros(1, private_dim),
            requires_grad=model_config.learn_modality_prior,
        )

    # General prior (for the entire latent code) referred to as p in the paper.
    self.mean_priors["shared"] = torch.nn.Parameter(
        torch.zeros(1, joint_dim),
        requires_grad=False,
    )
    self.logvars_priors["shared"] = torch.nn.Parameter(
        torch.zeros(1, joint_dim),
        requires_grad=model_config.learn_shared_prior,
    )

    self.model_name = "MMVAEPlus"
    self.objective = model_config.loss
def _log_var_to_std(self, log_var):
    """Convert a log-covariance tensor into the scale of the latent distribution.

    The transform depends on ``self.model_config.prior_and_posterior_dist``:

    - 'laplace_with_softmax': softmax over the last dimension, multiplied by
      the size of that dimension, plus 1e-6.
    - 'normal_with_softplus': softplus, plus 1e-6.
    - anything else: ``exp(0.5 * log_var)``.
    """
    dist_kind = self.model_config.prior_and_posterior_dist

    if dist_kind == "laplace_with_softmax":
        n_dims = log_var.size(-1)
        return F.softmax(log_var, dim=-1) * n_dims + 1e-6

    if dist_kind == "normal_with_softplus":
        return F.softplus(log_var) + 1e-6

    return torch.exp(0.5 * log_var)
def _compute_posteriors_and_embeddings(self, inputs, detach, **kwargs):
    """Encode each modality, sample its latents and decode every modality from them.

    Args:
        inputs: The multimodal batch. It is first passed through
            ``drop_unused_modalities``; afterwards only ``inputs.data`` is read.
        detach (bool): If True, the returned posteriors are rebuilt from detached
            copies of their parameters. The samples themselves are always drawn
            from the non-detached posteriors.
        **kwargs: ``K`` overrides ``self.model_config.K`` as the number of samples
            drawn per datapoint.

    Returns:
        tuple: ``(embeddings, posteriors, reconstructions)`` where
            ``embeddings[m]`` and ``posteriors[m]`` are dicts with keys 'u' (shared)
            and 'w' (private), and ``reconstructions[m][r]`` is the decoding of
            modality ``r`` from the latents of modality ``m``.
    """
    # Drop unused modalities
    inputs = drop_unused_modalities(inputs)
    n_samples = kwargs.pop("K", self.model_config.K)

    embeddings, reconstructions = {}, {}
    posteriors = {name: {} for name in inputs.data}

    for src_mod, src_data in inputs.data.items():
        enc_out = self.encoders[src_mod](src_data)
        shared_mu = enc_out.embedding
        private_mu = enc_out.style_embedding
        shared_std = self._log_var_to_std(enc_out.log_covariance)
        private_std = self._log_var_to_std(enc_out.style_log_covariance)

        # Shared latent variable
        shared_post = self.post_dist(shared_mu, shared_std)
        shared_sample = shared_post.rsample([n_samples])

        # Private latent variable
        private_post = self.post_dist(private_mu, private_std)
        private_sample = private_post.rsample([n_samples])

        # The DREG loss uses detached parameters in the loss computation afterwards.
        if detach:
            shared_post = self.post_dist(
                shared_mu.clone().detach(), shared_std.clone().detach()
            )
            private_post = self.post_dist(
                private_mu.clone().detach(), private_std.clone().detach()
            )

        # Then compute all the cross-modal reconstructions
        n_rows = len(shared_mu)
        decoded = {}
        for tgt_mod in inputs.data:
            if tgt_mod == src_mod:
                # Self-reconstruction: use this modality's own private code
                private_code = private_sample
            else:
                # Cross-modal reconstruction: only keep the shared latent and
                # draw the private code from the target modality's prior
                prior_mu = torch.cat([self.mean_priors[tgt_mod]] * n_rows, dim=0)
                prior_std = torch.cat(
                    [self._log_var_to_std(self.logvars_priors[tgt_mod])] * n_rows,
                    dim=0,
                )
                private_code = self.prior_dist(prior_mu, prior_std).rsample(
                    [n_samples]
                )
            full_code = torch.cat([shared_sample, private_code], dim=-1)

            # Decode on a flattened (samples * batch) axis, then restore the
            # leading dimensions of the latent code
            flat_code = full_code.reshape(-1, full_code.shape[-1])
            flat_recon = self.decoders[tgt_mod](flat_code)["reconstruction"]
            decoded[tgt_mod] = flat_recon.reshape(
                (*full_code.shape[:-1], *flat_recon.shape[1:])
            )

        reconstructions[src_mod] = decoded
        posteriors[src_mod] = {"u": shared_post, "w": private_post}
        embeddings[src_mod] = {"u": shared_sample, "w": private_sample}

    return embeddings, posteriors, reconstructions
[docs]
def forward(self, inputs: MultimodalBaseDataset, **kwargs):
    """Compute the loss for a batch with the configured objective.

    Args:
        inputs (MultimodalBaseDataset): The batch to compute the loss on.
        **kwargs: Accepted but not used by this method.

    Returns:
        The output of ``_dreg_looser`` or ``_iwae_looser``, depending on
        ``self.objective``.

    Raises:
        NotImplementedError: If ``self.objective`` is neither 'dreg_looser'
            nor 'iwae_looser'.
    """
    objective = self.objective
    if objective not in ("dreg_looser", "iwae_looser"):
        raise NotImplementedError

    # The DreG estimation uses detached posteriors; the plain IWAE loss does not.
    use_dreg = objective == "dreg_looser"
    embeddings, posteriors, reconstructions = (
        self._compute_posteriors_and_embeddings(inputs, detach=use_dreg)
    )

    loss_fn = self._dreg_looser if use_dreg else self._iwae_looser
    return loss_fn(posteriors, embeddings, reconstructions, inputs)
@property
def pz_params(self):
    """Parameters of the prior over the whole latent code.

    The scale is obtained by passing the stored 'shared' log-covariance
    through ``_log_var_to_std``, so it follows the configured choice of
    prior distribution.

    Returns:
        tuple: mean, std
    """
    return (
        self.mean_priors["shared"],
        self._log_var_to_std(self.logvars_priors["shared"]),
    )
def _compute_k_lws(self, posteriors, embeddings, reconstructions, inputs):
    """Compute the per-sample log importance weights of every modality.

    No aggregation is done over the k_iwae samples or over the batch.

    Args:
        posteriors (dict): ``posteriors[m]['u']`` / ``['w']`` are the shared and
            private posterior distributions of modality ``m``.
        embeddings (dict): ``embeddings[m]['u']`` / ``['w']`` are the matching
            latent samples.
        reconstructions (dict): ``reconstructions[m][r]`` is the reconstruction
            of modality ``r`` from the latents of modality ``m``.
        inputs: The batch. ``inputs.data`` holds the targets; if the batch has
            a ``masks`` attribute it flags which modalities are available.

    Returns:
        tuple: ``(lws, n_mods_sample)`` where ``lws[m]`` is the log weight
            tensor of modality ``m`` and ``n_mods_sample`` is the number of
            available modalities (per sample when masks are present).
    """
    has_masks = hasattr(inputs, "masks")

    if has_masks:
        # Compute the number of available modalities per sample
        stacked_masks = torch.stack(tuple(inputs.masks.values()))
        n_mods_sample = stacked_masks.int().sum(dim=0)
    else:
        n_mods_sample = torch.tensor([self.n_modalities])

    lws = {}
    for mod, latents in embeddings.items():
        u = latents["u"]  # shared sample, leading dims (K, n_batch)
        w = latents["w"]  # private sample, leading dims (K, n_batch)
        n_mods_sample = n_mods_sample.to(u.device)

        ### Compute log p(z)
        joint_code = torch.cat([u, w], dim=-1)
        lpz = self.prior_dist(*self.pz_params).log_prob(joint_code).sum(-1)

        ### Compute log q(u|X) where u is the shared latent
        # Get all the individual log q(u|x_i) for all modalities
        per_mod_lqu = []
        for m, post in posteriors.items():
            lqu_m = post["u"].log_prob(u).sum(-1)
            if has_masks:
                # For unavailable modalities, set the log prob to -infinity so
                # that it accounts for 0 in the log_sum_exp.
                unavailable = inputs.masks[m] == False  # noqa: E712 (elementwise)
                lqu_m[torch.stack([unavailable] * len(u))] = -torch.inf
            per_mod_lqu.append(lqu_m)

        # Mixture of experts: log_mean_exp over the modalities
        lqu_x = torch.logsumexp(torch.stack(per_mod_lqu), dim=0) - torch.log(
            n_mods_sample
        )

        ### Compute log q(w |x_m)
        lqw_x = posteriors[mod]["w"].log_prob(w).sum(-1)

        ### Compute log p(X|u,w) for all modalities
        lpx_z = 0
        for recon_mod, x_recon in reconstructions[mod].items():
            n_samples, n_batch = x_recon.shape[:2]
            log_lik = self.recon_log_probs[recon_mod](
                x_recon, inputs.data[recon_mod]
            )
            log_lik = (
                log_lik.view(n_samples, n_batch, -1)
                .mul(self.rescale_factors[recon_mod])
                .sum(-1)
            )
            if has_masks:
                # cancel unavailable modalities
                log_lik *= inputs.masks[recon_mod].float()
            lpx_z += log_lik

        ### Compute the entire likelihood
        lw = lpx_z + self.beta * (lpz - lqu_x - lqw_x)
        if has_masks:
            # cancel unavailable modalities
            lw *= inputs.masks[mod].float()
        lws[mod] = lw

    return lws, n_mods_sample
def _dreg_looser(self, posteriors, embeddings, reconstructions, inputs):
    """The DreG estimation for IWAE.

    The loss components in lws need to have been computed on **detached**
    posteriors.

    Returns:
        ModelOutput: with fields ``loss`` and ``loss_sum`` (both the negated
            objective summed over the batch) and an empty ``metrics`` dict.
    """
    lws, n_mods_sample = self._compute_k_lws(
        posteriors, embeddings, reconstructions, inputs
    )

    ### Compute the wk for each modality
    # The wk are the log weights normalised over the k_iwae samples. They are
    # constants, hence no_grad.
    with torch.no_grad():
        wk = {
            mod: (lw - torch.logsumexp(lw, 0, keepdim=True)).exp()  # K, batch_size
            for mod, lw in lws.items()
        }

    ### Compute the loss
    weighted = torch.stack(
        [lws[mod] * wk[mod] for mod in lws], dim=0
    )  # n_modalities, K, batch_size
    per_modality = weighted.sum(1)  # Sum over the k_iwae samples
    # Take the mean over the modalities (outside the log)
    per_sample = per_modality.sum(0) / n_mods_sample

    # The gradient with respect to \phi is multiplied one more time by wk.
    # To achieve that, we register a hook on the latent variables u and w.
    def _scale_grad_by(weights):
        return lambda grad: weights.unsqueeze(-1) * grad

    for mod, latents in embeddings.items():
        latents["w"].register_hook(_scale_grad_by(wk[mod]))
        latents["u"].register_hook(_scale_grad_by(wk[mod]))

    ### Return the sum over the batch
    batch_total = per_sample.sum()
    return ModelOutput(loss=-batch_total, loss_sum=-batch_total, metrics=dict())
def _iwae_looser(self, posteriors, embeddings, reconstructions, inputs):
    """The IWAE loss with the sum over modalities taken outside of the log.

    This is done for increased stability (following Shi et al 2019).

    Returns:
        ModelOutput: with fields ``loss`` and ``loss_sum`` (both the negated
            objective summed over the batch) and an empty ``metrics`` dict.
    """
    # Get all individual likelihoods
    per_mod_lws, n_mods_sample = self._compute_k_lws(
        posteriors, embeddings, reconstructions, inputs
    )
    stacked = torch.stack(list(per_mod_lws.values()), dim=0)  # (n_modalities, K, n_batch)

    # Take log_mean_exp on the k_iwae samples to obtain the k-sampled estimate
    n_samples = stacked.size(1)
    k_estimate = torch.logsumexp(stacked, dim=1) - math.log(
        n_samples
    )  # n_modalities, n_batch

    # Take the mean on modalities
    per_sample = k_estimate.sum(0) / n_mods_sample

    # Return the sum over the batch
    batch_total = per_sample.sum()
    return ModelOutput(loss=-batch_total, loss_sum=-batch_total, metrics=dict())
[docs]
def encode(
    self,
    inputs: MultimodalBaseDataset,
    cond_mod: Union[list, str] = "all",
    N: int = 1,
    return_mean=False,
    **kwargs,
):
    """Generate encodings conditioning on all modalities or a subset of modalities.

    The shared code is sampled from the posterior of one conditioning modality
    chosen at random (or, with ``return_mean``, is the average of the posterior
    means). Private codes come from the posteriors for conditioning modalities
    and from the prior selected by ``self.reconstruction_option`` for the others.

    Args:
        inputs (MultimodalBaseDataset): The dataset to use for the conditional generation.
        cond_mod (Union[list, str]): Either 'all' or a list of str containing the modalities
            names to condition on.
        N (int) : The number of encodings to sample for each datapoint. Default to 1.
        return_mean (bool) : if True, returns the mean of the posterior distribution
            (instead of a sample).
        **kwargs: forwarded to the parent ``encode``. ``flatten`` (bool, default False)
            reshapes the outputs to 2D tensors of shape (-1, dim).

    Returns:
        ModelOutput : contains fields
            'z' (torch.Tensor): shared code, (n_data, latent_dim) if N == 1 else
                (N, n_data, latent_dim)
            'one_latent_space' (bool) = False
            'modalities_z' (Dict[str, torch.Tensor]): private codes, each
                (n_data, modalities_specific_dim) if N == 1 else
                (N, n_data, modalities_specific_dim)

    Raises:
        AttributeError: If a conditioning modality has no encoder, or if
            ``self.reconstruction_option`` is neither 'single_prior' nor
            'joint_prior' when a private code must be drawn from a prior.
    """
    # Look up the batchsize
    batch_size = len(list(inputs.data.values())[0])

    # The parent implementation returns the conditioning modalities as ``cond_mod``.
    cond_mod = super().encode(inputs, cond_mod, N, **kwargs).cond_mod
    flatten = kwargs.pop("flatten", False)

    # Fail early and explicitly rather than running into undefined state below.
    unknown_mods = [m for m in cond_mod if m not in self.encoders.keys()]
    if unknown_mods:
        raise AttributeError(
            f"Cannot encode: no encoder is available for modalities {unknown_mods}. "
            f"Available modalities are {list(self.encoders.keys())}."
        )

    sample_shape = torch.Size([]) if N == 1 else torch.Size([N])

    # For the conditioning modalities we compute all the embeddings
    encoders_outputs = {m: self.encoders[m](inputs.data[m]) for m in cond_mod}

    if return_mean:
        # Average the posterior means of the conditioning modalities
        list_mean = [o.embedding for o in encoders_outputs.values()]
        embedding = torch.mean(torch.stack(list_mean), dim=0)
        z = torch.stack([embedding] * N) if N > 1 else embedding
    else:
        # Choose one of the conditioning modalities at random to sample the
        # shared information.
        random_mod = np.random.choice(cond_mod)

        # Sample the shared latent code
        mu = encoders_outputs[random_mod].embedding
        sigma = self._log_var_to_std(encoders_outputs[random_mod].log_covariance)
        z = self.post_dist(mu, sigma).rsample(sample_shape)

    if flatten:
        z = z.reshape(-1, self.latent_dim)

    # Modality specific encodings : given by encoders for conditioning modalities.
    # Sampling from the priors for the rest of the modalities.
    style_z = {}
    for m in self.encoders:
        if m in cond_mod:
            # Sample from posteriors parameters
            mu_m = encoders_outputs[m].style_embedding
            logvar_m = encoders_outputs[m].style_log_covariance
        else:
            # Sample from priors parameters.
            if self.reconstruction_option == "single_prior":
                mu_m = self.mean_priors[m]
                logvar_m = self.logvars_priors[m]
            elif self.reconstruction_option == "joint_prior":
                # Private part of the prior over the entire latent code
                mu_m = self.mean_priors["shared"][:, self.latent_dim :]
                logvar_m = self.logvars_priors["shared"][:, self.latent_dim :]
            else:
                raise AttributeError(
                    "The reconstruction_option must be either 'single_prior' or "
                    f"'joint_prior'. {self.reconstruction_option} was provided."
                )
            # The priors are stored with a single row: tile them over the batch
            mu_m = torch.cat([mu_m] * batch_size, dim=0)
            logvar_m = torch.cat([logvar_m] * batch_size, dim=0)

        if return_mean:
            style_z[m] = torch.stack([mu_m] * N) if N > 1 else mu_m
        else:
            sigma_m = self._log_var_to_std(logvar_m)
            style_z[m] = self.post_dist(mu_m, sigma_m).rsample(sample_shape)

        if flatten:
            style_z[m] = style_z[m].reshape(-1, self.modalities_specific_dim)

    return ModelOutput(z=z, one_latent_space=False, modalities_z=style_z)
[docs]
def generate_from_prior(self, n_samples, **kwargs):
    """Sample latent codes from the prior over the entire latent code.

    Args:
        n_samples (int): Number of codes to draw. A single draw is made
            when ``n_samples`` is not greater than 1.
        **kwargs: Accepted but not used by this method.

    Returns:
        ModelOutput : contains fields
            'z' : the samples, moved to ``self.device`` and with all
                singleton dimensions squeezed out
            'one_latent_space' (bool) = True
    """
    prior = self.prior_dist(*self.pz_params)
    shape = [n_samples] if n_samples > 1 else []
    samples = prior.rsample(shape).to(self.device)
    return ModelOutput(z=samples.squeeze(), one_latent_space=True)
def default_encoders(self, model_config) -> nn.ModuleDict:
    """Build the default multi-latent encoders from ``model_config``.

    Every modality listed in ``model_config.input_dims`` gets a private latent
    space of size ``model_config.modalities_specific_dim``.

    Args:
        model_config: The configuration providing ``input_dims``,
            ``latent_dim`` and ``modalities_specific_dim``.

    Returns:
        nn.ModuleDict: the result of ``BaseDictEncoders_MultiLatents``.
    """
    # Read everything from the argument (as default_decoders does) rather
    # than mixing it with self.model_config.
    return BaseDictEncoders_MultiLatents(
        input_dims=model_config.input_dims,
        latent_dim=model_config.latent_dim,
        modality_dims={
            m: model_config.modalities_specific_dim for m in model_config.input_dims
        },
    )
def default_decoders(self, model_config) -> nn.ModuleDict:
    """Build the default multi-latent decoders from ``model_config``.

    Every modality listed in ``model_config.input_dims`` is given a private
    latent dimension of ``model_config.modalities_specific_dim``.

    Returns:
        nn.ModuleDict: the result of ``BaseDictDecodersMultiLatents``.
    """
    private_dim = model_config.modalities_specific_dim
    private_dims = {name: private_dim for name in model_config.input_dims}
    return BaseDictDecodersMultiLatents(
        input_dims=model_config.input_dims,
        latent_dim=model_config.latent_dim,
        modality_dims=private_dims,
    )
[docs]
@torch.no_grad()
def compute_joint_nll(self, inputs, K=1000, batch_size_K=100):
    """Estimate the negative joint likelihood.

    Each datapoint is processed on its own: ``K // n_modalities`` importance
    samples (at least one) are drawn per modality, and the log weights of all
    modalities are pooled in a single log-mean-exp. The rescale factors and
    beta are set to 1 for the duration of the computation and restored
    afterwards, even if an error occurs.

    Args:
        inputs (MultimodalBaseDataset) : a batch of samples. It is not modified.
        K (int) : the number of importance samples for the estimation. Default to 1000.
        batch_size_K (int) : Default to 100. Currently unused by this implementation.

    Returns:
        The negative log-likelihood summed over the batch.

    Raises:
        AttributeError: If ``inputs`` has a ``masks`` attribute (incomplete dataset).
    """
    self.eval()

    # Check that the dataset is not incomplete
    if hasattr(inputs, "masks"):
        raise AttributeError(
            "The compute_joint_nll method is not yet implemented for incomplete datasets."
        )

    # Number of samples in the batch. Read it without popping from inputs.data:
    # removing an entry would drop a modality from the caller's batch and from
    # the likelihood computed below.
    n_data = len(list(inputs.data.values())[0])

    # Number of samples per modality; keep at least one so that the
    # log-mean-exp below is well defined.
    k_iwae = max(K // self.n_modalities, 1)

    # Set the rescale factors and beta to 1 for the computation of the likelihood
    saved_rescale_factors, saved_beta = self.rescale_factors, self.beta
    self.rescale_factors = {m: 1 for m in saved_rescale_factors}
    self.beta = 1

    # Tensor accumulator so that an empty batch returns 0 instead of failing.
    ll = torch.zeros(1, device=self.device)
    try:
        for i in range(n_data):
            inputs_i = MultimodalBaseDataset(
                data={m: inputs.data[m][i].unsqueeze(0) for m in inputs.data}
            )

            embeddings, posteriors, reconstructions = (
                self._compute_posteriors_and_embeddings(
                    inputs_i, detach=False, K=k_iwae
                )
            )
            lws, _ = self._compute_k_lws(
                posteriors, embeddings, reconstructions, inputs_i
            )

            # Pool the log weights of all modalities
            lws = torch.cat(list(lws.values()), dim=0)  # n_modalities*K, n_batch

            # Take log_mean_exp on all samples
            ll = ll + (torch.logsumexp(lws, dim=0) - math.log(lws.size(0)))  # n_batch
    finally:
        # Revert changes made on rescale factors and beta
        self.rescale_factors = saved_rescale_factors
        self.beta = saved_beta

    return -ll.sum()