ConceptDetection2DTrainTestHandle

class hybrid_learning.concepts.models.concept_models.concept_detection.ConceptDetection2DTrainTestHandle(concept_model, data, *, model_output_transform=None, metric_input_transform=None, **kwargs)

Bases: TrainEvalHandle

Train/test handle for concept localization models. It applies sensible defaults to model_output_transform.

Public Data Attributes:

DEFAULT_MASK_INTERPOLATION

Interpolation method used in the default transforms for resizing masks to the activation map size.

Inherited from TrainEvalHandle:

LOSS_KEY

Key for the loss evaluation results.

NLL_KEY

Key for the proper scoring evaluation function used for second-stage training.
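
As a rough illustration, and assuming that NLL_KEY refers to a negative log-likelihood on per-pixel mask probabilities (the key name suggests this, but the exact function is not documented here), the hypothetical helper binary_nll below sketches such a proper scoring rule. It is not the library's implementation.

```python
import math
from typing import Sequence


def binary_nll(probs: Sequence[float], targets: Sequence[float], eps: float = 1e-7) -> float:
    """Mean binary negative log-likelihood, a proper scoring rule.

    Hypothetical helper for illustration only.
    probs: predicted foreground probabilities per pixel, in [0, 1].
    targets: ground-truth mask values per pixel (0 or 1).
    """
    if not probs or len(probs) != len(targets):
        raise ValueError("probs and targets must be non-empty and of equal length")
    total = 0.0
    for p, t in zip(probs, targets):
        # Clamp to avoid log(0) for fully confident predictions.
        p = min(max(p, eps), 1.0 - eps)
        total += -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
    return total / len(probs)


# Two pixels predicted with 90% confidence in the correct class.
print(binary_nll([0.9, 0.1], [1, 0]))  # equals -log(0.9)
```

Lower values are better; a perfectly calibrated and confident prediction approaches 0.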

kpi_fns

Metric and loss functions.

settings

The current training settings as dictionary.

Public Methods:

Inherited from TrainEvalHandle:

add_callbacks(callbacks)

Append the given callbacks.

remove_callback(callback)

Remove a single given callback.

loader([data, mode, batch_size, shuffle, ...])

Prepare and return a torch data loader from the dataset according to the settings.

reset_optimizer([optimizer, device, model])

Move the model to the correct device and initialize the optimizer with the model's parameters.

reset_training_handles([optimizer, device, ...])

(Re)set all handles associated with training and move them to the device.

reset_kpis([kpis])

Reset the aggregating KPIs.

train([train_loader, val_loader, ...])

Train the model according to the specified training parameters.

cross_validate([num_splits, train_val_data, ...])

Record training results for num_splits distinct validation splits.
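
Conceptually, cross-validation needs num_splits disjoint validation splits whose union covers the data. The sketch below shows one generic way to produce them (plain k-fold over sample indices); kfold_splits is a hypothetical helper, and how TrainEvalHandle actually splits train_val_data is not specified here.

```python
from typing import List, Tuple


def kfold_splits(num_samples: int, num_splits: int) -> List[Tuple[List[int], List[int]]]:
    """Return (train_indices, val_indices) pairs for plain k-fold splitting.

    Hypothetical helper for illustration only.
    """
    if num_splits < 2 or num_splits > num_samples:
        raise ValueError("need 2 <= num_splits <= num_samples")
    indices = list(range(num_samples))
    # Distribute the remainder over the first folds so sizes differ by at most 1.
    base, rest = divmod(num_samples, num_splits)
    fold_sizes = [base + (1 if i < rest else 0) for i in range(num_splits)]
    splits = []
    start = 0
    for size in fold_sizes:
        val = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        splits.append((train, val))
        start += size
    return splits


for train_idx, val_idx in kfold_splits(5, 2):
    print("train:", train_idx, "val:", val_idx)
```

Each sample lands in exactly one validation split, so results from the num_splits training runs are based on distinct validation data.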

train_val_one_epoch([pbar_desc, ...])

Train for one epoch, evaluate, and return history and test results.

train_one_epoch([train_loader, ...])

Train for one epoch and return history results as pandas.DataFrame.

evaluate([mode, val_loader, prefix, ...])

Evaluate the model with respect to …

second_stage_train([callback_context])

Do a second-stage training for calibration using Laplace approximation.

Special Methods:

__init__(concept_model, data, *[, ...])

Init.

Inherited from TrainEvalHandle:

__repr__()

Return repr(self).

__init__(concept_model, data, *, model_output_transform=None, metric_input_transform=None, **kwargs)

Init.

For further parameter descriptions, see TrainEvalHandle.__init__().

DEFAULT_MASK_INTERPOLATION: str = 'bilinear'

Interpolation method used in the default transforms for resizing masks to the activation map size. The value may be any mode accepted by the mode argument of torch.nn.functional.interpolate(), e.g. 'nearest' or 'bilinear'.
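
The following is a minimal sketch of how a mask might be resized to an activation map size with this interpolation mode, using torch.nn.functional.interpolate(). The helper resize_mask and the 7x7 target size are hypothetical; whether the default transforms of this handle resize masks exactly like this is an assumption.

```python
import torch
import torch.nn.functional as F

DEFAULT_MASK_INTERPOLATION = "bilinear"  # value documented above

# Modes for which interpolate() accepts an align_corners flag.
_ALIGNABLE_MODES = ("linear", "bilinear", "bicubic", "trilinear")


def resize_mask(mask: torch.Tensor, size, mode: str = DEFAULT_MASK_INTERPOLATION) -> torch.Tensor:
    """Resize a 2D (H, W) mask to size=(h, w). Hypothetical helper."""
    # interpolate() expects a batched (N, C, H, W) tensor for 2D resizing.
    batched = mask.float()[None, None]
    align = False if mode in _ALIGNABLE_MODES else None
    resized = F.interpolate(batched, size=size, mode=mode, align_corners=align)
    # Drop the batch and channel dimensions again.
    return resized[0, 0]


# A 224x224 binary mask resized to a hypothetical 7x7 activation map.
mask = torch.zeros(224, 224)
mask[50:150, 50:150] = 1.0
small = resize_mask(mask, (7, 7))
print(small.shape)
```

With 'bilinear', output values are weighted averages of neighboring mask pixels, so a binary mask becomes a soft mask with values in [0, 1]; 'nearest' would keep it binary.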