train_one_epoch
- hybrid_learning.concepts.train_eval.train_eval_funs.train_one_epoch(model, loss_fn, metric_fns, train_loader, optimizer, callbacks=None, callback_context=None, ensemble_count=None, prefix='train', loss_key='loss')
- Train for one epoch, evaluate, and return history and test results. History and test results are stored in a pandas.DataFrame. The device used is the one the model lies on. Distributed models are not supported.
- Additional features:
- ensemble_count: if set and >0, the model is assumed to be an ensemble that returns a stack of result tensors (stacked in dim 0). The additional metric 'disag@0.5' (average disagreement of ensemble outputs if binarized at a threshold of 0.5) is calculated and added to the results.
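The exact formula behind 'disag@0.5' is not given here. The following is a minimal sketch of one plausible reading: binarize each ensemble member's outputs at the threshold, then average the fraction of positions on which each pair of members differs. The function names, the strict `>` comparison, and the pairwise averaging are assumptions made for illustration, not the library's implementation; plain Python lists stand in for the stacked tensors.

```python
from itertools import combinations


def binarize(values, threshold=0.5):
    # Assumption: strictly greater than the threshold counts as positive.
    return [1 if v > threshold else 0 for v in values]


def mean_pairwise_disagreement(ensemble_outputs, threshold=0.5):
    """Average fraction of positions on which two ensemble members differ.

    ensemble_outputs: list with one flat list of outputs per ensemble member
    (the analogue of result tensors stacked in dim 0).
    """
    masks = [binarize(member, threshold) for member in ensemble_outputs]
    pairs = list(combinations(masks, 2))
    if not pairs:
        return 0.0  # a single member cannot disagree with itself
    per_pair = [
        sum(a != b for a, b in zip(first, second)) / len(first)
        for first, second in pairs
    ]
    return sum(per_pair) / len(per_pair)


# Three hypothetical ensemble members, four output values each.
outputs = [
    [0.9, 0.2, 0.6, 0.4],
    [0.8, 0.7, 0.4, 0.1],
    [0.7, 0.1, 0.55, 0.3],
]
score = mean_pairwise_disagreement(outputs)
```

With these values the first and third members binarize identically and each differs from the second on two of four positions, so the score is the mean of 0.5, 0.0 and 0.5.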
 - Parameters
- model (Module) – model to train 
- loss_fn (Callable[[Tensor, Tensor], Tensor]) – function that calculates the optimization objective value 
- metric_fns (Dict[str, Callable[[Tensor, Tensor], Tensor]]) – further KPI functions to gather training stats 
- train_loader (torch.utils.data.DataLoader) – train data loader
- optimizer (Union[torch.optim.Optimizer, ResettableOptimizer]) – optimizer to use for weight update steps, initialized with the model’s weights
- callbacks (List[Mapping[CallbackEvents, Callable]]) – callbacks to feed with callback context after each batch, and before and after training epoch 
- callback_context (Dict[str, Any]) – dict with any additional context to be handed over to the callbacks as keyword arguments 
- ensemble_count (int) – if set to a value >0, treat the output of the model as ensemble_count outputs stacked in dim 0
- prefix (str) – prefix to prepend to KPI names for the final pandas.Series naming
- loss_key (str) – key the loss should have in the output 
 
- Returns
- training history as pandas.DataFrame with
- columns: loss and the KPI names (keys from dict metric_fns),
- index: the batch indices,
- items: the results of KPI evaluations of the output on the training batch (i.e. before the back-propagation step)
 
- Return type
- DataFrame
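To illustrate the shape of the returned history, here is a rough sketch of a single-epoch loop that records the loss and each KPI per batch, before the weight update, and collects them in a pandas.DataFrame indexed by batch. The function `sketch_train_one_epoch` is a hypothetical stand-in written for this example, not the library's code; it leaves out callbacks, ensemble handling, and the prefix, and the toy model, data, and metric are assumptions.

```python
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def sketch_train_one_epoch(model, loss_fn, metric_fns, train_loader,
                           optimizer, loss_key="loss"):
    """Hypothetical illustration only; not the library implementation."""
    device = next(model.parameters()).device  # the device the model lies on
    model.train()
    rows = []
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        # Record loss and KPIs on the outputs produced before the update.
        row = {loss_key: loss.item()}
        with torch.no_grad():
            for name, metric_fn in metric_fns.items():
                row[name] = float(metric_fn(outputs, targets))
        loss.backward()
        optimizer.step()
        rows.append(row)
    # The default RangeIndex serves as the batch index.
    return pd.DataFrame(rows)


torch.manual_seed(0)
# Toy binary classification data: 32 samples, 4 features (made up).
dataset = TensorDataset(torch.randn(32, 4),
                        torch.randint(0, 2, (32, 1)).float())
loader = DataLoader(dataset, batch_size=8)
model = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
metric_fns = {
    "acc@0.5": lambda out, tgt: ((out > 0.5).float() == tgt).float().mean(),
}
history = sketch_train_one_epoch(
    model, nn.BCELoss(), metric_fns, loader,
    torch.optim.SGD(model.parameters(), lr=0.1),
)
```

Here `history` has one row per batch (four in this toy setup) and one column for the loss plus one per key of `metric_fns`.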