Source code for multivae.models.auto_model.auto_model

import json
import logging
import os

from torch import nn

from ..base.base_utils import hf_hub_is_available

# Module-level logger: records at INFO and above are accepted, and a dedicated
# stream handler (kept under the name ``console``) is attached to emit them.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
console = logging.StreamHandler()
logger.addHandler(console)


class AutoModel(nn.Module):
    """Utils class allowing to reload any :class:`multivae.models` automatically.

    The concrete model class is chosen from the ``name`` entry stored in the
    saved ``model_config.json``. The sub-module holding that class is imported
    lazily, only once the name has been read.
    """

    # Config name found in ``model_config.json`` -> (sibling sub-module, class name).
    # Both loading entry points share this single table so they cannot drift apart.
    _MODEL_REGISTRY = {
        "JMVAEConfig": ("jmvae", "JMVAE"),
        "JNFConfig": ("jnf", "JNF"),
        "MMVAEConfig": ("mmvae", "MMVAE"),
        "TELBOConfig": ("telbo", "TELBO"),
        "MVAEConfig": ("mvae", "MVAE"),
        "MoPoEConfig": ("mopoe", "MoPoE"),
        "MVTCAEConfig": ("mvtcae", "MVTCAE"),
        "MMVAEPlusConfig": ("mmvaePlus", "MMVAEPlus"),
        "NexusConfig": ("nexus", "Nexus"),
        "CVAEConfig": ("cvae", "CVAE"),
        "MHVAEConfig": ("mhvae", "MHVAE"),
        "DMVAEConfig": ("dmvae", "DMVAE"),
        "CMVAEConfig": ("cmvae", "CMVAE"),
        "CRMVAEConfig": ("crmvae", "CRMVAE"),
    }

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def _read_model_name(dir_path: str):
        """Return the ``name`` entry of ``model_config.json`` located in ``dir_path``."""
        config_file = os.path.join(dir_path, "model_config.json")
        # Read as UTF-8 explicitly so the result does not depend on the platform locale.
        with open(config_file, encoding="utf-8") as f:
            return json.load(f)["name"]

    @classmethod
    def _resolve_model_class(cls, model_name):
        """Return the model class registered for ``model_name``.

        Raises:
            NameError: If ``model_name`` is not one of the known config names.
        """
        # A non-string value (e.g. a list from a damaged config) cannot be a
        # registry key; treat it like any other unknown name.
        location = (
            cls._MODEL_REGISTRY.get(model_name)
            if isinstance(model_name, str)
            else None
        )
        if location is None:
            raise NameError(
                "Cannot reload automatically the model... "
                f"The model name in the `model_config.json may be corrupted. Got {model_name}"
            )

        # Local import: only needed once a known model has to be resolved.
        import importlib

        module_name, class_name = location
        # Equivalent to ``from ..<module_name> import <class_name>``, executed lazily.
        module = importlib.import_module(f"..{module_name}", package=__package__)
        return getattr(module, class_name)

    @classmethod
    def load_from_folder(cls, dir_path: str):
        """Class method to be used to load the model from a specific folder.

        Args:
            dir_path (str): The path where the model should have been saved.

        Returns:
            The object returned by the ``load_from_folder`` class method of the
            model class matching the config name.

        Raises:
            NameError: If the name stored in ``model_config.json`` is not a
                known model config name.
            KeyError: If ``model_config.json`` has no ``name`` entry.

        .. note::
            This function requires the folder to contain:

            - | a ``model_config.json`` and a ``model.pt`` if no custom architectures were provided

            **or**

            - | a ``model_config.json``, a ``model.pt`` and a ``encoders.pkl`` (resp.
                ``decoders.pkl``) if a custom encoders (resp. decoders) were provided
        """
        model_name = cls._read_model_name(dir_path)
        model_class = cls._resolve_model_class(model_name)
        return model_class.load_from_folder(dir_path)

    @classmethod
    def load_from_hf_hub(
        cls, hf_hub_path: str, allow_pickle: bool = False
    ):  # pragma: no cover
        """Class method to be used to automatically load a pretrained model from the Hugging Face hub.

        Args:
            hf_hub_path (str): The path where the model should have been saved on the
                Hugging Face hub.
            allow_pickle (bool): Forwarded unchanged to the ``load_from_hf_hub`` class
                method of the selected model class. Presumably controls whether pickled
                custom architectures may be loaded -- confirm against the model classes.
                Default: False.

        Returns:
            The object returned by the ``load_from_hf_hub`` class method of the
            model class matching the config name.

        Raises:
            ModuleNotFoundError: If ``hf_hub_is_available()`` reports that the
                ``huggingface_hub`` package is missing.
            NameError: If the name stored in ``model_config.json`` is not a
                known model config name.
            KeyError: If ``model_config.json`` has no ``name`` entry.

        .. note::
            This function requires the folder to contain:

            - | a ``model_config.json`` and a ``model.pt`` if no custom architectures were provided

            **or**

            - | a ``model_config.json``, a ``model.pt`` and a ``encoder.pkl`` (resp.
                ``decoder.pkl``) if a custom encoder (resp. decoder) was provided
        """
        if not hf_hub_is_available():
            raise ModuleNotFoundError(
                "`huggingface_hub` package must be installed to load models from the HF hub. "
                "Run `python -m pip install huggingface_hub` and log in to your account with "
                "`huggingface-cli login`."
            )

        from huggingface_hub import hf_hub_download

        logger.info("Downloading config file ...")

        # Only the config is fetched here, to discover which model class to use;
        # the selected class performs its own loading from the hub afterwards.
        config_path = hf_hub_download(repo_id=hf_hub_path, filename="model_config.json")
        dir_path = os.path.dirname(config_path)

        model_name = cls._read_model_name(dir_path)
        model_class = cls._resolve_model_class(model_name)
        return model_class.load_from_hf_hub(hf_hub_path, allow_pickle)