from itertools import chain, combinations
from typing import Union
import numpy as np
import torch
import torch.distributions as dist
from pythae.models.base.base_utils import ModelOutput
from multivae.data.datasets.base import IncompleteDataset, MultimodalBaseDataset
from multivae.models.nn.default_architectures import (
BaseDictDecodersMultiLatents,
BaseDictEncoders_MultiLatents,
)
from ..base import BaseMultiVAE
from ..base.base_utils import poe, rsample_from_gaussian
from .mopoe_config import MoPoEConfig
[docs]
class MoPoE(BaseMultiVAE):
    """Implementation of the Mixture of Products of Experts (MoPoE) model from
    'Generalized Multimodal ELBO', Sutter 2021 (https://arxiv.org/abs/2105.02470).

    This implementation is heavily based on the official one at
    https://github.com/thomassutter/MoPoE.

    Args:
        model_config (MoPoEConfig): Contains all the parameters for the model.
        encoders (dict): Contains the encoder for each modality. When using
            modality-specific latent spaces, the encoders must be instances
            of ~multivae.models.nn.base_architectures.BaseMultilatentEncoder. Else,
            the encoders must be instances of
            ~pythae.models.nn.base_architectures.BaseEncoder.
            When None is provided, default MLP architectures are used.
        decoders (dict): Contains the decoder for each modality. Each decoder must be
            an instance of ~pythae.models.nn.base_architectures.BaseDecoder. When
            using modality-specific latent spaces, the decoder takes as input the
            concatenation of both latent codes (shared and modality-specific).
            When None is provided, default MLP architectures are used.
    """
def __init__(
    self, model_config: MoPoEConfig, encoders: dict = None, decoders: dict = None
):
    """Build the model, register its subsets and, if needed, the default
    multi-latent architectures.

    Args:
        model_config (MoPoEConfig): The model configuration. Its ``subsets`` field
            may be ``None`` (all subsets are used), a dict (its values are used)
            or any iterable of modality-name collections.
        encoders (dict): One encoder per modality, or None for the defaults.
        decoders (dict): One decoder per modality, or None for the defaults.
    """
    super().__init__(model_config, encoders, decoders)
    self.model_name = "MoPoE"
    self.multiple_latent_spaces = model_config.modalities_specific_dim is not None

    # Resolve the subsets to use as experts of the mixture.
    requested_subsets = self.model_config.subsets
    if isinstance(requested_subsets, dict):
        requested_subsets = list(requested_subsets.values())
    elif requested_subsets is None:
        requested_subsets = self.all_subsets()
    self.set_subsets(requested_subsets)

    # Nothing more to do when there is a single, shared latent space.
    if not self.multiple_latent_spaces:
        return

    self.style_dims = model_config.modalities_specific_dim

    # Replace the missing architectures with multi-latent default ones.
    if encoders is None:
        self.set_encoders(
            BaseDictEncoders_MultiLatents(
                input_dims=model_config.input_dims,
                latent_dim=model_config.latent_dim,
                modality_dims=model_config.modalities_specific_dim,
            )
        )
    if decoders is None:
        self.set_decoders(
            BaseDictDecodersMultiLatents(
                input_dims=model_config.input_dims,
                latent_dim=model_config.latent_dim,
                modality_dims=model_config.modalities_specific_dim,
            )
        )
[docs]
def all_subsets(self):
    """Enumerate every combination of the modality names.

    The names are read from ``self.encoders``. Combinations are produced by
    increasing size starting at size 0, so the first element yielded is the
    empty tuple.

    Returns:
        A lazy, single-use iterator over tuples of modality names
        (an iterator, not a list).
    """
    modality_names = list(self.encoders.keys())
    subset_sizes = range(len(modality_names) + 1)
    combos_by_size = (combinations(modality_names, size) for size in subset_sizes)
    return chain.from_iterable(combos_by_size)
