Source code for multivae.models.base.base_config

from dataclasses import field
from typing import Dict, Literal, Optional, Tuple, Union

from pydantic.dataclasses import dataclass
from pythae.config import BaseConfig


@dataclass
class BaseMultiVAEConfig(BaseConfig):
    """This is the base config for a Multi-Modal VAE model.

    Args:
        n_modalities (int): The number of modalities. Required: this field has
            no default value.
        latent_dim (int): The dimension of the latent space. Default: 10.
        input_dims (dict[str, tuple]): The modalities' names (str) and input
            shapes (tuple). When provided, every shape is converted to a tuple
            after initialization. Default: None.
        uses_likelihood_rescaling (bool): To mitigate modality collapse, it is
            possible to use likelihood rescaling
            (see: https://proceedings.mlr.press/v162/javaloy22a.html).
            The ``input_dims`` must be provided to compute the likelihood
            rescalings. It is used in a number of models, which is why it is
            included here. Default: False.
        rescale_factors (dict[str, float]): The reconstruction rescaling
            factors per modality. If None is provided but
            ``uses_likelihood_rescaling`` is True, a default value proportional
            to the input modality size is computed. Default: None.
        decoders_dist (Dict[str, str]): The decoder distribution to use per
            modality. Per modality, you can provide a string in
            ['normal', 'bernoulli', 'laplace', 'categorical']. For the
            Bernoulli distribution, the decoder is expected to output
            **logits**. If None is provided, a normal distribution is used for
            each modality. Default: None.
        decoder_dist_params (Dict[str, dict]): Parameters for the output
            decoder distributions, used for computing the log-probability. For
            instance, with a normal or laplace distribution, you can pass the
            scale in this dictionary with
            :code:`decoder_dist_params = {'mod1' : {"scale" : 0.75}}`.
            Default: None.
        custom_architectures (list): A list that is empty by default. Its
            contents are not interpreted by this class.
    """

    n_modalities: int
    latent_dim: int = 10
    input_dims: Optional[dict] = None
    uses_likelihood_rescaling: bool = False
    rescale_factors: Optional[dict] = None
    decoders_dist: Optional[
        Dict[str, Literal["normal", "bernoulli", "laplace", "categorical"]]
    ] = None
    decoder_dist_params: Optional[dict] = None
    # A factory gives every instance its own list instead of a shared one.
    custom_architectures: list = field(default_factory=list)

    def __post_init__(self):
        """Run the parent post-init hook, then store input shapes as tuples."""
        super().__post_init__()
        if self.input_dims is not None:
            # Shapes may be supplied as any iterable (e.g. lists); keep them
            # as tuples so they are stored in a single, consistent form.
            self.input_dims = {
                modality: tuple(shape)
                for modality, shape in self.input_dims.items()
            }
@dataclass
class EnvironmentConfig(BaseConfig):
    """Base environment config used to save the Python version.

    Args:
        python_version (str): The Python version to record. Default: "3.8".
    """

    python_version: str = "3.8"


@dataclass
class BaseAEConfig(BaseConfig):
    """Base configuration for encoder/decoder models.

    It derives from :class:`~pythae.config.BaseConfig`.

    Args:
        input_dim (tuple): The input data dimension
            (channels X x_dim X y_dim). Default: None.
        latent_dim (int): The latent space dimension. Default: 10.
        style_dim (int): For models with private latent spaces for each
            modality. Default: 0.
    """

    input_dim: Optional[Tuple[int, ...]] = None
    latent_dim: int = 10
    style_dim: int = 0