Source code for multivae.models.auto_model.auto_config

from pydantic.dataclasses import dataclass
from pythae.config import BaseConfig


@dataclass
class AutoConfig(BaseConfig):
    """Class to reload any multivae.models configuration."""

    @classmethod
    def from_json_file(cls, json_path):
        """Create the model configuration described by a JSON config file.

        The ``"name"`` entry stored in the JSON file selects which
        ``multivae.models`` configuration class is used; the actual loading
        is then delegated to that class's own ``from_json_file``.

        Args:
            json_path (str): The path to the json file containing all the
                parameters.

        Returns:
            The configuration instance built by the configuration class
            named in the file.

        Raises:
            KeyError: If the loaded dictionary has no ``"name"`` entry.
            NameError: If the stored name does not match any supported
                configuration class.
        """
        from importlib import import_module

        # Configuration class name -> sibling sub-package (relative to
        # ``multivae.models``) that defines it. The class is always exposed
        # under the same name as the key. Module names are spelled out
        # because they are not derivable from the class name
        # (e.g. ``mmvaePlus``).
        config_modules = {
            "BaseMultiVAEConfig": "base",
            "JMVAEConfig": "jmvae",
            "JNFConfig": "jnf",
            "MMVAEConfig": "mmvae",
            "TELBOConfig": "telbo",
            "MVAEConfig": "mvae",
            "MoPoEConfig": "mopoe",
            "MVTCAEConfig": "mvtcae",
            "MMVAEPlusConfig": "mmvaePlus",
            "NexusConfig": "nexus",
            "CVAEConfig": "cvae",
            "MHVAEConfig": "mhvae",
            "DMVAEConfig": "dmvae",
            "CMVAEConfig": "cmvae",
            "CRMVAEConfig": "crmvae",
        }

        config_dict = cls._dict_from_json(json_path)
        config_name = config_dict.pop("name")

        # Guard against non-string (possibly unhashable) values so that a
        # corrupted file still ends up in the NameError branch below.
        module_name = (
            config_modules.get(config_name)
            if isinstance(config_name, str)
            else None
        )

        if module_name is None:
            raise NameError(
                "Cannot reload automatically the model configuration... "
                "The model name in the `model_config.json` may be corrupted. "
                f"Got `{config_name}`"
            )

        # Import lazily, only for the requested model, exactly like a
        # function-level ``from ..<module> import <Config>`` would.
        module = import_module(f"..{module_name}", package=__package__)
        config_class = getattr(module, config_name)

        return config_class.from_json_file(json_path)