ConceptClassification2DTrainTestHandle
- class hybrid_learning.concepts.models.concept_models.concept_classification.ConceptClassification2DTrainTestHandle(concept_model, data, *, model_output_transform=False, metric_input_transform=False, **kwargs)
Bases: ConceptDetection2DTrainTestHandle
Train and test handle for concept classification providing loss and metric defaults. See ConceptClassificationModel2D for details on the model.
Public Data Attributes:
Inherited from ConceptDetection2DTrainTestHandle
DEFAULT_MASK_INTERPOLATION: Interpolation method used in default transforms for resizing masks to activation map size.
Inherited from TrainEvalHandle
LOSS_KEY: Key for the loss evaluation results.
NLL_KEY: Key for the proper scoring evaluation function used for second stage training.
kpi_fns: Metric and loss functions.
settings: The current training settings as dictionary.
Public Methods:
Inherited from TrainEvalHandle
add_callbacks(callbacks): Append the given callbacks.
remove_callback(callback): Remove a single given callback.
loader([data, mode, batch_size, shuffle, ...]): Prepare and return a torch data loader from the dataset according to settings.
reset_optimizer([optimizer, device, model]): Move model to correct device, init optimizer to parameters of model.
reset_training_handles([optimizer, device, ...]): (Re)set all handles associated with training, and move to device.
reset_kpis([kpis]): Resets aggregating kpis.
train([train_loader, val_loader, ...]): Train the model according to the specified training parameters.
cross_validate([num_splits, train_val_data, ...]): Record training results for num_splits distinct val splits.
train_val_one_epoch([pbar_desc, ...]): Train for one epoch, evaluate, and return history and test results.
train_one_epoch([train_loader, ...]): Train for one epoch and return history results as pandas.DataFrame.
evaluate([mode, val_loader, prefix, ...]): Evaluate the model wrt.
second_stage_train([callback_context]): Do a second stage training for calibration using Laplace approximation.
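The summary of cross_validate above only states that results are recorded for num_splits distinct validation splits. How the library actually partitions the data (shuffling, fold sizes) is not documented here, so the following is a minimal sketch of one way to build such splits over sample indices. The helper distinct_val_splits is hypothetical and not part of hybrid_learning.

```python
from typing import List, Tuple


def distinct_val_splits(
    num_samples: int, num_splits: int
) -> List[Tuple[List[int], List[int]]]:
    """Return (train_indices, val_indices) pairs with pairwise disjoint val sets.

    Hypothetical helper for illustration only; not part of hybrid_learning.
    """
    if not 1 < num_splits <= num_samples:
        raise ValueError("num_splits must be in [2, num_samples]")
    indices = list(range(num_samples))
    # Fold i takes every num_splits-th index starting at i, so folds never overlap.
    folds = [indices[i::num_splits] for i in range(num_splits)]
    splits = []
    for val in folds:
        val_set = set(val)
        # Everything that is not in the current validation fold is used for training.
        train = [idx for idx in indices if idx not in val_set]
        splits.append((train, val))
    return splits


splits = distinct_val_splits(num_samples=10, num_splits=3)
for train_idx, val_idx in splits:
    print(len(train_idx), "train /", len(val_idx), "val:", val_idx)
```

Each sample lands in exactly one validation fold, so every split is evaluated on data that no other split validates on.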
Special Methods:
__init__(concept_model, data, *[, ...]): Init.
Inherited from ConceptDetection2DTrainTestHandle
__init__(concept_model, data, *[, ...]): Init.
Inherited from TrainEvalHandle
__repr__(): Return repr(self).
__init__(concept_model, data, *[, ...]): Init.
- Parameters
concept_model (ConceptClassificationModel2D) –
data (DataTriple) –
model_output_transform (TupleTransforms) –
metric_input_transform (TupleTransforms) –
- __init__(concept_model, data, *, model_output_transform=False, metric_input_transform=False, **kwargs)
Init.
For details on the init parameters, see the init of the super class ConceptDetection2DTrainTestHandle.
- Parameters
concept_model (ConceptClassificationModel2D) –
data (DataTriple) –
model_output_transform (TupleTransforms) –
metric_input_transform (TupleTransforms) –
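The bare * in the documented signature makes model_output_transform and metric_input_transform keyword-only, and **kwargs collects any further keyword arguments (presumably forwarded to the super class init, which is an assumption). The sketch below shows that calling convention with a hypothetical stand-in function, make_handle, that mirrors the signature. It is not the library class, and some_setting is a placeholder keyword.

```python
def make_handle(concept_model, data, *, model_output_transform=False,
                metric_input_transform=False, **kwargs):
    """Hypothetical stand-in mirroring the documented __init__ signature.

    It only records its arguments; it is not the hybrid_learning class.
    """
    return {
        "concept_model": concept_model,
        "data": data,
        "model_output_transform": model_output_transform,
        "metric_input_transform": metric_input_transform,
        "extra": kwargs,
    }


# Transforms are passed by keyword; unknown keywords end up in **kwargs.
# "model" and "data" are placeholder values, some_setting is a placeholder keyword.
handle = make_handle("model", "data", model_output_transform=None, some_setting=8)
print(handle["extra"])

# Passing a transform positionally is rejected by Python itself.
positional_error = None
try:
    make_handle("model", "data", None)
except TypeError as err:
    positional_error = str(err)
    print("TypeError:", positional_error)
```

With the real class, the same rule applies: the two transform arguments have to be named explicitly at the call site.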