[docs]
def set_subsets(self, subsets_list):
    """Register the subsets of modalities used by the model.

    Each subset is stored under a key made of its sorted modality names joined
    by ``"_"``; the associated value is the sorted list of names. The resulting
    dictionary is stored both in ``self.subsets`` and ``self.model_config.subsets``.

    Args:
        subsets_list: An iterable of collections of modality names.

    Raises:
        AttributeError: If a non-empty name is not a key of ``self.encoders``.
    """
    known_modalities = self.encoders.keys()
    registry = {}
    for candidate in subsets_list:
        ordered_names = sorted(candidate)
        for mod_name in ordered_names:
            # The empty string is tolerated; any other unknown name is rejected.
            if mod_name != "" and mod_name not in known_modalities:
                raise AttributeError(
                    f"The provided subsets list contains unknown modality name {mod_name}."
                    " that is not the encoders dictionary or inputs_dim dictionary."
                )
        registry["_".join(ordered_names)] = ordered_names
    self.subsets = registry
    self.model_config.subsets = registry
[docs]
def calc_joint_divergence(
    self, mus: torch.Tensor, logvars: torch.Tensor, weights: torch.Tensor
):
    """Computes the weighted sum of the KL divergences between each Gaussian expert
    and the standard normal prior.

    The closed-form KL of every expert is computed for all experts at once, then
    weighted, summed over the experts and averaged over the samples.

    Args:
        mus (Tensor): The means of the experts. (n_subset, n_samples, latent_dim)
        logvars (Tensor): The logvars of the experts. (n_subset, n_samples, latent_dim)
        weights (Tensor): The weights of the experts. (n_subset, n_samples)

    Returns:
        dict: A dictionary with a single key ``"joint_divergence"`` holding the
        scalar tensor of the weighted divergence. ``weights`` is not modified.
    """
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over the latent dimension.
    # Result has shape (n_subset, n_samples) and keeps the dtype/device of the inputs.
    klds = -0.5 * (1 - logvars.exp() - mus.pow(2) + logvars).sum(-1)

    # `.to` returns a new tensor when a move is needed, the caller's tensor is untouched.
    weights = weights.to(klds.device)

    # Sum over experts, mean over samples.
    group_div = (weights * klds).sum(dim=0).mean()
    return {"joint_divergence": group_div}
[docs]
def forward(self, inputs: MultimodalBaseDataset, **kwargs) -> ModelOutput:
    """Compute the training loss on a batch.

    The loss is the sum over modalities of the (rescaled) negative reconstruction
    log-probabilities, plus ``beta`` times the divergence term. When
    modality-specific latent spaces are used, the divergence term additionally
    includes the KL of each modality-specific posterior, scaled by ``beta_style``.

    Args:
        inputs (MultimodalBaseDataset): The batch. If it has a ``masks`` attribute,
            the per-sample terms of a modality are multiplied by its mask.

    Returns:
        ModelOutput: with fields ``loss``, ``loss_sum`` (``loss`` times the batch
        size) and ``metrics`` (``joint_divergence`` and one ``recon_<modality>``
        entry per modality).

    Raises:
        AttributeError: When modality-specific latent spaces are enabled but the
            style embedding of a modality cannot be sampled from the encoder output.
    """
    # Compute latents parameters for all subsets
    latents = self.inference(inputs)
    results = dict()

    # Get the embeddings for shared latent space
    shared_embeddings = rsample_from_gaussian(
        latents["joint"][0], latents["joint"][1]
    )
    len_batch = shared_embeddings.shape[0]

    # Compute the divergence to the prior
    div = self.calc_joint_divergence(
        latents["mus"], latents["logvars"], latents["weights"]
    )
    results.update(div)

    has_masks = hasattr(inputs, "masks")

    # Compute the reconstruction losses for each modality
    loss = 0
    kld = results["joint_divergence"]
    for m_key in self.encoders.keys():
        # reconstruct this modality from the shared embeddings representation
        if self.multiple_latent_spaces:
            try:
                # sample from the modality specific latent space
                style_mu = latents["modalities"][m_key].style_embedding
                style_log_var = latents["modalities"][m_key].style_log_covariance
                style_embeddings = rsample_from_gaussian(style_mu, style_log_var)
                full_embedding = torch.cat(
                    [shared_embeddings, style_embeddings], dim=-1
                )
            except Exception as err:
                # Keep the historical exception type, but chain the real cause
                # so that it is not lost.
                raise AttributeError(
                    " model_config.modality_specific_dims is not None, "
                    f"but encoder output for modality {m_key} doesn't have a "
                    "style_embedding attribute. "
                    "When using multiple latent spaces, the encoders' output"
                    "should be of the form : ModelOuput(embedding = ...,"
                    "style_embedding = ...,log_covariance = ..., style_log_covariance = ...)"
                ) from err
        else:
            full_embedding = shared_embeddings

        recon = self.decoders[m_key](full_embedding).reconstruction
        m_rec = (
            (
                -self.recon_log_probs[m_key](recon, inputs.data[m_key])
                * self.rescale_factors[m_key]
            )
            .view(recon.size(0), -1)
            .sum(-1)
        )

        # reconstruction loss
        if has_masks:
            # We still take the mean over the all batch even if some samples are missing
            # in order to give the same weights to all samples between batches
            results["recon_" + m_key] = (m_rec * inputs.masks[m_key].float()).mean()
        else:
            results["recon_" + m_key] = m_rec.mean()
        loss += results["recon_" + m_key]

        # If using modality specific latent spaces, add modality specific klds
        if self.multiple_latent_spaces:
            style_kld = -0.5 * (
                1 - style_log_var.exp() - style_mu.pow(2) + style_log_var
            ).view(style_mu.size(0), -1).sum(-1)
            if has_masks:
                style_kld *= inputs.masks[m_key].float()
            # Out-of-place add: `kld` initially aliases results["joint_divergence"],
            # an in-place `+=` would silently add the style KL to that metric.
            kld = kld + style_kld.mean() * self.model_config.beta_style

    loss = loss + self.model_config.beta * kld
    return ModelOutput(loss=loss, loss_sum=loss * len_batch, metrics=results)
