ConceptSegmentation2DTrainTestHandle

class hybrid_learning.concepts.models.concept_models.concept_segmentation.ConceptSegmentation2DTrainTestHandle(concept_model, data, *, model_output_transform=None, metric_input_transform=None, **kwargs)

Bases: ConceptDetection2DTrainTestHandle

Train and test handle for concept segmentation providing loss and metric defaults. See ConceptSegmentationModel2D for details on the model.

Public Data Attributes:

Inherited from ConceptDetection2DTrainTestHandle

DEFAULT_MASK_INTERPOLATION

Interpolation method used in default transforms for resizing masks to activation map size.

Inherited from TrainEvalHandle

LOSS_KEY

Key for the loss evaluation results.

NLL_KEY

Key for the proper scoring evaluation function used for second stage training.

kpi_fns

Metric and loss functions.

settings

The current training settings as dictionary.

Public Methods:

Inherited from TrainEvalHandle

add_callbacks(callbacks)

Append the given callbacks.

remove_callback(callback)

Remove a single given callback.

loader([data, mode, batch_size, shuffle, ...])

Prepare and return a torch data loader from the dataset according to settings.

reset_optimizer([optimizer, device, model])

Move the model to the correct device and initialize the optimizer with the model parameters.

reset_training_handles([optimizer, device, ...])

(Re)set all handles associated with training, and move to device.

reset_kpis([kpis])

Reset aggregating KPIs.

train([train_loader, val_loader, ...])

Train the model according to the specified training parameters.

cross_validate([num_splits, train_val_data, ...])

Record training results for num_splits distinct val splits.

train_val_one_epoch([pbar_desc, ...])

Train for one epoch, evaluate, and return history and test results.

train_one_epoch([train_loader, ...])

Train for one epoch and return history results as pandas.DataFrame.

evaluate([mode, val_loader, prefix, ...])

Evaluate the model wrt. the given mode.

second_stage_train([callback_context])

Do a second stage training for calibration using Laplace approximation.

Special Methods:

Inherited from ConceptDetection2DTrainTestHandle

__init__(concept_model, data, *[, ...])

Init.

Inherited from TrainEvalHandle

__repr__()

Return repr(self).

__init__(concept_model, data, *[, ...])

Init.


Parameters
batch_size: int

Default training batch size.

batch_size_hessian: int

Default batch size for calculating the hessian.

batch_size_val: int

Default validation batch size.

callback_context: Dict[str, Any]

The default callback context values to use. In any training run where context is used, the context can either be handed over explicitly or defaults to a copy of this dict.

callbacks: List[Mapping[CallbackEvents, Callable]]

Mappings from events to callables that are called with the current state every time the event occurs. Some default logging callbacks are defined.

For details on available events, see hybrid_learning.concepts.train_eval.callbacks.CallbackEvents. After the event, all callbacks for this event are called in order with keyword arguments from a hybrid_learning.concepts.train_eval.base_handles.train_test_handle.TrainEvalHandle.callback_context. The base context is dependent on the event and includes e.g. the model, and can be extended by specifying the callback context during function call or in the default callback context. Note that side effects on objects in the callback context (e.g. the model) will influence the training. Callback application examples:

  • Logging

  • Storing of best n models

  • Results saving, e.g. to tensorboard or sacred log

  • Forced weight normalization

Callbacks can be added and removed.
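The event-to-callback dispatch described above can be sketched as follows. This is a minimal illustration, not the library's actual implementation: the event names and the run_callbacks helper are hypothetical, and the real events live in hybrid_learning.concepts.train_eval.callbacks.CallbackEvents.

```python
from enum import Enum
from typing import Any, Callable, List, Mapping


class CallbackEvents(Enum):
    # Illustrative event names only; see the library's CallbackEvents
    # for the actual set of events.
    AFTER_BATCH_TRAIN = "after_batch_train"
    AFTER_EPOCH_TRAIN = "after_epoch_train"


def run_callbacks(callbacks: List[Mapping[CallbackEvents, Callable]],
                  event: CallbackEvents,
                  **context: Any) -> None:
    """Call every callback registered for ``event`` in order, passing
    the current callback context as keyword arguments."""
    for mapping in callbacks:
        callback = mapping.get(event)
        if callback is not None:
            callback(**context)


# Example: a logging callback that records the losses it sees.
seen_losses = []
callbacks = [{CallbackEvents.AFTER_BATCH_TRAIN:
              lambda loss=None, **ctx: seen_losses.append(loss)}]
run_callbacks(callbacks, CallbackEvents.AFTER_BATCH_TRAIN,
              loss=0.5, model=None)
```

Because the context is passed by keyword, a callback can pick out just the entries it needs (here loss) and ignore the rest via **ctx.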

data: DataTriple

Train, validation, and test data splits to use. Must be converted to data loaders before usage.

device: torch.device

Device to run training and testing on (this is where the data loaders are put).

early_stopping_handle: Optional[EarlyStoppingHandle]

Handle that is stepped during training and indicates the need for early stopping. To disable early stopping, set early_stopping_handle to None, or equivalently specify early_stopping_handle=False in the __init__ arguments.
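The "stepped during training" contract can be illustrated with a minimal patience-based handle. This sketch is hypothetical and not the library's EarlyStoppingHandle; it only shows the step-and-check interaction pattern described above.

```python
class SimpleEarlyStopping:
    """Illustrative early-stopping handle: signal a stop once the
    monitored value has not improved for ``patience`` consecutive steps."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.num_bad_steps = 0

    def step(self, value: float) -> bool:
        """Register a new validation value; return True to request a stop."""
        if value < self.best - self.min_delta:
            self.best = value
            self.num_bad_steps = 0
        else:
            self.num_bad_steps += 1
        return self.num_bad_steps >= self.patience


# Validation loss improves twice, then stagnates for two epochs:
handle = SimpleEarlyStopping(patience=2)
stop_signals = [handle.step(v) for v in (1.0, 0.9, 0.95, 0.97)]
```

The training loop would call step() once per validation round and break out of the epoch loop as soon as it returns True.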

epochs: int

Default maximum number of epochs. May be reduced by early_stopping_handle.

loss_fn: Callable[[Tensor, Tensor], Tensor]

Loss function callable. Defaults to a balanced binary cross-entropy, assuming on average 1% positive pixels per image. The function must be wrapped into a tuple to hide its parameters, since these are not to be updated.
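A balanced binary cross-entropy of the kind described above can be sketched in plain Python (the library's default is a PyTorch implementation; this standalone version, including the pos_fraction parameter, is illustrative only). The positive and negative terms are re-weighted by the assumed class frequencies, so that the rare positive pixels are not drowned out:

```python
import math


def balanced_bce(output, target, pos_fraction: float = 0.01,
                 eps: float = 1e-7) -> float:
    """Binary cross-entropy with the positive and negative terms
    re-weighted for an assumed ~1% positive pixel rate."""
    w_pos, w_neg = 1.0 - pos_fraction, pos_fraction
    losses = []
    for y, t in zip(output, target):
        y = min(max(y, eps), 1.0 - eps)  # clamp for numerical stability
        losses.append(-(w_pos * t * math.log(y)
                        + w_neg * (1.0 - t) * math.log(1.0 - y)))
    return sum(losses) / len(losses)
```

With 1% positives, a missed positive pixel is weighted 99 times as heavily as a false positive, which counteracts the class imbalance in the mask.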

metric_fns: Dict[str, Union[AggregatingKpi, Callable[[Tensor, Tensor], Tensor]]]

Dictionary of metric functions to apply for evaluation and logging. Each function must have a signature of (output, target) -> metric_value. See also kpi_fns. Keys must not contain LOSS_KEY or NLL_KEY.

model: ConceptDetectionModel2D

The model to work on.

nll_fn: Callable[[Tensor, Tensor], Tensor]

Proper scoring function callable used as the loss in second stage training for Laplace approximation. Usually chosen as the negative log-likelihood; defaults to loss_fn.

num_workers: int

The default number of workers to use for data loading. See hybrid_learning.concepts.train_eval.base_handles.train_test_handle.TrainEvalHandle.loader().

optimizer: ResettableOptimizer

Optimizer and learning rate scheduler handle.

show_progress_bars: str

Whether to show progress bars for batch-wise operations. Value must be a comma-separated concatenation of run types (the values of dataset splits) for which to show progress, or 'always'.
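The comma-separated-or-'always' convention can be checked with a small helper. This parser is a hypothetical sketch of the rule stated above, not the library's internal logic:

```python
def show_progress_for(setting: str, run_type: str) -> bool:
    """Return whether progress bars should be shown for ``run_type``,
    given a ``show_progress_bars`` value that is either 'always' or a
    comma-separated list of run types."""
    if setting == "always":
        return True
    return run_type in (part.strip() for part in setting.split(","))
```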