BaseTrainer

Base trainer for MultiVae’s models.

class multivae.trainers.BaseTrainerConfig(output_dir=None, per_device_train_batch_size=64, per_device_eval_batch_size=64, num_epochs=100, train_dataloader_num_workers=0, eval_dataloader_num_workers=0, optimizer_cls='Adam', optimizer_params=None, scheduler_cls=None, scheduler_params=None, learning_rate=0.0001, steps_saving=None, steps_predict=None, keep_best_on_train=False, seed=8, no_cuda=False, world_size=-1, local_rank=-1, rank=-1, dist_backend='nccl', master_addr='localhost', master_port='12345', drop_last=False, gradient_clipping_max_norm=None)[source]

BaseTrainer config class stating the main training arguments.

Parameters:
  • output_dir (str) – The directory where model checkpoints, configs and final model will be stored. Default: None.

  • per_device_train_batch_size (int) – The number of training samples per batch and per device. Default: 64

  • per_device_eval_batch_size (int) – The number of evaluation samples per batch and per device. Default: 64

  • num_epochs (int) – The maximal number of epochs for training. Default: 100

  • train_dataloader_num_workers (int) – Number of subprocesses to use for train data loading. 0 means that the data will be loaded in the main process. Default: 0

  • eval_dataloader_num_workers (int) – Number of subprocesses to use for evaluation data loading. 0 means that the data will be loaded in the main process. Default: 0

  • optimizer_cls (str) – The name of the torch.optim.Optimizer used for training. Default: Adam.

  • optimizer_params (dict) – A dict containing the parameters to use for the torch.optim.Optimizer. If None, uses the default parameters. Default: None.

  • scheduler_cls (str) – The name of the torch.optim.lr_scheduler used for training. If None, no scheduler is used. Default: None.

  • scheduler_params (dict) – A dict containing the parameters to use for the torch.optim.lr_scheduler. If None, uses the default parameters. Default: None.

  • learning_rate (float) – The learning rate applied to the Optimizer. Default: 1e-4

  • steps_saving (int) – A model checkpoint will be saved every steps_saving epochs. Default: None

  • steps_predict (int) – A prediction using the best model will be run every steps_predict epochs. Default: None

  • keep_best_on_train (bool) – Whether to keep the best model on the train set. Default: False

  • seed (int) – The random seed for reproducibility. Default: 8

  • no_cuda (bool) – Disable cuda training. Default: False

  • world_size (int) – The total number of processes to run. Default: -1

  • local_rank (int) – The rank of the node for distributed training. Default: -1

  • rank (int) – The rank of the process for distributed training. Default: -1

  • dist_backend (str) – The distributed backend to use. Default: ‘nccl’

  • master_addr (str) – The master address for distributed training. Default: ‘localhost’

  • master_port (str) – The master port for distributed training. Default: ‘12345’

  • drop_last (bool) – If True, the last incomplete batch is dropped in the dataloaders. Default: False

  • gradient_clipping_max_norm (float) – The maximum norm used to clip the gradients. Default: None.
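To make the effect of per_device_train_batch_size, drop_last and steps_saving concrete, here is a minimal standard-library sketch. It is not MultiVae code: the helper names batches_per_epoch and checkpoint_epochs are hypothetical, and the rule "save when epoch % steps_saving == 0, with epochs counted from 1" is an assumption about how the trainer schedules checkpoints.

```python
import math


def batches_per_epoch(n_samples, batch_size, drop_last=False):
    """Number of batches a single device iterates over in one epoch."""
    if drop_last:
        # The trailing incomplete batch is discarded.
        return n_samples // batch_size
    # Otherwise the last, smaller batch is kept.
    return math.ceil(n_samples / batch_size)


def checkpoint_epochs(num_epochs, steps_saving=None):
    """Epochs at which a checkpoint would be written.

    Assumption: a checkpoint is saved when epoch % steps_saving == 0,
    with epochs numbered from 1. None disables intermediate checkpoints.
    """
    if steps_saving is None:
        return []
    return [e for e in range(1, num_epochs + 1) if e % steps_saving == 0]


print(batches_per_epoch(1000, 64))                  # 16
print(batches_per_epoch(1000, 64, drop_last=True))  # 15
print(checkpoint_epochs(100, steps_saving=25))      # [25, 50, 75, 100]
```

With the default batch size of 64 and 1000 samples, drop_last=True removes the final batch of 40 samples from every epoch.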

class multivae.trainers.BaseTrainer(model, train_dataset, eval_dataset=None, training_config=None, callbacks=None, checkpoint=None)[source]

Base class to perform model training.

Parameters:
eval_step(epoch)[source]

Perform an evaluation step

Parameters:

epoch (int) – The current epoch number

Returns:

The evaluation loss

Return type:

(torch.Tensor)

predict(model, epoch, n_data=8)[source]

For BaseMultiVAE models, computes self and cross reconstructions during training.

prepare_train_step(epoch, best_train_loss, best_eval_loss)[source]

Applies changes between training steps, such as resetting the optimizer and the best loss values.

prepare_training()[source]

Sets up the trainer for training

resume_training(checkpoint)[source]

Sets up the trainer to resume training from a checkpoint

save_checkpoint(model, dir_path, epoch)[source]

Saves a checkpoint allowing training to be restarted from this point

Parameters:
  • dir_path (str) – The folder where the checkpoint should be saved

  • epoch (int) – The epoch number

save_model(model, dir_path)[source]

This method saves the final model along with the config files

Parameters:
  • model (BaseMultiVAE) – The model to be saved

  • dir_path (str) – The folder where the model and config files should be saved

train(log_output_dir=None)[source]

Runs the main training loop

Parameters:
  • log_output_dir (str) – The path in which the log will be stored

  • start_epoch (int) – The first epoch to do. Is useful in case of restarting a training after saving a checkpoint. Default to 1.
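The keep_best_on_train option decides which loss is used to pick the best model over the epochs. The sketch below illustrates that selection in plain Python. It is not the trainer's implementation: select_best_epoch is a hypothetical helper, and falling back to the train loss when no evaluation losses exist is an assumption.

```python
def select_best_epoch(train_losses, eval_losses=None, keep_best_on_train=False):
    """Return (best_epoch, best_loss), with epochs numbered from 1.

    Assumption: the train loss is used when keep_best_on_train is True
    or when no evaluation losses are available; otherwise the eval loss.
    """
    use_train = keep_best_on_train or eval_losses is None
    losses = train_losses if use_train else eval_losses

    best_epoch, best_loss = None, float("inf")
    for epoch, loss in enumerate(losses, start=1):
        if loss < best_loss:  # strict improvement keeps the earliest best
            best_epoch, best_loss = epoch, loss
    return best_epoch, best_loss


train = [1.0, 0.8, 0.5, 0.4]
evals = [1.1, 0.7, 0.9, 1.0]

print(select_best_epoch(train, evals))                           # (2, 0.7)
print(select_best_epoch(train, evals, keep_best_on_train=True))  # (4, 0.4)
```

In this toy run the evaluation loss starts rising after epoch 2, so the two settings select different epochs.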

train_step(epoch)[source]

Performs one training loop over the train_loader.

Parameters:

epoch (int) – The current epoch number

Returns:

The step training loss

Return type:

(torch.Tensor)
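When gradient_clipping_max_norm is set in the training config, gradients are presumably rescaled during the training step before the optimizer update. The following pure-Python sketch shows what clipping by a maximum global norm computes, in the same spirit as torch.nn.utils.clip_grad_norm_. Whether the trainer calls that exact function is an assumption, and clip_by_global_norm is a hypothetical helper operating on a flat list of floats rather than tensors.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale a flat list of gradient values so their L2 norm is at most max_norm.

    Returns the (possibly rescaled) gradients and the norm before clipping.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    # The scale never exceeds 1.0, so small gradients are left unchanged.
    scale = min(1.0, max_norm / (total_norm + eps))
    return [g * scale for g in grads], total_norm


clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
norm_after = math.sqrt(sum(g * g for g in clipped))
print(norm_before)         # 5.0
print(norm_after <= 1.0)   # True
```

Clipping preserves the direction of the gradient and only shortens it, which limits the size of a single update without changing where it points.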