train_one_epoch

hybrid_learning.concepts.train_eval.train_eval_funs.train_one_epoch(model, loss_fn, metric_fns, train_loader, optimizer, callbacks=None, callback_context=None, ensemble_count=None, prefix='train', loss_key='loss')

Train for one epoch, evaluating the loss and KPIs on each training batch, and return the resulting training history as a pandas.DataFrame. Training runs on the device the model's parameters lie on. Distributed models are not supported.
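The core of such an epoch can be sketched as follows. This is a minimal, hypothetical stand-in (`train_one_epoch_sketch` is not part of hybrid_learning) that omits callbacks and ensemble handling. It assumes the loader yields `(inputs, targets)` pairs and that every metric function returns a scalar tensor. Loss and KPIs are recorded on each batch before the weight update and collected into a DataFrame indexed by batch.

```python
from typing import Callable, Dict

import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

MetricFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def train_one_epoch_sketch(model: nn.Module, loss_fn: MetricFn,
                           metric_fns: Dict[str, MetricFn],
                           train_loader: DataLoader,
                           optimizer: torch.optim.Optimizer,
                           loss_key: str = "loss") -> pd.DataFrame:
    """Hypothetical simplified epoch loop (no callbacks, no ensembles)."""
    device = next(model.parameters()).device  # use the device the model lies on
    model.train()
    rows = {}
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        # KPIs are evaluated on the output before the back-propagation step
        with torch.no_grad():
            row = {loss_key: loss.item()}
            row.update({name: fn(outputs, targets).item()
                        for name, fn in metric_fns.items()})
        rows[batch_idx] = row
        # weight update step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # index: batch indices; columns: loss key followed by the KPI names
    return pd.DataFrame.from_dict(rows, orient="index")


# usage on random data: 10 samples in batches of 4 gives 3 batches
torch.manual_seed(0)
model = nn.Linear(3, 1)
loader = DataLoader(TensorDataset(torch.randn(10, 3), torch.randn(10, 1)), batch_size=4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
history = train_one_epoch_sketch(model, nn.functional.mse_loss,
                                 {"mae": nn.functional.l1_loss}, loader, opt)
print(history)
```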

Additional features:

  • ensemble_count: if set and >0, the model is assumed to be an ensemble whose output is a stack of the members' result tensors along dim 0. In this case, the additional metric 'disag@0.5' (the average disagreement of the ensemble outputs when binarized at a threshold of 0.5) is calculated and added to the results.
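A possible way to compute such a disagreement score is sketched below. It assumes that "disagreement" means the fraction of output entries on which the binarized ensemble members do not all agree; the library's exact definition (e.g. pairwise disagreement or per-sample averaging) may differ, and `disagreement_at` is a hypothetical helper.

```python
import torch


def disagreement_at(ensemble_outputs: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Fraction of entries on which the binarized ensemble members do not all agree.

    ensemble_outputs: members' outputs stacked in dim 0, shape (ensemble_count, ...).
    """
    binarized = ensemble_outputs > threshold
    # an entry counts as a disagreement if some, but not all, members predict positive
    disagree = binarized.any(dim=0) & ~binarized.all(dim=0)
    return disagree.float().mean()


# three ensemble members, each predicting four values
outputs = torch.tensor([[0.9, 0.2, 0.6, 0.1],
                        [0.8, 0.7, 0.4, 0.3],
                        [0.7, 0.1, 0.9, 0.2]])
print(disagreement_at(outputs))  # tensor(0.5000): members disagree on entries 1 and 2
```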

Parameters
  • model (Module) – model to train

  • loss_fn (Callable[[Tensor, Tensor], Tensor]) – function that calculates the optimization objective value

  • metric_fns (Dict[str, Callable[[Tensor, Tensor], Tensor]]) – further KPI functions, keyed by KPI name, used to gather training statistics

  • train_loader (torch.utils.data.DataLoader) – train data loader

  • optimizer (Union[torch.optim.Optimizer, ResettableOptimizer]) – optimizer to use for the weight update steps, already initialized with the model’s weights

  • callbacks (List[Mapping[CallbackEvents, Callable]]) – callbacks to feed with the callback context after each batch, and before and after the training epoch

  • callback_context (Dict[str, Any]) – dict with any additional context to be handed over to the callbacks as keyword arguments

  • ensemble_count (int) – if set to a value >0, treat the model output as ensemble_count outputs stacked in dim 0

  • prefix (str) – prefix to prepend to the KPI names when naming the final pandas.Series

  • loss_key (str) – key under which the loss is stored in the output
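The callback mechanism described for callbacks and callback_context can be illustrated in plain Python: each callback is a mapping from an event to a callable, and the context dict is handed over as keyword arguments. The enum `Event`, its members, and the helper `run_callbacks` are hypothetical stand-ins; the actual members of CallbackEvents and the keys the function puts into the context are not documented here.

```python
from enum import Enum
from typing import Any, Callable, Dict, List, Mapping


class Event(Enum):  # hypothetical stand-in for CallbackEvents
    BEFORE_EPOCH = "before_epoch"
    AFTER_BATCH = "after_batch"
    AFTER_EPOCH = "after_epoch"


def run_callbacks(callbacks: List[Mapping[Event, Callable[..., Any]]],
                  event: Event, context: Dict[str, Any]) -> None:
    """Call every callback registered for event, passing the context as keyword arguments."""
    for cb in callbacks:
        if event in cb:
            cb[event](**context)


log: List[str] = []
callbacks = [{
    Event.AFTER_BATCH: lambda batch, loss, **_: log.append(f"batch {batch}: loss={loss:.2f}"),
    Event.AFTER_EPOCH: lambda **_: log.append("epoch done"),
}]
callback_context = {"run": "demo"}  # additional context handed to every callback

run_callbacks(callbacks, Event.BEFORE_EPOCH, callback_context)  # nothing registered: no-op
for batch, loss in enumerate([0.9, 0.5]):
    run_callbacks(callbacks, Event.AFTER_BATCH,
                  {**callback_context, "batch": batch, "loss": loss})
run_callbacks(callbacks, Event.AFTER_EPOCH, callback_context)
print(log)  # ['batch 0: loss=0.90', 'batch 1: loss=0.50', 'epoch done']
```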

Returns

training history as pandas.DataFrame with

  • columns: loss and the KPI names (keys from dict metric_fns),

  • index: the batch indices,

  • items: the results of KPI evaluations of the output on the training batch (i.e. before back-propagation step)

Return type

DataFrame
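A history of this layout can be reduced to epoch-level averages with standard pandas calls. The values below are dummy numbers, and the `train_` naming scheme is only an assumption about how prefix might be applied; the function's own use of prefix may differ.

```python
import pandas as pd

# dummy per-batch history with the documented layout:
# index = batch indices, columns = loss and KPI names
history = pd.DataFrame(
    {"loss": [0.8, 0.6, 0.4], "acc": [0.50, 0.75, 1.00]},
    index=pd.Index([0, 1, 2], name="batch"),
)

# epoch-level averages as a pandas.Series; the "train_" prefix is an assumption
epoch_means = history.mean().add_prefix("train_")
print(epoch_means)
```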