ConceptDetection2DTrainTestHandle
- class hybrid_learning.concepts.models.concept_models.concept_detection.ConceptDetection2DTrainTestHandle(concept_model, data, *, model_output_transform=None, metric_input_transform=None, **kwargs)
Bases: TrainEvalHandle
Train/test handle for concept localization models. Applies sensible defaults to model_output_transform.
Public Data Attributes:
DEFAULT_MASK_INTERPOLATION: Interpolation method used in default transforms for resizing masks to activation map size.
Inherited from TrainEvalHandle:
LOSS_KEY: Key for the loss evaluation results.
NLL_KEY: Key for the proper scoring evaluation function used for second-stage training.
kpi_fns: Metric and loss functions.
settings: The current training settings as a dictionary.
Public Methods:
Inherited from TrainEvalHandle:
add_callbacks(callbacks): Append the given callbacks.
remove_callback(callback): Remove a single given callback.
loader([data, mode, batch_size, shuffle, ...]): Prepare and return a torch data loader from the dataset according to the settings.
reset_optimizer([optimizer, device, model]): Move the model to the correct device and initialize the optimizer with the model's parameters.
reset_training_handles([optimizer, device, ...]): (Re)set all handles associated with training, and move them to device.
reset_kpis([kpis]): Reset the aggregating KPIs.
train([train_loader, val_loader, ...]): Train the model according to the specified training parameters.
cross_validate([num_splits, train_val_data, ...]): Record training results for num_splits distinct validation splits.
train_val_one_epoch([pbar_desc, ...]): Train for one epoch, evaluate, and return the history and test results.
train_one_epoch([train_loader, ...]): Train for one epoch and return the history results as a pandas.DataFrame.
evaluate([mode, val_loader, prefix, ...]): Evaluate the model wrt.
second_stage_train([callback_context]): Perform a second-stage training for calibration using a Laplace approximation.
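The cross_validate() summary mentions num_splits distinct validation splits, but the splitting strategy itself is not documented here. The following sketch assumes a standard k-fold scheme: the indices are partitioned into num_splits disjoint validation folds, and each fold is paired with all remaining indices for training. The function k_fold_splits is hypothetical and is not part of hybrid_learning.

```python
from typing import List, Tuple


def k_fold_splits(num_samples: int, num_splits: int) -> List[Tuple[List[int], List[int]]]:
    """Return (train_indices, val_indices) pairs whose validation folds are disjoint.

    Hypothetical helper illustrating a k-fold scheme; not the library's implementation.
    """
    if not 1 < num_splits <= num_samples:
        raise ValueError("need 1 < num_splits <= num_samples")
    indices = list(range(num_samples))
    # Spread the remainder over the first folds so fold sizes differ by at most one.
    fold_sizes = [num_samples // num_splits + (1 if i < num_samples % num_splits else 0)
                  for i in range(num_splits)]
    splits, start = [], 0
    for size in fold_sizes:
        val = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        splits.append((train, val))
        start += size
    return splits


for train_idx, val_idx in k_fold_splits(num_samples=7, num_splits=3):
    print(len(train_idx), val_idx)
```

Each sample lands in exactly one validation fold, so every sample is used for validation once across the num_splits runs.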
Special Methods:
__init__(concept_model, data, *[, ...]): Init.
Inherited from TrainEvalHandle:
__repr__(): Return repr(self).
- Parameters
concept_model (ConceptDetectionModel2D)
data (DataTriple)
model_output_transform (TupleTransforms)
metric_input_transform (TupleTransforms)
- __init__(concept_model, data, *, model_output_transform=None, metric_input_transform=None, **kwargs)
Init.
For further parameter descriptions see __init__() of TrainEvalHandle.
- Parameters
concept_model (ConceptDetectionModel2D) – the concept localization model to work on, together with its concept.
data (DataTriple) – data for the concept model, i.e. a sequence of (activation, mask) tuples.
model_output_transform (Optional[TupleTransforms])
metric_input_transform (Optional[TupleTransforms])
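The two transform parameters act on pairs rather than on single values. As a rough illustration only, the pure-Python sketch below assumes that a tuple transform is a callable mapping an (output, target) pair to a new pair, and that several such transforms are chained in order. The helpers compose, clamp_output and binarize_target are hypothetical; they are not the TupleTransforms API of hybrid_learning.

```python
from typing import Callable, List, Sequence, Tuple

Mask = List[List[float]]
# Hypothetical signature: a tuple transform maps an (output, target) pair to a new pair.
TupleTransform = Callable[[Mask, Mask], Tuple[Mask, Mask]]


def compose(transforms: Sequence[TupleTransform]) -> TupleTransform:
    """Chain tuple transforms, applying them from first to last."""
    def composed(output: Mask, target: Mask) -> Tuple[Mask, Mask]:
        for transform in transforms:
            output, target = transform(output, target)
        return output, target
    return composed


def clamp_output(low: float = 0.0, high: float = 1.0) -> TupleTransform:
    """Clamp every output value to [low, high]; leave the target unchanged."""
    def transform(output: Mask, target: Mask) -> Tuple[Mask, Mask]:
        return [[min(max(v, low), high) for v in row] for row in output], target
    return transform


def binarize_target(threshold: float = 0.5) -> TupleTransform:
    """Map target values above threshold to 1.0 and all others to 0.0; leave the output unchanged."""
    def transform(output: Mask, target: Mask) -> Tuple[Mask, Mask]:
        return output, [[1.0 if v > threshold else 0.0 for v in row] for row in target]
    return transform


metric_input = compose([clamp_output(), binarize_target(0.5)])
print(metric_input([[1.5, -0.2], [0.4, 0.9]], [[0.7, 0.3], [0.5, 1.0]]))
# prints ([[1.0, 0.0], [0.4, 0.9]], [[1.0, 0.0], [0.0, 1.0]])
```

Because every transform receives and returns the full pair, a single step can adjust the prediction, the ground truth, or both consistently.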
- DEFAULT_MASK_INTERPOLATION: str = 'bilinear'
Interpolation method used in default transforms for resizing masks to activation map size. The value may be any of the modes accepted by torch.nn.functional.interpolate() that apply to 2D inputs, e.g. 'nearest', 'bilinear', 'bicubic', or 'area'.
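A minimal sketch of resizing a mask to the spatial size of an activation map with this interpolation mode, assuming a mask of shape (C, H, W) and an activation map of shape (C', h, w). The helper resize_mask_to_activation is hypothetical and not part of hybrid_learning; torch.nn.functional.interpolate is the real PyTorch function referenced above.

```python
import torch
import torch.nn.functional as F

DEFAULT_MASK_INTERPOLATION: str = "bilinear"

# Modes for which F.interpolate accepts the align_corners argument.
_ALIGNABLE_MODES = {"linear", "bilinear", "bicubic", "trilinear"}


def resize_mask_to_activation(mask: torch.Tensor, activation: torch.Tensor,
                              mode: str = DEFAULT_MASK_INTERPOLATION) -> torch.Tensor:
    """Resize a (C, H, W) mask to the (h, w) size of a (C', h, w) activation map.

    Hypothetical helper for illustration; not the library's default transform.
    """
    target_size = tuple(activation.shape[-2:])
    extra = {"align_corners": False} if mode in _ALIGNABLE_MODES else {}
    # F.interpolate expects a batch dimension, so add one and remove it afterwards.
    resized = F.interpolate(mask.unsqueeze(0).float(), size=target_size, mode=mode, **extra)
    return resized.squeeze(0)


mask = torch.zeros(1, 64, 64)
mask[:, :32, :] = 1.0  # upper half of the image belongs to the concept
activation = torch.randn(256, 8, 8)
small_mask = resize_mask_to_activation(mask, activation)
print(small_mask.shape)  # torch.Size([1, 8, 8])
```

In general, 'bilinear' yields soft values in [0, 1] along region boundaries, whereas 'nearest' keeps a binary mask binary.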