BaseTrainer
Base trainer for MultiVae's models.
- class multivae.trainers.BaseTrainerConfig(output_dir=None, per_device_train_batch_size=64, per_device_eval_batch_size=64, num_epochs=100, train_dataloader_num_workers=0, eval_dataloader_num_workers=0, optimizer_cls='Adam', optimizer_params=None, scheduler_cls=None, scheduler_params=None, learning_rate=0.0001, steps_saving=None, steps_predict=None, keep_best_on_train=False, seed=8, no_cuda=False, world_size=-1, local_rank=-1, rank=-1, dist_backend='nccl', master_addr='localhost', master_port='12345', drop_last=False, gradient_clipping_max_norm=None)
BaseTrainer config class stating the main training arguments.
- Parameters:
output_dir (str) – The directory where model checkpoints, configs and the final model will be stored. Default: None.
per_device_train_batch_size (int) – The number of training samples per batch and per device. Default: 64
per_device_eval_batch_size (int) – The number of evaluation samples per batch and per device. Default: 64
num_epochs (int) – The maximum number of epochs for training. Default: 100
train_dataloader_num_workers (int) – Number of subprocesses to use for train data loading. 0 means that the data will be loaded in the main process. Default: 0
eval_dataloader_num_workers (int) – Number of subprocesses to use for evaluation data loading. 0 means that the data will be loaded in the main process. Default: 0
optimizer_cls (str) – The name of the torch.optim.Optimizer used for training. Default: 'Adam'.
optimizer_params (dict) – A dict containing the parameters to use for the torch.optim.Optimizer. If None, uses the default parameters. Default: None.
scheduler_cls (str) – The name of the torch.optim.lr_scheduler used for training. If None, no scheduler is used. Default: None.
scheduler_params (dict) – A dict containing the parameters to use for the torch.optim.lr_scheduler. If None, uses the default parameters. Default: None.
learning_rate (float) – The learning rate applied to the Optimizer. Default: 1e-4
steps_saving (int) – A model checkpoint will be saved every steps_saving epochs. Default: None
steps_predict (int) – A prediction using the best model will be run every steps_predict epochs. Default: None
keep_best_on_train (bool) – Whether to keep the best model on the train set. Default: False
seed (int) – The random seed for reproducibility. Default: 8
no_cuda (bool) – Disable CUDA training. Default: False
world_size (int) – The total number of processes to run. Default: -1
local_rank (int) – The local rank of the process on its node for distributed training. Default: -1
rank (int) – The rank of the process for distributed training. Default: -1
dist_backend (str) – The distributed backend to use. Default: 'nccl'
master_addr (str) – The master address for distributed training. Default: 'localhost'
master_port (str) – The master port for distributed training. Default: '12345'
drop_last (bool) – If True, the last incomplete batch of each dataloader is dropped. Default: False
gradient_clipping_max_norm (float) – If set, the maximum norm to which the gradient norm is clipped. Default: None.
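The batch-related fields above interact through simple arithmetic. The sketch below assumes PyTorch DataLoader semantics for drop_last (the final, smaller batch is discarded when the number of samples is not a multiple of the batch size), counts the samples seen by a single process, and treats world_size=-1 as a single non-distributed process. The helper names are hypothetical and not part of MultiVae.

```python
import math


def batches_per_epoch(n_samples: int, batch_size: int, drop_last: bool) -> int:
    """Number of batches one process sees per epoch (DataLoader-style semantics)."""
    if drop_last:
        # The trailing incomplete batch is discarded.
        return n_samples // batch_size
    # The trailing incomplete batch is kept as a smaller batch.
    return math.ceil(n_samples / batch_size)


def global_batch_size(per_device_batch_size: int, world_size: int) -> int:
    """Samples consumed per optimizer step across all processes.

    Assumption: world_size=-1 (the config default) means one process.
    """
    return per_device_batch_size * max(world_size, 1)


print(batches_per_epoch(1000, 64, drop_last=False))  # 16
print(batches_per_epoch(1000, 64, drop_last=True))   # 15
print(global_batch_size(64, world_size=4))           # 256
```

With 1000 samples and per_device_train_batch_size=64, enabling drop_last discards the last 40 samples of each epoch on that process.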
- class multivae.trainers.BaseTrainer(model, train_dataset, eval_dataset=None, training_config=None, callbacks=None, checkpoint=None)
Base class to perform model training.
- Parameters:
model (BaseModel) – An instance of BaseMultiVAE to train.
train_dataset (MultimodalBaseDataset) – The training dataset of type MultimodalBaseDataset.
eval_dataset (MultimodalBaseDataset) – The evaluation dataset of type MultimodalBaseDataset. Default: None.
training_config (BaseTrainerConfig) – The training arguments summarizing the main parameters used for training. If None, a default BaseTrainerConfig instance is used. Default: None.
callbacks (List[TrainingCallback]) – A list of callbacks to use during training. Default: None.
checkpoint (str) – The directory path of the checkpoint generated by the save_checkpoint function, used when training is resumed from a previous checkpoint. Default: None.
- eval_step(epoch)
Performs an evaluation step.
- Parameters:
epoch (int) β The current epoch number
- Returns:
The evaluation loss
- Return type:
- predict(model, epoch, n_data=8)
For BaseMultiVAE models, computes self and cross reconstructions during training.
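The difference between self and cross reconstructions can be made concrete by enumerating (source, target) modality pairs: a self reconstruction encodes a modality and decodes the same one, a cross reconstruction decodes a different one. This is a plain-Python sketch with hypothetical modality names; it does not call MultiVae and does not show how predict selects or displays its n_data samples.

```python
from itertools import product


def reconstruction_pairs(modalities):
    """Split (source, target) modality pairs into self and cross reconstructions."""
    self_pairs, cross_pairs = [], []
    for source, target in product(modalities, repeat=2):
        if source == target:
            # Encode and decode the same modality.
            self_pairs.append((source, target))
        else:
            # Encode one modality, decode another.
            cross_pairs.append((source, target))
    return self_pairs, cross_pairs


self_pairs, cross_pairs = reconstruction_pairs(["images", "text"])
print(self_pairs)   # [('images', 'images'), ('text', 'text')]
print(cross_pairs)  # [('images', 'text'), ('text', 'images')]
```

With M modalities there are M self reconstructions and M(M-1) cross reconstructions.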
- prepare_train_step(epoch, best_train_loss, best_eval_loss)
Applies changes between training steps, such as resetting the optimizer and the best loss values.
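A minimal sketch of the best-loss bookkeeping that prepare_train_step can reset, assuming (this page does not confirm it) that the best model is selected on the training loss when keep_best_on_train=True and on the evaluation loss otherwise. The function track_best and its inputs are hypothetical.

```python
import math


def track_best(epoch_losses, keep_best_on_train=False):
    """Return (best_epoch, best_loss) over (train_loss, eval_loss) pairs.

    The best loss starts at +inf, as it would right after a reset.
    """
    best_epoch, best_loss = None, math.inf
    for epoch, (train_loss, eval_loss) in enumerate(epoch_losses, start=1):
        # Choose which loss drives model selection.
        loss = train_loss if keep_best_on_train else eval_loss
        if loss < best_loss:
            best_epoch, best_loss = epoch, loss
    return best_epoch, best_loss


history = [(1.0, 1.2), (0.8, 1.1), (0.6, 1.3)]
print(track_best(history))                           # (2, 1.1)
print(track_best(history, keep_best_on_train=True))  # (3, 0.6)
```

The two criteria can disagree: here the training loss keeps falling while the evaluation loss rises after epoch 2.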
- save_checkpoint(model, dir_path, epoch)
Saves a checkpoint allowing training to be resumed from this point.
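The steps_saving field of BaseTrainerConfig controls how often periodic checkpoints are written (steps_predict plays the same role for predictions). The sketch below assumes 1-indexed epochs, a trigger whenever the epoch is a multiple of the step value, and None disabling the action; the exact condition MultiVae uses is not stated on this page, and scheduled_epochs is a hypothetical helper.

```python
from typing import List, Optional


def scheduled_epochs(num_epochs: int, every: Optional[int]) -> List[int]:
    """Epochs (1-indexed) at which a periodic action fires; None disables it."""
    if every is None:
        return []
    return [epoch for epoch in range(1, num_epochs + 1) if epoch % every == 0]


print(scheduled_epochs(100, 25))    # [25, 50, 75, 100]
print(scheduled_epochs(100, None))  # []
```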
- save_model(model, dir_path)
Saves the final model along with its config files.
- Parameters:
model (BaseMultiVAE) β The model to be saved
dir_path (str) β The folder where the model and config files should be saved
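save_model writes the model together with its configuration files. The sketch below only illustrates the config half of that idea in plain Python: a hypothetical ToyTrainerConfig dataclass holding a few BaseTrainerConfig fields is serialized to JSON inside dir_path and read back. The class, the helper functions and the file name training_config.json are assumptions for illustration, not MultiVae's actual on-disk layout.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class ToyTrainerConfig:
    """Hypothetical stand-in holding a few BaseTrainerConfig fields."""
    num_epochs: int = 100
    learning_rate: float = 1e-4
    optimizer_cls: str = "Adam"
    optimizer_params: Optional[dict] = None


def save_config(config: ToyTrainerConfig, dir_path: str) -> str:
    """Write the config as JSON inside dir_path and return the file path."""
    os.makedirs(dir_path, exist_ok=True)
    path = os.path.join(dir_path, "training_config.json")  # hypothetical file name
    with open(path, "w") as f:
        json.dump(asdict(config), f, indent=2)
    return path


def load_config(path: str) -> ToyTrainerConfig:
    """Rebuild the config from the JSON file written by save_config."""
    with open(path) as f:
        return ToyTrainerConfig(**json.load(f))


with tempfile.TemporaryDirectory() as tmp:
    cfg = ToyTrainerConfig(num_epochs=10)
    restored = load_config(save_config(cfg, tmp))
    print(restored == cfg)  # True
```

Keeping the config next to the weights means a saved run can be reloaded without reconstructing its training arguments by hand.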