Source code for multivae.models.telbo.telbo_config

from typing import Union

from pydantic.dataclasses import dataclass

from ..joint_models import BaseJointModelConfig


@dataclass
class TELBOConfig(BaseJointModelConfig):
    """Configuration class for the TELBO model from
    "Generative Models of Visually Grounded Imagination"
    (Vedantam et al., 2018), arXiv:1705.10762 [cs, stat].

    Args:
        n_modalities (int): The number of modalities. Default: None.
        latent_dim (int): The dimension of the latent space. Default: None.
        input_dims (dict[str, tuple]): The modalities' names (str) and input
            shapes (tuple).
        uses_likelihood_rescaling (bool): Whether to use likelihood rescaling
            to mitigate modality collapse
            (see https://proceedings.mlr.press/v162/javaloy22a.html).
            ``input_dims`` must be provided to compute the rescaling factors.
            This option is shared by a number of models, which is why it is
            documented here. For this model it also indicates how to set
            ``lambda_factors`` and ``gamma_factors`` when they are None; it is
            ignored when ``lambda_factors`` and ``gamma_factors`` are
            provided. Default to True.
        rescale_factors (dict[str, float]): The reconstruction rescaling
            factors per modality. If None is provided but
            ``uses_likelihood_rescaling`` is True, a default value
            proportional to the input modality size is computed.
            Default to None.
        decoders_dist (Dict[str, Union[function, str]]): The decoder
            distributions to use per modality. Per modality, you can provide
            a string in ['normal', 'bernoulli', 'laplace']. For the Bernoulli
            distribution, the decoder is expected to output **logits**.
            If None is provided, a normal distribution is used for each
            modality.
        decoder_dist_params (Dict[str, dict]): Parameters of the output
            decoder distributions, used to compute the log-probability.
            For instance, with a normal or Laplace distribution, you can pass
            the scale in this dictionary with
            :code:`decoder_dist_params = {'mod1' : {"scale" : 0.75}}`.
        warmup (int): How many epochs to train the joint encoder and the
            decoders before freezing them and learning the unimodal encoders.
            It is recommended to use half of the total training time for the
            first step. Default to 10.
        lambda_factors (dict[str, float]): Weighting factors for the
            reconstruction terms in the joint ELBO. If None is provided but
            ``uses_likelihood_rescaling`` is True, the inverse product of
            dimensions is used as the rescaling factor for each modality.
            If None is provided and ``uses_likelihood_rescaling`` is False,
            each factor is set to one. Default to None.
        gamma_factors (dict[str, float]): Weighting factors for the
            reconstruction terms in the unimodal ELBOs. If None is provided
            but ``uses_likelihood_rescaling`` is True, the inverse product of
            dimensions is used as the rescaling factor for each modality.
            If None is provided and ``uses_likelihood_rescaling`` is False,
            each factor is set to one. Default to None.
    """

    warmup: int = 10
    lambda_factors: Union[dict, None] = None
    gamma_factors: Union[dict, None] = None
    uses_likelihood_rescaling: bool = True
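
The fallback rule described in the docstring for `lambda_factors` and `gamma_factors` can be illustrated with a small standalone sketch. The helper name `default_factors` and its signature are hypothetical and are not part of multivae; the sketch only mirrors the documented behaviour (explicit factors win, otherwise the inverse product of the input dimensions, otherwise one) and is not the library's actual implementation.

```python
from math import prod
from typing import Dict, Optional, Tuple


def default_factors(
    input_dims: Dict[str, Tuple[int, ...]],
    factors: Optional[Dict[str, float]] = None,
    uses_likelihood_rescaling: bool = True,
) -> Dict[str, float]:
    """Hypothetical helper mirroring the documented fallback rule."""
    # Explicitly provided factors are used unchanged.
    if factors is not None:
        return dict(factors)
    # With rescaling: inverse of the product of each modality's dimensions.
    if uses_likelihood_rescaling:
        return {name: 1.0 / prod(shape) for name, shape in input_dims.items()}
    # Without rescaling: every modality gets a factor of one.
    return {name: 1.0 for name in input_dims}


if __name__ == "__main__":
    # Illustrative modality names and shapes, chosen for this example only.
    dims = {"image": (1, 28, 28), "label": (10,)}
    print(default_factors(dims))
    print(default_factors(dims, uses_likelihood_rescaling=False))
```

Under this rule a high-dimensional modality such as an image receives a much smaller weight than a low-dimensional one such as a label, which is the intent of likelihood rescaling: reconstruction terms are put on a comparable scale across modalities.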