Source code for multivae.trainers.base.base_trainer

import datetime
import json
import logging
import os
from copy import deepcopy
from typing import List, Optional

import torch
import torch.distributed as dist
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.utils import make_grid, save_image

from ...data import MultimodalBaseDataset
from ...data.datasets.utils import adapt_shape
from ...data.utils import set_inputs_to_device
from ...models import BaseModel, BaseMultiVAE
from .base_trainer_config import BaseTrainerConfig
from .callbacks import (
    CallbackHandler,
    MetricConsolePrinterCallback,
    ProgressBarCallback,
    TrainingCallback,
)
from .utils import set_seed, update_dict

# Module-level logger for the trainer; INFO records are echoed to the console
# through a dedicated stream handler.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

console = logging.StreamHandler()
logger.addHandler(console)


class BaseTrainer:
    """Base class to perform model training.

    Args:
        model (BaseModel): A instance of :class:`~multivae.models.BaseMultiVAE` to train.

        train_dataset (MultimodalBaseDataset): The training dataset of type
            :class:`~multivae.data.datasets.MultimodalBaseDataset`

        eval_dataset (MultimodalBaseDataset): The evaluation dataset of type
            :class:`~multivae.data.datasets.MultimodalBaseDataset`

        training_config (BaseTrainerConfig): The training arguments summarizing the main
            parameters used for training. If None, a basic training instance of
            :class:`BaseTrainerConfig` is used (or, when ``checkpoint`` is given, the
            configuration stored in the checkpoint folder). Default: None.

        callbacks (List[~multivae.trainers.base.callbacks.TrainingCallback]):
            A list of callbacks to use during training.

        checkpoint (str) : The directory path of the checkpoint generated by the
            `save_checkpoint` function. Default to None. Used when training is resumed
            from a previous checkpoint.
    """

    def __init__(
        self,
        model: BaseModel,
        train_dataset: MultimodalBaseDataset,
        eval_dataset: Optional[MultimodalBaseDataset] = None,
        training_config: Optional[BaseTrainerConfig] = None,
        callbacks: Optional[List[TrainingCallback]] = None,
        checkpoint: Optional[str] = None,
    ):
        if training_config is None:
            if checkpoint is None:
                training_config = BaseTrainerConfig()
            else:
                training_config = BaseTrainerConfig.from_json_file(
                    os.path.join(checkpoint, "training_config.json")
                )

        if training_config.output_dir is None:
            output_dir = "dummy_output_dir"
            training_config.output_dir = output_dir

        self.training_config = training_config
        self.model_config = model.model_config
        self.model_name = model.model_name

        # for distributed training
        self.world_size = self.training_config.world_size
        self.local_rank = self.training_config.local_rank
        self.rank = self.training_config.rank
        self.dist_backend = self.training_config.dist_backend

        self.distributed = self.world_size > 1

        if self.distributed:
            device = self._setup_devices()
        else:
            device = (
                "cuda"
                if torch.cuda.is_available() and not self.training_config.no_cuda
                else "cpu"
            )

        self.device = device

        if checkpoint is not None:
            model = model.load_from_folder(checkpoint)

        # place model on device
        model = model.to(device)
        model.device = device

        # Keep a handle on the bare model: the DDP wrapper does not expose the
        # custom attributes (`start_keep_best_epoch`, `reset_optimizer_epochs`)
        # that are inspected below.
        unwrapped_model = model

        if self.distributed:
            model = DDP(model, device_ids=[self.local_rank])

        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset

        # Define the loaders
        train_loader = self.get_train_dataloader(train_dataset)

        if eval_dataset is not None:
            eval_loader = self.get_eval_dataloader(eval_dataset)
        else:
            logger.info(
                "! No eval dataset provided ! -> keeping best model on train.\n"
            )
            self.training_config.keep_best_on_train = True
            eval_loader = None

        self.train_loader = train_loader
        self.eval_loader = eval_loader
        self.callbacks = callbacks

        self.start_keep_best_epoch = getattr(
            unwrapped_model, "start_keep_best_epoch", 0
        )

        # run sanity check on the model
        self._run_model_sanity_check(model, train_loader)

        if self.is_main_process:
            logger.info("Model passed sanity check !\n" "Ready for training.\n")

        # Assert that the trainer is suited for the chosen model
        self.checktrainer(unwrapped_model)

        self.model = model

        if checkpoint is None:
            self.prepare_training()
        else:
            self.resume_training(checkpoint)

    def checktrainer(self, model):
        """Check that this trainer is suited to ``model``.

        Raises:
            AttributeError: If ``model`` has a non-empty ``reset_optimizer_epochs``
                attribute.
        """
        if hasattr(model, "reset_optimizer_epochs"):
            if len(model.reset_optimizer_epochs) != 0:
                # A single message string (the pieces used to be passed as
                # separate exception arguments, which rendered as a tuple).
                raise AttributeError(
                    f"The model {self.model_name} has a 'reset_optimizer_epochs' attribute "
                    "that is not empty. That means that it requires multistage training and "
                    "therefore you should use the ~multivae.trainers.MultistageTrainer "
                    "instead of the BaseTrainer."
                )

    @property
    def is_main_process(self):
        """Whether ``self.rank`` is 0 or -1."""
        return self.rank == 0 or self.rank == -1

    def _setup_devices(self):
        """Sets up the devices to perform distributed training."""
        if dist.is_available() and dist.is_initialized() and self.local_rank == -1:
            logger.warning(
                "torch.distributed process group is initialized, but local_rank == -1. "
            )

        if self.training_config.no_cuda:
            self._n_gpus = 0
            device = "cpu"
        else:
            torch.cuda.set_device(self.local_rank)
            device = torch.device("cuda", self.local_rank)

            if not dist.is_initialized():
                dist.init_process_group(
                    backend=self.dist_backend,
                    init_method="env://",
                    world_size=self.world_size,
                    rank=self.rank,
                )

        return device

    def get_train_dataloader(
        self, train_dataset: MultimodalBaseDataset
    ) -> torch.utils.data.DataLoader:
        """Build the training loader (with a ``DistributedSampler`` when distributed)."""
        if self.distributed:
            train_sampler = DistributedSampler(
                train_dataset, num_replicas=self.world_size, rank=self.rank
            )
        else:
            train_sampler = None

        return DataLoader(
            dataset=train_dataset,
            batch_size=self.training_config.per_device_train_batch_size,
            num_workers=self.training_config.train_dataloader_num_workers,
            # `shuffle` and `sampler` are mutually exclusive in DataLoader
            shuffle=(train_sampler is None),
            sampler=train_sampler,
            drop_last=self.training_config.drop_last,
        )

    def get_eval_dataloader(
        self, eval_dataset: MultimodalBaseDataset
    ) -> torch.utils.data.DataLoader:
        """Build the evaluation loader (with a ``DistributedSampler`` when distributed)."""
        if self.distributed:
            eval_sampler = DistributedSampler(
                eval_dataset, num_replicas=self.world_size, rank=self.rank
            )
        else:
            eval_sampler = None

        return DataLoader(
            dataset=eval_dataset,
            batch_size=self.training_config.per_device_eval_batch_size,
            num_workers=self.training_config.eval_dataloader_num_workers,
            shuffle=(eval_sampler is None),
            sampler=eval_sampler,
        )

    def set_optimizer(self):
        """Instantiate ``self.optimizer`` from the ``torch.optim`` class named in the config."""
        optimizer_cls = getattr(optim, self.training_config.optimizer_cls)

        logger.info(
            f"Setting the optimizer with learning rate {self.training_config.learning_rate}"
        )

        if self.training_config.optimizer_params is not None:
            optimizer = optimizer_cls(
                self.model.parameters(),
                lr=self.training_config.learning_rate,
                **self.training_config.optimizer_params,
            )
        else:
            optimizer = optimizer_cls(
                self.model.parameters(), lr=self.training_config.learning_rate
            )

        self.optimizer = optimizer

    def set_scheduler(self):
        """Instantiate ``self.scheduler`` (``None`` when no scheduler class is configured)."""
        if self.training_config.scheduler_cls is not None:
            scheduler_cls = getattr(lr_scheduler, self.training_config.scheduler_cls)

            if self.training_config.scheduler_params is not None:
                scheduler = scheduler_cls(
                    self.optimizer, **self.training_config.scheduler_params
                )
            else:
                scheduler = scheduler_cls(self.optimizer)
        else:
            scheduler = None

        self.scheduler = scheduler

    def _set_output_dir(self):
        """Create the output folder and a timestamped training folder inside it."""
        # Create folder
        if not os.path.exists(self.training_config.output_dir) and self.is_main_process:
            os.makedirs(self.training_config.output_dir, exist_ok=True)
            logger.info(
                f"Created {self.training_config.output_dir} folder since did not exist.\n"
            )

        # "YYYY-MM-DD_HH-MM-SS"
        self._training_signature = (
            str(datetime.datetime.now())[0:19].replace(" ", "_").replace(":", "-")
        )

        training_dir = os.path.join(
            self.training_config.output_dir,
            f"{self.model_name}_training_{self._training_signature}",
        )

        self.training_dir = training_dir

        if not os.path.exists(training_dir) and self.is_main_process:
            os.makedirs(training_dir, exist_ok=True)
            logger.info(
                f"Created {training_dir}. \n"
                "Training config, checkpoints and final model will be saved here.\n"
            )

    def _get_file_logger(self, log_output_dir):
        """Return a non-propagating logger writing to a file in ``log_output_dir``."""
        log_dir = log_output_dir

        # if dir does not exist create it
        if not os.path.exists(log_dir) and self.is_main_process:
            os.makedirs(log_dir, exist_ok=True)
            logger.info(f"Created {log_dir} folder since did not exists.")
            logger.info("Training logs will be recodered here.\n")
            logger.info(" -> Training can be monitored here.\n")

        # create and set logger
        log_name = f"training_logs_{self._training_signature}"

        file_logger = logging.getLogger(log_name)
        file_logger.setLevel(logging.INFO)
        f_handler = logging.FileHandler(
            os.path.join(log_dir, f"training_logs_{self._training_signature}.log")
        )
        f_handler.setLevel(logging.INFO)
        file_logger.addHandler(f_handler)

        # Do not output logs in the console
        file_logger.propagate = False

        return file_logger

    def _setup_callbacks(self):
        """Build the callback handler and register the progress-bar and console callbacks."""
        if self.callbacks is None:
            self.callbacks = [TrainingCallback()]

        self.callback_handler = CallbackHandler(
            callbacks=self.callbacks, model=self.model
        )

        self.callback_handler.add_callback(ProgressBarCallback())
        self.callback_handler.add_callback(MetricConsolePrinterCallback())

    def _run_model_sanity_check(self, model, loader):
        """Run one forward pass on the first batch of ``loader``.

        Raises:
            Exception: If anything fails during the forward pass; the original
                exception is chained as the cause.
        """
        try:
            inputs = next(iter(loader))
            train_dataset = set_inputs_to_device(
                inputs, device=self.device, keys=["data"]
            )
            model(train_dataset)
            model.zero_grad()
            torch.cuda.empty_cache()
            del inputs, train_dataset

        except Exception as e:
            raise Exception(
                "Error when calling forward method from model. Potential issues: \n"
                " - Wrong model architecture -> check encoder, decoder and metric architecture if "
                "you provide yours \n"
                " - The data input dimension provided is wrong -> when no encoder, decoder or metric "
                "provided, a network is built automatically but requires the shape of the flatten "
                "input data.\n"
                f"Exception raised: {type(e)} with message: " + str(e)
            ) from e

    def _optimizers_step(self, model_output=None):
        """Back-propagate ``model_output.loss`` and take one optimizer step.

        When ``gradient_clipping_max_norm`` is set in the training config, the
        gradient norm is clipped to that value before the step.
        """
        loss = model_output.loss

        self.optimizer.zero_grad()
        loss.backward()

        # Clipping has to happen once the gradients exist, i.e. after
        # `backward()` and before `step()`; otherwise it has no effect.
        if self.training_config.gradient_clipping_max_norm is not None:
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.training_config.gradient_clipping_max_norm,
            )

        torch.cuda.empty_cache()
        self.optimizer.step()

    def _schedulers_step(self, epoch, metrics=None):
        """Step the scheduler, if any, once ``start_lr_scheduler_epoch`` is reached.

        ``metrics`` is only forwarded to ``ReduceLROnPlateau`` schedulers.
        """
        if self.scheduler is None:
            return

        if self.training_config.start_lr_scheduler_epoch > epoch:
            logger.info("Not taking a scheduler step yet.")
        elif isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau):
            self.scheduler.step(metrics)
        else:
            self.scheduler.step()

    @staticmethod
    def _metrics_to_scalars(metrics, prefix):
        """Return ``metrics`` with ``prefix`` added to each key and tensors converted with ``.item()``."""
        return {
            prefix + k: (v.item() if isinstance(v, torch.Tensor) else v)
            for k, v in metrics.items()
        }

    def prepare_training(self):
        """Sets up the trainer for a training started from scratch."""
        # set random seed
        set_seed(self.training_config.seed)

        # set optimizer
        self.set_optimizer()

        # set scheduler
        self.set_scheduler()

        # create folder for saving
        self._set_output_dir()

        # set callbacks
        self._setup_callbacks()

        # set already_trained epochs
        self.trained_epochs = 0

        self.best_train_loss = torch.inf
        self.best_eval_loss = torch.inf
        self.metrics_best_model = {}

        # set up the best_model
        self._best_model = deepcopy(self.model)

    def resume_training(self, checkpoint):
        """Sets up the trainer to resume training from a checkpoint folder.

        Args:
            checkpoint (str): Folder containing ``info_checkpoint.json``,
                ``metrics_best_model.json``, ``optimizer.pt`` and, when a scheduler
                is configured, ``scheduler.pt``.
        """
        with open(os.path.join(checkpoint, "info_checkpoint.json"), "r") as fp:
            dict_checkpoint = json.load(fp)

        with open(os.path.join(checkpoint, "metrics_best_model.json"), "r") as fp:
            self.metrics_best_model = json.load(fp)

        # set random seed
        set_seed(self.training_config.seed)

        # set optimizer
        self.set_optimizer()
        self.optimizer.load_state_dict(
            torch.load(
                os.path.join(checkpoint, "optimizer.pt"), map_location=self.device
            )
        )

        # set scheduler
        self.set_scheduler()
        if self.scheduler is not None:
            self.scheduler.load_state_dict(
                torch.load(
                    os.path.join(checkpoint, "scheduler.pt"), map_location=self.device
                )
            )

        # reuse the training folder recorded in the checkpoint
        self.training_dir = dict_checkpoint["training_dir"]

        # set callbacks
        self._setup_callbacks()

        # set already_trained epochs
        self.trained_epochs = dict_checkpoint["trained_epochs"]
        self.best_train_loss = dict_checkpoint["best_train_loss"]
        self.best_eval_loss = dict_checkpoint["best_eval_loss"]

        # set up the best_model
        self._best_model = deepcopy(self.model)

    def prepare_train_step(self, epoch, best_train_loss, best_eval_loss):
        """Function to operate changes between train_steps such as resetting the optimizer
        and the best losses values.

        The base implementation returns both losses unchanged.
        """
        return best_train_loss, best_eval_loss

    def train(self, log_output_dir: str = None):
        """This function is the main training function

        Training runs from epoch ``self.trained_epochs + 1`` up to and including
        ``training_config.num_epochs``.

        Args:
            log_output_dir (str): The path in which the log will be stored
        """
        self.callback_handler.on_train_begin(
            training_config=self.training_config, model_config=self.model_config
        )

        log_verbose = False

        msg = (
            f"Training params:\n - max_epochs: {self.training_config.num_epochs}\n"
            " - per_device_train_batch_size: "
            f"{self.training_config.per_device_train_batch_size}\n"
            " - per_device_eval_batch_size: "
            f"{self.training_config.per_device_eval_batch_size}\n"
            f" - checkpoint saving every: {self.training_config.steps_saving}\n"
            f"Optimizer: {self.optimizer}\n"
            f"Scheduler: {self.scheduler}\n"
        )

        if self.is_main_process:
            logger.info(msg)

        # set up log file
        if log_output_dir is not None and self.is_main_process:
            log_verbose = True
            file_logger = self._get_file_logger(log_output_dir=log_output_dir)
            file_logger.info(msg)

        if self.is_main_process:
            logger.info("Successfully launched training !\n")

        for epoch in range(
            self.trained_epochs + 1, self.training_config.num_epochs + 1
        ):
            self.callback_handler.on_epoch_begin(
                training_config=self.training_config,
                epoch=epoch,
                train_loader=self.train_loader,
                eval_loader=self.eval_loader,
            )

            self.best_train_loss, self.best_eval_loss = self.prepare_train_step(
                epoch, self.best_train_loss, self.best_eval_loss
            )

            epoch_train_loss, epoch_metrics = self.train_step(epoch)
            metrics = self._metrics_to_scalars(epoch_metrics, "train_")
            metrics["train_epoch_loss"] = epoch_train_loss

            torch.cuda.empty_cache()

            if self.eval_dataset is not None:
                epoch_eval_loss, epoch_eval_metrics = self.eval_step(epoch)
                metrics["eval_epoch_loss"] = epoch_eval_loss
                # The eval entries must come from the eval metrics (they used
                # to be read from the train metrics by mistake).
                update_dict(
                    metrics, self._metrics_to_scalars(epoch_eval_metrics, "eval_")
                )
                self._schedulers_step(epoch, epoch_eval_loss)
                torch.cuda.empty_cache()
            else:
                epoch_eval_loss = self.best_eval_loss
                self._schedulers_step(epoch, epoch_train_loss)

            if epoch <= self.start_keep_best_epoch:
                # save the model, don't keep track of the best loss
                self._best_model = deepcopy(self.model)
                self.metrics_best_model = metrics
                logger.info("New model saved!")

            elif (
                epoch_eval_loss < self.best_eval_loss
                and not self.training_config.keep_best_on_train
            ):
                self.best_eval_loss = epoch_eval_loss
                self._best_model = deepcopy(self.model)
                self.metrics_best_model = metrics
                logger.info("New best model on eval saved!")

            elif (
                epoch_train_loss < self.best_train_loss
                and self.training_config.keep_best_on_train
            ):
                self.best_train_loss = epoch_train_loss
                self._best_model = deepcopy(self.model)
                self.metrics_best_model = metrics
                logger.info("New best model on train saved!")

            # If steps_predict is not None, compute reconstruction images
            if (
                self.training_config.steps_predict is not None
                and (epoch % self.training_config.steps_predict == 0 or epoch == 1)
                and self.is_main_process
            ):
                metrics_media = self.predict(self._best_model, epoch)
                self.callback_handler.on_prediction_step(
                    self.training_config,
                    metrics_media=metrics_media,
                    global_step=epoch,
                )

                # Save the reconstructions to folder
                images = metrics_media.pop("images", {})
                for key, image in images.items():
                    save_image(image, os.path.join(self.training_dir, f"{key}.png"))

                torch.cuda.empty_cache()

            self.callback_handler.on_epoch_end(training_config=self.training_config)

            # save checkpoints
            if (
                self.training_config.steps_saving is not None
                and epoch % self.training_config.steps_saving == 0
            ):
                if self.is_main_process:
                    self.save_checkpoint(
                        model=self._best_model, dir_path=self.training_dir, epoch=epoch
                    )
                    logger.info(f"Saved checkpoint at epoch {epoch}\n")

                    if log_verbose:
                        file_logger.info(f"Saved checkpoint at epoch {epoch}\n")

            self.callback_handler.on_log(
                self.training_config,
                metrics,
                logger=logger,
                global_step=epoch,
                rank=self.rank,
            )

        final_dir = os.path.join(self.training_dir, "final_model")

        if self.is_main_process:
            self.save_model(self._best_model, dir_path=final_dir)
            logger.info("Training ended!")
            logger.info(f"Saved final model in {final_dir}")

        if self.distributed:
            dist.destroy_process_group()

        self.callback_handler.on_train_end(self.training_config)

    def eval_step(self, epoch: int):
        """Perform an evaluation step

        Args:
            epoch (int): The current epoch number

        Returns:
            tuple: ``(epoch_loss, epoch_metrics)`` where ``epoch_loss`` is the summed
            batch loss divided by the size of the eval dataset and ``epoch_metrics``
            holds the accumulated model metrics divided by the number of batches.

        Raises:
            ArithmeticError: If the accumulated loss becomes NaN.
        """
        self.callback_handler.on_eval_step_begin(
            training_config=self.training_config,
            eval_loader=self.eval_loader,
            epoch=epoch,
            rank=self.rank,
        )

        self.model.eval()

        epoch_loss = 0
        epoch_metrics = {}

        for inputs in self.eval_loader:
            inputs = set_inputs_to_device(inputs, device=self.device, keys=["data"])

            try:
                with torch.no_grad():
                    model_output = self.model(
                        inputs,
                        epoch=epoch,
                        dataset_size=len(self.eval_loader.dataset),
                        uses_ddp=self.distributed,
                        use_mean_embedding=True,
                    )

            except RuntimeError:
                # Retry with autograd enabled.
                # NOTE(review): presumably for models whose forward pass needs
                # gradients - confirm.
                model_output = self.model(
                    inputs,
                    epoch=epoch,
                    dataset_size=len(self.eval_loader.dataset),
                    uses_ddp=self.distributed,
                    use_mean_embedding=True,
                )

            loss = (
                model_output.loss_sum
                if hasattr(model_output, "loss_sum")
                else model_output.loss
            )

            epoch_loss += loss.item()
            update_dict(epoch_metrics, model_output.metrics)

            # NaN is the only value that differs from itself
            if epoch_loss != epoch_loss:
                raise ArithmeticError("NaN detected in eval loss")

            self.callback_handler.on_eval_step_end(training_config=self.training_config)

        epoch_metrics = {
            k: epoch_metrics[k] / len(self.eval_loader) for k in epoch_metrics
        }
        epoch_loss = epoch_loss / len(self.eval_loader.dataset)

        return epoch_loss, epoch_metrics

    def train_step(self, epoch: int):
        """The trainer performs training loop over the train_loader.

        Args:
            epoch (int): The current epoch number

        Returns:
            tuple: ``(epoch_loss, epoch_model_metrics)`` where ``epoch_loss`` is the
            summed batch loss divided by the size of the train dataset and
            ``epoch_model_metrics`` holds the accumulated model metrics divided by
            the number of batches.

        Raises:
            ArithmeticError: If the accumulated loss becomes NaN.
        """
        self.callback_handler.on_train_step_begin(
            training_config=self.training_config,
            train_loader=self.train_loader,
            epoch=epoch,
            rank=self.rank,
        )

        # set model in train model
        self.model.train()

        epoch_loss = 0
        epoch_model_metrics = {}

        for batch_idx, inputs in enumerate(self.train_loader):
            inputs = set_inputs_to_device(inputs, device=self.device, keys=["data"])

            if hasattr(self.training_config, "beta_schedule"):
                beta_epoch = self.training_config.beta_schedule[epoch - 1]
            else:
                beta_epoch = 1

            model_output = self.model(
                inputs,
                epoch=epoch,
                dataset_size=len(self.train_loader.dataset),
                uses_ddp=self.distributed,
                batch_ratio=(batch_idx) / len(self.train_loader),
                beta=beta_epoch,
            )

            self._optimizers_step(model_output)

            loss = (
                model_output.loss_sum
                if hasattr(model_output, "loss_sum")
                else model_output.loss
            )

            epoch_loss += loss.item()
            update_dict(epoch_model_metrics, model_output.metrics)

            # NaN is the only value that differs from itself
            if epoch_loss != epoch_loss:
                raise ArithmeticError("NaN detected in train loss")

            self.callback_handler.on_train_step_end(
                training_config=self.training_config
            )

        # Allows model updates if needed
        if self.distributed:
            self.model.module.update()
        else:
            self.model.update()

        epoch_model_metrics = {
            k: epoch_model_metrics[k] / len(self.train_loader)
            for k in epoch_model_metrics
        }
        epoch_loss = epoch_loss / len(self.train_dataset)

        return epoch_loss, epoch_model_metrics

    def save_model(self, model: BaseModel, dir_path: str):
        """This method saves the final model along with the config files

        Args:
            model (BaseMultiVAE): The model to be saved
            dir_path (str): The folder where the model and config files should be saved
        """
        os.makedirs(dir_path, exist_ok=True)

        # save model (unwrap the DDP container when distributed)
        if self.distributed:
            model.module.save(dir_path)
        else:
            model.save(dir_path)

        # save training config
        self.training_config.save_json(dir_path, "training_config")

        # save metrics
        with open(os.path.join(dir_path, "metrics_best_model.json"), mode="w+") as fp:
            json.dump(self.metrics_best_model, fp)

        self.callback_handler.on_save(self.training_config, dir_path=dir_path)

    def save_checkpoint(self, model: BaseModel, dir_path, epoch: int):
        """Saves a checkpoint alowing to restart training from here

        The checkpoint is written to ``<dir_path>/checkpoint_epoch_<epoch>``.

        Args:
            model (BaseModel): The model to be saved
            dir_path (str): The folder where the checkpoint should be saved
            epoch (int): The epoch number
        """
        checkpoint_dir = os.path.join(dir_path, f"checkpoint_epoch_{epoch}")
        os.makedirs(checkpoint_dir, exist_ok=True)

        # save optimizer
        torch.save(
            deepcopy(self.optimizer.state_dict()),
            os.path.join(checkpoint_dir, "optimizer.pt"),
        )

        # save scheduler
        if self.scheduler is not None:
            torch.save(
                deepcopy(self.scheduler.state_dict()),
                os.path.join(checkpoint_dir, "scheduler.pt"),
            )

        # save model (unwrap the DDP container when distributed)
        if self.distributed:
            model.module.save(checkpoint_dir)
        else:
            model.save(checkpoint_dir)

        # save training config
        self.training_config.save_json(checkpoint_dir, "training_config")

        # save metrics
        with open(
            os.path.join(checkpoint_dir, "metrics_best_model.json"), mode="w+"
        ) as fp:
            json.dump(self.metrics_best_model, fp)

        # save info about checkpoint (read back by `resume_training`)
        info = dict(
            training_dir=self.training_dir,
            trained_epochs=epoch,
            best_train_loss=self.best_train_loss,
            best_eval_loss=self.best_eval_loss,
        )

        with open(os.path.join(checkpoint_dir, "info_checkpoint.json"), "w+") as fp:
            json.dump(info, fp, sort_keys=True, indent=4)

        self.callback_handler.on_save_checkpoint(
            self.training_config, checkpoint_dir=checkpoint_dir
        )

    def predict(self, model: BaseModel, epoch: int, n_data=8):
        """For BaseMultiVaE models, compute self and cross reconstructions during training.

        Args:
            model (BaseModel): The model used to generate the reconstructions.
            epoch (int): The current epoch number (not used by this implementation).
            n_data (int): Number of datapoints taken from the dataset; also the
                number of images per row in the returned grids. Default: 8.

        Returns:
            dict: ``{"images": {...}}`` mapping ``recon_from_<modality>`` (only for
            :class:`BaseMultiVAE` models) and ``recon_from_all`` to image grids.
        """
        model.eval()

        predict_dataset = (
            self.eval_dataset if self.eval_dataset is not None else self.train_dataset
        )

        # Take one sample with n_data datapoints
        inputs = next(iter(DataLoader(predict_dataset, batch_size=n_data)))
        inputs = set_inputs_to_device(inputs, self.device, keys=["data"])

        all_recons = {"images": {}}

        # For multimodal VAEs we compute all 1-to-1 cross-modal reconstruction
        if isinstance(model, BaseMultiVAE):
            for mod in inputs.data:
                recon = model.predict(
                    inputs, mod, "all", N=8, flatten=True, ignore_incomplete=True
                )

                if hasattr(predict_dataset, "transform_for_plotting"):
                    recon = {
                        mod_name: predict_dataset.transform_for_plotting(
                            recon[mod_name], modality=mod_name
                        )
                        for mod_name in recon
                    }
                    recon["true_data"] = predict_dataset.transform_for_plotting(
                        inputs.data[mod], modality=mod
                    )
                else:
                    recon["true_data"] = inputs.data[mod]

                recon, _ = adapt_shape(recon)

                # ground truth first, then every generated modality
                recon_image = [recon["true_data"]] + [
                    recon[m] for m in recon if m != "true_data"
                ]
                recon_image = torch.cat(recon_image)
                recon_image = make_grid(recon_image, nrow=n_data)

                all_recons["images"][f"recon_from_{mod}"] = recon_image

        # For multimodal VAE or CVAE model, we compute the joint reconstruction
        recon = model.predict(
            inputs=inputs,
            cond_mod="all",
            gen_mod="all",
            N=8,
            flatten=True,
            ignore_incomplete=True,
        )

        # recorded before the "true_data_*" entries are added to `recon`
        reconstructed_modalities = list(recon.keys())

        if hasattr(predict_dataset, "transform_for_plotting"):
            recon = {
                mod_name: predict_dataset.transform_for_plotting(
                    recon[mod_name], modality=mod_name
                )
                for mod_name in recon
            }
            recon.update(
                {
                    f"true_data_{mod_name}": predict_dataset.transform_for_plotting(
                        inputs.data[mod_name], modality=mod_name
                    )
                    for mod_name in inputs.data
                }
            )
        else:
            recon.update(
                {
                    f"true_data_{mod_name}": inputs.data[mod_name]
                    for mod_name in inputs.data
                }
            )

        recon, _ = adapt_shape(recon)

        # ground truth of every modality first, then the reconstructions
        recon_image = [recon[f"true_data_{m}"] for m in inputs.data] + [
            recon[m] for m in reconstructed_modalities
        ]
        recon_image = torch.cat(recon_image)
        recon_image = make_grid(recon_image, nrow=n_data)

        all_recons["images"]["recon_from_all"] = recon_image

        return all_recons