IncompleteDataset
- class multivae.data.datasets.IncompleteDataset(data, masks, labels=None)
This class is the base class for datasets with incomplete data. It adds a masks field that indicates, for each modality, which samples are available. It is used with models that are compatible with partial data.
__getitem__ is redefined and returns a Python dictionary whose keys correspond to data and masks (and optionally labels). This class should be used for any new incomplete dataset. You can also create your own dataset class by inheriting from IncompleteDataset and overriding the __getitem__ method; just make sure the output format stays the same.
For instance:
>>> from multivae.data.datasets import IncompleteDataset, DatasetOutput
>>>
>>> class MyDataset(IncompleteDataset):
...     def __init__(self, my_arguments):
...         # your code
...
...     def __getitem__(self, index):
...         # your code
...         your_data = {
...             'modality_name_1' : ...,
...             'modality_name_2' : ...
...         }
...         # Warning: if 'modality_name_2' is unavailable for this index,
...         # artificially fill your_data['modality_name_2'] with
...         # a zero tensor (or any value you want, it doesn't matter) OF THE RIGHT SHAPE.
...         # Otherwise MultiVae models won't work.
...
...         your_masks = {
...             'modality_name_1' : True,
...             'modality_name_2' : False  # set to False if the modality is unavailable
...         }
...
...         return DatasetOutput(
...             data = your_data,  # must be a Dict[str, Tensor]
...             masks = your_masks,  # must be a Dict[str, 1d Tensor]
...             labels = your_labels  # optional: omit this field if you don't have labels
...         )
Warning
If you intend to define your own IncompleteDataset subclass, please take a close look at the code snippet above before doing so.
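To make the expected per-index output concrete, the following is a minimal sketch of the selection that __getitem__ is described as performing. It is a stand-in, not the library's implementation: `get_item_sketch` is a hypothetical helper, the modality names `mod_a` and `mod_b` are made up, and a plain dict is returned instead of DatasetOutput so that the snippet runs with only PyTorch installed.

```python
import torch


def get_item_sketch(data, masks, index, labels=None):
    """Hypothetical stand-in: select one sample from every modality and mask."""
    out = {
        "data": {name: tensor[index] for name, tensor in data.items()},
        "masks": {name: mask[index] for name, mask in masks.items()},
    }
    # The labels key is only present when labels were provided.
    if labels is not None:
        out["labels"] = labels[index]
    return out


# Toy inputs: 3 samples, two made-up modalities.
data = {"mod_a": torch.ones(3, 5), "mod_b": torch.zeros(3, 2)}
masks = {
    "mod_a": torch.tensor([True, True, True]),
    "mod_b": torch.tensor([True, False, True]),  # sample 1 lacks mod_b
}

item = get_item_sketch(data, masks, index=1)
```

Here `item["data"]` still contains an entry for `mod_b` with the right shape, and `item["masks"]["mod_b"]` is what marks that entry as a placeholder.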
- Parameters:
data (dict[str, torch.Tensor]) – A dictionary mapping each modality name to a tensor or numpy array containing the data for that modality.
masks (dict[str, torch.Tensor]) – A dictionary mapping each modality name to a boolean tensor of the same length as the corresponding tensor in the data dictionary. For each modality, the mask tensor indicates whether each sample is available. Unavailable samples are assumed to have been filled with random values or zeros in the data dictionary.
labels (Union[Tensor, numpy.ndarray]) – A torch.Tensor or numpy.ndarray instance containing the labels. Optional; defaults to None.
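The parameters above can be prepared along these lines. This is a sketch under stated assumptions: the modality names (`images`, `text`), shapes, and availability pattern are invented for illustration, and the constructor call is left as a comment so the snippet runs with only PyTorch installed.

```python
import torch

# Hypothetical toy setup: 4 samples, two modalities with made-up names and shapes.
n_samples = 4
images = torch.randn(n_samples, 3, 8, 8)
text = torch.randn(n_samples, 16)

# One boolean entry per sample and per modality: 'text' is missing for samples 1 and 3.
masks = {
    "images": torch.ones(n_samples, dtype=torch.bool),
    "text": torch.tensor([True, False, True, False]),
}

# Fill the unavailable entries with zeros so the tensor keeps its full shape.
text[~masks["text"]] = 0.0

data = {"images": images, "text": text}

# Sanity check: each mask is 1-d and as long as its data tensor.
for name in data:
    assert masks[name].shape == (n_samples,)
    assert len(data[name]) == len(masks[name])

# With MultiVae installed, the dictionaries would then be passed to the class:
# from multivae.data.datasets import IncompleteDataset
# dataset = IncompleteDataset(data=data, masks=masks)
```

Zero-filling is only one option; as noted in the masks description, the placeholder values themselves are not meaningful as long as the shape is correct and the mask is False for those samples.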