[docs]
def modality_encode(
    self, inputs: Union[MultimodalBaseDataset, IncompleteDataset], **kwargs
):
    """Run every modality's encoder on the corresponding entry of ``inputs.data``.

    Args:
        inputs (MultimodalBaseDataset): The data to encode.

    Returns:
        dict: Maps each modality name to the output of its encoder
        (the parameters of the unimodal posterior).
    """
    return {
        m_key: encoder(inputs.data[m_key])
        for m_key, encoder in self.encoders.items()
    }
def _poe_fusion(self, mus: torch.Tensor, logvars: torch.Tensor):
    """Fuse the experts stacked along dimension 0 with a product of experts.

    Following the original implementation, an additional expert with zero mean
    and zero log-variance (the prior) is appended when the number of stacked
    experts equals the number of modalities, i.e. for the full subset.

    Args:
        mus (Tensor): Stacked means, experts along dimension 0.
        logvars (Tensor): Stacked log-variances, experts along dimension 0.

    Returns:
        The output of ``poe`` on the (possibly extended) experts.
    """
    is_full_subset = mus.shape[0] == len(self.encoders.keys())
    if not is_full_subset:
        return poe(mus, logvars)

    n_samples = mus[0].shape[0]
    target_device = mus.device
    prior_mu = torch.zeros(1, n_samples, self.latent_dim).to(target_device)
    prior_logvar = torch.zeros(1, n_samples, self.latent_dim).to(target_device)
    mus_with_prior = torch.cat((mus, prior_mu), dim=0)
    logvars_with_prior = torch.cat((logvars, prior_logvar), dim=0)
    return poe(mus_with_prior, logvars_with_prior)
[docs]
def subset_mask(self, inputs: IncompleteDataset, subset: Union[list, tuple]):
    """Return the mask of the samples available in ALL the modalities of ``subset``.

    Args:
        inputs (IncompleteDataset): Data with a ``masks`` entry per modality.
        subset (Union[list, tuple]): Modality names; must contain at least one
            name since the first one is used to pick the device.

    Returns:
        The logical AND of the masks of all the modalities in ``subset``.
    """
    first_mask = inputs.masks[subset[0]]
    available = torch.tensor(True).to(first_mask.device)
    for mod_name in subset:
        available = torch.logical_and(available, inputs.masks[mod_name])
    return available
