Concept Analysis Tooling
The tooling for concept analysis is collected in the module hybrid_learning.concepts.
For details have a look at the API Reference.
Analysis Handles
ConceptAnalysis is a convenient analysis handle that provides functions to conduct the concept analysis steps and to store their results.
Concept and Concept Embedding Modelling
A concept is defined by sample data points and is modeled as an instance of Concept.
A concept can be used to train a model that predicts or detects it.
Here, concepts are used to train a linear concept model (e.g. an instance of ConceptDetectionModel2D), which predicts a concept from the intermediate output of a DNN.
The parameters of a trained concept model describe a linear concept embedding of the concept into a layer of the DNN. Embeddings and operations on them are modeled in ConceptEmbedding.
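The following sketch illustrates why such a model is called "linear". It assumes the concept model acts on a 2D convolutional activation map as a 1x1 convolution followed by a sigmoid, so the convolution weight plays the role of a hyperplane normal in activation space and the bias its offset. The class name `LinearConceptModel2D` is hypothetical and is not the library's ConceptDetectionModel2D implementation.

```python
import torch
from torch import nn


class LinearConceptModel2D(nn.Module):
    """Hypothetical linear concept model: per-pixel hyperplane test on an activation map."""

    def __init__(self, in_channels: int):
        super().__init__()
        # A 1x1 convolution applies the same linear function w·a + b to the
        # channel vector a at every spatial position of the activation map.
        self.linear = nn.Conv2d(in_channels, 1, kernel_size=1)

    def forward(self, activations: torch.Tensor) -> torch.Tensor:
        # (batch, channels, height, width) -> (batch, 1, height, width);
        # the sigmoid turns the linear score into a per-pixel concept probability.
        return torch.sigmoid(self.linear(activations))


model = LinearConceptModel2D(in_channels=8)
acts = torch.randn(2, 8, 5, 5)  # stand-in for the intermediate output of a DNN
mask = model(acts)
print(mask.shape)  # torch.Size([2, 1, 5, 5])

# These trained parameters are what a linear concept embedding describes:
weight = model.linear.weight.detach().flatten()  # normal vector, shape (8,)
bias = model.linear.bias.detach().item()         # scalar offset
```

Because the same weight vector is applied at every position, the whole model is determined by one vector and one scalar. That is what makes a compact embedding description possible.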
To translate between concept model and concept embedding, use to_embedding() for model to embedding, and from_embedding() for embedding to model.
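A round trip between model parameters and an embedding can be sketched as follows. This assumes a hypothetical convention in which the embedding stores a normal vector n and a support factor s, and the hyperplane is {a : n·(a - s·n) = 0}, i.e. it passes through the point s·n. This is only an illustration; the actual parametrization used by to_embedding() and from_embedding() may differ. All names below are hypothetical.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Embedding:
    """Hypothetical embedding: hyperplane {a : n·(a - s*n) = 0}."""
    normal: List[float]  # normal vector n
    support: float       # support factor s (hyperplane passes through s*n)


def _sq_norm(v: List[float]) -> float:
    return sum(x * x for x in v)


def model_to_embedding(weight: List[float], bias: float) -> Embedding:
    # w·a + b = 0  <=>  w·a - s*|w|^2 = 0  with  s = -b / |w|^2
    return Embedding(normal=list(weight), support=-bias / _sq_norm(weight))


def embedding_to_model(emb: Embedding) -> Tuple[List[float], float]:
    # Inverse mapping: weight = n, bias = -s * |n|^2
    return list(emb.normal), -emb.support * _sq_norm(emb.normal)


emb = model_to_embedding([3.0, 4.0], bias=-50.0)
print(emb.support)              # 2.0
print(embedding_to_model(emb))  # ([3.0, 4.0], -50.0)
```

Here the point s·n = (6, 8) satisfies 3·6 + 4·8 - 50 = 0, so it lies on the decision boundary of the model, as the convention requires.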
The relevant modelling classes are:
- Representation of a concept with data and meta information.
- Concept with segmentation data.
- Representation of an embedding of a concept within a DNN.
- PyTorch model implementation of a concept embedding for 2D conv layers.
- A concept model that segments the concept in an image.
- A concept model that classifies whether an image-level concept is recognized in an activation map.
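The segmentation and classification models listed above differ in their output granularity. The sketch below shows this on a grid of per-position concept scores. It assumes, purely for illustration, that the image-level decision is obtained by max pooling the scores. The functions `segment` and `classify` are hypothetical and do not reproduce the library's models.

```python
from typing import List

Scores = List[List[float]]  # per-position concept probabilities of one activation map


def segment(scores: Scores, threshold: float = 0.5) -> List[List[bool]]:
    # Segmentation: one binary decision per spatial position.
    return [[p > threshold for p in row] for row in scores]


def classify(scores: Scores, threshold: float = 0.5) -> bool:
    # Image-level classification (assumed here via max pooling):
    # the concept counts as present if any position exceeds the threshold.
    return max(max(row) for row in scores) > threshold


scores = [[0.1, 0.2], [0.7, 0.3]]
print(segment(scores))   # [[False, False], [True, False]]
print(classify(scores))  # True
```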
The concept model classes are accompanied by custom handles for training and testing derived from TrainEvalHandle.
The handles are:
- Train and test handle for concept localization models.
- Train and test handle for concept segmentation providing loss and metric defaults.
- Train and test handle for concept classification providing loss and metric defaults.
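The idea of a specialised handle that "provides loss and metric defaults" can be sketched as a subclass that only fills in default arguments of a generic handle. The class names, the mean squared error loss, and the accuracy metric below are hypothetical stand-ins. They are not the defaults or the interface of the library's handles.

```python
from typing import Callable, Dict, Optional, Sequence

# A KPI maps (predictions, targets) to a float.
KPI = Callable[[Sequence[float], Sequence[float]], float]


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def accuracy(pred: Sequence[float], target: Sequence[float]) -> float:
    # Fraction of positions where thresholded prediction and target agree.
    return sum((p > 0.5) == (t > 0.5) for p, t in zip(pred, target)) / len(pred)


class GenericHandleSketch:
    """Hypothetical generic train/eval handle: loss and metrics must be given."""

    def __init__(self, loss_fn: KPI, metric_fns: Dict[str, KPI]):
        self.loss_fn = loss_fn
        self.metric_fns = metric_fns

    def evaluate(self, pred, target) -> Dict[str, float]:
        results = {"loss": self.loss_fn(pred, target)}
        for name, fn in self.metric_fns.items():
            results[name] = fn(pred, target)
        return results


class ClassificationHandleSketch(GenericHandleSketch):
    """Hypothetical specialised handle: only fills in task-specific defaults."""

    def __init__(self, loss_fn: Optional[KPI] = None,
                 metric_fns: Optional[Dict[str, KPI]] = None):
        super().__init__(
            loss_fn=loss_fn if loss_fn is not None else mse,
            metric_fns=metric_fns if metric_fns is not None else {"acc": accuracy},
        )


handle = ClassificationHandleSketch()
print(handle.evaluate([0.9, 0.2], [1.0, 0.0]))
```

Keeping the training logic in the generic base and only the defaults in the subclasses means user-supplied losses or metrics still override the defaults.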
Intermediate Output Retrieval and Model Extension
To retrieve the intermediate output of PyTorch models, the PyTorch hooks mechanism is used.
Wrappers for adding and retrieving intermediate output of DNNs are defined in the module model_extension:
- Wrapper class to obtain intermediate outputs from models.
- Optionally apply a modification to the model stump output in the forward method.
- Wrapper that registers and unregisters hooks on a model that save intermediate output.
- Wraps a given model and extends its output.
- Obtain the intermediate output of a sub-module of a complete NN.
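The underlying PyTorch forward hook mechanism can be shown without the wrapper classes. In this sketch a hook is registered on one sub-module, captures its output during a forward pass, and is removed again. The small `nn.Sequential` model and the `captured` dictionary are stand-ins for illustration, not part of model_extension.

```python
import torch
from torch import nn

# Small stand-in DNN; any nn.Module with named sub-modules works the same way.
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, 2, kernel_size=3, padding=1),
)

captured = {}


def save_output(name):
    # Returns a forward hook that stores the sub-module's output under `name`.
    def hook(module, inputs, output):
        captured[name] = output.detach()
    return hook


# Look up the sub-module by its registered name ("1" is the ReLU) and hook it ...
relu = dict(model.named_modules())["1"]
handle = relu.register_forward_hook(save_output("relu"))
final = model(torch.randn(1, 3, 8, 8))
# ... then unregister the hook so later forward passes are unaffected.
handle.remove()

print(captured["relu"].shape)  # torch.Size([1, 4, 8, 8])
```

Removing the handle matters when the same model is reused, which is presumably why the wrappers above take care of both registering and unregistering hooks.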
Training and Validation
Training and Validation Interface
The following handle classes from hybrid_learning.concepts.train_eval provide a generic training and validation interface for PyTorch models:
- Handle encapsulating early stopping checks.
- Wrapper around torch optimizers to enable reset and automatic learning rate handling.
- Handle for training and evaluation of PyTorch models.
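An early stopping check of the kind encapsulated by the first handle can be sketched as follows. This assumes a common patience-based rule: stop after `patience` consecutive epochs without the validation loss improving by more than `min_delta`. The class and its parameters are hypothetical; the library's handle may use different criteria.

```python
import math


class EarlyStoppingSketch:
    """Hypothetical early stopping: stop after `patience` epochs without improvement."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf  # best validation loss seen so far
        self.epochs_without_improvement = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


stopper = EarlyStoppingSketch(patience=2)
losses = [1.0, 0.8, 0.85, 0.9, 0.7]
for epoch, loss in enumerate(losses):
    if stopper.step(loss):
        print(f"stopping after epoch {epoch}")  # stopping after epoch 3
        break
```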
Furthermore, a set of predefined callback handles can be added to a training or evaluation run:
- A callback base class that eases implementing a custom callback handle.
- Extract the values stored in matplotlib figures and store them as CSV.
- Log batch and epoch KPI results.
- Update the progress bar postfix after each batch and epoch.
- Write batch and epoch KPI results to a TensorBoard log directory.
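The callback pattern can be sketched as a base class whose event methods are no-ops, so a custom callback only overrides the events it cares about. The event names `after_batch` and `after_epoch`, the dummy training loop, and all class names are hypothetical. They do not reproduce the library's callback interface.

```python
from typing import Dict, List


class CallbackSketch:
    """Hypothetical callback base: every event method is a no-op unless overridden."""

    def after_batch(self, batch_idx: int, kpis: Dict[str, float]) -> None:
        pass

    def after_epoch(self, epoch: int, kpis: Dict[str, float]) -> None:
        pass


class EpochLoggerSketch(CallbackSketch):
    """Collects epoch-level KPI lines (a stand-in for a logging callback)."""

    def __init__(self):
        self.lines: List[str] = []

    def after_epoch(self, epoch, kpis):
        self.lines.append(f"epoch {epoch}: " + ", ".join(f"{k}={v:.3f}" for k, v in kpis.items()))


def run(epochs: int, callbacks: List[CallbackSketch]) -> None:
    # Dummy training loop with made-up KPI values: notifies all callbacks at each event.
    for epoch in range(epochs):
        for batch_idx in range(2):
            for cb in callbacks:
                cb.after_batch(batch_idx, {"loss": 1.0 / (epoch + batch_idx + 1)})
        for cb in callbacks:
            cb.after_epoch(epoch, {"loss": 1.0 / (epoch + 1)})


logger = EpochLoggerSketch()
run(epochs=2, callbacks=[logger])
print(logger.lines)  # ['epoch 0: loss=1.000', 'epoch 1: loss=0.500']
```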
Training KPIs
Several loss and metric functions are available for the different training setups of the concept models.
They all inherit from torch.nn.Module.
- Implementation of KPIs that cannot be computed on a batch level.
- Loss and metric functions and classes that can be calculated per batch.
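Why some KPIs cannot be computed on a batch level can be seen with intersection over union (IoU), chosen here only as an example. The sketch does not claim which KPIs the library implements. Averaging per-batch IoU values weights each batch equally. Accumulating intersection and union counts over all batches and dividing once instead yields the dataset-level value, and the two can differ substantially.

```python
from typing import List, Tuple

Batch = Tuple[List[bool], List[bool]]  # (predicted mask, target mask), flattened


def intersection_union(pred: List[bool], target: List[bool]) -> Tuple[int, int]:
    inter = sum(p and t for p, t in zip(pred, target))
    union = sum(p or t for p, t in zip(pred, target))
    return inter, union


batches: List[Batch] = [
    ([True, False, False], [True, False, False]),  # perfect overlap: IoU 1.0
    ([True, True, False], [False, False, True]),   # no overlap: IoU 0.0
]

# Batch-wise KPI, then averaged over batches.
per_batch = [i / u for i, u in (intersection_union(p, t) for p, t in batches)]
mean_of_batches = sum(per_batch) / len(per_batch)

# Aggregating KPI: accumulate counts over all batches, divide once at the end.
total_i = total_u = 0
for p, t in batches:
    i, u = intersection_union(p, t)
    total_i += i
    total_u += u
dataset_iou = total_i / total_u

print(mean_of_batches, dataset_iou)  # 0.5 0.25
```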