Source code for multivae.models.nn.base_architectures

from typing import Dict

import torch
from pythae.models.base.base_model import BaseDecoder, BaseEncoder


class BaseJointEncoder(BaseEncoder):
    """This is a base class for joint encoder neural networks.

    A joint encoder receives the multimodal input as a single dictionary
    (see :meth:`forward`).

    Attributes:
        latent_dim: ``None`` after ``__init__``; to be set in the child class.
    """

    def __init__(self):
        BaseEncoder.__init__(self)
        self.latent_dim = None  # to be set in child class

    def forward(self, x: dict):
        r"""This function must be implemented in a child class.

        It takes the input data and returns an instance of
        :class:`~pythae.models.base.base_utils.ModelOutput`.

        If you decide to provide your own joint encoder network, you must make
        sure your model inherits from this class, sets ``latent_dim`` and
        defines its forward function as below.

        .. code-block::

            >>> from multivae.models.nn.base_architectures import BaseJointEncoder
            >>> from pythae.models.base.base_utils import ModelOutput
            ...
            >>> class My_Joint_Encoder(BaseJointEncoder):
            ...
            ...     def __init__(self):
            ...         BaseJointEncoder.__init__(self)
            ...         # your code
            ...         self.latent_dim = ...
            ...
            ...     def forward(self, x: dict):
            ...         # x is a dict with a tensor for each modality
            ...         # your code
            ...         output = ModelOutput(
            ...             embedding=embedding,
            ...             log_covariance=log_var  # for VAE based models
            ...         )
            ...         return output

        Args:
            x (dict): Multimodal input to encode: a dictionary that contains
                modalities' names as keys and modalities' data as values.

        Returns:
            output (~pythae.models.base.base_utils.ModelOutput): The output of
            the encoder.

        Raises:
            NotImplementedError: Always, in this base class; child classes
                must override this method.
        """
        raise NotImplementedError()
class BaseMultilatentEncoder(BaseEncoder):
    """This is a base class for encoders with multiple latent spaces.

    Attributes:
        latent_dim: ``None`` after ``__init__``; to be set in the child class.
        style_dim: ``None`` after ``__init__``; to be set in the child class.
    """

    def __init__(self):
        BaseEncoder.__init__(self)
        self.latent_dim = None  # to be set in child class
        self.style_dim = None  # to be set in child class

    def forward(self, x: torch.Tensor):
        r"""This function must be implemented in a child class.

        It takes the input tensor x and returns an instance of
        :class:`~pythae.models.base.base_utils.ModelOutput` with the parameters
        for the shared latent space and the modality-specific latent space.

        If you decide to provide your own encoder network in a model that uses
        multiple latent spaces, you must make sure your model inherits from
        this class, sets ``latent_dim`` and ``style_dim`` and defines its
        forward function as below.

        .. code-block::

            >>> from multivae.models.nn import BaseMultilatentEncoder
            >>> from pythae.models.base.base_utils import ModelOutput
            ...
            >>> class My_Encoder(BaseMultilatentEncoder):
            ...
            ...     def __init__(self):
            ...         BaseMultilatentEncoder.__init__(self)
            ...         # your code
            ...         self.latent_dim = ...
            ...         self.style_dim = ...
            ...
            ...     def forward(self, x):
            ...         # your code
            ...         output = ModelOutput(
            ...             embedding=...,  # shared latent space
            ...             log_covariance=...,
            ...             style_embedding=...,  # modality-specific latent space
            ...             style_log_covariance=...,
            ...         )
            ...         return output

        Args:
            x (torch.Tensor): Input data.

        Returns:
            output (~pythae.models.base.base_utils.ModelOutput): The output of
            the encoder.

        Raises:
            NotImplementedError: Always, in this base class; child classes
                must override this method.
        """
        raise NotImplementedError()
class BaseConditionalDecoder(BaseDecoder):
    """This is a base class for conditional decoder architectures.

    Attributes:
        latent_dim: ``None`` after ``__init__``; to be set in the child class.
    """

    def __init__(self):
        BaseDecoder.__init__(self)
        self.latent_dim = None  # to be set in child class

    def forward(self, z: torch.Tensor, cond_mods: Dict[str, torch.Tensor]):
        r"""This function must be implemented in a child class.

        It takes the latent variable z and the conditioning modalities and
        returns an instance of
        :class:`~pythae.models.base.base_utils.ModelOutput` with the
        reconstruction.

        If you decide to provide your own decoder network, you must make sure
        your model inherits from this class, sets ``latent_dim`` and defines
        its forward function as below.

        .. code-block::

            >>> from multivae.models.nn.base_architectures import BaseConditionalDecoder
            >>> from pythae.models.base.base_utils import ModelOutput
            ...
            >>> class My_Conditional_Decoder(BaseConditionalDecoder):
            ...
            ...     def __init__(self):
            ...         BaseConditionalDecoder.__init__(self)
            ...         # your code
            ...         self.latent_dim = ...
            ...
            ...     def forward(self, z, cond_mods):
            ...         # your code
            ...         output = ModelOutput(
            ...             reconstruction=...
            ...         )
            ...         return output

        Args:
            z (torch.Tensor): Latent variable.
            cond_mods (Dict[str, torch.Tensor]): Conditioning data.

        Returns:
            output (~pythae.models.base.base_utils.ModelOutput): The output of
            the decoder.

        Raises:
            NotImplementedError: Always, in this base class; child classes
                must override this method.
        """
        raise NotImplementedError()