[docs]
def inference(self, inputs: MultimodalBaseDataset, **kwargs):
    """Compute the parameters of every subset posterior and of the joint posterior.

    Each modality is encoded, then for every non-empty subset the unimodal
    posteriors are fused with ``_poe_fusion``. One subset posterior is finally
    selected per sample to play the role of the joint posterior: randomly among
    the available subsets when ``inputs`` has a ``masks`` attribute,
    deterministically otherwise.

    Args:
        inputs (MultimodalBaseDataset): The data.

    Returns:
        dict: all the subset and joint posteriors parameters, with keys
            ``"modalities"`` (encoder outputs per modality),
            ``"mus"`` / ``"logvars"`` (subset posteriors stacked along dim 0),
            ``"weights"`` (weight of each subset for each sample),
            ``"joint"`` (``[mu, logvar]`` of the selected posterior) and
            ``"subsets"`` (``[mu, logvar]`` per subset name).
    """
    latents = dict()
    enc_mods = self.modality_encode(inputs)
    latents["modalities"] = enc_mods
    # All tensors created below live on the device of the first modality's data.
    device = inputs.data[list(inputs.data.keys())[0]].device
    # Empty tensors that are grown by concatenation, one subset at a time.
    mus = torch.Tensor().to(device)
    logvars = torch.Tensor().to(device)
    distr_subsets = dict()
    availabilities = []
    for k, s_key in enumerate(self.subsets.keys()):
        # The empty subset (key "") has no expert and is skipped.
        if s_key != "":
            mods = self.subsets[s_key]
            mus_subset = torch.Tensor().to(device)
            logvars_subset = torch.Tensor().to(device)
            if hasattr(inputs, "masks"):
                # Samples for which all the modalities of the subset are available.
                filter = self.subset_mask(inputs, mods)
                availabilities.append(filter)
            # Stack the unimodal posteriors of the subset along a new first dimension.
            for m, mod in enumerate(mods):
                mus_mod = enc_mods[mod].embedding
                log_vars_mod = enc_mods[mod].log_covariance
                mus_subset = torch.cat((mus_subset, mus_mod.unsqueeze(0)), dim=0)
                logvars_subset = torch.cat(
                    (logvars_subset, log_vars_mod.unsqueeze(0)), dim=0
                )
            # Case with only one sample : adapt the shape
            if len(mus_subset.shape) == 2:
                mus_subset = mus_subset.unsqueeze(1)
                logvars_subset = logvars_subset.unsqueeze(1)
            s_mu, s_logvar = self._poe_fusion(mus_subset, logvars_subset)
            distr_subsets[s_key] = [s_mu, s_logvar]
            # Add the subset posterior to be part of the mixture of experts
            mus = torch.cat((mus, s_mu.unsqueeze(0)), dim=0)
            logvars = torch.cat((logvars, s_logvar.unsqueeze(0)), dim=0)
    if hasattr(inputs, "masks"):
        # if we have an incomplete dataset, we need to randomly choose
        # from the mixture of available experts
        availabilities = torch.stack(availabilities, dim=0).float()
        if len(availabilities.shape) == 1:
            availabilities = availabilities.unsqueeze(1)
        # Normalize over the subsets so that each sample gets a probability vector.
        # NOTE(review): a sample with no available subset gives 0/0 here.
        availabilities /= torch.sum(availabilities, dim=0)  # (n_subset,n_samples)
        joint_mu, joint_logvar = self.random_mixture_component_selection(
            mus, logvars, availabilities
        )
        weights = availabilities
    else:
        # Uniform weights over the subsets, one value per subset for the selection...
        weights = (1 / float(mus.shape[0])) * torch.ones(mus.shape[0]).to(device)
        joint_mu, joint_logvar = self.deterministic_mixture_component_selection(
            mus, logvars, weights
        )
        # ... then expanded to one value per subset and per sample for the output.
        weights = (1 / float(mus.shape[0])) * torch.ones(
            mus.shape[0], mus.shape[1]
        ).to(device)
    latents["mus"] = mus
    latents["logvars"] = logvars
    latents["weights"] = weights
    latents["joint"] = [joint_mu, joint_logvar]
    latents["subsets"] = distr_subsets
    return latents
