import os
from pathlib import Path
import numpy as np
import torch
from torch.nn.functional import one_hot
from .base import DatasetOutput, IncompleteDataset
def unstack_tensor(tensor, dim=0):
tensor_lst = []
for i in range(tensor.size(dim)):
tensor_lst.append(tensor[i])
tensor_unstack = torch.cat(tensor_lst, dim=0)
return tensor_unstack
[docs]
class MHD(IncompleteDataset): # pragma: no cover
"""Dataset class for the MHD dataset introduced in the paper:
'Leveraging hierarchy in multimodal generative models for effective
cross-modality inference' (Vasco et al, 2021).'
In this version of the dataset class, we add the possibility to
simulate missingness in the data, depending on the dataclass (Missing Not At Random).
For that, the `missing_probabilities` parameter provides probabilities of missingness for each class,
and for each modality. For instance, the code below will define a dataset with missing samples in the
trajectory modality, only in the classes 0,1,2, et 9.
.. code-block:: python
>>> from multivae.data.datasets import MHD
>>> missing_probabilities = {
... image = np.zeros(10,).float(),
... audio = np.zeros(10,).float(),
... trajectory = [0.1,0.3,0.4,0.,0.,0.,0.,0.,0.,0.9]
... }
>>> dataset = MHD(data_path,
... 'train',
... modalities = ['image', 'audio', 'trajectory'],
... download = True,
... missing_probabilities = missing_probabilities)
Args:
datapath (str) : Where the data is stored. It must contained the 'mhd_train.pt' file and
'mhd_test.pt' file.
split (Literal['train', 'test']) : Split of the data to use. Default to 'train'.
modalities (list) : The modalities to use among 'label', 'trajectory', 'image', 'audio'.
By default, we use all.
download (bool) : If the dataset is not present at the given path, wether to download it or not.
Default to False.
missing_probabilities (dict) : For each modality, the probabilities for each class
to be missing in the created incomplete dataset. By default, we use no missing data.
seed (int) : default to 0. You can change the seed to create a different incomplete dataset.
"""
def __init__(
self,
datapath: str,
split="train",
modalities: list = ["label", "audio", "trajectory", "image"],
download=False,
missing_probabilities=dict(
label=[0.0] * 10, audio=[0.0] * 10, trajectory=[0.0] * 10, image=[0.0] * 10
),
seed=0,
keep_incomplete=True,
):
self.data_file = os.path.join(datapath, f"mhd_{split}.pt")
self.modalities = modalities
if not os.path.exists(self.data_file):
if not download:
raise RuntimeError(
f"Dataset not found at path {datapath} and download is set to False. "
"Please change the path or set download to True"
)
else:
try:
self.__download__(split, datapath)
except:
raise RuntimeError(
"gdown must be installed to download the dataset automatically."
"Install gdown with "
' "pip install gdown" or download the dataset manually at the following url'
"train : https://docs.google.com/uc?export=download&id=1Tj1i-hXA0INQpU0jmuTMO4IwfDoGD2oV"
"test : https://docs.google.com/uc?export=download&id=1qiEjFNCFn1ws383pKmY3zJtm4JDymOU6"
)
(
self._s_data,
self._i_data,
self._t_data,
self._a_data,
self._traj_normalization,
self._audio_normalization,
) = torch.load(self.data_file)
self.data = dict()
if "image" in modalities:
self.data["image"] = self._i_data
if "label" in modalities:
self.data["label"] = one_hot(self._s_data, num_classes=10).float()
if "trajectory" in modalities:
self.data["trajectory"] = self._t_data
if "audio" in modalities:
self.data["audio"] = self._a_data
self.labels = self._s_data
self.n_data = len(self._s_data)
self.is_incomplete = (
sum([sum(missing_probabilities[s]) for s in missing_probabilities]) != 0
)
self.keep_incomplete = keep_incomplete
if self.is_incomplete:
# generate the masks
self.masks = {}
for i, mod in enumerate(self.data):
# randomly define the missing samples.
p = 1 - torch.tensor(missing_probabilities[mod])[self._s_data]
self.masks[mod] = torch.bernoulli(
p, generator=torch.Generator().manual_seed(seed + i)
).bool()
# To be sure, also erase the content of the masked samples
for k in self.masks:
reverse_dim_order = tuple(np.arange(len(self.data[k].shape))[::-1])
self.data[k] = self.data[k].permute(*reverse_dim_order).float()
# now the batch dimension is last
self.data[k] *= self.masks[k].float() # erase missing samples
# put dimensions back in order
self.data[k] = self.data[k].permute(*reverse_dim_order)
if not self.keep_incomplete:
# take the intersection of the modality masks
global_mask = torch.ones((self.n_data))
for m in self.modalities:
global_mask = global_mask * self.masks[m]
# only keep the samples where all modalities are available
self.data = {k: self.data[k][global_mask.bool()] for k in self.data}
self.n_data = torch.sum(global_mask.bool()).item()
def __download__(self, split, datapath): # pragram : no cover
import gdown
if not os.path.exists(datapath):
os.makedirs(Path(datapath), exist_ok=True)
if split == "train":
gdown.download(
"https://docs.google.com/uc?export=download&id=1Tj1i-hXA0INQpU0jmuTMO4IwfDoGD2oV",
output=os.path.join(datapath, f"mhd_{split}.pt"),
)
else:
gdown.download(
"https://docs.google.com/uc?export=download&id=1qiEjFNCFn1ws383pKmY3zJtm4JDymOU6",
output=os.path.join(datapath, f"mhd_{split}.pt"),
)
def __getitem__(self, index):
"""Args:
index (int): Index
Returns:
tuple: (t_data, m_data, f_data)
"""
data = {s: self.data[s][index] for s in self.data}
if "audio" in data:
# Audio modality is a 3x32x32 representation, need to unstack!
audio = unstack_tensor(data["audio"]).unsqueeze(0)
data["audio"] = audio.permute(0, 2, 1)
if not self.is_incomplete or not self.keep_incomplete:
return DatasetOutput(data=data, labels=self._s_data[index])
else:
masks = {s: self.masks[s][index] for s in self.data}
return DatasetOutput(data=data, labels=self._s_data[index], masks=masks)
def __len__(self):
return self.n_data
def get_audio_normalization(self):
return self._audio_normalization
def get_traj_normalization(self):
return self._traj_normalization