MultimodalBaseDataset
- class multivae.data.datasets.MultimodalBaseDataset(data, labels=None)
This class is the base class for multimodal datasets. Its __getitem__ method is redefined to output a Python dictionary with a data key and, when labels are provided, a labels key. You can use this class directly to define new datasets. If you want, you can also create your own dataset class by inheriting from MultimodalBaseDataset and overriding the __getitem__ method (just make sure the output format stays the same). For instance:
>>> from multivae.data.datasets import MultimodalBaseDataset, DatasetOutput
>>>
>>> class MyDataset(MultimodalBaseDataset):
...     def __init__(self, my_arguments):
...         # your code
...         pass
...
...     def __getitem__(self, index):
...         # your code
...         return DatasetOutput(
...             data=your_data,  # must be a Dict[str, Tensor]
...             labels=your_labels,  # optional: omit this field if you don't have labels
...         )
- Parameters:
data (dict) – A dictionary mapping each modality's name to a tensor or numpy array containing that modality's data.
labels (Union[torch.Tensor, numpy.ndarray]) – A torch.Tensor or numpy.ndarray instance containing the labels. Optional, defaults to None.
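To make the expected layout concrete, here is a small sketch in plain numpy. It assumes every modality array shares the same first (sample) dimension. The function get_item is a hypothetical, simplified stand-in for the __getitem__ behavior described above, not multivae's own implementation: it gathers the same index from every modality and adds the label only when labels are provided.

```python
import numpy as np

# Hypothetical toy data: 4 samples and two modalities with different shapes.
# All modalities are assumed to share the same first (sample) dimension.
data = {
    "image": np.zeros((4, 1, 28, 28), dtype=np.float32),
    "text": np.arange(4 * 16, dtype=np.int64).reshape(4, 16),
}
labels = np.array([0, 1, 1, 0])


def get_item(data, labels, index):
    """Hypothetical, simplified stand-in for MultimodalBaseDataset.__getitem__:
    gathers sample `index` from every modality (not multivae's actual code)."""
    item = {"data": {name: array[index] for name, array in data.items()}}
    if labels is not None:  # labels default to None in the real constructor
        item["labels"] = labels[index]
    return item


sample = get_item(data, labels, 2)
print(sorted(sample["data"]))           # ['image', 'text']
print(sample["data"]["image"].shape)    # (1, 28, 28)
print(sample["labels"])                 # 1
print(sorted(get_item(data, None, 0)))  # ['data']
```

With the real class you would pass the same dictionary, e.g. MultimodalBaseDataset(data=data, labels=labels), and then index it with dataset[2].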
- transform_for_plotting(tensor, modality)
A function to override in subclasses if you want to transform tensor data before plotting. It is called by the BaseTrainer to visualize generations during training, and by the Visualization Module.
For instance, if you have 3D images, you might want to visualize generations during training, but only 2D images can be logged to wandb. In that case, you can override this function in your dataset class, for instance with:
>>> def transform_for_plotting(self, tensor, modality):
...     if modality == '3Dimage':
...         return tensor[:, 0, :, :]  # select a slice
...     return tensor
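The example above keeps index 0 along axis 1. As a slightly more general sketch, the hypothetical class below assumes 3D images are stored as (batch, channels, depth, height, width) and keeps the middle depth slice, so each volume becomes a 2D image with its channel axis intact. The modality name '3Dimage', the volume_modalities attribute and the mixin itself are illustrative assumptions; in practice you would define the method directly in your MultimodalBaseDataset subclass. The indexing used here behaves the same on numpy arrays and torch tensors.

```python
import numpy as np


class VolumePlottingMixin:
    """Hypothetical mixin; in practice, define transform_for_plotting directly
    in your MultimodalBaseDataset subclass."""

    # Hypothetical name(s) of the modalities that hold 3D volumes.
    volume_modalities = ("3Dimage",)

    def transform_for_plotting(self, tensor, modality):
        if modality in self.volume_modalities:
            # Assumed layout: (batch, channels, depth, height, width).
            # Keep the middle depth slice -> (batch, channels, height, width).
            middle = tensor.shape[2] // 2
            return tensor[:, :, middle]
        # Any other modality is returned unchanged.
        return tensor


volumes = np.zeros((2, 1, 9, 32, 32), dtype=np.float32)
volumes[:, :, 4] = 1.0  # mark the middle slice (9 // 2 == 4)
images = VolumePlottingMixin().transform_for_plotting(volumes, "3Dimage")
print(images.shape)  # (2, 1, 32, 32)
```

If you use such a mixin, list it before MultimodalBaseDataset in the bases, e.g. class MyDataset(VolumePlottingMixin, MultimodalBaseDataset), so that its method takes precedence over the base implementation. Taking the middle slice rather than the first is only a choice made for this sketch.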