ConceptClassification2DTrainTestHandle
- class hybrid_learning.concepts.models.concept_models.concept_classification.ConceptClassification2DTrainTestHandle(concept_model, data, *, model_output_transform=False, metric_input_transform=False, **kwargs)
Bases: ConceptDetection2DTrainTestHandle
Train and test handle for concept classification that provides default loss and metric functions. See ConceptClassificationModel2D for details on the model.

Public Data Attributes:
Inherited from ConceptDetection2DTrainTestHandle
DEFAULT_MASK_INTERPOLATION
Interpolation method used in default transforms for resizing masks to activation map size.
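DEFAULT_MASK_INTERPOLATION controls how ground-truth masks are resized to the spatial size of the activation map. As a rough illustration only, here is a pure-Python sketch of one common choice, nearest-neighbor interpolation. The helper `resize_mask_nearest` is hypothetical, and the actual interpolation method stored in the attribute is not stated here.

```python
from typing import List

Mask = List[List[float]]


def resize_mask_nearest(mask: Mask, out_h: int, out_w: int) -> Mask:
    """Resize a 2D mask to (out_h, out_w) by nearest-neighbor sampling.

    Hypothetical helper for illustration; not part of hybrid_learning.
    """
    in_h, in_w = len(mask), len(mask[0])
    resized = []
    for i in range(out_h):
        # Map output row i to the source row containing its top edge
        src_i = min(in_h - 1, (i * in_h) // out_h)
        row = []
        for j in range(out_w):
            # Same mapping for columns
            src_j = min(in_w - 1, (j * in_w) // out_w)
            row.append(mask[src_i][src_j])
        resized.append(row)
    return resized


# A 4x4 ground-truth mask downsampled to a 2x2 activation map grid
mask = [[1, 1, 0, 0],
        [1, 1, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 0, 0]]
print(resize_mask_nearest(mask, 2, 2))  # [[1, 0], [0, 0]]
```

Nearest-neighbor keeps mask values binary, whereas bilinear interpolation would produce fractional values at concept boundaries.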
Inherited from TrainEvalHandle
LOSS_KEY
Key for the loss evaluation results.
NLL_KEY
Key for the proper scoring evaluation function used for second stage training.
kpi_fns
Metric and loss functions.
settings
The current training settings as dictionary.
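kpi_fns is described only as the metric and loss functions. The sketch below assumes a dictionary that maps KPI names to functions of (predictions, targets), with the loss under LOSS_KEY and a proper scoring function under NLL_KEY. The key strings, the choice of binary cross-entropy as loss, and the accuracy threshold are illustrative assumptions, not the library's actual defaults.

```python
import math
from typing import Callable, Dict, Sequence

# Hypothetical key values; the real LOSS_KEY and NLL_KEY strings are not given here
LOSS_KEY = "loss"
NLL_KEY = "NLL"

EPS = 1e-7  # clamp probabilities to avoid log(0)


def bce(probs: Sequence[float], targets: Sequence[int]) -> float:
    """Mean binary cross-entropy between probabilities and 0/1 targets."""
    total = 0.0
    for p, t in zip(probs, targets):
        p = min(max(p, EPS), 1 - EPS)
        total -= t * math.log(p) + (1 - t) * math.log(1 - p)
    return total / len(probs)


def nll(probs: Sequence[float], targets: Sequence[int]) -> float:
    """Summed negative log-likelihood under a Bernoulli model (a proper scoring rule)."""
    return bce(probs, targets) * len(probs)


def accuracy(probs: Sequence[float], targets: Sequence[int],
             threshold: float = 0.5) -> float:
    """Fraction of predictions on the correct side of the threshold."""
    hits = sum((p > threshold) == bool(t) for p, t in zip(probs, targets))
    return hits / len(probs)


# Hypothetical KPI dictionary in the spirit of kpi_fns
kpi_fns: Dict[str, Callable] = {LOSS_KEY: bce, NLL_KEY: nll, "acc": accuracy}

probs, targets = [0.9, 0.2, 0.6, 0.4], [1, 0, 0, 1]
results = {name: fn(probs, targets) for name, fn in kpi_fns.items()}
print(results["acc"])  # 0.5
```

For binary targets, the mean cross-entropy and the Bernoulli negative log-likelihood differ only by the normalization, which is why both can serve as training objectives.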
Public Methods:
Inherited from TrainEvalHandle
add_callbacks(callbacks)
Append the given callbacks.
remove_callback(callback)
Remove a single given callback.
loader([data, mode, batch_size, shuffle, ...])
Prepare and return a torch data loader from the dataset according to settings.
reset_optimizer([optimizer, device, model])
Move the model to the correct device and initialize the optimizer with the model's parameters.
reset_training_handles([optimizer, device, ...])
(Re)set all handles associated with training, and move them to device.
reset_kpis([kpis])
Reset the aggregating KPIs.
train([train_loader, val_loader, ...])
Train the model according to the specified training parameters.
cross_validate([num_splits, train_val_data, ...])
Record training results for num_splits distinct validation splits.
train_val_one_epoch([pbar_desc, ...])
Train for one epoch, evaluate, and return history and test results.
train_one_epoch([train_loader, ...])
Train for one epoch and return history results as pandas.DataFrame.
evaluate([mode, val_loader, prefix, ...])
Evaluate the model wrt.
second_stage_train([callback_context])
Do a second stage training for calibration using Laplace approximation.
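second_stage_train is summarized only as calibration using Laplace approximation. The following is a minimal, generic sketch of that technique on a one-parameter logistic model, not hybrid_learning's implementation. It fits the MAP weight with Newton's method (minimizing the negative log-likelihood plus a Gaussian prior term), takes the inverse Hessian at the MAP as posterior variance, and uses the probit approximation to obtain a predictive probability that moves towards 0.5 when the weight is uncertain. The function names, the scalar model, and the prior precision are hypothetical choices.

```python
import math
from typing import List, Tuple


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def laplace_fit(xs: List[float], ys: List[int], prior_prec: float = 1.0,
                steps: int = 50) -> Tuple[float, float]:
    """Return (w_map, posterior variance) for p(y=1|x) = sigmoid(w * x).

    Hypothetical helper: Newton's method on the negative log posterior,
    then the Laplace approximation N(w_map, 1 / Hessian).
    """
    w = 0.0
    for _ in range(steps):
        grad = prior_prec * w  # gradient of the Gaussian prior term
        hess = prior_prec
        for x, y in zip(xs, ys):
            p = sigmoid(w * x)
            grad += (p - y) * x          # gradient of the NLL term
            hess += p * (1 - p) * x * x  # curvature of the NLL term
        w -= grad / hess  # Newton step
    # Recompute the Hessian at the final weight so the variance matches w_map
    hess = prior_prec + sum(
        sigmoid(w * x) * (1 - sigmoid(w * x)) * x * x for x in xs)
    return w, 1.0 / hess


def predict_calibrated(x: float, w_map: float, var: float) -> float:
    """Probit approximation to the posterior predictive p(y=1|x)."""
    mean = w_map * x
    logit_var = x * x * var
    kappa = 1.0 / math.sqrt(1.0 + math.pi * logit_var / 8.0)
    return sigmoid(kappa * mean)


xs = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]
ys = [0, 0, 1, 0, 1, 1]
w_map, var = laplace_fit(xs, ys)
p_map = sigmoid(w_map * 2.0)                 # plug-in MAP prediction
p_cal = predict_calibrated(2.0, w_map, var)  # Laplace-calibrated prediction
```

Because the scaling factor kappa is below 1 whenever the posterior variance is positive, the calibrated probability p_cal lies between 0.5 and the overconfident plug-in value p_map.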
Special Methods:
__init__(concept_model, data, *[, ...])
Init.
Inherited from ConceptDetection2DTrainTestHandle
__init__(concept_model, data, *[, ...])
Init.
Inherited from TrainEvalHandle
__repr__()
Return repr(self).
__init__(concept_model, data, *[, ...])
Init.
- Parameters
concept_model (ConceptClassificationModel2D)
data (DataTriple)
model_output_transform (TupleTransforms)
metric_input_transform (TupleTransforms)
- __init__(concept_model, data, *, model_output_transform=False, metric_input_transform=False, **kwargs)
Init.
For details on the init parameters, see the init of the super class ConceptDetection2DTrainTestHandle.
- Parameters
concept_model (ConceptClassificationModel2D)
data (DataTriple)
model_output_transform (TupleTransforms)
metric_input_transform (TupleTransforms)
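The init accepts **kwargs and defers to the super class for parameter details, which suggests that this handle mainly supplies loss and metric defaults before delegating. A minimal sketch of that general pattern follows. BaseHandle, ClassificationHandle, loss_fn, metric_fns, and the placeholder functions are hypothetical names, not the actual keyword arguments of hybrid_learning.

```python
import math
from typing import Any, Callable, Dict, Optional


def default_loss(prob: float, target: int) -> float:
    """Placeholder default loss: binary cross-entropy for one prediction."""
    prob = min(max(prob, 1e-7), 1 - 1e-7)
    return -(target * math.log(prob) + (1 - target) * math.log(1 - prob))


def default_accuracy(prob: float, target: int) -> float:
    """Placeholder default metric: 1.0 if thresholded prediction is correct."""
    return float((prob > 0.5) == bool(target))


class BaseHandle:
    """Hypothetical stand-in for the super class; it only stores its inputs."""

    def __init__(self, model: Any, data: Any, *,
                 loss_fn: Optional[Callable] = None,
                 metric_fns: Optional[Dict[str, Callable]] = None,
                 **kwargs: Any) -> None:
        self.model, self.data = model, data
        self.loss_fn = loss_fn
        self.metric_fns = metric_fns or {}
        self.extra_settings = kwargs


class ClassificationHandle(BaseHandle):
    """Fills in loss and metric defaults, then defers to the super class."""

    def __init__(self, model: Any, data: Any, **kwargs: Any) -> None:
        # Only set defaults for arguments the caller did not provide
        kwargs.setdefault("loss_fn", default_loss)
        kwargs.setdefault("metric_fns", {"acc": default_accuracy})
        super().__init__(model, data, **kwargs)


handle = ClassificationHandle(model=None, data=None)
custom = ClassificationHandle(None, None, loss_fn=lambda p, t: abs(p - t))
```

Using setdefault keeps any explicitly passed loss or metrics intact, so the defaults apply only when the caller omits them.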