ConceptSegmentation2DTrainTestHandle
- class hybrid_learning.concepts.models.concept_models.concept_segmentation.ConceptSegmentation2DTrainTestHandle(concept_model, data, *, model_output_transform=None, metric_input_transform=None, **kwargs)
 Bases: ConceptDetection2DTrainTestHandle
 Train and test handle for concept segmentation providing loss and metric defaults. See ConceptSegmentationModel2D for details on the model.
Public Data Attributes:
Inherited from ConceptDetection2DTrainTestHandle
DEFAULT_MASK_INTERPOLATION – Interpolation method used in default transforms for resizing masks to activation map size.
Inherited from TrainEvalHandle
LOSS_KEY – Key for the loss evaluation results.
NLL_KEY – Key for the proper scoring evaluation function used for second-stage training.
kpi_fns – Metric and loss functions.
settings – The current training settings as a dictionary.
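DEFAULT_MASK_INTERPOLATION determines how masks are resized to the activation map size in the default transforms. The following plain-Python sketch is for illustration only. It resizes a 2D mask by nearest-neighbour sampling. The helper resize_mask_nearest is hypothetical, and the interpolation method actually selected by DEFAULT_MASK_INTERPOLATION may differ.

```python
from typing import List

Mask = List[List[float]]


def resize_mask_nearest(mask: Mask, out_h: int, out_w: int) -> Mask:
    """Resize a 2D mask to (out_h, out_w) by nearest-neighbour sampling.

    Hypothetical helper for illustration; not part of hybrid_learning.
    """
    in_h, in_w = len(mask), len(mask[0])
    resized = []
    for i in range(out_h):
        # Map the output row index back to a source row (floor mapping)
        src_i = min(in_h - 1, int(i * in_h / out_h))
        row = []
        for j in range(out_w):
            # Same floor mapping for the column index
            src_j = min(in_w - 1, int(j * in_w / out_w))
            row.append(mask[src_i][src_j])
        resized.append(row)
    return resized
```

Nearest-neighbour sampling keeps a binary mask binary. Bilinear interpolation, for example, yields soft values at object borders instead.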
Public Methods:
Inherited from TrainEvalHandle
add_callbacks(callbacks) – Append the given callbacks.
remove_callback(callback) – Remove a single given callback.
loader([data, mode, batch_size, shuffle, ...]) – Prepare and return a torch data loader from the dataset according to settings.
reset_optimizer([optimizer, device, model]) – Move the model to the correct device and initialize the optimizer with the model parameters.
reset_training_handles([optimizer, device, ...]) – (Re)set all handles associated with training and move to the device.
reset_kpis([kpis]) – Reset aggregating KPIs.
train([train_loader, val_loader, ...]) – Train the model according to the specified training parameters.
cross_validate([num_splits, train_val_data, ...]) – Record training results for num_splits distinct validation splits.
train_val_one_epoch([pbar_desc, ...]) – Train for one epoch, evaluate, and return history and test results.
train_one_epoch([train_loader, ...]) – Train for one epoch and return history results as a pandas.DataFrame.
evaluate([mode, val_loader, prefix, ...]) – Evaluate the model wrt.
second_stage_train([callback_context]) – Do a second-stage training for calibration using a Laplace approximation.
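cross_validate records training results for num_splits distinct validation splits. One common way to obtain such splits is k-fold partitioning of the combined train/validation indices, shown below in plain Python. The function kfold_indices is hypothetical. It need not match how hybrid_learning actually splits train_val_data.

```python
from typing import List, Tuple


def kfold_indices(num_samples: int, num_splits: int) -> List[Tuple[List[int], List[int]]]:
    """Return (train_indices, val_indices) for each of num_splits folds.

    Every sample index lands in exactly one validation fold.
    """
    if not 1 < num_splits <= num_samples:
        raise ValueError("need 1 < num_splits <= num_samples")
    indices = list(range(num_samples))
    # Spread the remainder over the first folds so sizes differ by at most one
    fold_sizes = [num_samples // num_splits + (1 if k < num_samples % num_splits else 0)
                  for k in range(num_splits)]
    splits = []
    start = 0
    for size in fold_sizes:
        val = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        splits.append((train, val))
        start += size
    return splits
```

In practice, the indices would usually be shuffled before splitting.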
Special Methods:
Inherited from ConceptDetection2DTrainTestHandle
__init__(concept_model, data, *[, ...]) – Init.
Inherited from TrainEvalHandle
__repr__() – Return repr(self).
- __init__(concept_model, data, *, model_output_transform=None, metric_input_transform=None, **kwargs)
 Init.
- Parameters
 concept_model (ConceptDetectionModel2D) –
data (DataTriple) –
model_output_transform (TupleTransforms) –
metric_input_transform (TupleTransforms) –
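The name TupleTransforms suggests that both transforms act on tuples. Here these are assumed to be (model output, target) pairs, e.g. to binarize outputs before metrics are computed. The sketch below illustrates the idea with plain callables on pairs of flat lists. binarize_output and compose_tuple_transforms are hypothetical helpers, not the hybrid_learning TupleTransforms API, and the 0.5 threshold is an arbitrary example choice.

```python
from typing import Callable, List, Tuple

Pair = Tuple[List[float], List[float]]
PairTransform = Callable[[Pair], Pair]


def binarize_output(threshold: float = 0.5) -> PairTransform:
    """Return a transform that thresholds the output and keeps the target."""
    def transform(pair: Pair) -> Pair:
        output, target = pair
        return [1.0 if o > threshold else 0.0 for o in output], target
    return transform


def compose_tuple_transforms(*transforms: PairTransform) -> PairTransform:
    """Apply the given pair transforms from left to right."""
    def composed(pair: Pair) -> Pair:
        for transform in transforms:
            pair = transform(pair)
        return pair
    return composed
```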
- callback_context: Dict[str, Any]
 The default callback context values to use. In any training run that uses a context, the context can either be handed over explicitly or defaults to a copy of this dict.
- callbacks: List[Mapping[CallbackEvents, Callable]]
 A list of callback dictionaries, each mapping events to a callable that is called every time the event occurs with the current state. Some default logging callbacks are defined.
For details on available events, see hybrid_learning.concepts.train_eval.callbacks.CallbackEvents. After the event, all callbacks for this event are called in order with keyword arguments from hybrid_learning.concepts.train_eval.base_handles.train_test_handle.TrainEvalHandle.callback_context. The base context depends on the event and includes, e.g., the model. It can be extended by specifying the callback context during the function call or in the default callback context. Note that side effects on objects in the callback context (e.g. the model) will influence the training. Callback application examples:
  - Logging
  - Storing of the best n models
  - Results saving, e.g. to TensorBoard or a Sacred log
  - Forced weight normalization
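The dispatch described above works as follows, read in plain Python. When an event fires, every callable registered for it is called in list order. It receives the context (default or handed over) plus the event state as keyword arguments. The event names, the run_callbacks helper, and the context keys are hypothetical. Real code would use CallbackEvents members as keys.

```python
from typing import Any, Callable, Dict, List, Mapping, Optional

Callbacks = List[Mapping[str, Callable[..., Any]]]


def run_callbacks(callbacks: Callbacks, event: str,
                  default_context: Dict[str, Any],
                  context: Optional[Dict[str, Any]] = None,
                  **event_state: Any) -> None:
    """Call all callbacks registered for ``event`` in list order."""
    ctx = default_context if context is None else context
    # Build a new kwargs dict: the context dicts themselves stay unchanged,
    # but objects inside them (e.g. a model) are shared with the callbacks,
    # so mutating those objects has side effects on training.
    kwargs = {**ctx, **event_state}
    for callback_map in callbacks:
        if event in callback_map:
            callback_map[event](**kwargs)
```

Accepting `**_` in each callback lets it ignore context keys it does not need.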
- data: DataTriple
 Train, validation, and test data splits to use. Must be converted to data loaders before usage.
- device: torch.device
 Device to run training and testing on (this is where the data loaders are put).
- early_stopping_handle: Optional[EarlyStoppingHandle]
 Handle that is stepped during training and indicates the need for early stopping. To disable early stopping, set early_stopping_handle to None, or specify early_stopping_handle=False in the __init__ arguments.
- epochs: int
 Default maximum number of epochs. May be reduced by early_stopping_handle.
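early_stopping_handle is stepped during training and can end training before epochs is reached. For illustration, the sketch below assumes a common patience-based rule: stop once the validation loss has not improved by more than min_delta for patience consecutive steps. The class PatienceEarlyStopping and the helper epochs_run are hypothetical and do not reproduce EarlyStoppingHandle.

```python
import math
from typing import List


class PatienceEarlyStopping:
    """Hypothetical patience-based early stopping rule."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf
        self.num_bad_steps = 0

    def step(self, val_loss: float) -> bool:
        """Record a validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.num_bad_steps = 0
        else:
            self.num_bad_steps += 1
        return self.num_bad_steps >= self.patience


def epochs_run(val_losses: List[float], epochs: int, patience: int) -> int:
    """Return the number of epochs run; needs len(val_losses) >= epochs."""
    stopper = PatienceEarlyStopping(patience=patience)
    for epoch in range(epochs):
        if stopper.step(val_losses[epoch]):
            return epoch + 1
    return epochs
```

With early_stopping_handle set to None, the loop would simply run for all epochs.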
- loss_fn: Callable[[Tensor, Tensor], Tensor]
 Loss function callable. Defaults to a balanced binary cross-entropy that assumes on average 1% positive pixels per image. Must be wrapped into a tuple to hide its parameters, since these are not to be updated.
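One way to balance binary cross-entropy for an expected positive-pixel fraction f = 0.01 is to weight positive pixels by 1 - f and negative pixels by f. Both classes then contribute f(1 - f) in expectation. This weighting is an assumption for illustration, not necessarily the library's exact formula. The function balanced_bce is hypothetical.

```python
import math
from typing import Sequence


def balanced_bce(outputs: Sequence[float], targets: Sequence[float],
                 pos_fraction: float = 0.01, eps: float = 1e-7) -> float:
    """Mean class-balanced binary cross-entropy over pixels.

    outputs are predicted probabilities in [0, 1], targets are 0/1 labels.
    Positive pixels are weighted by (1 - pos_fraction), negative pixels by
    pos_fraction (assumed weighting scheme for illustration).
    """
    total = 0.0
    for p, t in zip(outputs, targets):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total -= ((1.0 - pos_fraction) * t * math.log(p)
                  + pos_fraction * (1.0 - t) * math.log(1.0 - p))
    return total / len(outputs)
```

Under this weighting, missing a rare positive pixel costs far more than a false positive of the same confidence.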
- metric_fns: Dict[str, Union[AggregatingKpi, Callable[[Tensor, Tensor], Tensor]]]
 Dictionary of metric functions to apply for evaluation and logging. Each function must have the signature (output, target) -> metric_value. See also kpi_fns. Keys must not contain LOSS_KEY or NLL_KEY.
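Values in metric_fns may be plain (output, target) -> metric_value callables or AggregatingKpi objects that accumulate statistics over batches. The sketch below contrasts the two and checks the key restriction. pixel_accuracy, AggregatedIoU, and check_metric_keys are hypothetical. The LOSS_KEY and NLL_KEY strings are placeholders, since their actual values are not given here.

```python
from typing import Dict, Sequence

LOSS_KEY = "loss"  # placeholder; the real TrainEvalHandle.LOSS_KEY may differ
NLL_KEY = "nll"    # placeholder; the real TrainEvalHandle.NLL_KEY may differ


def pixel_accuracy(output: Sequence[float], target: Sequence[float]) -> float:
    """Plain metric callable with signature (output, target) -> value."""
    hits = sum(1 for o, t in zip(output, target) if (o > 0.5) == (t > 0.5))
    return hits / len(target)


class AggregatedIoU:
    """Hypothetical aggregating KPI: IoU over all pixels seen since reset."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.intersection = 0
        self.union = 0

    def update(self, output: Sequence[float], target: Sequence[float]) -> None:
        for o, t in zip(output, target):
            pred, true = o > 0.5, t > 0.5
            self.intersection += int(pred and true)
            self.union += int(pred or true)

    def value(self) -> float:
        # Convention chosen here: empty prediction and target count as perfect
        return self.intersection / self.union if self.union else 1.0


def check_metric_keys(metric_fns: Dict[str, object]) -> None:
    """Raise if a metric key contains one of the reserved keys."""
    for key in metric_fns:
        if LOSS_KEY in key or NLL_KEY in key:
            raise ValueError(f"metric key {key!r} clashes with reserved keys")
```

Aggregated over two batches with per-batch IoU 0.25 and 1.0, the IoU is 2/5 = 0.4, not the batch mean 0.625. This is why aggregating KPIs need an explicit reset between evaluations.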
- model: ConceptDetectionModel2D
 The model to work on.
- nll_fn: Callable[[Tensor, Tensor], Tensor]
 Proper scoring function callable used as loss in second-stage training for the Laplace approximation. Usually chosen as the negative log-likelihood; defaults to loss_fn.
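The following is a generic, textbook illustration of a Laplace approximation driven by a negative log-likelihood. It is not the hybrid_learning second_stage_train implementation. The sketch fits a one-parameter logistic model p = sigmoid(w * x) by minimizing the NLL plus a Gaussian prior term. It then approximates the posterior over w by a Gaussian whose variance is the inverse Hessian at the optimum. Finally, it applies the probit approximation to obtain less overconfident predictions. The model, prior variance, and function names are assumptions.

```python
import math
from typing import Sequence, Tuple


def sigmoid(a: float) -> float:
    return 1.0 / (1.0 + math.exp(-a))


def laplace_fit(xs: Sequence[float], ys: Sequence[float],
                prior_var: float = 1.0, steps: int = 50) -> Tuple[float, float]:
    """Return (w_map, posterior_var) for p(y=1|x) = sigmoid(w * x).

    Objective: NLL(w) + w**2 / (2 * prior_var), minimized with Newton steps
    (strictly convex thanks to the prior term).
    """
    w = 0.0
    for _ in range(steps):
        grad = w / prior_var
        hess = 1.0 / prior_var
        for x, y in zip(xs, ys):
            p = sigmoid(w * x)
            grad += (p - y) * x
            hess += p * (1.0 - p) * x * x
        w -= grad / hess
    # The Hessian at the optimum gives the Laplace posterior variance
    hess = 1.0 / prior_var + sum(
        sigmoid(w * x) * (1.0 - sigmoid(w * x)) * x * x for x in xs)
    return w, 1.0 / hess


def predictive_prob(x: float, w_map: float, post_var: float) -> float:
    """Probit approximation of the posterior predictive p(y=1|x)."""
    mean, var = w_map * x, post_var * x * x
    return sigmoid(mean / math.sqrt(1.0 + math.pi * var / 8.0))
```

The predictive probability is pulled towards 0.5 relative to sigmoid(w_map * x). This softening is the calibration effect.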
- num_workers: int
 The default number of workers to use for data loading. See hybrid_learning.concepts.train_eval.base_handles.train_test_handle.TrainEvalHandle.loader().
- optimizer: ResettableOptimizer
 Optimizer and learning rate scheduler handle.
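ResettableOptimizer bundles an optimizer and learning rate scheduler so that they can be re-initialized for the current model parameters, e.g. by reset_optimizer. The sketch below shows the underlying pattern with a toy SGD optimizer on plain float lists: store the constructor and its settings instead of a bound instance. ToySGD and ResettableSketch are hypothetical and do not reproduce the ResettableOptimizer API. A scheduler could be rebuilt the same way.

```python
from typing import Any, Callable, Dict, List, Optional


class ToySGD:
    """Minimal SGD on a list of float parameters (hypothetical)."""

    def __init__(self, params: List[float], lr: float) -> None:
        self.params = params
        self.lr = lr

    def step(self, grads: List[float]) -> None:
        # Update the parameter list in place
        for i, g in enumerate(grads):
            self.params[i] -= self.lr * g


class ResettableSketch:
    """Keep an optimizer factory and its settings; rebuild on reset."""

    def __init__(self, optimizer_cls: Callable[..., Any], **settings: Any) -> None:
        self.optimizer_cls = optimizer_cls
        self.settings: Dict[str, Any] = settings
        self.optimizer: Optional[Any] = None

    def reset(self, params: List[float]) -> None:
        # Create a fresh optimizer (with fresh internal state) bound to params
        self.optimizer = self.optimizer_cls(params, **self.settings)

    def step(self, grads: List[float]) -> None:
        if self.optimizer is None:
            raise RuntimeError("call reset(params) before step()")
        self.optimizer.step(grads)
```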