[docs]
def encode(
    self,
    inputs: MultimodalBaseDataset,
    cond_mod: Union[list, str] = "all",
    N: int = 1,
    return_mean=False,
    **kwargs,
) -> ModelOutput:
    """Generate encodings conditioning on all modalities or a subset of modalities.

    The posterior used is the one stored by ``inference`` for the subset made of
    the conditioning modalities. When ``return_mean`` is True and all the
    modalities are conditioned on, the mean is replaced by the average of the
    means of all the subset posteriors.

    Args:
        inputs (MultimodalBaseDataset): The dataset to use for the conditional generation.
        cond_mod (Union[list, str]): Either 'all' or a list of str containing the modalities
            names to condition on.
        N (int) : The number of encodings to sample for each datapoint. Default to 1.
        return_mean (bool) : if True, returns the mean of the posterior distribution
            (instead of a sample).
        **kwargs: ``flatten`` (default False) is forwarded to the sampling function.

    Returns:
        ModelOutput instance with fields:
            z (torch.Tensor (n_data, N, latent_dim))
            one_latent_space (bool)
            modalities_z (Dict[str,torch.Tensor (n_data, N, latent_dim) ]), only when
                modality-specific latent spaces are used.
    """
    cond_mod = super().encode(inputs, cond_mod, N, **kwargs).cond_mod

    # Name of the subset formed by the conditioning modalities
    subset_name = "_".join(sorted(cond_mod))

    inferred = self.inference(inputs)
    subset_posteriors = inferred["subsets"]
    mu, log_var = subset_posteriors[subset_name]

    if return_mean and len(cond_mod) == self.n_modalities:
        # Aggregate the posterior: average the means of all the subsets
        subset_means = [params[0] for params in subset_posteriors.values()]
        mu = torch.stack(subset_means).mean(0)

    flatten = kwargs.pop("flatten", False)
    z = rsample_from_gaussian(mu, log_var, N, return_mean, flatten=flatten)

    if not self.multiple_latent_spaces:
        return ModelOutput(z=z, one_latent_space=True)

    modalities_z = {}
    for m in self.encoders:
        if m in cond_mod:
            # Conditioning modality: use its own style posterior
            encoder_output = inferred["modalities"][m]
            mu_style = encoder_output.style_embedding
            log_var_style = encoder_output.style_log_covariance
        else:
            # Other modalities: zero mean and zero log-variance
            style_shape = (len(mu), self.style_dims[m])
            mu_style = torch.zeros(style_shape).to(mu.device)
            log_var_style = torch.zeros(style_shape).to(mu.device)
        modalities_z[m] = rsample_from_gaussian(
            mu_style, log_var_style, N, return_mean, flatten
        )
    return ModelOutput(z=z, one_latent_space=False, modalities_z=modalities_z)
[docs]
def random_mixture_component_selection(self, mus, logvars, availabilities):
    """Randomly select a subset for each sample among the available subsets.

    Args:
        mus (tensor): (n_subset,n_samples,latent_dim) the means of subset posterior.
        logvars (tensor): (n_subset,n_samples,latent_dim) the log covariance of subset posterior.
        availabilities (tensor): (n_subset,n_samples) used, per sample, as the
            probabilities of picking each subset.

    Returns:
        tuple: the means and the log covariances of the subset picked for each sample.
    """
    # One categorical distribution per sample: (n_samples, n_subset)
    per_sample_probs = availabilities.permute(1, 0)
    picked = dist.OneHotCategorical(probs=per_sample_probs).sample().bool()

    # Put the samples first so that the one-hot mask selects one subset per sample
    selected_mus = mus.permute(1, 0, 2)[picked]
    selected_logvars = logvars.permute(1, 0, 2)[picked]
    return selected_mus, selected_logvars
