ConceptDetection2DTrainTestHandle
- class hybrid_learning.concepts.models.concept_models.concept_detection.ConceptDetection2DTrainTestHandle(concept_model, data, *, model_output_transform=None, metric_input_transform=None, **kwargs)
Bases: TrainEvalHandle
Train test handle for concept localization models. Applies sensible defaults to model_output_transform.
Public Data Attributes:
DEFAULT_MASK_INTERPOLATION: Interpolation method used in default transforms for resizing masks to activation map size.
Inherited from TrainEvalHandle:
LOSS_KEY: Key for the loss evaluation results.
NLL_KEY: Key for the proper scoring evaluation function used for second stage training.
kpi_fns: Metric and loss functions.
settings: The current training settings as dictionary.
Public Methods:
Inherited from TrainEvalHandle:
add_callbacks(callbacks): Append the given callbacks.
remove_callback(callback): Remove a single given callback.
loader([data, mode, batch_size, shuffle, ...]): Prepare and return a torch data loader from the dataset according to settings.
reset_optimizer([optimizer, device, model]): Move model to correct device, init optimizer to parameters of model.
reset_training_handles([optimizer, device, ...]): (Re)set all handles associated with training, and move to device.
reset_kpis([kpis]): Resets aggregating kpis.
train([train_loader, val_loader, ...]): Train the model according to the specified training parameters.
cross_validate([num_splits, train_val_data, ...]): Record training results for num_splits distinct val splits.
train_val_one_epoch([pbar_desc, ...]): Train for one epoch, evaluate, and return history and test results.
train_one_epoch([train_loader, ...]): Train for one epoch and return history results as pandas.DataFrame.
evaluate([mode, val_loader, prefix, ...]): Evaluate the model wrt.
second_stage_train([callback_context]): Do a second stage training for calibration using Laplace approximation.
Special Methods:
__init__(concept_model, data, *[, ...]): Init.
Inherited from TrainEvalHandle:
__repr__(): Return repr(self).
__init__(concept_model, data, *[, ...]): Init.
- Parameters
concept_model (ConceptDetectionModel2D) –
data (DataTriple) –
model_output_transform (TupleTransforms) –
metric_input_transform (TupleTransforms) –
- __init__(concept_model, data, *, model_output_transform=None, metric_input_transform=None, **kwargs)
Init.
For further parameter descriptions see __init__() of TrainEvalHandle.
- Parameters
concept_model (ConceptDetectionModel2D) – the concept localization model to work on.
data (DataTriple) – data for the concept model, i.e. a sequence of tuples (activation, mask)
model_output_transform (Optional[TupleTransforms]) –
metric_input_transform (Optional[TupleTransforms]) –
- DEFAULT_MASK_INTERPOLATION: str = 'bilinear'
Interpolation method used in default transforms for resizing masks to activation map size. The value may be any of the modes accepted by torch.nn.functional.interpolate().
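The following is a minimal sketch of what resizing a mask to an activation map size with such a mode can look like. It calls torch.nn.functional.interpolate() directly. The names MASK_INTERPOLATION and resize_mask are hypothetical stand-ins for illustration only, and the actual default transforms of this class may be implemented differently.

```python
from typing import Tuple

import torch
import torch.nn.functional as F

# Hypothetical stand-in for the DEFAULT_MASK_INTERPOLATION class attribute.
MASK_INTERPOLATION = "bilinear"


def resize_mask(mask: torch.Tensor,
                size: Tuple[int, int],
                mode: str = MASK_INTERPOLATION) -> torch.Tensor:
    """Resize a batch of masks of shape (N, 1, H, W) to the spatial size `size`.

    Illustrative helper only, not part of hybrid_learning.
    """
    # interpolate() accepts align_corners only for interpolating modes
    # such as 'bilinear' and 'bicubic'; for 'nearest' or 'area' it must be None.
    align_corners = False if mode in ("bilinear", "bicubic") else None
    return F.interpolate(mask, size=size, mode=mode, align_corners=align_corners)


# Example: an 8x8 mask with a square of ones, resized to a 4x4 activation map size.
mask = torch.zeros(1, 1, 8, 8)
mask[..., 2:6, 2:6] = 1.0
resized = resize_mask(mask, size=(4, 4))
resized_nearest = resize_mask(mask, size=(4, 4), mode="nearest")
```

With 'bilinear', the resized mask can contain fractional values at object borders, whereas 'nearest' keeps a binary mask binary. Which of the two suits a given metric or loss depends on whether it expects soft or hard targets.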