TrainEvalHandle
- class hybrid_learning.concepts.train_eval.base_handles.train_test_handle.TrainEvalHandle(model, data, device=None, batch_size=None, batch_size_val=None, batch_size_hessian=None, max_epochs=None, num_workers=None, loss_fn=None, nll_fn=None, metric_fns=None, early_stopping_handle=None, optimizer=None, model_output_transform=None, metric_input_transform=None, callbacks=None, callback_context=None, show_progress_bars=True)[source]
Bases: ABC
Handle for training and evaluation of pytorch models. The model base class should be torch.nn.Module. The main functions are train() and evaluate(). Metrics and loss functions must be given on initialization. Training and evaluation results are returned as pandas.DataFrame resp. pandas.Series with the metric keys as columns (prefixed according to the mode). Modes can be train, test, or validation (see the instances of the DatasetSplit enum). The non-prefixed loss key is saved in LOSS_KEY.
For a usage example see ConceptDetection2DTrainTestHandle.
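A minimal usage sketch (all concrete names here are hypothetical: MyHandle stands for any concrete subclass such as ConceptDetection2DTrainTestHandle, my_model for a prepared torch.nn.Module, and data_triple for a DataTriple of dataset splits):

import torch

handle = MyHandle(
    model=my_model,              # the torch.nn.Module to train/eval
    data=data_triple,            # DataTriple of train/val/test splits
    device=torch.device("cpu"),
    batch_size=8,
    max_epochs=3,
    loss_fn=torch.nn.BCELoss(),
)
train_hist, test_hist = handle.train()  # two pandas.DataFrame
test_kpis = handle.evaluate()           # pandas.Series of loss and metrics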
Public Data Attributes:
LOSS_KEY: Key for the loss evaluation results.
NLL_KEY: Key for the proper scoring evaluation function used for second stage training.
kpi_fns: Metric and loss functions.
settings: The current training settings as dictionary.
Public Methods:
add_callbacks(callbacks): Append the given callbacks.
remove_callback(callback): Remove a single given callback.
loader([data, mode, batch_size, shuffle, ...]): Prepare and return a torch data loader from the dataset according to settings.
reset_optimizer([optimizer, device, model]): Move model to correct device, init optimizer to parameters of model.
reset_training_handles([optimizer, device, ...]): (Re)set all handles associated with training, and move to device.
reset_kpis([kpis]): Reset aggregating KPIs.
train([train_loader, val_loader, ...]): Train the model according to the specified training parameters.
cross_validate([num_splits, train_val_data, ...]): Record training results for num_splits distinct validation splits.
train_val_one_epoch([pbar_desc, ...]): Train for one epoch, evaluate, and return history and test results.
train_one_epoch([train_loader, ...]): Train for one epoch and return history results as pandas.DataFrame.
evaluate([mode, val_loader, prefix, ...]): Evaluate the model wrt. settings.
second_stage_train([callback_context]): Do a second stage training for calibration using Laplace approximation.
Special Methods:
__repr__(): Return repr(self).
__init__(model, data[, device, batch_size, ...]): Init.
- Parameters
model (Module) –
data (DataTriple) –
device (device) –
batch_size (int) –
batch_size_val (int) –
batch_size_hessian (int) –
max_epochs (int) –
num_workers (int) –
metric_fns (Dict[str, Union[AggregatingKpi, Callable[[Tensor, Tensor], Tensor]]]) –
early_stopping_handle (EarlyStoppingHandle) –
optimizer (Callable[[...], torch.optim.Optimizer]) –
model_output_transform (TupleTransforms) –
metric_input_transform (TupleTransforms) –
callbacks (List[Mapping[CallbackEvents, Callable]]) –
- __init__(model, data, device=None, batch_size=None, batch_size_val=None, batch_size_hessian=None, max_epochs=None, num_workers=None, loss_fn=None, nll_fn=None, metric_fns=None, early_stopping_handle=None, optimizer=None, model_output_transform=None, metric_input_transform=None, callbacks=None, callback_context=None, show_progress_bars=True)[source]
Init.
- Parameters
model (Module) – model to train/eval
data (DataTriple) – train, validation, and test data splits to use
device (Optional[device]) – device on which to load the data and the model parameters
num_workers (Optional[int]) – number of workers to use for data loading; see loader(); single-process loading is used if unset or < 2
optimizer (Optional[Callable[[...], torch.optim.Optimizer]]) – callable that yields a fresh optimizer instance when called on the model's trainable parameters
early_stopping_handle (Optional[EarlyStoppingHandle]) – handle for early stopping; defaults to a default EarlyStoppingHandle if None; set to False to disable early stopping
loss_fn (Optional[Callable[[Tensor, Tensor], Tensor]]) – differentiable metric function to use as loss
nll_fn (Optional[Callable[[Tensor, Tensor], Tensor]]) – negative log-likelihood (or other proper scoring function) used as loss in the second stage training for Laplace approximation
metric_fns (Optional[Dict[str, Union[AggregatingKpi, Callable[[Tensor, Tensor], Tensor]]]]) – dictionary of metric functions, each accepting the batch model output tensor and the batch ground truth tensor, and yielding the value of the specified metric (see the sketch after this list)
model_output_transform (Optional[TupleTransforms]) – transformation applied to the tuples of (model output, target) before applying loss functions or metric functions; the functions are wrapped correspondingly
metric_input_transform (Optional[TupleTransforms]) – transformation applied to the tuples of (model output, target) before applying metric functions only (not the loss and scoring functions), after model_output_transform is applied; the functions are wrapped correspondingly; meant as a convenient way to modify all metrics simultaneously
callbacks (Optional[List[Mapping[CallbackEvents, Callable]]]) – see callbacks
show_progress_bars (Union[bool, str]) – see show_progress_bars
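As an illustration of the metric_fns format, a minimal sketch with a plain callable metric (binary_accuracy is a hypothetical name; any callable mapping (output, target) tensors to a scalar tensor qualifies):

import torch

def binary_accuracy(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Fraction of correctly predicted entries for 0/1 masks:
    return ((output > 0.5).float() == target).float().mean()

metric_fns = {"accuracy": binary_accuracy}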
- add_callbacks(callbacks)[source]
Append the given callbacks.
- Parameters
callbacks (Iterable[Mapping[CallbackEvents, Callable]]) –
- cross_validate(num_splits=5, train_val_data=None, run_info_templ='Run {run}/{runs}', callback_context=None, pbar_desc_templ='epoch {epoch}/{epochs}')[source]
Record training results for num_splits distinct validation splits. The original model state dict is restored after the training runs. The model must feature a reset_parameters() method to reinitialize it between the runs.
- Parameters
run_info_templ (str) – template containing as substrings the placeholders {run} (the number of the current run) and {runs} (the total number of runs); the template is used as prefix for logging and progress bars
num_splits (int) – number of equal-sized, distinct validation splits to use
train_val_data – optional dataset to split into train and validation dataset splits; defaults to the train_val split in data
callback_context (Optional[Dict[str, Any]]) – current callback context to use; defaults to a copy of callback_context
pbar_desc_templ (str) – template for the progress bar description (prefixed by run information); see train()
- Returns
list of tuples of the form (final state_dict, epoch- and batch-wise train history as pandas.DataFrame, epoch-wise validation history as pandas.DataFrame)
- Return type
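The returned list can be unpacked per run as sketched below (handle stands for a fully initialized instance as above):

results = handle.cross_validate(num_splits=5)
for state_dict, train_hist, val_hist in results:
    # state_dict: final model parameters of the run;
    # train_hist: (epoch, batch)-indexed pandas.DataFrame;
    # val_hist:   epoch-indexed pandas.DataFrame
    print(val_hist.iloc[-1])  # final validation KPIs of the run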
- static detached_state_dict(model, device='cpu')[source]
Return a properly detached copy of the state dict of model on device. By default, the copy is created on the cpu device to avoid overloading the GPU memory.
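The behavior roughly corresponds to the following sketch (an assumption based on the description above, not the actual implementation):

def detached_state_dict_sketch(model, device="cpu"):
    # Copy each parameter/buffer, detached from the autograd graph, onto device:
    return {key: tensor.detach().clone().to(device)
            for key, tensor in model.state_dict().items()}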
- evaluate(mode=DatasetSplit.TEST, val_loader=None, prefix=None, callback_context=None, pbar_desc='Progress')[source]
Evaluate the model wrt. settings. This is a wrapper around evaluate() which uses the defaults given by settings. Override them by specifying them as custom_args. The device used for evaluation is device or the one of the model.
- Parameters
mode (Union[str, DatasetSplit]) – see loader()
prefix (Optional[str]) – see evaluate()
val_loader (Optional[torch.utils.data.DataLoader]) – the evaluation dataset loader; defaults to one with data of the respective mode
callback_context (Optional[Dict[str, Any]]) – see callback_context
pbar_desc (str) – leading static description text for the progress bar if newly created
- Returns
Dictionary of all KPIs, i.e. of loss and each metric in metric_fns; format: {<KPI-name>: <KPI value as float>}
- Return type
Series
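Usage sketch (handle as above; the exact mode-prefixed key names depend on the prefixing scheme of the handle):

test_kpis = handle.evaluate()  # defaults to the test split
print(test_kpis.to_string())   # one float per KPI, incl. the loss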
- loader(data=None, *, mode=None, batch_size=None, shuffle=False, device=None, model=None, num_workers=None, **_)[source]
Prepare and return a torch data loader from the dataset according to settings. For details see loader().
- Parameters
data (Optional[Union[torch.utils.data.Dataset, BaseDataset]]) – data to obtain the loader for; defaults to data of the respective mode
mode (Optional[Union[str, DatasetSplit]]) – which data split to use by default; specify as instance of DatasetSplit or the name of one
batch_size (Optional[int]) – defaults to batch_size
num_workers (Optional[int]) – defaults to num_workers
shuffle (bool) –
model (Optional[Module]) –
- Return type
torch.utils.data.DataLoader
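For example, a shuffled training loader with a custom batch size might be obtained as sketched below (handle as above; it is assumed that "train" is accepted as the name of a DatasetSplit member and that batches are (input, target) pairs):

train_loader = handle.loader(mode="train", shuffle=True, batch_size=16)
for inputs, targets in train_loader:  # plain torch.utils.data.DataLoader iteration
    pass  # e.g. custom inspection of batches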
- classmethod prefix_by(mode, text)[source]
Prefix text with the given mode.
- Parameters
mode (DatasetSplit) –
text (str) –
- Return type
str
- reset_optimizer(optimizer=None, device=None, model=None)[source]
Move the model to the correct device and initialize the optimizer to the parameters of the model. By default this is applied to optimizer, device, and model.
- Parameters
optimizer (Optional[ResettableOptimizer]) –
device (Optional[device]) –
model (Optional[Module]) –
- reset_training_handles(optimizer=None, device=None, model=None)[source]
(Re)set all handles associated with training, and move them to device. These are: optimizer, early_stopping_handle, and the data loaders. The argument values default to the corresponding attributes of this instance.
- Parameters
optimizer (Optional[ResettableOptimizer]) –
device (Optional[device]) –
model (Optional[Module]) –
- Return type
None
- second_stage_train(callback_context=None)[source]
Do a second stage training for calibration using Laplace approximation. This is a wrapper around hybrid_learning.concepts.train_eval.train_eval_funs.second_stage_train() that uses defaults from settings. Before and after the second stage training, one epoch is processed on the test and the validation set to enable logging of metrics for comparison.
Note: Evaluation runs on the validation and test splits are conducted before (epoch 0) and after (epoch 1) the second stage training.
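A typical call order might look like this sketch (handle as above, with nll_fn set on initialization):

handle.train()               # first stage: regular training
handle.second_stage_train()  # second stage: calibration via Laplace approximation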
- train(train_loader=None, val_loader=None, pbar_desc_templ='Epoch {epoch}/{epochs}', callback_context=None)[source]
Train the model according to the specified training parameters. Defaults are taken from settings. To override, specify custom_args (compare arguments of train_val_one_epoch()).
- Parameters
callback_context (Optional[Dict[str, Any]]) – see callback_context
pbar_desc_templ (str) – template for the progress bar description; must accept as substrings the placeholders {epoch} (the current epoch number) and {tot_epoch} (the total number of epochs)
train_loader (Optional[torch.utils.data.DataLoader]) – see train_one_epoch()
val_loader (Optional[torch.utils.data.DataLoader]) – see evaluate()
- Returns
two pandas.DataFrame with history information on
- training: the epoch- and batch-wise loss and KPI results on the training data; the index is a multi-index of (epoch, batch);
- test: the epoch-wise evaluation results on the test set; the index is the epoch index;
columns for both are loss and the KPI names (keys of metric_fns)
- Return type
Tuple[DataFrame, DataFrame]
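The (epoch, batch) multi-index of the training history permits, e.g., per-epoch aggregation, as in this sketch (handle as above; the loss column name follows LOSS_KEY):

train_hist, test_hist = handle.train()
print(train_hist["loss"].groupby(level=0).mean())  # mean train loss per epoch
print(test_hist)                                   # epoch-wise test KPIs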
- train_one_epoch(train_loader=None, callback_context=None, pbar_desc='Train progress')[source]
Train for one epoch and return history results as pandas.DataFrame. This is a wrapper around hybrid_learning.concepts.train_eval.train_eval_funs.train_one_epoch() that uses defaults from settings.
- Parameters
train_loader (Optional[torch.utils.data.DataLoader]) – the training loader to use; defaults to a shuffled one with the training data of self
callback_context (Optional[Dict[str, Any]]) – see callback_context
pbar_desc (str) – leading static description text for the progress bar if newly created
- Returns
training history as pandas.DataFrame with:
- columns: loss and the KPI names (keys from the dict metric_fns),
- index: the batch indices,
- items: the results of KPI evaluations of the output on the training batch (i.e. before the back-propagation step)
- Return type
DataFrame
- train_val_one_epoch(pbar_desc=None, callback_context=None, train_loader=None, val_loader=None)[source]
Train for one epoch, evaluate, and return history and validation results. This is a wrapper around train_one_epoch() and evaluate() with nice progress bar printing and logging after the epoch. History and validation results are stored in a pandas.DataFrame resp. pandas.Series. The device used for training is that of the parameters of the used model (see device_of()).
- Parameters
pbar_desc (Optional[str]) – leading static description text for the progress bar
callback_context (Optional[Dict[str, Any]]) – see callback_context
train_loader (Optional[torch.utils.data.DataLoader]) – see train_one_epoch()
val_loader (Optional[torch.utils.data.DataLoader]) – see evaluate()
- Returns
tuple of training history and validation results; columns resp. index are loss and the KPI names (keys from the dict metric_fns):
- pandas.DataFrame: index are the batch indices, items are the results of KPI evaluations of the output on the training batch (i.e. before the back-propagation step)
- pandas.Series: items are the final evaluations of the KPIs on the validation set
- Return type
Tuple[DataFrame, Series]
- LOSS_KEY = 'loss'
Key for the loss evaluation results.
- NLL_KEY = 'NLL'
Key for the proper scoring evaluation function used for second stage training. Typically a negative log-likelihood.
- callback_context: Dict[str, Any]
The default callback context values to use. In any training run where context is used, the context can either be handed over or it defaults to a copy of this dict.
- callbacks: List[Mapping[CallbackEvents, Callable]]
A list of mappings from events to callables that are called every time the respective event occurs with the current state. Some default logging callbacks are defined.
For details on available events, see hybrid_learning.concepts.train_eval.callbacks.CallbackEvents. After the event, all callbacks for this event are called in order with keyword arguments from a hybrid_learning.concepts.train_eval.base_handles.train_test_handle.TrainEvalHandle.callback_context. The base context depends on the event and includes e.g. the model; it can be extended by specifying the callback context during the function call or in the default callback context. Note that side effects on objects in the callback context (e.g. the model) will influence the training. Callback application examples (see also the sketch after this list):
Logging
Storing of best n models
Results saving, e.g. to tensorboard or sacred log
Forced weight normalization
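A callbacks specification might look as follows (the event member AFTER_EPOCH_TRAIN is purely illustrative and may be named differently; consult CallbackEvents for the actual members):

from hybrid_learning.concepts.train_eval.callbacks import CallbackEvents

def log_epoch(**context):
    # Receives the current callback context as keyword arguments:
    print("event fired; context keys:", sorted(context))

handle.add_callbacks([{CallbackEvents.AFTER_EPOCH_TRAIN: log_epoch}])  # hypothetical member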
- data: DataTriple
Train, validation, and test data splits to use. Must be converted to data loaders before usage.
- device: torch.device
Device to run training and testing on (this is where the data loaders are put).
- early_stopping_handle: Optional[EarlyStoppingHandle]
Handle that is stepped during training and indicates the need for early stopping. To disable early stopping, set early_stopping_handle to None, resp. specify early_stopping_handle=False in the __init__ arguments (see the sketch below).
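For example, disabling early stopping at construction time (a sketch; names as in the intro example):

handle_no_es = MyHandle(model=my_model, data=data_triple,
                        early_stopping_handle=False)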
- epochs: int
Default maximum number of epochs. May be reduced by early_stopping_handle.
- property kpi_fns: Dict[str, Callable[[Tensor, Tensor], Tensor]]
Metric and loss functions. Nomenclature: metric_fns holds all metrics meant for evaluation, while kpi_fns also encompasses the losses.
- loss_fn: Callable[[Tensor, Tensor], Tensor]
Loss function callable. Defaults to a balanced binary cross-entropy assuming on average 1% positive pixels per image. Must be wrapped into a tuple to hide its parameters, since these are not to be updated.
- metric_fns: Dict[str, Union[AggregatingKpi, Callable[[Tensor, Tensor], Tensor]]]
Dictionary of metric functions to apply for evaluation and logging. Each function must have a signature of (output, target) -> metric_value. See also kpi_fns. Keys must not contain LOSS_KEY or NLL_KEY.
- model: torch.nn.modules.module.Module
The model to work on.
- nll_fn: Callable[[Tensor, Tensor], Tensor]
Proper scoring function callable used as loss in the second stage training for Laplace approximation. Usually chosen as a negative log-likelihood; defaults to loss_fn.
- num_workers: int
The default number of workers to use for data loading. See hybrid_learning.concepts.train_eval.base_handles.train_test_handle.TrainEvalHandle.loader().
- optimizer: ResettableOptimizer
Optimizer and learning rate scheduler handle.