[docs]
def deterministic_mixture_component_selection(self, mus, logvars, w_modalities):
    """Assign to each sample the mean and log covariance of one subset, in a balanced way.

    The samples are split into consecutive ranges, one per subset, whose lengths
    are ``floor(n_samples * w_modalities[k])``; the last range extends to the end
    of the batch.

    Args:
        mus (tensor): (n_subset,n_samples,latent_dim) the means of subset posterior.
        logvars (tensor): (n_subset,n_samples,latent_dim) the log covariance of subset posterior.
        w_modalities (tensor): (n_subset,) the proportion of samples for each subset.

    Returns:
        list: ``[mu_sel, logvar_sel]``, the selected parameters concatenated over samples.
    """
    num_components, num_samples = mus.shape[0], mus.shape[1]
    num_weights = w_modalities.shape[0]

    # Consecutive (start, stop) ranges of samples, one per component
    bounds = []
    cursor = 0
    for k in range(num_components):
        if k == num_weights - 1:
            stop = num_samples
        else:
            stop = cursor + int(torch.floor(num_samples * w_modalities[k]))
        bounds.append((cursor, stop))
        cursor = stop
    # The last range always goes up to the end of the batch
    bounds[-1] = (bounds[-1][0], num_samples)

    mu_sel = torch.cat(
        [mus[k, bounds[k][0] : bounds[k][1], :] for k in range(num_weights)]
    )
    logvar_sel = torch.cat(
        [logvars[k, bounds[k][0] : bounds[k][1], :] for k in range(num_weights)]
    )
    return [mu_sel, logvar_sel]
[docs]
@torch.no_grad()
def compute_joint_nll(
    self,
    inputs: Union[MultimodalBaseDataset, IncompleteDataset],
    K: int = 1000,
    batch_size_K: int = 100,
):
    """Estimate the negative joint likelihood.

    In the original code, the product of experts is used as inference distribution
    for computing the nll instead of the MoPoE, but that is less coherent with the
    definition of the MoPoE as the joint posterior.

    The importance weights use the uniform mixture of all the subset posteriors as
    the shared posterior density. The model is switched to evaluation mode.

    Args:
        inputs (MultimodalBaseDataset) : a batch of samples.
        K (int) : the number of importance samples for the estimation. Default to 1000.
        batch_size_K (int) : the number of importance samples decoded at once.
            Default to 100.

    Returns:
        The negative log-likelihood summed over the batch.

    Raises:
        AttributeError: if ``inputs`` has a ``masks`` attribute (incomplete dataset).
    """
    self.eval()

    # Check that the dataset is complete
    if hasattr(inputs, "masks"):
        raise AttributeError(
            "The compute_joint_nll method is not yet implemented for incomplete datasets."
        )

    # Compute the parameters of each subset posterior
    infer = self.inference(inputs)
    mu, log_var = infer["joint"]  # one subset posterior selected for each sample
    mus_subset = infer["mus"]
    log_vars_subset = infer["logvars"]

    # And sample from the posterior
    z_joint = rsample_from_gaussian(
        mu, log_var, N=K
    )  # shape K x n_data x latent_dim
    z_joint = z_joint.permute(1, 0, 2)  # shape n_data x K x latent_dim
    n_data, _, _ = z_joint.shape

    # If using multiple latent spaces, sample from the private latent spaces as well
    private_params = {}
    private_zs = {}
    if self.multiple_latent_spaces:
        for mod in inputs.data:
            private_params[mod] = (
                infer["modalities"][mod].style_embedding,
                infer["modalities"][mod].style_log_covariance,
            )
            style_embeddings = rsample_from_gaussian(*private_params[mod], N=K)
            private_zs[mod] = style_embeddings.permute(
                1, 0, 2
            )  # shape n_data x K x private dim

    # The standard normal prior does not depend on the sample nor on the chunk
    prior = dist.Normal(0, 1)

    # Then iter on each datapoint to compute the iwae estimate of ln(p(x))
    ll = 0
    for i in range(n_data):
        start_idx = 0
        stop_idx = min(start_idx + batch_size_K, K)
        lnpxs = []

        # Process the K importance samples by chunks of at most batch_size_K
        while start_idx < stop_idx:
            shared_latents = z_joint[i][start_idx:stop_idx]

            # Compute ln p(x_m|z) for z in latents and for each modality m
            lpx_zs = 0
            lpz = 0
            lqz_xs = 0
            for mod in inputs.data:
                if self.multiple_latent_spaces:
                    private_latents = private_zs[mod][i][start_idx:stop_idx]
                    full_embedding = torch.cat(
                        [shared_latents, private_latents], dim=-1
                    )
                else:
                    full_embedding = shared_latents

                decoder = self.decoders[mod]
                recon = decoder(full_embedding)[
                    "reconstruction"
                ]  # (batch_size_K, nb_channels, w, h)
                x_m = inputs.data[mod][i]  # (nb_channels, w, h)

                lpx_zs += (
                    self.recon_log_probs[mod](
                        recon, torch.stack([x_m] * len(recon))
                    )
                    .reshape(recon.size(0), -1)
                    .sum(-1)
                )

                # If we are using modalities specific latent spaces compute the
                # modalities priors and posteriors
                if self.multiple_latent_spaces:
                    lpz += dist.Normal(0, 1).log_prob(private_latents).sum(dim=-1)
                    qz_x = dist.Normal(
                        private_params[mod][0][i],
                        torch.exp(0.5 * private_params[mod][1][i]),
                    )
                    lqz_xs += qz_x.log_prob(private_latents).sum(dim=-1)

            # Compute ln(p(z))
            lpz += prior.log_prob(shared_latents).sum(dim=-1)

            # Compute shared posterior ln(q(z|x,y)) = ln(1/S \sum q(z|x_s))
            qz_xs = [
                dist.Normal(
                    mus_subset[j][i], torch.exp(0.5 * log_vars_subset[j][i])
                )
                for j in range(len(mus_subset))
            ]
            lqz_xs_tensor = torch.stack(
                [q.log_prob(shared_latents).sum(-1) for q in qz_xs]
            )
            lqz_xs += torch.logsumexp(lqz_xs_tensor, dim=0) - np.log(
                len(lqz_xs_tensor)
            )  # log_mean_exp

            ln_px = torch.logsumexp(lpx_zs + lpz - lqz_xs, dim=0)
            lnpxs.append(ln_px)

            # next batch
            start_idx += batch_size_K
            stop_idx = min(stop_idx + batch_size_K, K)

        # torch.stack keeps the dtype and the device of the per-chunk results,
        # unlike torch.Tensor(list) which rebuilds a default float tensor.
        ll += torch.logsumexp(torch.stack(lnpxs), dim=0) - np.log(K)

    return -ll
