train_one_epoch
- hybrid_learning.concepts.train_eval.train_eval_funs.train_one_epoch(model, loss_fn, metric_fns, train_loader, optimizer, callbacks=None, callback_context=None, ensemble_count=None, prefix='train', loss_key='loss')
Train for one epoch, evaluate, and return history and test results. History and test results are stored in a pandas.DataFrame. The device used is the one the model lies on. Distributed models are not supported.
Additional features:
ensemble_count: if set and >0, the model is assumed to be an ensemble that returns a stack of result tensors (stacked in dim 0). The additional metric 'disag@0.5' (average disagreement of ensemble outputs if binarized at a threshold of 0.5) is calculated and added to the results.
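The exact formula behind 'disag@0.5' is not given here. As a minimal sketch of one plausible reading, the following dependency-free function binarizes each ensemble member's scores at the threshold and reports the fraction of samples on which the members do not all agree. The function name `disagreement_at`, the nested-list input (standing in for a tensor stacked in dim 0), and the "not all members agree" definition are assumptions made for illustration, not the library's implementation.

```python
from typing import List


def disagreement_at(outputs: List[List[float]], threshold: float = 0.5) -> float:
    """Fraction of samples on which the binarized ensemble members disagree.

    outputs[m][i] is the score of ensemble member m for sample i, i.e. the
    members are stacked along the first axis ("dim 0").
    This is an assumed definition, not necessarily the one used by the library.
    """
    if not outputs or not outputs[0]:
        raise ValueError("need at least one ensemble member and one sample")
    num_samples = len(outputs[0])
    disagreeing = 0
    for i in range(num_samples):
        # Set of distinct binarized votes for sample i across all members
        votes = {member[i] > threshold for member in outputs}
        if len(votes) > 1:
            disagreeing += 1
    return disagreeing / num_samples


# Three ensemble members, four samples each
stacked = [
    [0.9, 0.2, 0.6, 0.1],
    [0.8, 0.7, 0.4, 0.3],
    [0.7, 0.1, 0.9, 0.2],
]
score = disagreement_at(stacked)
print(score)
```

Other readings are possible, for example averaging pairwise disagreement between members, so the value should be checked against the library source before relying on it.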
- Parameters
model (Module) – model to train
loss_fn (Callable[[Tensor, Tensor], Tensor]) – function that calculates the optimization objective value
metric_fns (Dict[str, Callable[[Tensor, Tensor], Tensor]]) – further KPI functions to gather training stats
train_loader (torch.utils.data.DataLoader) – train data loader
optimizer (Union[torch.optim.Optimizer, ResettableOptimizer]) – optimizer to use for weight update steps, initialized with the model’s weights
callbacks (List[Mapping[CallbackEvents, Callable]]) – callbacks to feed with callback context after each batch, and before and after the training epoch
callback_context (Dict[str, Any]) – dict with any additional context to be handed over to the callbacks as keyword arguments
ensemble_count (int) – if set to a value >0, treat the output of the model as ensemble_count outputs stacked in dim 0
prefix (str) – prefix to prepend to KPI names for the final pandas.Series naming
loss_key (str) – key the loss should have in the output
- Returns
training history as pandas.DataFrame with
- columns: loss and the KPI names (keys from the dict metric_fns),
- index: the batch indices,
- items: the results of KPI evaluations of the output on the training batch (i.e. before the back-propagation step)
- Return type
DataFrame
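The per-batch bookkeeping described above (one history row per batch, evaluated before the weight update, with callbacks receiving the context as keyword arguments) can be illustrated with a dependency-free sketch. Everything named here is hypothetical: `Event` is a stand-in for CallbackEvents whose real member names may differ, `collect_history` is not a library function, the keyword names `batch_idx` and `results` are invented, and plain lists and callables replace tensors, the model, and the optimizer. The `prefix` handling is not modeled.

```python
from enum import Enum
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple


class Event(Enum):
    """Hypothetical stand-in for CallbackEvents; real member names may differ."""
    BEFORE_EPOCH = "before_epoch"
    AFTER_BATCH = "after_batch"
    AFTER_EPOCH = "after_epoch"


Batch = Tuple[List[float], List[float]]  # (inputs, targets)
Metric = Callable[[List[float], List[float]], float]


def run_callbacks(callbacks: Sequence[Mapping[Event, Callable]],
                  event: Event, **context: Any) -> None:
    """Call every callback registered for `event`, passing context as kwargs."""
    for callback_map in callbacks:
        if event in callback_map:
            callback_map[event](**context)


def collect_history(predict: Callable[[List[float]], List[float]],
                    loss_fn: Metric,
                    metric_fns: Dict[str, Metric],
                    batches: Sequence[Batch],
                    callbacks: Optional[Sequence[Mapping[Event, Callable]]] = None,
                    callback_context: Optional[Dict[str, Any]] = None,
                    loss_key: str = "loss") -> Dict[int, Dict[str, float]]:
    """Return {batch index: {loss_key: ..., metric name: ...}} for one pass."""
    callbacks = callbacks or []
    context = dict(callback_context or {})
    history: Dict[int, Dict[str, float]] = {}
    run_callbacks(callbacks, Event.BEFORE_EPOCH, **context)
    for batch_idx, (inputs, targets) in enumerate(batches):
        outputs = predict(inputs)
        # KPIs are computed on the output before any weight update
        row = {loss_key: loss_fn(outputs, targets)}
        row.update({name: fn(outputs, targets) for name, fn in metric_fns.items()})
        history[batch_idx] = row
        # A real training loop would back-propagate and step the optimizer here
        run_callbacks(callbacks, Event.AFTER_BATCH,
                      batch_idx=batch_idx, results=row, **context)
    run_callbacks(callbacks, Event.AFTER_EPOCH, **context)
    return history


def mse(outputs: List[float], targets: List[float]) -> float:
    return sum((o - t) ** 2 for o, t in zip(outputs, targets)) / len(targets)


def accuracy(outputs: List[float], targets: List[float]) -> float:
    return sum((o > 0.5) == (t > 0.5) for o, t in zip(outputs, targets)) / len(targets)


seen = []  # filled by the after-batch callback
logger = {Event.AFTER_BATCH:
          lambda batch_idx, results, **ctx: seen.append((batch_idx, ctx["run_name"]))}

batches = [([0.0, 1.0], [0.0, 1.0]), ([1.0, 0.0], [0.0, 1.0])]
history = collect_history(predict=lambda xs: xs,  # identity "model"
                          loss_fn=mse, metric_fns={"acc": accuracy},
                          batches=batches, callbacks=[logger],
                          callback_context={"run_name": "demo"})
print(history)
```

With pandas installed, `pandas.DataFrame.from_dict(history, orient="index")` turns such a mapping into a table with the batch indices as index and one column per key, which matches the layout described under Returns.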