import logging
import math
from typing import Union
import numpy as np
import torch
import torch.distributions as dist
import torch.nn as nn
import torch.nn.functional as F
from pythae.models.base.base_utils import ModelOutput
from scipy.stats import entropy
from torch.distributions import Laplace, Normal
from torch.utils.data import DataLoader
from tqdm import tqdm
from multivae.data.datasets.base import MultimodalBaseDataset
from multivae.data.utils import drop_unused_modalities, set_inputs_to_device
from multivae.models.nn.default_architectures import (
BaseDictDecodersMultiLatents,
BaseDictEncoders_MultiLatents,
)
from ..base import BaseMultiVAE
from .cmvae_config import CMVAEConfig
# Module-level logger. A dedicated StreamHandler is attached and the level is
# set to INFO so that progress messages (e.g. the cluster pruning logs) are
# emitted without any further logging configuration by the caller.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
console = logging.StreamHandler()
logger.addHandler(console)
[docs]
class CMVAE(BaseMultiVAE):
    """The CMVAE model from "Deep Generative Clustering with Multimodal Diffusion Variational Autoencoders"
    (Palumbo et al, 2023).

    The diffusion decoders are not implemented in this version.

    Each modality is encoded into a shared latent variable ``u`` and a private
    (modality-specific) latent variable ``w``. The prior on the shared space is a
    mixture with one component per cluster; the mixture weights are given by
    :attr:`pc_params`.

    Args:
        model_config (CMVAEConfig): An instance of CMVAEConfig in which all the model's
            parameters are made available.
        encoders (Dict[str, ~multivae.models.nn.base_architectures.BaseMultilatentEncoder]): A dictionary containing
            the modalities names and the encoders for each modality. Each encoder is an instance of
            Multivae's BaseMultilatentEncoder since this model uses multiple latent spaces. Default: None.
        decoders (Dict[str, ~pythae.models.nn.base_architectures.BaseDecoder]): A dictionary containing
            the modalities names and the decoders for each modality. Each decoder is an instance of
            Pythae's BaseDecoder.
    """

    def __init__(
        self,
        model_config: CMVAEConfig,
        encoders: dict = None,
        decoders: dict = None,
    ):
        if model_config.modalities_specific_dim is None:
            raise AttributeError(
                "The modalities_specific_dim attribute must"
                " be provided in the model config."
            )

        super().__init__(model_config, encoders, decoders)
        self.model_name = "CMVAE"

        # The same distribution family is used for the priors and the posteriors.
        if model_config.prior_and_posterior_dist == "laplace_with_softmax":
            self.latent_dist = Laplace
        elif model_config.prior_and_posterior_dist in ("normal", "normal_with_softplus"):
            self.latent_dist = Normal
        else:
            raise AttributeError(
                " The prior_and_posterior_dist parameter must be "
                " either 'laplace_with_softmax','normal' or 'normal_with_softplus'. "
                f" {model_config.prior_and_posterior_dist} was provided."
            )

        self.multiple_latent_spaces = True
        self.n_clusters = model_config.number_of_clusters
        self.style_dims = {
            m: self.model_config.modalities_specific_dim for m in self.encoders
        }

        # Modality-specific priors on the private spaces (referred to as r in the paper).
        # The means are fixed at zero; the log-variances are learned only when
        # ``learn_modality_prior`` is set.
        self.r_mean_priors = torch.nn.ParameterDict()
        self.r_logvars_priors = torch.nn.ParameterDict()
        for mod in self.encoders:
            self.r_mean_priors[mod] = torch.nn.Parameter(
                torch.zeros(1, model_config.modalities_specific_dim),
                requires_grad=False,
            )
            self.r_logvars_priors[mod] = torch.nn.Parameter(
                torch.zeros(1, model_config.modalities_specific_dim),
                requires_grad=model_config.learn_modality_prior,
            )

        # Regularization prior on the private spaces (referred to as p(w_m) in the paper).
        self.w_mean_prior = torch.nn.Parameter(
            torch.zeros(1, model_config.modalities_specific_dim), requires_grad=False
        )
        self.w_logvar_prior = torch.nn.Parameter(
            torch.zeros(1, model_config.modalities_specific_dim), requires_grad=False
        )

        # Unnormalized logits of the prior over clusters (normalized by ``pc_params``).
        # ``prune_clusters`` sets the logits of eliminated clusters to -inf.
        self._pc_params = torch.nn.Parameter(
            torch.zeros(self.n_clusters),
            requires_grad=True,
        )

        # Cluster means in the shared latent space, initialized uniformly in [-1, 1).
        self.mean_clusters = nn.ParameterList(
            [
                nn.Parameter(
                    ((2 * torch.rand(1, self.latent_dim)) - 1), requires_grad=True
                )
                for _ in range(self.n_clusters)
            ]
        )
        # NOTE: the cluster log-variances are frozen (requires_grad=False), as in
        # the original code.
        self.logvar_clusters = nn.ParameterList(
            [
                nn.Parameter(torch.zeros(1, self.latent_dim), False)
                for _ in range(self.n_clusters)
            ]
        )

    @property
    def pc_params(self):
        """Parameters of prior distribution on latent clusters (softmax of the logits)."""
        return F.softmax(self._pc_params, dim=-1)

    def _log_var_to_std(self, log_var):
        """For latent distributions parameters, transform the log covariance to the
        standard deviation of the distribution either applying softmax, softplus
        or simply torch.exp(0.5 * ...) depending on the model configuration.
        """
        if self.model_config.prior_and_posterior_dist == "laplace_with_softmax":
            return F.softmax(log_var, dim=-1) * log_var.size(-1) + 1e-6
        elif self.model_config.prior_and_posterior_dist == "normal_with_softplus":
            return F.softplus(log_var) + 1e-6
        else:
            return torch.exp(0.5 * log_var)

    def _private_prior_params(self, mod):
        """Return the (mean, log-variance) parameters of the prior used to sample
        the private latent variable of modality ``mod`` when it is not observed.

        ``model_config.reconstruction_option`` selects the prior:
        ``"single_prior"`` uses the modality-specific prior r_m and
        ``"joint_prior"`` uses the regularization prior p(w).

        Raises:
            NotImplementedError: for any other reconstruction option.
        """
        option = self.model_config.reconstruction_option
        if option == "single_prior":
            return self.r_mean_priors[mod], self.r_logvars_priors[mod]
        if option == "joint_prior":
            return self.w_mean_prior, self.w_logvar_prior
        raise NotImplementedError()

    def _log_p_z_given_c(self, z):
        """Compute log p(z | c) for every cluster slot c (including pruned ones).

        Args:
            z (torch.Tensor): shared latent codes whose last axis is the latent dimension.

        Returns:
            torch.Tensor: shape ``(n_cluster_slots, *z.shape[:-1])`` where
            ``n_cluster_slots == len(self.mean_clusters)``.
        """
        means = torch.cat(tuple(self.mean_clusters), dim=0)  # n_slots, latent_dim
        stds = self._log_var_to_std(torch.cat(tuple(self.logvar_clusters), dim=0))
        # Insert singleton axes so the cluster parameters broadcast against the
        # leading axes of z (e.g. (K, n_batch) during training).
        view_shape = (means.shape[0],) + (1,) * (z.dim() - 1) + (means.shape[-1],)
        cluster_dist = self.latent_dist(means.view(view_shape), stds.view(view_shape))
        return cluster_dist.log_prob(z.unsqueeze(0)).sum(-1)

    def _compute_posteriors_and_embeddings(self, inputs, detach, **kwargs):
        """Encode every available modality, sample K latent codes per sample and
        decode all self- and cross-modal reconstructions.

        Args:
            inputs (MultimodalBaseDataset): the batch.
            detach (bool): if True, the returned posterior distributions are
                rebuilt from detached parameters (needed by the DReG estimator).
            K (int, optional keyword): number of importance samples. Defaults to
                ``model_config.K``.

        Returns:
            posteriors (dict): ``posteriors[mod] = {"u": q(u|x_mod), "w": q(w|x_mod)}``.
            embeddings (dict): ``embeddings[mod] = {"u": u, "w": w}``, sampled with
                ``rsample([K])`` so the first axis is K.
            reconstructions (dict): ``reconstructions[cond_mod][recon_mod]`` with
                leading axes (K, n_batch).
        """
        inputs = drop_unused_modalities(inputs)

        embeddings = {}
        posteriors = {}
        reconstructions = {}
        k_iwae = kwargs.pop("K", self.model_config.K)

        for cond_mod in inputs.data:
            output = self.encoders[cond_mod](inputs.data[cond_mod])
            mu, log_var = output.embedding, output.log_covariance
            mu_style = output.style_embedding
            log_var_style = output.style_log_covariance

            sigma = self._log_var_to_std(log_var)
            sigma_style = self._log_var_to_std(log_var_style)

            # Shared latent variable
            qu_x = self.latent_dist(mu, sigma)
            u_x = qu_x.rsample([k_iwae])

            # Private latent variable
            qw_x = self.latent_dist(mu_style, sigma_style)
            w_x = qw_x.rsample([k_iwae])

            # Then compute all the cross-modal reconstructions
            reconstructions[cond_mod] = {}
            for recon_mod in inputs.data:
                if recon_mod == cond_mod:
                    # Self-reconstruction: use both the shared and private codes.
                    z_x = torch.cat([u_x, w_x], dim=-1)
                else:
                    # Cross-modal reconstruction: keep only the shared code and
                    # sample the private code of recon_mod from its prior r.
                    mu_prior_mod = torch.cat(
                        [self.r_mean_priors[recon_mod]] * len(mu), dim=0
                    )
                    sigma_prior_mod = torch.cat(
                        [self._log_var_to_std(self.r_logvars_priors[recon_mod])]
                        * len(mu),
                        dim=0,
                    )
                    w = self.latent_dist(mu_prior_mod, sigma_prior_mod).rsample(
                        [k_iwae]
                    )  # K, n_batch, modalities_specific_dim
                    z_x = torch.cat([u_x, w], dim=-1)

                # Decode with K and batch axes flattened, then restore them.
                z = z_x.reshape(-1, z_x.shape[-1])
                recon = self.decoders[recon_mod](z)["reconstruction"]
                recon = recon.reshape((*z_x.shape[:-1], *recon.shape[1:]))
                reconstructions[cond_mod][recon_mod] = recon

            # The DREG loss uses detached posteriors in the loss computation afterwards.
            if detach:
                qu_x = self.latent_dist(mu.clone().detach(), sigma.clone().detach())
                qw_x = self.latent_dist(
                    mu_style.clone().detach(), sigma_style.clone().detach()
                )

            posteriors[cond_mod] = {"u": qu_x, "w": qw_x}
            embeddings[cond_mod] = {"u": u_x, "w": w_x}

        return posteriors, embeddings, reconstructions

    def forward(self, inputs: MultimodalBaseDataset, **kwargs):
        """Forward pass of the CMVAE model. Returns the loss on the batch.

        The estimator is selected by ``model_config.loss`` ("dreg_looser" or
        "iwae_looser").

        Raises:
            NotImplementedError: for any other loss name.
        """
        if self.model_config.loss == "dreg_looser":
            posteriors, embeddings, reconstructions = (
                self._compute_posteriors_and_embeddings(inputs, detach=True, **kwargs)
            )
            # For the DreG estimation, we compute the individual likelihoods with detached posteriors.
            lws, embeddings, n_mods_sample = self._compute_k_lws(
                posteriors, embeddings, reconstructions, inputs
            )
            return self._dreg_looser(lws, embeddings, n_mods_sample)

        if self.model_config.loss == "iwae_looser":
            posteriors, embeddings, reconstructions = (
                self._compute_posteriors_and_embeddings(inputs, detach=False, **kwargs)
            )
            lws, _, n_mods_sample = self._compute_k_lws(
                posteriors, embeddings, reconstructions, inputs
            )
            return self._iwae_looser(lws, n_mods_sample)

        raise NotImplementedError(
            f"Loss {self.model_config.loss} is not implemented for the CMVAE."
            " Use 'dreg_looser' or 'iwae_looser'."
        )

    def _compute_k_lws(self, posteriors, embeddings, reconstructions, inputs):
        """Compute all losses components without any aggregation on K nor batch.

        Returns:
            lws (dict) : the losses for each modality, each of shape (K, n_batch)
            embeddings (dict) : the embeddings for each modality
            n_mod_samples (Tensor): the number of available modalities per sample
        """
        has_masks = hasattr(inputs, "masks")
        device = next(iter(embeddings.values()))["u"].device

        if has_masks:
            # Compute the number of available modalities per sample
            n_mods_sample = torch.sum(
                torch.stack(tuple(inputs.masks.values())).int(), dim=0
            ).to(device)
        else:
            n_mods_sample = torch.tensor([self.n_modalities], device=device)
        log_n_mods = torch.log(n_mods_sample)

        # log p_pi(c) for every cluster slot. The epsilon keeps the values finite
        # for clusters eliminated by ``prune_clusters`` (logit -inf -> prob 0),
        # consistently with ``predict_clusters``.
        lpc = torch.log(self.pc_params + 1e-20).view(-1, 1, 1)  # n_clusters, 1, 1

        # Regularizing prior p(w_m) on the private spaces
        pw = self.latent_dist(
            self.w_mean_prior, self._log_var_to_std(self.w_logvar_prior)
        )

        lws = {}
        for mod in embeddings:
            w = embeddings[mod]["w"]  # private latent variable
            u = embeddings[mod]["u"]  # shared latent variable

            ### Compute log p(w_m)
            lpw = pw.log_prob(w).sum(-1)

            ### Compute log q(w_m | x_m)
            lqw_x = posteriors[mod]["w"].log_prob(w).sum(-1)

            ### Compute log q_{\phi_z}(u | X) as a mixture of the available unimodal posteriors
            lqu_x = []
            for m in posteriors:
                lqu = posteriors[m]["u"].log_prob(u).sum(-1)  # K, n_batch
                if has_masks:
                    # Unavailable modalities do not contribute to the mixture.
                    missing = torch.logical_not(inputs.masks[m]).to(lqu.device)
                    lqu = lqu.masked_fill(missing.expand_as(lqu), -torch.inf)
                lqu_x.append(lqu)
            lqu_x = torch.logsumexp(torch.stack(lqu_x), dim=0) - log_n_mods  # log_mean_exp

            ### Compute log p(u|c) for all clusters
            lpzc = self._log_p_z_given_c(u)  # n_clusters, K, n_batch

            ### Compute q(c | u) for all clusters
            qzc = torch.softmax(lpc + lpzc, dim=0) + 1e-20  # n_clusters, K, n_batch

            ### Compute \sum_m log p(x_m|u,w_m)
            lpx_z = 0
            for recon_mod, x_recon in reconstructions[mod].items():
                k_iwae, n_batch = x_recon.shape[0], x_recon.shape[1]
                lpx_z_mod = (
                    self.recon_log_probs[recon_mod](x_recon, inputs.data[recon_mod])
                    .view(k_iwae, n_batch, -1)
                    .mul(self.rescale_factors[recon_mod])
                    .sum(-1)
                )
                if has_masks:
                    # don't reconstruct unavailable modalities
                    lpx_z_mod = lpx_z_mod * inputs.masks[recon_mod].float()
                lpx_z = lpx_z + lpx_z_mod

            ### Compute the explicit expectation on q(c|u, X)
            lw_c = lpx_z + self.model_config.beta * (
                lpc + lpzc + lpw - lqu_x - lqw_x - qzc.log()
            )  # n_clusters, K, n_batch
            lw = (qzc * lw_c).sum(0)  # K, n_batch
            assert lw.shape[0] == k_iwae

            if has_masks:
                # cancel unavailable modalities
                lw = lw * inputs.masks[mod].float()

            lws[mod] = lw

        return lws, embeddings, n_mods_sample

    def _iwae_looser(self, lws, n_mods_sample):
        """The IWAE loss with the sum outside of the log for increased stability.
        (following Shi et al 2019).
        """
        lws = torch.stack(list(lws.values()), dim=0)  # n_modalities, K, n_batch

        # Take log_mean_exp on K
        lws = torch.logsumexp(lws, dim=1) - math.log(
            lws.size(1)
        )  # n_modalities, n_batch

        # Take the mean on modalities
        lws = lws.sum(0) / n_mods_sample  # n_batch

        # Return the sum over the batch
        return ModelOutput(loss=-lws.sum(), loss_sum=-lws.sum(), metrics=dict())

    def _dreg_looser(self, lws, embeddings, n_mods_sample):
        """The DreG estimation for IWAE. losses components in lws needs to have been computed on
        **detached** posteriors.
        """
        # Normalized importance weights over K; constants that do not require grad.
        with torch.no_grad():
            wk = {
                mod: (lw - torch.logsumexp(lw, 0, keepdim=True)).exp()  # K, n_batch
                for mod, lw in lws.items()
            }

        # Compute the loss
        lws = torch.stack(
            [(lws[mod] * wk[mod]) for mod in embeddings], dim=0
        )  # n_modalities, K, n_batch
        lws = lws.sum(1)  # sum on K

        # The gradient with respect to \phi is multiplied one more time by wk.
        # To achieve that, we register a hook on the latent variables w and u.
        # Hooks can only be registered on tensors that require grad, which is not
        # the case when the forward pass runs under torch.no_grad() (evaluation).
        for mod in embeddings:
            for key in ("w", "u"):
                latent = embeddings[mod][key]
                if latent.requires_grad:
                    latent.register_hook(
                        lambda grad, w=wk[mod]: w.unsqueeze(-1) * grad
                    )

        # Average over modalities
        lws = lws.sum(0) / n_mods_sample  # n_batch

        # Return the sum over the batch
        return ModelOutput(loss=-lws.sum(), loss_sum=-lws.sum(), metrics=dict())

    def encode(
        self,
        inputs: MultimodalBaseDataset,
        cond_mod: Union[list, str] = "all",
        N: int = 1,
        return_mean=False,
        **kwargs,
    ):
        """Generate encodings conditioning on all modalities or a subset of modalities.

        The shared code is taken from one conditioning modality chosen at random.
        Private codes come from the encoders for conditioning modalities and from
        the prior selected by ``model_config.reconstruction_option`` for the others.

        Args:
            inputs (MultimodalBaseDataset): The dataset to use for the conditional generation.
            cond_mod (Union[list, str]): Either 'all' or a list of str containing the modalities
                names to condition on.
            N (int) : The number of encodings to sample for each datapoint. Default to 1.
            return_mean (bool) : if True, returns the mean of the posterior distribution (instead of a sample).
            flatten (bool, optional keyword): if True, the sample axis is merged into the batch axis.

        Returns:
            ModelOutput: Contains the following fields
                'z' (torch.Tensor (n_data, latent_dim) if N == 1, else (N, n_data, latent_dim))
                'one_latent_space' (bool)
                'modalities_z' (Dict[str, torch.Tensor]) with the same leading axes as 'z'

        Raises:
            NotImplementedError: if a private code must be sampled from the prior and
                ``model_config.reconstruction_option`` is unknown.
        """
        cond_mod = super().encode(inputs, cond_mod, N, return_mean, **kwargs).cond_mod
        # NOTE(review): if some modality is not a known encoder, nothing is returned
        # (None); the base class presumably validates cond_mod beforehand — confirm.
        if all([s in self.encoders.keys() for s in cond_mod]):
            # For the conditioning modalities we compute all the embeddings
            encoders_outputs = {m: self.encoders[m](inputs.data[m]) for m in cond_mod}

            # Choose one of the conditioning modalities at random to sample the shared information.
            random_mod = np.random.choice(cond_mod)

            # Sample the shared latent code
            mu = encoders_outputs[random_mod].embedding
            sigma = self._log_var_to_std(encoders_outputs[random_mod].log_covariance)

            # Adapt shape in the case of one sample for uniformity
            if len(mu.shape) == 1:
                mu = mu.unsqueeze(0)
                sigma = sigma.unsqueeze(0)

            sample_shape = torch.Size([]) if N == 1 else torch.Size([N])
            if return_mean:
                z = torch.stack([mu] * N) if N > 1 else mu
            else:
                z = self.latent_dist(mu, sigma).rsample(sample_shape)

            flatten = kwargs.pop("flatten", False)
            if flatten:
                z = z.reshape(-1, self.latent_dim)

            # Modality specific encodings : given by encoders for conditioning modalities
            # Sampling from the priors for the rest of the modalities.
            style_z = {}
            for m in self.encoders:
                if m not in cond_mod:
                    # Sample from priors parameters.
                    mu_m, logvar_m = self._private_prior_params(m)
                    mu_m = torch.cat([mu_m] * len(mu), dim=0)
                    logvar_m = torch.cat([logvar_m] * len(mu), dim=0)
                else:
                    # Sample from posteriors parameters
                    mu_m = encoders_outputs[m].style_embedding
                    logvar_m = encoders_outputs[m].style_log_covariance
                    if len(mu_m.shape) == 1:
                        # adapt the shape when there is one sample for uniformity
                        mu_m = mu_m.unsqueeze(0)
                        logvar_m = logvar_m.unsqueeze(0)

                sigma_m = self._log_var_to_std(logvar_m)
                if return_mean:
                    style_z[m] = torch.stack([mu_m] * N) if N > 1 else mu_m
                else:
                    style_z[m] = self.latent_dist(mu_m, sigma_m).rsample(sample_shape)

                if flatten:
                    style_z[m] = style_z[m].reshape(
                        -1, self.model_config.modalities_specific_dim
                    )

            return ModelOutput(z=z, one_latent_space=False, modalities_z=style_z)

    def generate_from_prior(self, n_samples, **kwargs):
        """Generate latent variables sampling from the prior distribution.

        A cluster is drawn for each sample from ``Categorical(logits=_pc_params)``,
        the shared code is sampled from that cluster's distribution and private
        codes are sampled from the prior selected by
        ``model_config.reconstruction_option``.

        Returns:
            ModelOutput: fields ``z`` (n_samples, latent_dim), ``one_latent_space``
            and ``modalities_z`` (dict of (n_samples, modalities_specific_dim)).

        Raises:
            NotImplementedError: if ``model_config.reconstruction_option`` is unknown.
        """
        # generate the clusters assignments
        clusters = dist.Categorical(logits=self._pc_params).sample(
            [n_samples]
        )  # n_samples

        # gather the parameters of the selected clusters
        means = torch.cat(tuple(self.mean_clusters), dim=0)[clusters]  # n_samples, latent_dim
        lvs = torch.cat(tuple(self.logvar_clusters), dim=0)[clusters]  # n_samples, latent_dim

        # sample shared latent variable
        z_shared = self.latent_dist(
            means, self._log_var_to_std(lvs)
        ).sample()  # n_samples, latent_dim

        # generate private latent variables
        style_z = {}
        for m in self.encoders:
            mu_m, logvar_m = self._private_prior_params(m)
            mu_m = torch.cat([mu_m] * n_samples, dim=0)
            logvar_m = torch.cat([logvar_m] * n_samples, dim=0)
            style_z[m] = self.latent_dist(mu_m, self._log_var_to_std(logvar_m)).sample()

        return ModelOutput(z=z_shared, one_latent_space=False, modalities_z=style_z)

    def predict_clusters(self, inputs: MultimodalBaseDataset, **kwargs):
        """Returns the clusters for all samples in inputs.

        Each modality predicts a cluster with argmax_c p(c|z_m), and the final
        assignment is the majority vote (``torch.mode``) across modalities.

        Args:
            inputs (MultimodalBaseDataset): the batch.
            compute_lliks (bool, optional keyword): if True, also return the
                normalized likelihoods averaged over modalities (``norm_lliks``).

        Returns:
            ModelOutput: with fields: clusters (batch_size,) and pc_zs (dict of
            (n_clusters, batch_size) tensors), plus norm_lliks (batch_size,) when
            requested.

        .. note::
            The clusters assignment can be accessed through
            ``clusters = model_output.clusters``
        """
        compute_norm_lliks = kwargs.pop("compute_lliks", False)

        with torch.no_grad():
            modalities_cluster_assign = []
            pc_zs = {}
            normalized_likelihoods = []

            lpc = torch.log(self.pc_params + 1e-20).view(-1, 1)  # n_clusters, 1

            # First we compute the cluster assignments according to each modality individually
            for mod in inputs.data:
                # Compute shared embeddings
                output_encoder = self.encoders[mod](inputs.data[mod])
                mu = output_encoder.embedding
                sigma = self._log_var_to_std(output_encoder.log_covariance)
                z = self.latent_dist(mu, sigma).sample()

                # Compute p(c|z) \propto p(z|c)p(c)
                lpz_c = self._log_p_z_given_c(z)  # n_clusters, batch_size
                pc_z = torch.softmax(lpc + lpz_c, dim=0)  # n_clusters, batch_size

                modalities_cluster_assign.append(torch.argmax(pc_z, dim=0))  # batch_size
                pc_zs[mod] = pc_z

                if compute_norm_lliks:
                    # No squeeze here: a batch of size 1 must stay 1-D.
                    normalized_likelihoods.append(
                        ((lpz_c + lpc - pc_z.log()) * pc_z).sum(0) / self.latent_dim
                    )  # batch_size

            # Take a majority vote among modalities
            modalities_cluster_assign = torch.stack(
                modalities_cluster_assign, dim=-1
            )  # batch_size, n_modalities
            vote_cluster = torch.mode(modalities_cluster_assign, dim=-1)[0]  # batch_size

            if compute_norm_lliks:
                # Compute the mean normalized likelihood over modalities
                mean_norm_llik = torch.stack(normalized_likelihoods, dim=0).mean(0)
                return ModelOutput(
                    clusters=vote_cluster, pc_zs=pc_zs, norm_lliks=mean_norm_llik
                )

        return ModelOutput(clusters=vote_cluster, pc_zs=pc_zs)

    def prune_clusters(self, train_data: MultimodalBaseDataset, batch_size=128):
        """Follows the pruning procedure described in the paper to compute the optimal
        number of clusters.

        At the end of this pruning, the model._pc_params will have been
        adapted to correspond to selected clusters and ``n_clusters`` is set to the
        selected number (an int). The model is moved to CUDA when available.

        Args:
            train_data (MultimodalBaseDataset): The data to use for pruning.
            batch_size (int, optional): Defaults to 128.

        Returns:
            h_values (list): the list of entropy values from 0 to max_clusters
            (``inf`` for numbers of clusters that were not evaluated).
        """
        with torch.no_grad():
            dataloader = DataLoader(train_data, batch_size=batch_size)
            n_cluster_params = [None] * (self.n_clusters + 1)
            h_values = [torch.inf] * (self.n_clusters + 1)

            if self.n_clusters < 2:
                # Nothing to prune: keep the current parameters untouched.
                return h_values

            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.to(device)

            while self.n_clusters >= 2:
                logger.info(f"Computing entropy value with {self.n_clusters} clusters")
                mass_per_clusters = torch.zeros_like(self._pc_params)
                h_data = []

                for batch in tqdm(dataloader):
                    batch.data = set_inputs_to_device(batch.data, device)

                    # Compute all p(c|z_m) and cluster assignments
                    cluster_predict = self.predict_clusters(batch, compute_lliks=True)

                    # Compute the mass per cluster
                    mass_per_clusters += torch.bincount(
                        cluster_predict.clusters, minlength=mass_per_clusters.numel()
                    ).to(mass_per_clusters.dtype)

                    # Compute the normalized entropies H(p(c|z_m)) along the cluster axis
                    h_pzc = []
                    for pc_z in cluster_predict.pc_zs.values():
                        pc_z_np = pc_z.cpu().numpy()  # n_clusters, batch_size
                        h = entropy(pc_z_np, axis=0) / np.log(
                            np.count_nonzero(pc_z_np, axis=0)
                        )  # batch_size
                        h_pzc.append(
                            torch.as_tensor(h, dtype=pc_z.dtype, device=device)
                        )

                    # Compute the mean entropy over modalities
                    h_pzc = torch.stack(h_pzc, dim=0).mean(0)

                    # Compute the penalized_norm_entropy
                    h_data.append(
                        self.model_config.beta * h_pzc - cluster_predict.norm_lliks
                    )

                # Take mean on the dataset
                h_data = torch.cat(h_data, dim=-1).mean(-1)

                # Save the parameters pc
                logger.info(f"Entropy value : {h_data}")
                h_values[self.n_clusters] = h_data
                n_cluster_params[self.n_clusters] = self._pc_params.clone()

                # Sanity check : verify that there is no mass in previously eliminated cluster
                assert torch.all(
                    mass_per_clusters[torch.argwhere(self._pc_params == -torch.inf)]
                    == 0
                )
                logger.info(f"Mass in each cluster : {mass_per_clusters}")

                # Adapt the clusters parameters by removing the cluster with less mass
                self.n_clusters = self.n_clusters - 1
                # set inf in mass for the clusters that were already removed
                mass_per_clusters[self._pc_params.isinf()] = torch.inf
                cluster_to_eliminate = torch.argmin(mass_per_clusters)
                self._pc_params[cluster_to_eliminate] = -torch.inf
                assert torch.sum(~self._pc_params.isinf()) == self.n_clusters
                logger.info(f"Adapted pc_params to {self._pc_params}")

            # Get the parameters for the number of clusters that minimizes entropy.
            # Convert to floats first: entries are a mix of Python floats and
            # (possibly CUDA) 0-d tensors; keep n_clusters a plain int.
            self.n_clusters = int(
                torch.argmin(torch.tensor([float(h) for h in h_values]))
            )
            self._pc_params = torch.nn.Parameter(n_cluster_params[self.n_clusters])
            logger.info(
                f"The optimal number of clusters is {self.n_clusters} and the pc_params have been adapted to :{self.pc_params}"
            )
        return h_values

    def default_encoders(self, model_config) -> nn.ModuleDict:
        """Build default multi-latent encoders from ``model_config``."""
        return BaseDictEncoders_MultiLatents(
            input_dims=model_config.input_dims,
            latent_dim=model_config.latent_dim,
            modality_dims={
                m: model_config.modalities_specific_dim for m in model_config.input_dims
            },
        )

    def default_decoders(self, model_config) -> nn.ModuleDict:
        """Build default multi-latent decoders from ``model_config``."""
        return BaseDictDecodersMultiLatents(
            input_dims=model_config.input_dims,
            latent_dim=model_config.latent_dim,
            modality_dims={
                m: model_config.modalities_specific_dim for m in model_config.input_dims
            },
        )

    @torch.no_grad()
    def compute_joint_nll(self, inputs, K=1000, **kwargs):
        """Estimate the negative joint likelihood.

        The estimate is computed one sample at a time, with ``beta`` and all
        rescale factors temporarily set to 1. They are restored afterwards,
        even if an exception is raised. The model is put in eval mode.

        Args:
            inputs (MultimodalBaseDataset) : a batch of samples.
            K (int) : the number of importance samples for the estimation. Default to 1000.

        Returns:
            The negative log-likelihood summed over the batch.

        Raises:
            AttributeError: if ``inputs`` is an incomplete dataset (has masks).
        """
        self.eval()
        # Check that the dataset is not incomplete for this computation.
        if hasattr(inputs, "masks"):
            raise AttributeError(
                "The compute_joint_nll method is not yet implemented for incomplete datasets."
            )

        # Get the batch size from the input
        n_data = len(next(iter(inputs.data.values())))
        # We dispatch the K samples equally between the unimodal posteriors
        k_iwae = K // self.n_modalities

        # Set the rescale factors and beta to one while computing the joint likelihood
        rescale_factors, self.rescale_factors = (
            self.rescale_factors.copy(),
            {m: 1 for m in self.rescale_factors},
        )
        beta, self.model_config.beta = self.model_config.beta, 1

        try:
            ll = 0
            for i in range(n_data):
                inputs_i = MultimodalBaseDataset(
                    data={m: inputs.data[m][i].unsqueeze(0) for m in inputs.data}
                )
                posteriors, embeddings, reconstructions = (
                    self._compute_posteriors_and_embeddings(
                        inputs_i, detach=False, K=k_iwae
                    )
                )
                lws, _, _ = self._compute_k_lws(
                    posteriors, embeddings, reconstructions, inputs_i
                )
                # Aggregate by taking the logsumexp on all lws element
                lws = torch.cat(list(lws.values()), dim=0)  # n_modalities*K, n_batch
                # Take log_mean_exp on all samples
                ll += torch.logsumexp(lws, dim=0) - math.log(lws.size(0))  # n_batch
        finally:
            # Revert the changes made for the rescale factors and beta
            self.rescale_factors = rescale_factors
            self.model_config.beta = beta

        return -ll.sum()