ConceptSegmentation2DTrainTestHandle
- class hybrid_learning.concepts.models.concept_models.concept_segmentation.ConceptSegmentation2DTrainTestHandle(concept_model, data, *, model_output_transform=None, metric_input_transform=None, **kwargs)
Bases: ConceptDetection2DTrainTestHandle
Train and test handle for concept segmentation, providing loss and metric defaults. See ConceptSegmentationModel2D for details on the model.
Public Data Attributes:
Inherited from ConceptDetection2DTrainTestHandle
DEFAULT_MASK_INTERPOLATION
Interpolation method used in default transforms for resizing masks to activation map size.
Inherited from TrainEvalHandle
LOSS_KEY
Key for the loss evaluation results.
NLL_KEY
Key for the proper scoring evaluation function used for second stage training.
kpi_fns
Metric and loss functions.
settings
The current training settings as dictionary.
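A ground-truth mask covering the full input image has to be brought to the spatial size of the concept model's activation map before loss and metrics can compare the two. The sketch below illustrates this resizing step with nearest-neighbour sampling in plain Python. Which interpolation DEFAULT_MASK_INTERPOLATION actually selects is not stated here, and resize_mask_nearest is a hypothetical helper, not part of hybrid_learning.

```python
from typing import List

Mask = List[List[float]]


def resize_mask_nearest(mask: Mask, out_h: int, out_w: int) -> Mask:
    """Hypothetical helper: resize a 2D mask to (out_h, out_w) by nearest-neighbour sampling."""
    in_h, in_w = len(mask), len(mask[0])
    # Each output pixel copies the input pixel its top-left corner maps onto.
    return [
        [mask[(i * in_h) // out_h][(j * in_w) // out_w] for j in range(out_w)]
        for i in range(out_h)
    ]


# A 4x4 image mask with the concept in the top-left quadrant,
# reduced to an assumed 2x2 activation map size.
image_mask = [
    [1, 1, 0, 0],
    [1, 1, 0, 0],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
]
small_mask = resize_mask_nearest(image_mask, 2, 2)
```

Nearest-neighbour keeps the mask binary; a smooth interpolation such as bilinear would instead produce fractional values at concept borders.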
Public Methods:
Inherited from TrainEvalHandle
add_callbacks(callbacks)
Append the given callbacks.
remove_callback(callback)
Remove a single given callback.
loader([data, mode, batch_size, shuffle, ...])
Prepare and return a torch data loader from the dataset according to settings.
reset_optimizer([optimizer, device, model])
Move the model to the correct device and initialize the optimizer on the model's parameters.
reset_training_handles([optimizer, device, ...])
(Re)set all handles associated with training, and move to device.
reset_kpis([kpis])
Reset the aggregating KPIs.
train([train_loader, val_loader, ...])
Train the model according to the specified training parameters.
cross_validate([num_splits, train_val_data, ...])
Record training results for num_splits distinct validation splits.
train_val_one_epoch([pbar_desc, ...])
Train for one epoch, evaluate, and return history and test results.
train_one_epoch([train_loader, ...])
Train for one epoch and return history results as pandas.DataFrame.
evaluate([mode, val_loader, prefix, ...])
Evaluate the model wrt.
second_stage_train([callback_context])
Do a second-stage training for calibration using Laplace approximation.
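cross_validate records training results for num_splits distinct validation splits. How those splits are formed is not documented here; the sketch below assumes a k-fold style partition in which every sample lands in exactly one validation fold. kfold_splits is a hypothetical helper working on sample indices.

```python
from typing import List, Sequence, Tuple


def kfold_splits(indices: Sequence[int], num_splits: int) -> List[Tuple[List[int], List[int]]]:
    """Hypothetical helper: build num_splits (train, val) pairs with disjoint val folds."""
    if not 2 <= num_splits <= len(indices):
        raise ValueError("num_splits must be between 2 and len(indices)")
    # Round-robin assignment: sample i goes to fold i % num_splits.
    folds = [list(indices[k::num_splits]) for k in range(num_splits)]
    splits = []
    for k, val in enumerate(folds):
        # Training data for split k is the union of all other folds.
        train = [i for f, fold in enumerate(folds) if f != k for i in fold]
        splits.append((train, val))
    return splits


splits = kfold_splits(range(6), 3)
```

Each of the three splits would then get its own training run, and the per-split results are what a cross-validation routine records.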
Special Methods:
Inherited from ConceptDetection2DTrainTestHandle
__init__(concept_model, data, *[, ...])
Init.
Inherited from TrainEvalHandle
__repr__()
Return repr(self).
__init__(concept_model, data, *[, ...])
Init.
- Parameters
concept_model (ConceptDetectionModel2D)
data (DataTriple)
model_output_transform (TupleTransforms)
metric_input_transform (TupleTransforms)
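model_output_transform and metric_input_transform both act on (output, target) pairs rather than on single tensors. The plain-Python sketch below shows what such a tuple transform can look like. ApplyToOutput and binarize are hypothetical names, lists stand in for tensors, and binarizing predictions only for metric computation is an illustrative use, not a documented default.

```python
from typing import Callable, List, Tuple

Values = List[float]


class ApplyToOutput:
    """Hypothetical tuple transform: map the model output, pass the target through."""

    def __init__(self, fn: Callable[[Values], Values]):
        self.fn = fn

    def __call__(self, output: Values, target: Values) -> Tuple[Values, Values]:
        return self.fn(output), target


def binarize(values: Values, thresh: float = 0.5) -> Values:
    """Hypothetical helper: threshold soft mask values to {0, 1}."""
    return [1.0 if v > thresh else 0.0 for v in values]


# Illustrative choice: binarize predictions only before metric computation,
# while the loss would still see the raw soft outputs.
metric_input_transform = ApplyToOutput(binarize)
out, tgt = metric_input_transform([0.2, 0.7, 0.5], [0.0, 1.0, 1.0])
```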
- callback_context: Dict[str, Any]
The default callback context values to use. In any training run where context is used, the context can either be handed over or it defaults to a copy of this dict.
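The fallback rule above (use the handed-over context, otherwise a copy of the default dict) can be sketched as follows. ContextHolder and resolve_context are hypothetical names, and a shallow copy is assumed since the text does not say how deep the copy is.

```python
import copy
from typing import Any, Dict, Optional


class ContextHolder:
    """Hypothetical stand-in for a handle that owns a default callback context."""

    def __init__(self, callback_context: Optional[Dict[str, Any]] = None):
        self.callback_context: Dict[str, Any] = callback_context or {}

    def resolve_context(self, callback_context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        # A handed-over context is used as-is.
        if callback_context is not None:
            return callback_context
        # Otherwise use a (shallow) copy, so per-run additions do not leak into the defaults.
        return copy.copy(self.callback_context)


holder = ContextHolder({"log_prefix": "run1"})
ctx = holder.resolve_context()
ctx["epoch"] = 3  # modifies only the copy
```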
- callbacks: List[Mapping[CallbackEvents, Callable]]
A dictionary mapping events to a list of callables that are called every time the event occurs with the current state. Some default logging callbacks are defined.
For details on available events, see hybrid_learning.concepts.train_eval.callbacks.CallbackEvents. After the event, all callbacks for this event are called in order with keyword arguments taken from the callback context (see hybrid_learning.concepts.train_eval.base_handles.train_test_handle.TrainEvalHandle.callback_context). The base context depends on the event and includes, e.g., the model; it can be extended by specifying the callback context during the function call or in the default callback context. Note that side effects on objects in the callback context (e.g. the model) will influence the training. Callback application examples:
Logging
Storing of best n models
Results saving, e.g. to TensorBoard or a Sacred log
Forced weight normalization
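The dispatch described above (call every callback registered for an event, in order, with keyword arguments from the merged context) is sketched below in plain Python. Event and fire are hypothetical stand-ins for CallbackEvents and the handle's internals; the sketch follows the type annotation List[Mapping[CallbackEvents, Callable]], and the precedence of the per-call context over the base context on key clashes is an assumption.

```python
from enum import Enum
from typing import Any, Callable, Dict, List, Mapping


class Event(Enum):
    """Hypothetical stand-in for a few callback events."""
    AFTER_EPOCH = "after_epoch"
    AFTER_TRAIN = "after_train"


def fire(event: Event,
         callbacks: List[Mapping[Event, Callable[..., Any]]],
         base_context: Dict[str, Any],
         callback_context: Dict[str, Any]) -> None:
    """Call each callback registered for event, in list order, with keyword args."""
    kwargs = {**base_context, **callback_context}  # assumed precedence on key clashes
    for cb_map in callbacks:
        if event in cb_map:
            cb_map[event](**kwargs)


log: List[str] = []


def log_epoch(epoch: int, **_: Any) -> None:
    # Logging example: ignores all context entries except the epoch.
    log.append(f"epoch {epoch} done")


def normalize_weights(model: Dict[str, List[float]], **_: Any) -> None:
    # Side effect on the model object in the context: training continues with these weights.
    norm = sum(w * w for w in model["weights"]) ** 0.5
    model["weights"] = [w / norm for w in model["weights"]]


model = {"weights": [3.0, 4.0]}  # toy "model"
callbacks = [{Event.AFTER_EPOCH: log_epoch}, {Event.AFTER_EPOCH: normalize_weights}]
fire(Event.AFTER_EPOCH, callbacks, {"model": model, "epoch": 1}, {"run": "demo"})
```

Accepting **_ lets each callback pick only the context entries it needs, which keeps callbacks robust to additional context keys.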
- data: DataTriple
Train, validation, and test data splits to use. Must be converted to data loaders before usage.
- device: torch.device
Device to run training and testing on (this is where the data loaders are put).
- early_stopping_handle: Optional[EarlyStoppingHandle]
Handle that is stepped during training and indicates the need for early stopping. To disable early stopping, set early_stopping_handle to None, or specify early_stopping_handle=False in the __init__ arguments.
- epochs: int
Default maximum number of epochs. May be reduced by early_stopping_handle.
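The interplay of epochs and early_stopping_handle can be sketched as an epoch loop that steps the handle and stops early once it signals to. PatienceEarlyStopping is a hypothetical patience-based stand-in, not the library's EarlyStoppingHandle, and a fixed list of validation losses replaces real training.

```python
from typing import List, Optional


class PatienceEarlyStopping:
    """Hypothetical stand-in: request a stop after `patience` steps without improvement."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best: Optional[float] = None
        self.bad_steps = 0
        self.require_stop = False

    def step(self, val_loss: float) -> None:
        if self.best is None or val_loss < self.best - self.min_delta:
            self.best, self.bad_steps = val_loss, 0
        else:
            self.bad_steps += 1
        self.require_stop = self.bad_steps >= self.patience


def run_epochs(val_losses: List[float], epochs: int,
               handle: Optional[PatienceEarlyStopping]) -> int:
    """Return the number of epochs actually run (at most `epochs`)."""
    for epoch in range(epochs):
        val_loss = val_losses[epoch]  # stands in for one train + validate epoch
        if handle is not None:  # None disables early stopping
            handle.step(val_loss)
            if handle.require_stop:
                return epoch + 1
    return epochs


losses = [1.0, 0.8, 0.9, 0.85, 0.95, 0.7]
ran_with_stopping = run_epochs(losses, 6, PatienceEarlyStopping(patience=3))
ran_without = run_epochs(losses, 6, None)
```

With patience 3, the three epochs after the best loss of 0.8 trigger the stop after epoch 5; without a handle all 6 epochs run.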
- loss_fn: Callable[[Tensor, Tensor], Tensor]
Loss function callable. Defaults to a balanced binary cross-entropy assuming on average 1% positive pixels per image. Must be wrapped into a tuple to hide its parameters, since these are not to be updated.
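The exact balancing used by the default loss is not given here. The sketch below assumes one common scheme: positive pixels are up-weighted by (1 - p) / p for an expected positive fraction p, analogous to the pos_weight argument of torch.nn.BCEWithLogitsLoss. With p = 0.01 the weight is about 99, so the rare concept pixels are not drowned out by the background. balanced_bce is a hypothetical plain-Python function operating on probabilities.

```python
import math
from typing import Sequence


def balanced_bce(output: Sequence[float], target: Sequence[float],
                 pos_frac: float = 0.01, eps: float = 1e-7) -> float:
    """Hypothetical sketch: mean BCE with positive pixels up-weighted by (1 - p) / p."""
    pos_weight = (1.0 - pos_frac) / pos_frac  # about 99 for 1% positive pixels
    total = 0.0
    for x, y in zip(output, target):
        x = min(max(x, eps), 1.0 - eps)  # clamp probabilities to avoid log(0)
        total -= pos_weight * y * math.log(x) + (1.0 - y) * math.log(1.0 - x)
    return total / len(output)


# An uninformative prediction of 0.5 is penalized about 99x more on the positive pixel.
uncertain = balanced_bce([0.5, 0.5], [1.0, 0.0])
```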
- metric_fns: Dict[str, Union[AggregatingKpi, Callable[[Tensor, Tensor], Tensor]]]
Dictionary of metric functions to apply for evaluation and logging. Each function must have a signature of (output, target) -> metric_value. See also kpi_fns. Keys must not contain LOSS_KEY or NLL_KEY.
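The constraint on metric keys matters because metrics, loss, and NLL end up side by side under string keys. The sketch below assumes kpi_fns is such a merged dictionary and rejects metric keys equal to the reserved ones. The values of LOSS_KEY and NLL_KEY, as well as pixel_accuracy and build_kpi_fns, are hypothetical.

```python
from typing import Callable, Dict, Sequence

Metric = Callable[[Sequence[float], Sequence[float]], float]

LOSS_KEY = "loss"  # hypothetical value; the real constant may differ
NLL_KEY = "NLL"    # hypothetical value; the real constant may differ


def pixel_accuracy(output: Sequence[float], target: Sequence[float],
                   thresh: float = 0.5) -> float:
    """Example metric with the (output, target) -> metric_value signature."""
    hits = sum((o > thresh) == (t > thresh) for o, t in zip(output, target))
    return hits / len(output)


def build_kpi_fns(metric_fns: Dict[str, Metric], loss_fn: Metric,
                  nll_fn: Metric) -> Dict[str, Metric]:
    """Hypothetical: merge loss, NLL, and metrics into one dict, rejecting reserved keys."""
    for key in metric_fns:
        if key in (LOSS_KEY, NLL_KEY):
            raise ValueError(f"metric key {key!r} is reserved")
    return {LOSS_KEY: loss_fn, NLL_KEY: nll_fn, **metric_fns}


def dummy_loss(output: Sequence[float], target: Sequence[float]) -> float:
    return 0.0


kpis = build_kpi_fns({"acc": pixel_accuracy}, dummy_loss, dummy_loss)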
- model: ConceptDetectionModel2D
The model to work on.
- nll_fn: Callable[[Tensor, Tensor], Tensor]
Proper scoring function callable used as loss in second-stage training for Laplace approximation. Usually chosen as the negative log-likelihood; defaults to loss_fn.
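For per-pixel binary masks, treating each output as the probability of an independent Bernoulli variable gives the negative log-likelihood below, i.e. unweighted binary cross-entropy. The fallback to the training loss mirrors the documented default. Both functions are hypothetical sketches, not library code.

```python
import math
from typing import Callable, Optional, Sequence

LossFn = Callable[[Sequence[float], Sequence[float]], float]


def bernoulli_nll(output: Sequence[float], target: Sequence[float],
                  eps: float = 1e-7) -> float:
    """Mean negative log-likelihood of binary targets under predicted probabilities."""
    total = 0.0
    for p, y in zip(output, target):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total -= y * math.log(p) + (1.0 - y) * math.log(1.0 - p)
    return total / len(output)


def resolve_nll_fn(nll_fn: Optional[LossFn], loss_fn: LossFn) -> LossFn:
    # Without a dedicated proper scoring function, fall back to the training loss.
    return nll_fn if nll_fn is not None else loss_fn
```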
- num_workers: int
The default number of workers to use for data loading. See hybrid_learning.concepts.train_eval.base_handles.train_test_handle.TrainEvalHandle.loader().
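loader() returns a torch data loader built according to the settings; in PyTorch this is typically torch.utils.data.DataLoader(dataset, batch_size=..., shuffle=..., num_workers=...). The plain-Python sketch below only illustrates the batching and shuffling part. num_workers, which controls parallel loading processes, is not modelled, and simple_loader is a hypothetical helper.

```python
import random
from typing import Iterator, List, Optional, Sequence, TypeVar

T = TypeVar("T")


def simple_loader(data: Sequence[T], batch_size: int = 2, shuffle: bool = False,
                  seed: Optional[int] = None) -> Iterator[List[T]]:
    """Hypothetical helper: yield batches of items, optionally in shuffled order."""
    order = list(range(len(data)))
    if shuffle:
        random.Random(seed).shuffle(order)  # seeded for reproducibility
    for start in range(0, len(order), batch_size):
        yield [data[i] for i in order[start:start + batch_size]]


batches = list(simple_loader([1, 2, 3, 4, 5], batch_size=2))
```

The last batch is smaller when the dataset size is not a multiple of batch_size, which matches DataLoader's default drop_last=False behaviour.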
- optimizer: ResettableOptimizer
Optimizer and learning rate scheduler handle.
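A resettable optimizer handle can be understood as storing how to build an optimizer rather than a single optimizer instance, so that reset_optimizer can create fresh optimizer state for the current model parameters. The sketch below illustrates this idea with a toy SGD on a list of floats; ToySGD and ResettableHandle are hypothetical, and the real ResettableOptimizer (which also covers the learning rate scheduler) may work differently. In PyTorch, the stored factory could for instance be functools.partial(torch.optim.Adam, lr=1e-3), called again on model.parameters().

```python
from typing import Callable, List, Optional


class ToySGD:
    """Toy optimizer: plain gradient descent on a list of floats."""

    def __init__(self, params: List[float], lr: float):
        self.params, self.lr, self.steps = params, lr, 0

    def step(self, grads: List[float]) -> None:
        for i, g in enumerate(grads):
            self.params[i] -= self.lr * g  # updates the shared parameter list in place
        self.steps += 1


class ResettableHandle:
    """Hypothetical: keeps an optimizer factory and rebuilds the optimizer on reset."""

    def __init__(self, factory: Callable[[List[float]], ToySGD]):
        self.factory = factory
        self.optimizer: Optional[ToySGD] = None

    def reset(self, params: List[float]) -> ToySGD:
        # Fresh optimizer state (step counts etc.) bound to the given parameters.
        self.optimizer = self.factory(params)
        return self.optimizer


handle = ResettableHandle(lambda p: ToySGD(p, lr=0.1))
params = [1.0]
first = handle.reset(params)
first.step([1.0])
second = handle.reset(params)
```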