@torch.no_grad()
def _compute_joint_nll_from_subset_encoding(
    self,
    subset,
    inputs: Union[MultimodalBaseDataset, IncompleteDataset],
    K: int = 1000,
    batch_size_K: int = 100,
):
    """Computes the joint negative log-likelihood using the posterior of ``subset``
    as importance sampling distribution.

    The result is summed over the input batch. The model is switched to
    evaluation mode.

    Args:
        subset: The modality names of the subset whose posterior is used. It must
            be one of the subsets registered in the model.
        inputs (MultimodalBaseDataset) : a batch of samples.
        K (int) : the number of importance samples for the estimation. Default to 1000.
        batch_size_K (int) : the number of importance samples decoded at once.
            Default to 100.

    Returns:
        The negative log-likelihood summed over the batch.

    Raises:
        AttributeError: if ``inputs`` has a ``masks`` attribute (incomplete dataset).
    """
    self.eval()

    # Check that the dataset is complete
    if hasattr(inputs, "masks"):
        raise AttributeError(
            "The compute_joint_nll method is not yet implemented for incomplete datasets."
        )

    subset_name = "_".join(sorted(subset))

    # Compute the parameters of the subset posterior
    infer = self.inference(inputs)
    mu, log_var = infer["subsets"][subset_name]

    # And sample from the posterior
    z_joint = rsample_from_gaussian(mu, log_var, K)  # shape K x n_data x latent_dim
    z_joint = z_joint.permute(1, 0, 2)  # shape n_data x K x latent_dim
    n_data, _, _ = z_joint.shape

    # If using multiple latent spaces, sample from the private latent spaces as well
    private_params = {}
    private_zs = {}
    if self.multiple_latent_spaces:
        for mod in inputs.data:
            private_params[mod] = (
                infer["modalities"][mod].style_embedding,
                infer["modalities"][mod].style_log_covariance,
            )
            style_embeddings = rsample_from_gaussian(*private_params[mod], K)
            private_zs[mod] = style_embeddings.permute(
                1, 0, 2
            )  # shape n_data x K x private dim

    # The standard normal prior does not depend on the sample nor on the chunk
    prior = dist.Normal(0, 1)

    # Then iter on each datapoint to compute the iwae estimate of ln(p(x))
    ll = 0
    for i in range(n_data):
        start_idx = 0
        stop_idx = min(start_idx + batch_size_K, K)
        lnpxs = []

        # Process the K importance samples by chunks of at most batch_size_K
        while start_idx < stop_idx:
            shared_latents = z_joint[i][start_idx:stop_idx]

            # Compute ln p(x_m|z) for z in latents and for each modality m
            lpx_zs = 0  # ln(p(x,y|z))
            lpz = 0  # ln(p(z)) prior
            lqz_xs = 0  # ln(q(z|X)) posterior
            for mod in inputs.data:
                if self.multiple_latent_spaces:
                    private_latents = private_zs[mod][i][start_idx:stop_idx]
                    full_embedding = torch.cat(
                        [shared_latents, private_latents], dim=-1
                    )
                else:
                    full_embedding = shared_latents

                decoder = self.decoders[mod]
                recon = decoder(full_embedding)[
                    "reconstruction"
                ]  # (batch_size_K, nb_channels, w, h)
                x_m = inputs.data[mod][i]  # (nb_channels, w, h)

                lpx_zs += (
                    self.recon_log_probs[mod](
                        recon, torch.stack([x_m] * len(recon))
                    )
                    .reshape(recon.size(0), -1)
                    .sum(-1)
                )

                # If we are using modalities specific latent spaces compute the
                # modalities priors and posteriors
                if self.multiple_latent_spaces:
                    lpz += dist.Normal(0, 1).log_prob(private_latents).sum(dim=-1)
                    qz_x = dist.Normal(
                        private_params[mod][0][i],
                        torch.exp(0.5 * private_params[mod][1][i]),
                    )
                    lqz_xs += qz_x.log_prob(private_latents).sum(dim=-1)

            # Compute ln(p(z))
            lpz += prior.log_prob(shared_latents).sum(dim=-1)

            # Compute the shared posterior ln(q(z|x,y))
            qz_xy = dist.Normal(mu[i], torch.exp(0.5 * log_var[i]))
            lqz_xs += qz_xy.log_prob(shared_latents).sum(dim=-1)

            ln_px = torch.logsumexp(lpx_zs + lpz - lqz_xs, dim=0)
            lnpxs.append(ln_px)

            # next batch
            start_idx += batch_size_K
            stop_idx = min(stop_idx + batch_size_K, K)

        # torch.stack keeps the dtype and the device of the per-chunk results,
        # unlike torch.Tensor(list) which rebuilds a default float tensor.
        ll += torch.logsumexp(torch.stack(lnpxs), dim=0) - np.log(K)

    return -ll
[docs]
def compute_joint_nll_paper(
    self,
    inputs: Union[MultimodalBaseDataset, IncompleteDataset],
    K: int = 1000,
    batch_size_K: int = 100,
):
    """Estimates the negative joint likelihood using the PoE posterior of the subset
    containing all the modalities as importance sampling distribution.

    The result is summed over the input batch.
    This is the method used in the original paper implementation.

    Args:
        inputs (MultimodalBaseDataset) : a batch of samples.
        K (int) : the number of importance samples for the estimation. Default to 1000.
        batch_size_K (int) : Default to 100.

    Returns:
        The output of ``_compute_joint_nll_from_subset_encoding`` for the full subset.
    """
    every_modality = list(self.encoders.keys())
    nll = self._compute_joint_nll_from_subset_encoding(
        every_modality, inputs, K, batch_size_K
    )
    return nll