Concept Analysis Tooling

The tooling for concept analysis is collected in the module hybrid_learning.concepts. For details, see the API Reference.

Analysis Handles

ConceptAnalysis is a convenient analysis handle that provides functions to conduct the concept analysis steps and to store their results.

Concept and Concept Embedding Modelling

A concept is defined by sample data points and is modeled as an instance of Concept. Such a concept can be used to train a model that predicts or detects it. Here, concepts are used to train a linear concept model (e.g. an instance of ConceptDetectionModel2D), which predicts the concept from the intermediate output of a DNN. The parameters of a trained concept model describe a linear embedding of the concept into a layer of the DNN. Embeddings and operations on them are modeled in ConceptEmbedding. For translation between concept model and concept embedding use …

The relevant modelling classes are:

Concept

Representation of a concept with data and meta information.

SegmentationConcept2D

Concept with segmentation data.

ConceptEmbedding

Representation of an embedding of a concept within a DNN.

ConceptDetectionModel2D

PyTorch model implementation of a concept embedding for 2D conv layers.

ConceptSegmentationModel2D

A concept model that segments the concept in an image.

ConceptClassificationModel2D

A concept model that classifies whether an image-level concept is recognized in an activation map.
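To make the linear concept embedding described above more concrete, here is a minimal framework-free sketch. It assumes that the concept model acts like a 1x1 convolution followed by a sigmoid, so that the embedding reduces to one weight per channel (the normal vector of a hyperplane in activation space) plus a bias. The class LinearEmbeddingSketch and the nested-list activation layout are hypothetical illustrations, not part of hybrid_learning.concepts.

```python
import math
from dataclasses import dataclass
from typing import List

# Assumed activation map layout: [channels][height][width] as nested lists.
ActMap = List[List[List[float]]]


@dataclass
class LinearEmbeddingSketch:
    """Hypothetical stand-in for a linear concept embedding: normal vector plus bias."""
    normal: List[float]  # one weight per channel of the embedding layer
    bias: float

    def predict(self, act_map: ActMap) -> List[List[float]]:
        """Per spatial position: dot product over channels (like a 1x1 conv), then sigmoid."""
        channels = len(act_map)
        assert channels == len(self.normal), "embedding and layer must have equal #channels"
        height, width = len(act_map[0]), len(act_map[0][0])
        mask = []
        for y in range(height):
            row = []
            for x in range(width):
                z = sum(self.normal[c] * act_map[c][y][x] for c in range(channels)) + self.bias
                row.append(1.0 / (1.0 + math.exp(-z)))
            mask.append(row)
        return mask


emb = LinearEmbeddingSketch(normal=[1.0, -1.0], bias=0.0)
mask = emb.predict([[[2.0, 0.0]], [[0.0, 2.0]]])  # 2 channels, 1x2 spatial map
print([[round(v, 3) for v in row] for row in mask])  # [[0.881, 0.119]]
```

Positions whose activation vector lies on the positive side of the hyperplane get scores above 0.5, which is the sense in which the parameters embed the concept into the layer.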

The concept model classes are accompanied by custom handles for training and testing derived from TrainEvalHandle. The handles are:

ConceptDetection2DTrainTestHandle

Train and test handle for concept localization models.

ConceptSegmentation2DTrainTestHandle

Train and test handle for concept segmentation providing loss and metric defaults.

ConceptClassification2DTrainTestHandle

Train and test handle for concept classification providing loss and metric defaults.
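The split between a generic handle and concept-specific handles that provide loss and metric defaults can be sketched as follows. This is a framework-free illustration under stated assumptions: predictions and targets are flattened lists of per-pixel probabilities, binary cross-entropy and IoU are chosen merely as plausible defaults, and GenericHandleSketch and SegmentationHandleSketch are hypothetical names that do not reproduce the constructor signatures of TrainEvalHandle or ConceptSegmentation2DTrainTestHandle.

```python
import math
from typing import Callable, Dict, List, Optional

KpiFn = Callable[[List[float], List[float]], float]


def bce_loss(preds: List[float], targets: List[float], eps: float = 1e-7) -> float:
    """Mean binary cross-entropy over flattened per-pixel predictions."""
    total = 0.0
    for p, t in zip(preds, targets):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total -= t * math.log(p) + (1.0 - t) * math.log(1.0 - p)
    return total / len(preds)


def iou(preds: List[float], targets: List[float], threshold: float = 0.5) -> float:
    """Intersection over union of the thresholded prediction and the binary target."""
    pred_bin = [p > threshold for p in preds]
    tgt_bin = [t > 0.5 for t in targets]
    inter = sum(a and b for a, b in zip(pred_bin, tgt_bin))
    union = sum(a or b for a, b in zip(pred_bin, tgt_bin))
    return inter / union if union else 1.0


class GenericHandleSketch:
    """Hypothetical generic handle: the caller must supply loss and metrics."""

    def __init__(self, loss_fn: KpiFn, metric_fns: Dict[str, KpiFn]):
        self.loss_fn = loss_fn
        self.metric_fns = metric_fns

    def evaluate(self, preds: List[float], targets: List[float]) -> Dict[str, float]:
        results = {"loss": self.loss_fn(preds, targets)}
        results.update({name: fn(preds, targets) for name, fn in self.metric_fns.items()})
        return results


class SegmentationHandleSketch(GenericHandleSketch):
    """Hypothetical specialization that fills in segmentation defaults if none are given."""

    def __init__(self, loss_fn: Optional[KpiFn] = None,
                 metric_fns: Optional[Dict[str, KpiFn]] = None):
        super().__init__(
            loss_fn=loss_fn if loss_fn is not None else bce_loss,
            metric_fns=metric_fns if metric_fns is not None else {"iou": iou},
        )
```

With this design, SegmentationHandleSketch().evaluate(preds, targets) works without any configuration, while explicitly passed losses or metrics still override the defaults.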

Intermediate Output Retrieval and Model Extension

To retrieve the intermediate output of PyTorch models, the PyTorch hooks mechanism is used. Wrappers for adding and retrieving intermediate output of DNNs are defined in the module model_extension:

ActivationMapGrabber

Wrapper class to obtain intermediate outputs from models.

ExtendedModelStump

Optionally apply a modification to the model stump output in the forward method.

HooksHandle

Wrapper that registers hooks on a model to save intermediate output, and unregisters them again.

ModelExtender

This class wraps a given model and extends its output.

ModelStump

Obtain the intermediate output of a sub-module of a complete NN.
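The hooks mechanism that these wrappers rely on can be illustrated without PyTorch. The sketch below mimics the pattern of torch.nn.Module.register_forward_hook: after each forward pass, a hook is called with the module, its input and its output, and registration returns a handle whose remove() detaches the hook again. All classes and the function grab_intermediate are hypothetical stand-ins, not the model_extension implementations.

```python
from typing import Any, Callable, Dict, List

Hook = Callable[["Module", Any, Any], None]


class RemovableHandle:
    """Detaches a hook from its module when remove() is called."""

    def __init__(self, hooks: Dict[int, Hook], key: int):
        self._hooks, self._key = hooks, key

    def remove(self) -> None:
        self._hooks.pop(self._key, None)


class Module:
    """Minimal module: wraps a function and calls registered forward hooks."""

    def __init__(self, fn: Callable[[Any], Any]):
        self.fn = fn
        self._hooks: Dict[int, Hook] = {}
        self._next_key = 0

    def register_forward_hook(self, hook: Hook) -> RemovableHandle:
        key = self._next_key
        self._next_key += 1
        self._hooks[key] = hook
        return RemovableHandle(self._hooks, key)

    def __call__(self, x: Any) -> Any:
        out = self.fn(x)
        for hook in list(self._hooks.values()):
            hook(self, x, out)  # hooks observe (module, input, output)
        return out


class Sequential(Module):
    """Chain of sub-modules, each of which may carry its own hooks."""

    def __init__(self, *layers: Module):
        self.layers = layers

        def run(x: Any) -> Any:
            for layer in layers:
                x = layer(x)
            return x

        super().__init__(run)


def grab_intermediate(model: Module, layer: Module, x: Any) -> List[Any]:
    """Run the model once and return [final_output, output_of_layer]."""
    captured: List[Any] = []
    handle = layer.register_forward_hook(lambda mod, inp, out: captured.append(out))
    try:
        final = model(x)
    finally:
        handle.remove()  # always unregister, as a hooks-handle wrapper would
    return [final, captured[0]]
```

For example, with double = Module(lambda x: 2 * x) and model = Sequential(double, Module(lambda x: x + 1)), grab_intermediate(model, double, 3) returns the final output 7 together with the intermediate output 6, and no hook stays attached afterwards.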

Training and Validation

Training and Validation Interface

The following handle classes from hybrid_learning.concepts.train_eval provide a generic training and validation interface for PyTorch models:

EarlyStoppingHandle

Handle encapsulating early stopping checks.

ResettableOptimizer

Wrapper around torch optimizers to enable reset and automatic learning rate handling.

TrainEvalHandle

Handle for training and evaluation of PyTorch models.
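The interplay of a training loop with early stopping and resettable learning-rate handling can be sketched generically. This minimal pure-Python illustration assumes that early stopping monitors a validation loss with a patience and a minimum improvement, and that the learning rate decays exponentially per epoch and can be reset to its initial value. EarlyStoppingSketch, ResettableLRSketch and train are hypothetical and do not reproduce the parameters of EarlyStoppingHandle, ResettableOptimizer or TrainEvalHandle.

```python
from typing import List, Optional


class EarlyStoppingSketch:
    """Signal a stop once the monitored loss has not improved for `patience` epochs."""

    def __init__(self, patience: int = 2, min_delta: float = 0.0):
        self.patience, self.min_delta = patience, min_delta
        self.reset()

    def reset(self) -> None:
        self.best: Optional[float] = None
        self.bad_epochs = 0

    def step(self, loss: float) -> bool:
        """Record one epoch's loss; return True if training should stop."""
        if self.best is None or loss < self.best - self.min_delta:
            self.best, self.bad_epochs = loss, 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


class ResettableLRSketch:
    """Exponential learning-rate decay that can be reset to its initial value."""

    def __init__(self, lr: float, gamma: float = 0.5):
        self.init_lr, self.gamma = lr, gamma
        self.lr = lr

    def step(self) -> None:
        self.lr *= self.gamma

    def reset(self) -> None:
        self.lr = self.init_lr


def train(val_losses: List[float], stopper: EarlyStoppingSketch,
          lr: ResettableLRSketch) -> int:
    """Replay per-epoch validation losses; return the number of epochs run."""
    stopper.reset()  # resetting both helpers makes repeated runs independent
    lr.reset()
    epochs = 0
    for loss in val_losses:  # stands in for: train one epoch, then evaluate
        epochs += 1
        lr.step()
        if stopper.step(loss):
            break
    return epochs
```

With patience=2, the loss sequence [1.0, 0.8, 0.85, 0.9, 0.7] stops after the fourth epoch, because the best loss 0.8 was not improved in two consecutive epochs.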

Furthermore, a set of callback handles is pre-defined that can be added to a training or evaluation run:

Callback

A callback base class that eases implementing a custom callback handle.

CsvLoggingCallback

Extract the values stored in matplotlib figures and store them as CSV.

LoggingCallback

Log batch and epoch KPI results.

ProgressBarUpdater

Update the progress bar postfix after each batch and epoch.

TensorboardLogger

Write batch and epoch KPI results to a tensorboard log directory.
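A callback mechanism of this kind is typically a base class with no-op event methods that the training loop calls at fixed points. The sketch below assumes two events, after_batch and after_epoch, which receive KPI dictionaries. The class names, method names and the run_epochs loop are hypothetical and not the signatures used by Callback or TrainEvalHandle.

```python
from typing import Dict, List


class CallbackSketch:
    """Base class: every event is a no-op, so subclasses override only what they need."""

    def after_batch(self, epoch: int, batch: int, kpis: Dict[str, float]) -> None:
        pass

    def after_epoch(self, epoch: int, kpis: Dict[str, float]) -> None:
        pass


class EpochLogSketch(CallbackSketch):
    """Collects one formatted log line per epoch, similar in spirit to a logging callback."""

    def __init__(self) -> None:
        self.lines: List[str] = []

    def after_epoch(self, epoch: int, kpis: Dict[str, float]) -> None:
        parts = ", ".join(f"{k}={v:.3f}" for k, v in sorted(kpis.items()))
        self.lines.append(f"epoch {epoch}: {parts}")


class BatchCounterSketch(CallbackSketch):
    """Counts processed batches; only overrides the batch event."""

    def __init__(self) -> None:
        self.count = 0

    def after_batch(self, epoch: int, batch: int, kpis: Dict[str, float]) -> None:
        self.count += 1


def run_epochs(batch_losses: List[List[float]], callbacks: List[CallbackSketch]) -> None:
    """Dispatch per-batch and per-epoch KPI results (here: given losses) to all callbacks."""
    for epoch, losses in enumerate(batch_losses):
        for batch, loss in enumerate(losses):
            for cb in callbacks:
                cb.after_batch(epoch, batch, {"loss": loss})
        epoch_kpis = {"loss": sum(losses) / len(losses)}
        for cb in callbacks:
            cb.after_epoch(epoch, epoch_kpis)
```

Running run_epochs([[1.0, 0.5], [0.25, 0.25]], [log, counter]) with an EpochLogSketch and a BatchCounterSketch yields the log lines "epoch 0: loss=0.750" and "epoch 1: loss=0.250" and a batch count of 4.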

Training KPIs

Some loss and metric functions are available for different training setups of the concept models. They all inherit from torch.nn.Module.

aggregating_kpis

Implementation of KPIs that cannot be computed on a batch level.

batch_kpis

Loss and metric functions and classes that can be calculated per batch.
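The distinction between batch KPIs and aggregating KPIs matters for ratio-type metrics. For example, the mean of per-batch IoU values generally differs from the IoU over the whole dataset, which requires accumulating intersection and union counts across batches. The following pure-Python sketch shows the difference, using binary masks as flat lists; batch_iou and AggregatingIoU are hypothetical names, and this is not the implementation in aggregating_kpis.

```python
from typing import List, Tuple

Batch = Tuple[List[int], List[int]]  # (predicted mask, target mask), flattened 0/1 values


def batch_iou(pred: List[int], target: List[int]) -> float:
    """IoU computed on one batch only."""
    inter = sum(p & t for p, t in zip(pred, target))
    union = sum(p | t for p, t in zip(pred, target))
    return inter / union if union else 1.0


class AggregatingIoU:
    """Accumulate intersection and union counts; compute the IoU once at the end."""

    def __init__(self) -> None:
        self.inter = 0
        self.union = 0

    def update(self, pred: List[int], target: List[int]) -> None:
        self.inter += sum(p & t for p, t in zip(pred, target))
        self.union += sum(p | t for p, t in zip(pred, target))

    def compute(self) -> float:
        return self.inter / self.union if self.union else 1.0


batches: List[Batch] = [
    ([1, 0, 0, 0], [1, 0, 0, 0]),  # small object, perfectly segmented: IoU 1.0
    ([1, 1, 0, 0], [1, 1, 1, 1]),  # large object, half found: IoU 0.5
]
mean_of_batches = sum(batch_iou(p, t) for p, t in batches) / len(batches)
agg = AggregatingIoU()
for p, t in batches:
    agg.update(p, t)
print(mean_of_batches, agg.compute())  # 0.75 0.6
```

The batch mean weights the small, easy object as heavily as the large one, whereas the aggregated value 3/5 reflects all pixels equally.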
