Source code for multivae.metrics.likelihoods.likelihoods

from pythae.models.base.base_utils import ModelOutput

from ..base.evaluator_class import Evaluator
from .likelihoods_config import LikelihoodsEvaluatorConfig

try:
    from tqdm import tqdm
except ImportError:  # tqdm is optional: only catch a missing package, not every error

    def tqdm(iterable, *args, **kwargs):
        """Return *iterable* unchanged; stand-in used when tqdm is not installed."""
        return iterable

from multivae.data.utils import set_inputs_to_device


class LikelihoodsEvaluator(Evaluator):
    """Class for computing likelihood metrics.

    Args:
        model (BaseMultiVAE) : The model to evaluate.
        test_dataset (MultimodalBaseDataset) : The dataset to use for computing the metrics.
        output (str) : The folder path to save metrics. The metrics will be saved in a metrics.txt file.
        eval_config (EvaluatorConfig) : The configuration class to specify parameters for the evaluation.
            When omitted (or None), a fresh ``LikelihoodsEvaluatorConfig()`` is used.
    """

    def __init__(self, model, test_dataset, output=None, eval_config=None) -> None:
        # Build the default config per instance: a default evaluated in the
        # signature would be a single object shared by every evaluator.
        if eval_config is None:
            eval_config = LikelihoodsEvaluatorConfig()
        super().__init__(model, test_dataset, output, eval_config)
        self.num_samples = eval_config.num_samples
        self.batch_size_k = eval_config.batch_size_k
        self.unified = eval_config.unified_implementation

    def eval(self):
        """Compute the joint negative log-likelihood, log metrics to wandb, and return them.

        Returns:
            ModelOutput: built from the ``self.metrics`` dictionary.
        """
        self.joint_nll()
        self.log_to_wandb()
        return ModelOutput(**self.metrics)

    def joint_nll(self):
        """Compute the joint negative log-likelihood averaged over the test set.

        Uses ``model.compute_joint_nll`` when the unified implementation is
        requested or when the model has no ``compute_joint_nll_paper`` method;
        otherwise uses ``model.compute_joint_nll_paper``.

        The result is stored in ``self.metrics["joint_likelihood"]``.

        Returns:
            The summed per-batch values divided by the number of test samples.
        """
        # The implementation choice does not depend on the batch: decide (and log) once.
        use_paper = not self.unified and hasattr(self.model, "compute_joint_nll_paper")
        if use_paper:
            self.logger.info("Using the paper version of the joint nll.")
            compute = self.model.compute_joint_nll_paper
        else:
            compute = self.model.compute_joint_nll

        ll = 0
        for batch in tqdm(self.test_loader):
            batch = set_inputs_to_device(batch, self.device)
            ll += compute(batch, self.num_samples, self.batch_size_k)

        joint_nll = ll / len(self.test_loader.dataset)
        self.logger.info(f"Mean Joint likelihood : {str(joint_nll)}")
        self.metrics["joint_likelihood"] = joint_nll
        return joint_nll

    def joint_nll_from_subset(self, subset):
        """Only available for the MoPoE model for now.

        Use a subset posterior instead of the joint posterior as the importance
        sampling distribution.

        Args:
            subset: Subset identifier forwarded as-is to the model's
                ``_compute_joint_nll_from_subset_encoding`` (presumably a
                collection of modality names -- confirm against the model).

        Returns:
            The summed per-batch values divided by the number of test samples,
            or None when the model does not implement
            ``_compute_joint_nll_from_subset_encoding``.
        """
        if not hasattr(self.model, "_compute_joint_nll_from_subset_encoding"):
            return None

        ll = 0
        for batch in tqdm(self.test_loader):
            batch = set_inputs_to_device(batch, self.device)
            ll += self.model._compute_joint_nll_from_subset_encoding(
                subset, batch, self.num_samples, self.batch_size_k
            )

        # Normalize by the same sample count as joint_nll for consistency.
        joint_nll = ll / len(self.test_loader.dataset)
        self.logger.info("Joint likelihood from subset %s", str(joint_nll))
        self.metrics[f"Joint likelihood from subset {subset}"] = joint_nll
        return joint_nll