TrainEvalHandle

class hybrid_learning.concepts.train_eval.base_handles.train_test_handle.TrainEvalHandle(model, data, device=None, batch_size=None, batch_size_val=None, batch_size_hessian=None, max_epochs=None, num_workers=None, loss_fn=None, nll_fn=None, metric_fns=None, early_stopping_handle=None, optimizer=None, model_output_transform=None, metric_input_transform=None, callbacks=None, callback_context=None, show_progress_bars=True)

Bases: ABC

Handle for training and evaluating PyTorch models. The model should be an instance of torch.nn.Module.

The main functions are train() and evaluate(). Metrics and loss functions must be given on initialization. Training results are returned as pandas.DataFrame and evaluation results as pandas.Series, with the metric keys (prefixed according to the mode) as columns resp. index. Modes can be train, test, or validation (see the instances of the DatasetSplit enum). The non-prefixed loss key is stored in LOSS_KEY.

For a usage example see ConceptDetection2DTrainTestHandle.

Public Data Attributes:

LOSS_KEY

Key for the loss evaluation results.

NLL_KEY

Key for the proper scoring evaluation function used for second stage training.

kpi_fns

Metric and loss functions.

settings

The current training settings as dictionary.

Public Methods:

add_callbacks(callbacks)

Append the given callbacks.

remove_callback(callback)

Remove a single given callback.

loader([data, mode, batch_size, shuffle, ...])

Prepare and return a torch data loader from the dataset according to settings.

reset_optimizer([optimizer, device, model])

Move model to correct device, init optimizer to parameters of model.

reset_training_handles([optimizer, device, ...])

(Re)set all handles associated with training, and move to device.

reset_kpis([kpis])

Reset the aggregating KPIs.

train([train_loader, val_loader, ...])

Train the model according to the specified training parameters.

cross_validate([num_splits, train_val_data, ...])

Record training results for num_splits distinct val splits.

train_val_one_epoch([pbar_desc, ...])

Train for one epoch, evaluate, and return history and test results.

train_one_epoch([train_loader, ...])

Train for one epoch and return history results as pandas.DataFrame.

evaluate([mode, val_loader, prefix, ...])

Evaluate the model with respect to settings.

second_stage_train([callback_context])

Do a second stage training for calibration using Laplace approximation.

Special Methods:

__repr__()

Return repr(self).

__init__(model, data[, device, batch_size, ...])

Init.

__init__(model, data, device=None, batch_size=None, batch_size_val=None, batch_size_hessian=None, max_epochs=None, num_workers=None, loss_fn=None, nll_fn=None, metric_fns=None, early_stopping_handle=None, optimizer=None, model_output_transform=None, metric_input_transform=None, callbacks=None, callback_context=None, show_progress_bars=True)

Init.

__repr__()

Return repr(self).

add_callbacks(callbacks)

Append the given callbacks.

Parameters

callbacks (Iterable[Mapping[CallbackEvents, Callable]]) – the callbacks to append to callbacks

cross_validate(num_splits=5, train_val_data=None, run_info_templ='Run {run}/{runs}', callback_context=None, pbar_desc_templ='epoch {epoch}/{epochs}')

Record training results for num_splits distinct validation splits. The original model state dict is restored after the training runs. The model must feature a reset_parameters() method to reinitialize it between the runs.

Parameters
  • num_splits (int) – number of equal-sized, distinct validation splits to use

  • train_val_data – optional dataset to split into training and validation splits; defaults to the train_val split in data

  • run_info_templ (str) – template containing the placeholders {run} (the number of the current run) and {runs} (the total number of runs); it is used as prefix for logging and progress bars

  • callback_context (Optional[Dict[str, Any]]) – callback context to use; defaults to a copy of callback_context

  • pbar_desc_templ (str) – template for the progress bar description (prefixed by the run information); see train()

Returns

list of tuples of the form (final state_dict, epoch- and batch-wise train history as pandas.DataFrame, epoch-wise validation history as pandas.DataFrame)

Return type

List[Tuple[Dict[str, Tensor], DataFrame, DataFrame]]
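The splitting behind cross_validate can be pictured as a k-fold partition of the train_val data. The following is a minimal stdlib-only sketch, not the library's implementation: it assumes the validation splits are contiguous, equally sized chunks of size len(data) // num_splits, and that any leftover samples always stay in the training part. The helper name kfold_indices is hypothetical.

```python
from typing import List, Tuple


def kfold_indices(n: int, num_splits: int) -> List[Tuple[List[int], List[int]]]:
    """Return one (train_indices, val_indices) pair per run.

    Hypothetical helper: the validation chunks are contiguous, pairwise
    disjoint, and all of size n // num_splits; leftover samples are
    always assigned to the training part.
    """
    if not 2 <= num_splits <= n:
        raise ValueError("num_splits must lie in [2, n]")
    fold_size = n // num_splits
    splits = []
    for run in range(num_splits):
        val = list(range(run * fold_size, (run + 1) * fold_size))
        val_set = set(val)
        train = [i for i in range(n) if i not in val_set]
        splits.append((train, val))
    return splits


runs = kfold_indices(n=11, num_splits=5)
for run, (train_idx, val_idx) in enumerate(runs, start=1):
    # Same placeholder names as the default run_info_templ
    info = "Run {run}/{runs}".format(run=run, runs=len(runs))
    print(info, "val:", val_idx)
```

With 11 samples and 5 splits, each validation split holds 2 samples and sample 10 is never used for validation; the real handle may distribute remainders differently.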

static detached_state_dict(model, device='cpu')

Return a properly detached copy of the state dict of model on device. By default, the copy is created on the CPU to avoid overloading GPU memory.

Parameters
  • model (Module) – the model whose state dict is copied

  • device (Union[str, device]) – the device on which to create the copy

Return type

Dict[str, Tensor]

evaluate(mode=DatasetSplit.TEST, val_loader=None, prefix=None, callback_context=None, pbar_desc='Progress')

Evaluate the model with respect to settings. This is a wrapper around evaluate() that uses the defaults given by settings; override them by specifying the corresponding arguments. The device used for evaluation is either device or that of the model.

Returns

pandas.Series of all KPIs, i.e. of the loss and each metric in metric_fns; format: {<KPI-name>: <KPI value as float>}

Return type

Series

loader(data=None, *, mode=None, batch_size=None, shuffle=False, device=None, model=None, num_workers=None, **_)

Prepare and return a torch data loader from the dataset according to settings. For details see loader().

Return type

torch.utils.data.DataLoader

classmethod prefix_by(mode, text)

Prefix text with the given mode.

Return type

str

remove_callback(callback)

Remove a single given callback.

Parameters

callback (Callable) – the callback to remove

reset_kpis(kpis=None)

Reset the aggregating KPIs.

Parameters

kpis (Optional[Dict[str, Union[AggregatingKpi, Callable[[Tensor, Tensor], Tensor]]]]) – all metric functions and aggregating KPI instances
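The KPI dictionary mixes stateless metric callables with aggregating KPIs that accumulate values over batches, and only the latter carry state that needs resetting. A minimal stdlib-only sketch of that distinction; the class MeanAbsoluteError and its update/value/reset methods are hypothetical and not the library's AggregatingKpi interface.

```python
import math
from typing import Callable, Dict, Sequence, Union


class MeanAbsoluteError:
    """Hypothetical aggregating KPI: accumulates over batches, reset between runs."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self._abs_error_sum = 0.0
        self._count = 0

    def update(self, outputs: Sequence[float], targets: Sequence[float]) -> None:
        # Accumulate the per-sample absolute errors of one batch
        for out, tgt in zip(outputs, targets):
            self._abs_error_sum += abs(out - tgt)
            self._count += 1

    def value(self) -> float:
        # Mean over all samples seen since the last reset (NaN if none)
        return self._abs_error_sum / self._count if self._count else math.nan


def reset_kpis(kpis: Dict[str, Union[MeanAbsoluteError, Callable]]) -> None:
    # Plain metric functions are stateless; only aggregating KPIs need a reset.
    for kpi in kpis.values():
        if hasattr(kpi, "reset"):
            kpi.reset()


mae = MeanAbsoluteError()
kpis = {"mae": mae, "max_err": lambda o, t: max(abs(a - b) for a, b in zip(o, t))}
mae.update([1.0, 2.0], [1.0, 4.0])  # batch 1: errors 0 and 2
mae.update([0.0], [3.0])            # batch 2: error 3
print(mae.value())                  # mean error 5 / 3
reset_kpis(kpis)
print(math.isnan(mae.value()))      # True
```

Aggregating over samples rather than averaging per-batch values keeps the result independent of the batch size.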

reset_optimizer(optimizer=None, device=None, model=None)

Move the model to the correct device and initialize the optimizer with the model's parameters. The arguments default to the optimizer, device, and model attributes of this instance.

reset_training_handles(optimizer=None, device=None, model=None)

(Re)set all handles associated with training and move them to device. These are: optimizer, early_stopping_handle, and the data loaders. The argument values default to the corresponding attributes of this instance.

Return type

None

second_stage_train(callback_context=None)

Do a second stage training for calibration using the Laplace approximation. This is a wrapper around hybrid_learning.concepts.train_eval.train_eval_funs.second_stage_train() that uses defaults from settings. Before and after the second stage training, one evaluation epoch is run on the validation and test splits so that the metrics can be logged for comparison.

Note

Evaluation runs on validation and test split are conducted before (epoch 0) and after (epoch 1) the second stage training.
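For orientation, the textbook Laplace approximation replaces the weight posterior by a Gaussian centred at the trained (MAP) weights, with covariance given by the inverse Hessian of the negative log-posterior; here NLL plays the role of nll_fn. This is only the generic form: which parameters are included, which prior p(θ) is used, and how the Hessian is approximated (batch-wise, with batch_size_hessian) are not specified in this document.

```latex
p(\theta \mid \mathcal{D}) \approx \mathcal{N}\!\left(\theta_{\mathrm{MAP}},\, H^{-1}\right),
\qquad
H = \nabla_\theta^2 \left[ \sum_{(x, y) \in \mathcal{D}} \mathrm{NLL}\big(f_\theta(x), y\big) - \log p(\theta) \right] \Bigg|_{\theta = \theta_{\mathrm{MAP}}}
```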

Parameters

callback_context (Optional[Dict[str, Any]]) – see callback_context.

Returns

pandas.Series with the final evaluation results on the validation and test splits of data

Return type

Series

classmethod test_(metric_name)

Get name of metric for testing results.

Parameters

metric_name (str) –

Return type

str

train(train_loader=None, val_loader=None, pbar_desc_templ='Epoch {epoch}/{epochs}', callback_context=None)

Train the model according to the specified training parameters. Defaults are taken from settings; to override them, specify the corresponding arguments (compare the arguments of train_val_one_epoch()).

Returns

Two pandas.DataFrame with history information on

  • training: the epoch- and batch-wise loss and KPI results on the training data; the index is a multi-index of (epoch, batch);

  • validation: the epoch-wise evaluation results on the validation data (see val_loader); the index is the epoch index;

For both, the columns are the loss and KPI names (keys of metric_fns).

Return type

Tuple[DataFrame, DataFrame]
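The epoch loop of train, including how early_stopping_handle can end training before the maximum number of epochs, can be sketched with plain dictionaries keyed like the returned DataFrames: (epoch, batch) for the training history and epoch for the validation history. This is a hypothetical illustration; PatienceEarlyStopping, its require_stop property, and the callable arguments are assumptions, not the library's EarlyStoppingHandle API.

```python
from typing import Callable, Dict, List, Optional, Tuple


class PatienceEarlyStopping:
    """Hypothetical early-stopping handle: stop after `patience` epochs without improvement."""

    def __init__(self, patience: int) -> None:
        self.patience = patience
        self.best = float("inf")
        self.epochs_without_improvement = 0

    def step(self, val_loss: float) -> None:
        if val_loss < self.best:
            self.best = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1

    @property
    def require_stop(self) -> bool:
        return self.epochs_without_improvement >= self.patience


def train(train_one_epoch: Callable[[int], List[float]],
          evaluate: Callable[[int], float],
          max_epochs: int,
          early_stopping: Optional[PatienceEarlyStopping] = None,
          ) -> Tuple[Dict[Tuple[int, int], float], Dict[int, float]]:
    """Return (train history keyed by (epoch, batch), val history keyed by epoch)."""
    train_hist: Dict[Tuple[int, int], float] = {}
    val_hist: Dict[int, float] = {}
    for epoch in range(max_epochs):
        for batch, loss in enumerate(train_one_epoch(epoch)):
            train_hist[(epoch, batch)] = loss
        val_hist[epoch] = evaluate(epoch)
        if early_stopping is not None:
            early_stopping.step(val_hist[epoch])
            if early_stopping.require_stop:
                break  # early stopping reduces the effective number of epochs
    return train_hist, val_hist


val_losses = [1.0, 0.8, 0.9, 0.95, 0.7]
train_hist, val_hist = train(
    train_one_epoch=lambda epoch: [0.5, 0.4],  # two dummy batch losses per epoch
    evaluate=lambda epoch: val_losses[epoch],
    max_epochs=5,
    early_stopping=PatienceEarlyStopping(patience=2),
)
print(sorted(val_hist))  # [0, 1, 2, 3]
```

Training stops after epoch 3 because the validation loss has not improved for two epochs, even though max_epochs is 5.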

classmethod train_(metric_name)

Get name of metric for training results.

Parameters

metric_name (str) –

Return type

str

train_one_epoch(train_loader=None, callback_context=None, pbar_desc='Train progress')

Train for one epoch and return the history results as pandas.DataFrame. This is a wrapper around hybrid_learning.concepts.train_eval.train_eval_funs.train_one_epoch() that uses defaults from settings.

Returns

training history as pandas.DataFrame with:

  • columns: the loss and the KPI names (keys of metric_fns);

  • index: the batch indices;

  • items: the results of the KPI evaluations of the output on the training batch (i.e. before the back-propagation step)

Return type

DataFrame

train_val_one_epoch(pbar_desc=None, callback_context=None, train_loader=None, val_loader=None)

Train for one epoch, evaluate, and return the history and validation results. This is a wrapper around train_one_epoch() and evaluate() with progress bar printing and logging after the epoch. The training history is returned as pandas.DataFrame, the validation results as pandas.Series. The device used for training is that of the parameters of the model (see device_of()).

Returns

tuple of training history and validation results; the DataFrame columns resp. the Series index are the loss and the KPI names (keys of metric_fns).

  • pandas.DataFrame: index are the batch indices, items are the results of KPI evaluations of the output on the training batch (i.e. before back-propagation step)

  • pandas.Series: the items are the final evaluations of the KPIs on the validation set

Return type

Tuple[DataFrame, Series]

classmethod val_(metric_name)

Get name of metric for validation results.

Parameters

metric_name (str) –

Return type

str

LOSS_KEY = 'loss'

Key for the loss evaluation results.

NLL_KEY = 'NLL'

Key for the proper scoring evaluation function used for second stage training. Typically a negative log-likelihood.

batch_size: int

Default training batch size.

batch_size_hessian: int

Default batch size for calculating the hessian.

batch_size_val: int

Default validation batch size.

callback_context: Dict[str, Any]

The default callback context values to use. Wherever a callback context is used during a run, it can be passed explicitly; otherwise it defaults to a copy of this dict.

callbacks: List[Mapping[CallbackEvents, Callable]]

A list of dictionaries, each mapping events to a callable that is called with the current state every time the event occurs. Some default logging callbacks are defined.

For details on available events, see hybrid_learning.concepts.train_eval.callbacks.CallbackEvents. After an event occurs, all callbacks registered for it are called in order, with keyword arguments taken from the callback context (see hybrid_learning.concepts.train_eval.base_handles.train_test_handle.TrainEvalHandle.callback_context). The base context depends on the event and includes, e.g., the model; it can be extended by specifying the callback context in the function call or in the default callback context. Note that side effects on objects in the callback context (e.g. the model) influence the training. Callback application examples:

  • Logging

  • Storing of best n models

  • Results saving, e.g. to tensorboard or sacred log

  • Forced weight normalization

Callbacks can be added and removed.
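The dispatch described above, a list of event-to-callable mappings invoked in order with keyword arguments from the callback context, can be sketched as follows. This is a hypothetical illustration: the CallbackEvents members, the helper run_callbacks, and the rule that event-specific base values override user-provided keys are assumptions, not the library's behaviour.

```python
from enum import Enum
from typing import Any, Callable, Dict, List, Mapping


class CallbackEvents(Enum):
    # Hypothetical subset of events; see the library's CallbackEvents for the real ones.
    AFTER_BATCH = "after_batch"
    AFTER_EPOCH = "after_epoch"


def run_callbacks(callbacks: List[Mapping[CallbackEvents, Callable[..., Any]]],
                  event: CallbackEvents,
                  base_context: Dict[str, Any],
                  default_context: Dict[str, Any]) -> None:
    # Copy the default context so adding or replacing keys does not modify the
    # stored defaults; objects inside it (e.g. a model) are still shared.
    context = dict(default_context)
    # Assumption: event-specific base values take precedence over user-provided ones.
    context.update(base_context)
    for mapping in callbacks:  # callbacks are called in list order
        callback = mapping.get(event)
        if callback is not None:
            callback(**context)


log: List[str] = []
callbacks = [
    # Callbacks accept **_ so unused context keys do not raise a TypeError.
    {CallbackEvents.AFTER_EPOCH: lambda epoch, run, **_: log.append(f"run {run}, epoch {epoch}")},
    {CallbackEvents.AFTER_BATCH: lambda batch, **_: log.append(f"batch {batch}")},
]
run_callbacks(callbacks, CallbackEvents.AFTER_EPOCH,
              base_context={"epoch": 3}, default_context={"run": 1})
print(log)  # ['run 1, epoch 3']
```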

data: DataTriple

Train, validation, and test data splits to use. Must be converted to data loaders before usage.

device: torch.device

Device to run training and testing on (this is where the data loaders are put).

early_stopping_handle: Optional[EarlyStoppingHandle]

Handle that is stepped during training and indicates the need for early stopping. To disable early stopping, set early_stopping_handle to None or pass early_stopping_handle=False to __init__.

epochs: int

Default maximum number of epochs. May be reduced by early_stopping_handle.

property kpi_fns: Dict[str, Callable[[Tensor, Tensor], Tensor]]

Metric and loss functions. Nomenclature: metric_fns holds all metrics meant for evaluation, while kpi_fns also encompasses the losses.

loss_fn: Callable[[Tensor, Tensor], Tensor]

Loss function callable. Defaults to a balanced binary cross-entropy that assumes on average 1% positive pixels per image. It must be wrapped into a tuple to hide its parameters, since these are not to be updated.

metric_fns: Dict[str, Union[AggregatingKpi, Callable[[Tensor, Tensor], Tensor]]]

Dictionary of metric functions to apply for evaluation and logging. Each function must have a signature of (output, target) -> metric_value. See also kpi_fns. Keys must not contain LOSS_KEY or NLL_KEY.
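The relation between metric_fns, kpi_fns, and the reserved keys can be sketched as a dictionary merge. A minimal sketch, assuming kpi_fns holds loss_fn under LOSS_KEY, nll_fn under NLL_KEY (falling back to loss_fn, as documented for nll_fn), plus all entries of metric_fns; the helper make_kpi_fns and the list-based example metrics are hypothetical.

```python
from typing import Callable, Dict, Optional

LOSS_KEY = "loss"
NLL_KEY = "NLL"

Metric = Callable[[list, list], float]


def make_kpi_fns(loss_fn: Metric, metric_fns: Dict[str, Metric],
                 nll_fn: Optional[Metric] = None) -> Dict[str, Metric]:
    # metric_fns must not shadow the reserved loss keys
    clashes = {LOSS_KEY, NLL_KEY} & set(metric_fns)
    if clashes:
        raise KeyError(f"metric_fns uses reserved keys: {sorted(clashes)}")
    # nll_fn falls back to loss_fn
    nll = nll_fn if nll_fn is not None else loss_fn
    return {LOSS_KEY: loss_fn, NLL_KEY: nll, **metric_fns}


def mse(out: list, tgt: list) -> float:
    return sum((o - t) ** 2 for o, t in zip(out, tgt)) / len(out)


def accuracy(out: list, tgt: list) -> float:
    return sum(round(o) == t for o, t in zip(out, tgt)) / len(out)


kpis = make_kpi_fns(mse, {"acc": accuracy})
print(sorted(kpis))  # ['NLL', 'acc', 'loss']
results = {name: fn([0.9, 0.2], [1, 0]) for name, fn in kpis.items()}
```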

model: torch.nn.modules.module.Module

The model to work on.

nll_fn: Callable[[Tensor, Tensor], Tensor]

Proper scoring function used as loss in the second stage training for the Laplace approximation. Usually chosen to be a negative log-likelihood; defaults to loss_fn.

num_workers: int

The default number of workers to use for data loading. See hybrid_learning.concepts.train_eval.base_handles.train_test_handle.TrainEvalHandle.loader().

optimizer: ResettableOptimizer

Optimizer and learning rate scheduler handle.

property settings: Dict[str, Any]

The current training settings as dictionary.

show_progress_bars: str

Whether to show progress bars for batch-wise operations. Value must be a comma-separated concatenation of run types (the values of dataset splits) for which to show progress, or 'always'.
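How such a setting string might be interpreted is sketched below. This is a hypothetical parser, not the library's code: it assumes the run types are the strings "train", "val", and "test", that entries are separated by commas, and that surrounding whitespace is ignored.

```python
def show_progress(setting: str, run_type: str) -> bool:
    """Decide whether to show a progress bar for the given run type.

    Hypothetical parser; assumes comma-separated entries with surrounding
    whitespace ignored, and the special value 'always'.
    """
    entries = {entry.strip() for entry in setting.split(",") if entry.strip()}
    return "always" in entries or run_type in entries


print(show_progress("train,val", "train"))  # True
print(show_progress("train,val", "test"))   # False
print(show_progress("always", "test"))      # True
```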