import datetime
import json
import logging
import os
from copy import deepcopy
from typing import List, Optional
import torch
import torch.distributed as dist
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.utils import make_grid, save_image
from ...data import MultimodalBaseDataset
from ...data.datasets.utils import adapt_shape
from ...data.utils import set_inputs_to_device
from ...models import BaseModel, BaseMultiVAE
from .base_trainer_config import BaseTrainerConfig
from .callbacks import (
CallbackHandler,
MetricConsolePrinterCallback,
ProgressBarCallback,
TrainingCallback,
)
from .utils import set_seed, update_dict
# Module logger used by the trainer for progress and status messages.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Attach a StreamHandler (which writes to sys.stderr by default) so messages are
# shown on the console.
console = logging.StreamHandler()
logger.addHandler(console)
class BaseTrainer:
    """Base class to perform model training.

    Args:
        model (BaseModel): An instance of :class:`~multivae.models.BaseMultiVAE` to train.
        train_dataset (MultimodalBaseDataset): The training dataset of type
            :class:`~multivae.data.datasets.MultimodalBaseDataset`.
        eval_dataset (MultimodalBaseDataset): The evaluation dataset of type
            :class:`~multivae.data.datasets.MultimodalBaseDataset`. When None, the best
            model is kept based on the training loss. Default: None.
        training_config (BaseTrainerConfig): The training arguments summarizing the main
            parameters used for training. If None, the ``training_config.json`` stored in
            ``checkpoint`` is used when a checkpoint is given, otherwise a default
            :class:`BaseTrainerConfig` instance is used. Default: None.
        callbacks (List[TrainingCallback]): A list of callbacks to use during training.
            Default: None.
        checkpoint (str): The directory path of a checkpoint generated by
            :meth:`save_checkpoint`. Used when training is resumed from a previous
            checkpoint. Default: None.
    """

    def __init__(
        self,
        model: BaseModel,
        train_dataset: MultimodalBaseDataset,
        eval_dataset: Optional[MultimodalBaseDataset] = None,
        training_config: Optional[BaseTrainerConfig] = None,
        callbacks: List[TrainingCallback] = None,
        checkpoint: str = None,
    ):
        if training_config is None:
            if checkpoint is None:
                training_config = BaseTrainerConfig()
            else:
                # Resume with the configuration the checkpoint was trained with.
                training_config = BaseTrainerConfig.from_json_file(
                    os.path.join(checkpoint, "training_config.json")
                )

        if training_config.output_dir is None:
            training_config.output_dir = "dummy_output_dir"

        self.training_config = training_config
        self.model_config = model.model_config
        self.model_name = model.model_name

        # for distributed training
        self.world_size = self.training_config.world_size
        self.local_rank = self.training_config.local_rank
        self.rank = self.training_config.rank
        self.dist_backend = self.training_config.dist_backend
        self.distributed = self.world_size > 1

        if self.distributed:
            device = self._setup_devices()
        else:
            device = (
                "cuda"
                if torch.cuda.is_available() and not self.training_config.no_cuda
                else "cpu"
            )
        self.device = device

        if checkpoint is not None:
            model = model.load_from_folder(checkpoint)

        # place model on device
        model = model.to(device)
        model.device = device

        # Read model-level attributes before any DDP wrapping: DDP does not forward
        # arbitrary attribute lookups to the wrapped module.
        self.start_keep_best_epoch = getattr(model, "start_keep_best_epoch", 0)

        if self.distributed:
            # DDP only accepts `device_ids` for single-device CUDA modules; CPU
            # modules (no_cuda=True) must pass None.
            device_ids = [self.local_rank] if str(device).startswith("cuda") else None
            model = DDP(model, device_ids=device_ids)

        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset

        # Define the loaders
        train_loader = self.get_train_dataloader(train_dataset)
        if eval_dataset is not None:
            eval_loader = self.get_eval_dataloader(eval_dataset)
        else:
            logger.info(
                "! No eval dataset provided ! -> keeping best model on train.\n"
            )
            self.training_config.keep_best_on_train = True
            eval_loader = None

        self.train_loader = train_loader
        self.eval_loader = eval_loader
        self.callbacks = callbacks

        # run sanity check on the model
        self._run_model_sanity_check(model, train_loader)
        if self.is_main_process:
            logger.info("Model passed sanity check !\n" "Ready for training.\n")

        # Assert that the trainer is suited for the chosen model
        self.checktrainer(model)
        self.model = model

        if checkpoint is None:
            self.prepare_training()
        else:
            self.resume_training(checkpoint)

    def checktrainer(self, model):
        """Check that this trainer can train ``model``.

        Raises:
            AttributeError: If the model has a non-empty ``reset_optimizer_epochs``
                attribute, i.e. it requires multistage training.
        """
        # Inspect the underlying model, not the DDP wrapper.
        if isinstance(model, DDP):
            model = model.module
        if hasattr(model, "reset_optimizer_epochs"):
            if len(model.reset_optimizer_epochs) != 0:
                raise AttributeError(
                    f"The model {self.model_name} has a 'reset_optimizer_epochs' attribute "
                    "that is not empty. That means that it requires multistage training "
                    "and therefore you should use the ~multivae.trainers.MultistageTrainer "
                    "instead of the BaseTrainer."
                )

    @property
    def is_main_process(self):
        """True when this process has rank 0 or rank -1 (the one that logs and saves)."""
        return self.rank in (0, -1)

    def _setup_devices(self):
        """Set up the device for distributed training and init the process group.

        Returns:
            The device this process should use: ``"cpu"`` when ``no_cuda`` is set,
            otherwise ``torch.device("cuda", local_rank)``.
        """
        if dist.is_available() and dist.is_initialized() and self.local_rank == -1:
            logger.warning(
                "torch.distributed process group is initialized, but local_rank == -1. "
            )
        if self.training_config.no_cuda:
            self._n_gpus = 0
            device = "cpu"
        else:
            torch.cuda.set_device(self.local_rank)
            device = torch.device("cuda", self.local_rank)

            if not dist.is_initialized():
                dist.init_process_group(
                    backend=self.dist_backend,
                    init_method="env://",
                    world_size=self.world_size,
                    rank=self.rank,
                )

        return device

    def get_train_dataloader(
        self, train_dataset: MultimodalBaseDataset
    ) -> torch.utils.data.DataLoader:
        """Build the training DataLoader.

        In distributed mode a :class:`DistributedSampler` splits the data between
        processes; otherwise the DataLoader shuffles the data itself.
        """
        if self.distributed:
            train_sampler = DistributedSampler(
                train_dataset, num_replicas=self.world_size, rank=self.rank
            )
        else:
            train_sampler = None
        return DataLoader(
            dataset=train_dataset,
            batch_size=self.training_config.per_device_train_batch_size,
            num_workers=self.training_config.train_dataloader_num_workers,
            shuffle=(train_sampler is None),
            sampler=train_sampler,
            drop_last=self.training_config.drop_last,
        )

    def get_eval_dataloader(
        self, eval_dataset: MultimodalBaseDataset
    ) -> torch.utils.data.DataLoader:
        """Build the evaluation DataLoader (distributed sampler when distributed)."""
        if self.distributed:
            eval_sampler = DistributedSampler(
                eval_dataset, num_replicas=self.world_size, rank=self.rank
            )
        else:
            eval_sampler = None
        return DataLoader(
            dataset=eval_dataset,
            batch_size=self.training_config.per_device_eval_batch_size,
            num_workers=self.training_config.eval_dataloader_num_workers,
            shuffle=(eval_sampler is None),
            sampler=eval_sampler,
        )

    def set_optimizer(self):
        """Create ``self.optimizer`` from the config.

        ``optimizer_cls`` names a class in :mod:`torch.optim`; ``optimizer_params``
        (when not None) are forwarded as extra keyword arguments.
        """
        optimizer_cls = getattr(optim, self.training_config.optimizer_cls)

        logger.info(
            f"Setting the optimizer with learning rate {self.training_config.learning_rate}"
        )
        if self.training_config.optimizer_params is not None:
            optimizer = optimizer_cls(
                self.model.parameters(),
                lr=self.training_config.learning_rate,
                **self.training_config.optimizer_params,
            )
        else:
            optimizer = optimizer_cls(
                self.model.parameters(), lr=self.training_config.learning_rate
            )
        self.optimizer = optimizer

    def set_scheduler(self):
        """Create ``self.scheduler`` (a :mod:`torch.optim.lr_scheduler` class) or None."""
        if self.training_config.scheduler_cls is not None:
            scheduler_cls = getattr(lr_scheduler, self.training_config.scheduler_cls)
            if self.training_config.scheduler_params is not None:
                scheduler = scheduler_cls(
                    self.optimizer, **self.training_config.scheduler_params
                )
            else:
                scheduler = scheduler_cls(self.optimizer)
        else:
            scheduler = None
        self.scheduler = scheduler

    @staticmethod
    def _make_training_signature():
        """Return the current local time formatted as ``YYYY-MM-DD_HH-MM-SS``."""
        return str(datetime.datetime.now())[0:19].replace(" ", "_").replace(":", "-")

    def _set_output_dir(self):
        """Create ``output_dir`` and a time-stamped ``training_dir`` inside it."""
        # Create folder
        if not os.path.exists(self.training_config.output_dir) and self.is_main_process:
            os.makedirs(self.training_config.output_dir, exist_ok=True)
            logger.info(
                f"Created {self.training_config.output_dir} folder since did not exist.\n"
            )

        self._training_signature = self._make_training_signature()

        training_dir = os.path.join(
            self.training_config.output_dir,
            f"{self.model_name}_training_{self._training_signature}",
        )

        self.training_dir = training_dir

        if not os.path.exists(training_dir) and self.is_main_process:
            os.makedirs(training_dir, exist_ok=True)
            logger.info(
                f"Created {training_dir}. \n"
                "Training config, checkpoints and final model will be saved here.\n"
            )

    def _get_file_logger(self, log_output_dir):
        """Return a logger writing only to ``<log_output_dir>/training_logs_<signature>.log``."""
        log_dir = log_output_dir

        # if dir does not exist create it
        if not os.path.exists(log_dir) and self.is_main_process:
            os.makedirs(log_dir, exist_ok=True)
            logger.info(f"Created {log_dir} folder since did not exists.")
            logger.info("Training logs will be recodered here.\n")
            logger.info(" -> Training can be monitored here.\n")

        # create and set logger
        log_name = f"training_logs_{self._training_signature}"

        file_logger = logging.getLogger(log_name)
        file_logger.setLevel(logging.INFO)
        f_handler = logging.FileHandler(
            os.path.join(log_dir, f"training_logs_{self._training_signature}.log")
        )
        f_handler.setLevel(logging.INFO)
        file_logger.addHandler(f_handler)

        # Do not output logs in the console
        file_logger.propagate = False

        return file_logger

    def _setup_callbacks(self):
        """Build the callback handler: user callbacks + progress bar + console printer."""
        if self.callbacks is None:
            self.callbacks = [TrainingCallback()]

        self.callback_handler = CallbackHandler(
            callbacks=self.callbacks, model=self.model
        )

        self.callback_handler.add_callback(ProgressBarCallback())
        self.callback_handler.add_callback(MetricConsolePrinterCallback())

    def _run_model_sanity_check(self, model, loader):
        """Run one forward pass on a batch from ``loader`` to surface setup errors early.

        Raises:
            Exception: Wrapping (and chained from) any error raised by the forward pass.
        """
        try:
            inputs = next(iter(loader))
            batch = set_inputs_to_device(inputs, device=self.device, keys=["data"])
            model(batch)
            model.zero_grad()
            torch.cuda.empty_cache()
            del inputs, batch

        except Exception as e:
            raise Exception(
                "Error when calling forward method from model. Potential issues: \n"
                " - Wrong model architecture -> check encoder, decoder and metric architecture if "
                "you provide yours \n"
                " - The data input dimension provided is wrong -> when no encoder, decoder or metric "
                "provided, a network is built automatically but requires the shape of the flatten "
                "input data.\n"
                f"Exception raised: {type(e)} with message: " + str(e)
            ) from e

    def _optimizers_step(self, model_output=None):
        """Back-propagate ``model_output.loss`` and take one optimizer step.

        Gradient clipping (when ``gradient_clipping_max_norm`` is set) must run after
        ``backward()`` so it acts on the gradients that ``step()`` will use.
        """
        loss = model_output.loss

        self.optimizer.zero_grad()
        loss.backward()
        if self.training_config.gradient_clipping_max_norm is not None:
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.training_config.gradient_clipping_max_norm,
            )
        torch.cuda.empty_cache()
        self.optimizer.step()

    def _schedulers_step(self, epoch, metrics=None):
        """Advance the learning-rate scheduler, if any, at the end of ``epoch``.

        No step is taken while ``epoch < start_lr_scheduler_epoch``.
        ``ReduceLROnPlateau`` receives ``metrics`` (the monitored loss); other
        schedulers are stepped without arguments.
        """
        if self.scheduler is None:
            return
        if self.training_config.start_lr_scheduler_epoch > epoch:
            logger.info("Not taking a scheduler step yet.")
            return
        if isinstance(self.scheduler, lr_scheduler.ReduceLROnPlateau):
            self.scheduler.step(metrics)
        else:
            self.scheduler.step()

    @staticmethod
    def _metrics_to_python(metrics, prefix=""):
        """Return ``metrics`` with ``prefix`` added to each key and tensors turned into numbers."""
        return {
            prefix + name: value.item() if isinstance(value, torch.Tensor) else value
            for name, value in metrics.items()
        }

    def prepare_training(self):
        """Set up the trainer for a fresh training run.

        Seeds the RNGs, creates the optimizer, scheduler, output directory and
        callbacks, and resets the epoch counter and best losses.
        """
        # set random seed
        set_seed(self.training_config.seed)

        # set optimizer
        self.set_optimizer()

        # set scheduler
        self.set_scheduler()

        # create folder for saving
        self._set_output_dir()

        # set callbacks
        self._setup_callbacks()

        # set already_trained epochs
        self.trained_epochs = 0
        self.best_train_loss = torch.inf
        self.best_eval_loss = torch.inf
        self.metrics_best_model = {}

        # set up the best_model
        self._best_model = deepcopy(self.model)

    def resume_training(self, checkpoint):
        """Set up the trainer to resume training from ``checkpoint``.

        Restores the optimizer/scheduler states, the training directory, the number of
        trained epochs and the best losses/metrics written by :meth:`save_checkpoint`.
        """
        with open(os.path.join(checkpoint, "info_checkpoint.json"), "r") as fp:
            dict_checkpoint = json.load(fp)

        with open(os.path.join(checkpoint, "metrics_best_model.json"), "r") as fp:
            self.metrics_best_model = json.load(fp)

        # set random seed
        set_seed(self.training_config.seed)

        # set optimizer
        self.set_optimizer()
        self.optimizer.load_state_dict(
            torch.load(
                os.path.join(checkpoint, "optimizer.pt"), map_location=self.device
            )
        )

        # set scheduler
        self.set_scheduler()
        if self.scheduler is not None:
            self.scheduler.load_state_dict(
                torch.load(
                    os.path.join(checkpoint, "scheduler.pt"), map_location=self.device
                )
            )

        # keep saving into the original training directory
        self.training_dir = dict_checkpoint["training_dir"]
        # _set_output_dir is not called on resume, but _get_file_logger needs a
        # signature to name the log file.
        self._training_signature = self._make_training_signature()

        # set callbacks
        self._setup_callbacks()

        # set already_trained epochs
        self.trained_epochs = dict_checkpoint["trained_epochs"]
        self.best_train_loss = dict_checkpoint["best_train_loss"]
        self.best_eval_loss = dict_checkpoint["best_eval_loss"]

        # set up the best_model
        self._best_model = deepcopy(self.model)

    def prepare_train_step(self, epoch, best_train_loss, best_eval_loss):
        """Hook run before each epoch to change the trainer state (e.g. reset the
        optimizer or the best losses). The base implementation returns the losses unchanged.

        Returns:
            tuple: ``(best_train_loss, best_eval_loss)``.
        """
        return best_train_loss, best_eval_loss

    def train(self, log_output_dir: str = None):
        """Main training loop.

        Runs epochs ``trained_epochs + 1`` to ``num_epochs`` (inclusive). Each epoch
        trains, evaluates (if an eval dataset is given), updates the best model, saves
        reconstructions every ``steps_predict`` epochs and a checkpoint every
        ``steps_saving`` epochs. The best model is finally saved in
        ``<training_dir>/final_model``.

        Args:
            log_output_dir (str): The directory in which the training log file is
                written. If None, no log file is created. Default: None.
        """
        self.callback_handler.on_train_begin(
            training_config=self.training_config, model_config=self.model_config
        )

        log_verbose = False

        msg = (
            f"Training params:\n - max_epochs: {self.training_config.num_epochs}\n"
            " - per_device_train_batch_size: "
            f"{self.training_config.per_device_train_batch_size}\n"
            " - per_device_eval_batch_size: "
            f"{self.training_config.per_device_eval_batch_size}\n"
            f" - checkpoint saving every: {self.training_config.steps_saving}\n"
            f"Optimizer: {self.optimizer}\n"
            f"Scheduler: {self.scheduler}\n"
        )

        if self.is_main_process:
            logger.info(msg)

        # set up log file
        if log_output_dir is not None and self.is_main_process:
            log_verbose = True
            file_logger = self._get_file_logger(log_output_dir=log_output_dir)

            file_logger.info(msg)

        if self.is_main_process:
            logger.info("Successfully launched training !\n")

        for epoch in range(
            self.trained_epochs + 1, self.training_config.num_epochs + 1
        ):
            self.callback_handler.on_epoch_begin(
                training_config=self.training_config,
                epoch=epoch,
                train_loader=self.train_loader,
                eval_loader=self.eval_loader,
            )

            # A DistributedSampler only reshuffles across epochs when told the epoch.
            if isinstance(self.train_loader.sampler, DistributedSampler):
                self.train_loader.sampler.set_epoch(epoch)

            self.best_train_loss, self.best_eval_loss = self.prepare_train_step(
                epoch, self.best_train_loss, self.best_eval_loss
            )

            epoch_train_loss, epoch_metrics = self.train_step(epoch)
            metrics = self._metrics_to_python(epoch_metrics, prefix="train_")
            metrics["train_epoch_loss"] = epoch_train_loss

            torch.cuda.empty_cache()

            if self.eval_dataset is not None:
                epoch_eval_loss, epoch_eval_metrics = self.eval_step(epoch)
                metrics["eval_epoch_loss"] = epoch_eval_loss
                # The eval_* entries come from the metrics returned by eval_step.
                update_dict(
                    metrics,
                    self._metrics_to_python(epoch_eval_metrics, prefix="eval_"),
                )
                self._schedulers_step(epoch, epoch_eval_loss)
                torch.cuda.empty_cache()

            else:
                epoch_eval_loss = self.best_eval_loss
                self._schedulers_step(epoch, epoch_train_loss)

            if epoch <= self.start_keep_best_epoch:
                # save the model, don't keep track of the best loss
                self._best_model = deepcopy(self.model)
                self.metrics_best_model = metrics
                logger.info("New model saved!")

            elif (
                epoch_eval_loss < self.best_eval_loss
                and not self.training_config.keep_best_on_train
            ):
                self.best_eval_loss = epoch_eval_loss
                self._best_model = deepcopy(self.model)
                self.metrics_best_model = metrics
                logger.info("New best model on eval saved!")

            elif (
                epoch_train_loss < self.best_train_loss
                and self.training_config.keep_best_on_train
            ):
                self.best_train_loss = epoch_train_loss
                self._best_model = deepcopy(self.model)
                self.metrics_best_model = metrics
                logger.info("New best model on train saved!")

            # If steps_predict is not None, compute reconstruction images
            if (
                self.training_config.steps_predict is not None
                and (epoch % self.training_config.steps_predict == 0 or epoch == 1)
                and self.is_main_process
            ):
                metrics_media = self.predict(self._best_model, epoch)
                self.callback_handler.on_prediction_step(
                    self.training_config,
                    metrics_media=metrics_media,
                    global_step=epoch,
                )

                # Save the reconstructions to folder
                images = metrics_media.pop("images", {})
                for key, image in images.items():
                    save_image(image, os.path.join(self.training_dir, f"{key}.png"))

                torch.cuda.empty_cache()

            self.callback_handler.on_epoch_end(training_config=self.training_config)

            # save checkpoints
            if (
                self.training_config.steps_saving is not None
                and epoch % self.training_config.steps_saving == 0
            ):
                if self.is_main_process:
                    self.save_checkpoint(
                        model=self._best_model, dir_path=self.training_dir, epoch=epoch
                    )
                    logger.info(f"Saved checkpoint at epoch {epoch}\n")

                    if log_verbose:
                        file_logger.info(f"Saved checkpoint at epoch {epoch}\n")

            self.callback_handler.on_log(
                self.training_config,
                metrics,
                logger=logger,
                global_step=epoch,
                rank=self.rank,
            )

        final_dir = os.path.join(self.training_dir, "final_model")

        if self.is_main_process:
            self.save_model(self._best_model, dir_path=final_dir)

            logger.info("Training ended!")
            logger.info(f"Saved final model in {final_dir}")

        if self.distributed:
            dist.destroy_process_group()

        self.callback_handler.on_train_end(self.training_config)

    def eval_step(self, epoch: int):
        """Perform an evaluation pass over ``self.eval_loader``.

        Parameters:
            epoch (int): The current epoch number.

        Returns:
            tuple: ``(epoch_loss, epoch_metrics)`` where ``epoch_loss`` (float) is the
            accumulated loss divided by ``len(self.eval_loader.dataset)`` and
            ``epoch_metrics`` (dict) holds the accumulated model metrics divided by the
            number of batches.

        Raises:
            ArithmeticError: If the accumulated loss becomes NaN.
        """
        self.callback_handler.on_eval_step_begin(
            training_config=self.training_config,
            eval_loader=self.eval_loader,
            epoch=epoch,
            rank=self.rank,
        )

        self.model.eval()

        epoch_loss = 0
        epoch_metrics = {}

        for inputs in self.eval_loader:
            inputs = set_inputs_to_device(inputs, device=self.device, keys=["data"])
            try:
                with torch.no_grad():
                    model_output = self.model(
                        inputs,
                        epoch=epoch,
                        dataset_size=len(self.eval_loader.dataset),
                        uses_ddp=self.distributed,
                        use_mean_embedding=True,
                    )

            except RuntimeError:
                # Retry with autograd enabled, for models whose forward fails under no_grad.
                model_output = self.model(
                    inputs,
                    epoch=epoch,
                    dataset_size=len(self.eval_loader.dataset),
                    uses_ddp=self.distributed,
                    use_mean_embedding=True,
                )

            loss = (
                model_output.loss_sum
                if hasattr(model_output, "loss_sum")
                else model_output.loss
            )
            epoch_loss += loss.item()
            update_dict(epoch_metrics, model_output.metrics)

            # NaN is the only value not equal to itself.
            if epoch_loss != epoch_loss:
                raise ArithmeticError("NaN detected in eval loss")

            self.callback_handler.on_eval_step_end(training_config=self.training_config)

        epoch_metrics = {
            k: epoch_metrics[k] / len(self.eval_loader) for k in epoch_metrics
        }
        # NOTE(review): in distributed mode each process only iterates its own shard,
        # yet the division uses the full dataset length.
        epoch_loss = epoch_loss / len(self.eval_loader.dataset)

        return epoch_loss, epoch_metrics

    def train_step(self, epoch: int):
        """Run one training epoch over ``self.train_loader``.

        Parameters:
            epoch (int): The current epoch number.

        Returns:
            tuple: ``(epoch_loss, epoch_metrics)`` where ``epoch_loss`` (float) is the
            accumulated loss divided by ``len(self.train_dataset)`` and
            ``epoch_metrics`` (dict) holds the accumulated model metrics divided by the
            number of batches.

        Raises:
            ArithmeticError: If the accumulated loss becomes NaN.
        """
        self.callback_handler.on_train_step_begin(
            training_config=self.training_config,
            train_loader=self.train_loader,
            epoch=epoch,
            rank=self.rank,
        )

        # set model in train mode
        self.model.train()

        epoch_loss = 0
        epoch_model_metrics = {}
        batch_idx = 0

        for inputs in self.train_loader:
            inputs = set_inputs_to_device(inputs, device=self.device, keys=["data"])

            if hasattr(self.training_config, "beta_schedule"):
                beta_epoch = self.training_config.beta_schedule[epoch - 1]
            else:
                beta_epoch = 1

            model_output = self.model(
                inputs,
                epoch=epoch,
                dataset_size=len(self.train_loader.dataset),
                uses_ddp=self.distributed,
                batch_ratio=(batch_idx) / len(self.train_loader),
                beta=beta_epoch,
            )

            self._optimizers_step(model_output)

            loss = (
                model_output.loss_sum
                if hasattr(model_output, "loss_sum")
                else model_output.loss
            )

            epoch_loss += loss.item()
            update_dict(epoch_model_metrics, model_output.metrics)

            # NaN is the only value not equal to itself.
            if epoch_loss != epoch_loss:
                raise ArithmeticError("NaN detected in train loss")

            self.callback_handler.on_train_step_end(
                training_config=self.training_config
            )
            batch_idx += 1

        # Allows model updates if needed
        if self.distributed:
            self.model.module.update()
        else:
            self.model.update()

        epoch_model_metrics = {
            k: epoch_model_metrics[k] / len(self.train_loader)
            for k in epoch_model_metrics
        }
        epoch_loss = epoch_loss / len(self.train_dataset)

        return epoch_loss, epoch_model_metrics

    def save_model(self, model: BaseModel, dir_path: str):
        """Save the final model along with the training config and its metrics.

        Args:
            model (BaseMultiVAE): The model to be saved (DDP-wrapped when distributed).
            dir_path (str): The folder where the model and config files should be saved.
        """
        os.makedirs(dir_path, exist_ok=True)

        # save model
        if self.distributed:
            model.module.save(dir_path)
        else:
            model.save(dir_path)

        # save training config
        self.training_config.save_json(dir_path, "training_config")

        # save metrics
        with open(os.path.join(dir_path, "metrics_best_model.json"), mode="w+") as fp:
            json.dump(self.metrics_best_model, fp)

        self.callback_handler.on_save(self.training_config, dir_path=dir_path)

    def save_checkpoint(self, model: BaseModel, dir_path, epoch: int):
        """Save a checkpoint from which training can be resumed.

        Writes ``<dir_path>/checkpoint_epoch_<epoch>`` containing the optimizer (and
        scheduler) state, the model, the training config, the best-model metrics and
        ``info_checkpoint.json`` (training dir, trained epochs, best losses).

        Args:
            model (BaseModel): The model to save (DDP-wrapped when distributed).
            dir_path (str): The folder where the checkpoint folder is created.
            epoch (int): The epoch number, used in the folder name and as
                ``trained_epochs``.
        """
        checkpoint_dir = os.path.join(dir_path, f"checkpoint_epoch_{epoch}")
        os.makedirs(checkpoint_dir, exist_ok=True)

        # save optimizer
        torch.save(
            deepcopy(self.optimizer.state_dict()),
            os.path.join(checkpoint_dir, "optimizer.pt"),
        )

        # save scheduler
        if self.scheduler is not None:
            torch.save(
                deepcopy(self.scheduler.state_dict()),
                os.path.join(checkpoint_dir, "scheduler.pt"),
            )

        # save model
        if self.distributed:
            model.module.save(checkpoint_dir)
        else:
            model.save(checkpoint_dir)

        # save training config
        self.training_config.save_json(checkpoint_dir, "training_config")

        # save metrics
        with open(
            os.path.join(checkpoint_dir, "metrics_best_model.json"), mode="w+"
        ) as fp:
            json.dump(self.metrics_best_model, fp)

        # save info about checkpoint
        info = dict(
            training_dir=self.training_dir,
            trained_epochs=epoch,
            best_train_loss=self.best_train_loss,
            best_eval_loss=self.best_eval_loss,
        )
        with open(os.path.join(checkpoint_dir, "info_checkpoint.json"), "w+") as fp:
            json.dump(info, fp, sort_keys=True, indent=4)

        self.callback_handler.on_save_checkpoint(
            self.training_config, checkpoint_dir=checkpoint_dir
        )

    def predict(self, model: BaseModel, epoch: int, n_data=8):
        """Compute reconstruction image grids during training.

        For :class:`BaseMultiVAE` models, a grid of reconstructions from each single
        modality is produced; for all models a joint reconstruction from all
        modalities is produced. Samples come from the eval dataset if given, otherwise
        from the train dataset.

        Returns:
            dict: ``{"images": {name: grid_tensor}}``.
        """
        # The DDP wrapper exposes neither `predict` nor the model's class.
        if isinstance(model, DDP):
            model = model.module

        model.eval()

        predict_dataset = (
            self.eval_dataset if self.eval_dataset is not None else self.train_dataset
        )

        # Take one sample with n_data datapoints
        inputs = next(iter(DataLoader(predict_dataset, batch_size=n_data)))
        inputs = set_inputs_to_device(inputs, self.device, keys=["data"])
        all_recons = {"images": {}}

        # For multimodal VAEs we compute all 1-to-1 cross-modal reconstruction
        if isinstance(model, BaseMultiVAE):
            for mod in inputs.data:
                recon = model.predict(
                    inputs, mod, "all", N=8, flatten=True, ignore_incomplete=True
                )
                if hasattr(predict_dataset, "transform_for_plotting"):
                    recon = {
                        mod_name: predict_dataset.transform_for_plotting(
                            recon[mod_name], modality=mod_name
                        )
                        for mod_name in recon
                    }
                    recon["true_data"] = predict_dataset.transform_for_plotting(
                        inputs.data[mod], modality=mod
                    )
                else:
                    recon["true_data"] = inputs.data[mod]
                recon, _ = adapt_shape(recon)
                # First row(s): true data, followed by the reconstructions.
                recon_image = [recon["true_data"]] + [
                    recon[m] for m in recon if m != "true_data"
                ]
                recon_image = torch.cat(recon_image)
                recon_image = make_grid(recon_image, nrow=n_data)
                all_recons["images"][f"recon_from_{mod}"] = recon_image

        # For multimodal VAE or CVAE model, we compute the joint reconstruction
        recon = model.predict(
            inputs=inputs,
            cond_mod="all",
            gen_mod="all",
            N=8,
            flatten=True,
            ignore_incomplete=True,
        )
        reconstructed_modalities = list(recon.keys())
        if hasattr(predict_dataset, "transform_for_plotting"):
            recon = {
                mod_name: predict_dataset.transform_for_plotting(
                    recon[mod_name], modality=mod_name
                )
                for mod_name in recon
            }
            recon.update(
                {
                    f"true_data_{mod_name}": predict_dataset.transform_for_plotting(
                        inputs.data[mod_name], modality=mod_name
                    )
                    for mod_name in inputs.data
                }
            )
        else:
            recon.update(
                {
                    f"true_data_{mod_name}": inputs.data[mod_name]
                    for mod_name in inputs.data
                }
            )
        recon, _ = adapt_shape(recon)

        recon_image = [recon[f"true_data_{m}"] for m in inputs.data] + [
            recon[m] for m in reconstructed_modalities
        ]

        recon_image = torch.cat(recon_image)
        # Arrange everything in a single grid image with n_data images per row.
        recon_image = make_grid(recon_image, nrow=n_data)
        all_recons["images"]["recon_from_all"] = recon_image

        return all_recons