model_extension

Description

Wrapper classes to slice and extend torch.nn.modules. The main mechanism is hooks, which are used to obtain the intermediate outputs of layers. The base class implementing this mechanism is HooksHandle.

This is used to
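The hook mechanism can be illustrated with plain PyTorch. The following is a minimal sketch that uses only the public torch.nn.Module.register_forward_hook API, not the classes of this module. It assumes a toy nn.Sequential model, and it identifies layers by their names in model.named_modules(). The variable names are illustrative.

```python
import torch
from torch import nn

# Toy model (illustrative assumption); layer IDs are the named_modules() names "0".."3"
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(4 * 8 * 8, 2),
)

captured = {}

def save_output(module, inputs, output):
    # Forward hooks are called as hook(module, inputs, output) after module.forward()
    captured["relu"] = output.detach()

layer = dict(model.named_modules())["1"]
handle = layer.register_forward_hook(save_output)
try:
    logits = model(torch.zeros(1, 3, 8, 8))
finally:
    handle.remove()  # unregister so later forward passes are unaffected

print(logits.shape, captured["relu"].shape)
# torch.Size([1, 2]) torch.Size([1, 4, 8, 8])
```

Removing the handle in a finally clause ensures that the hook does not remain on the model if the forward pass raises.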

Classes

ActivationMapGrabber

Wrapper class to obtain intermediate outputs from models.
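Here is a sketch of what such a grabber might look like. GrabberSketch is a hypothetical name, and its interface is an assumption: it is constructed from a model and a list of layer IDs (names from named_modules()), and its forward returns only a dict that maps each layer ID to that layer's output. It is not the ActivationMapGrabber API.

```python
from typing import Dict, List

import torch
from torch import nn


class GrabberSketch(nn.Module):
    """Hypothetical sketch: run `model` and return the outputs of `layer_ids`."""

    def __init__(self, model: nn.Module, layer_ids: List[str]):
        super().__init__()
        self.model = model
        self.layer_ids = list(layer_ids)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        outputs: Dict[str, torch.Tensor] = {}
        modules = dict(self.model.named_modules())
        handles = []
        for layer_id in self.layer_ids:
            # The default argument binds the current layer_id to each hook
            def hook(module, inputs, output, layer_id=layer_id):
                outputs[layer_id] = output
            handles.append(modules[layer_id].register_forward_hook(hook))
        try:
            self.model(x)
        finally:
            for handle in handles:
                handle.remove()
        return outputs


model = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(), nn.Flatten(), nn.Linear(256, 2))
grabber = GrabberSketch(model, ["0", "3"])
acts = grabber(torch.zeros(1, 3, 8, 8))
print({k: tuple(v.shape) for k, v in acts.items()})
# {'0': (1, 4, 8, 8), '3': (1, 2)}
```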

ExtendedModelStump

Optionally apply a modification to the model stump output in the forward method.
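The idea of a stump followed by an optional modification can be sketched as follows. Both ExtendedStumpSketch and its parameters are hypothetical. For brevity, the sketch only supports nn.Sequential models, which it cuts by slicing up to and including index stump_head. The modification is any callable, and it is applied to the stump output in forward.

```python
from typing import Callable, Optional

import torch
from torch import nn


class ExtendedStumpSketch(nn.Module):
    """Hypothetical: nn.Sequential stump up to index `stump_head` (inclusive),
    followed by an optional `modification` applied to its output."""

    def __init__(self, model: nn.Sequential, stump_head: int,
                 modification: Optional[Callable[[torch.Tensor], torch.Tensor]] = None):
        super().__init__()
        # Slicing an nn.Sequential returns a new nn.Sequential sharing the same layers
        self.stump = model[: stump_head + 1]
        self.modification = modification

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.stump(x)
        return out if self.modification is None else self.modification(out)


model = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(), nn.Flatten(), nn.Linear(256, 2))
# Example modification: global average pooling of the ReLU activation map
ext = ExtendedStumpSketch(model, stump_head=1, modification=lambda t: t.mean(dim=(2, 3)))
print(ext(torch.zeros(1, 3, 8, 8)).shape)  # torch.Size([1, 4])
```

If the modification is itself an nn.Module, assigning it as an attribute registers it as a submodule, so its parameters are trained together with the stump.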

HooksHandle

Wrapper that registers hooks on a model to save intermediate outputs, and unregisters them again.
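A register/unregister wrapper of this kind can be sketched as a context manager. HooksHandleSketch and its register, unregister, and outputs members are hypothetical names, not the HooksHandle interface. The sketch stores the latest output of each requested layer and removes all hooks on exit.

```python
from typing import Dict, List

import torch
from torch import nn


class HooksHandleSketch:
    """Hypothetical: register forward hooks that save intermediate outputs,
    and remove them again; usable as a context manager."""

    def __init__(self, model: nn.Module, layer_ids: List[str]):
        self.modules = dict(model.named_modules())
        self.layer_ids = list(layer_ids)
        self.outputs: Dict[str, torch.Tensor] = {}
        self._handles = []

    def register(self) -> None:
        self.unregister()  # avoid registering the same hooks twice
        for layer_id in self.layer_ids:
            def hook(module, inputs, output, layer_id=layer_id):
                self.outputs[layer_id] = output
            self._handles.append(self.modules[layer_id].register_forward_hook(hook))

    def unregister(self) -> None:
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def __enter__(self) -> "HooksHandleSketch":
        self.register()
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        self.unregister()
        return False  # do not suppress exceptions


model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
with HooksHandleSketch(model, ["1"]) as hooks:
    model(torch.zeros(5, 4))
print(hooks.outputs["1"].shape, len(hooks._handles))  # torch.Size([5, 3]) 0
```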

ModelExtender

This class wraps a given model and extends its output.
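One possible reading of "extends its output" is sketched below. The model's own output is returned together with the outputs of extra modules that are applied to chosen intermediate layers. ExtenderSketch and the extensions mapping from layer ID to nn.Module are assumptions made for illustration. They do not describe the actual ModelExtender signature or return format.

```python
from typing import Dict, Tuple

import torch
from torch import nn


class ExtenderSketch(nn.Module):
    """Hypothetical: return the model output plus the outputs of extension
    modules applied to the outputs of chosen intermediate layers."""

    def __init__(self, model: nn.Module, extensions: Dict[str, nn.Module]):
        super().__init__()
        self.model = model
        self.layer_ids = list(extensions)
        # ModuleList registers the extension parameters with this wrapper
        self.extensions = nn.ModuleList([extensions[k] for k in self.layer_ids])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        modules = dict(self.model.named_modules())
        intermediates: Dict[str, torch.Tensor] = {}
        handles = []
        for layer_id in self.layer_ids:
            def hook(module, inputs, output, layer_id=layer_id):
                intermediates[layer_id] = output
            handles.append(modules[layer_id].register_forward_hook(hook))
        try:
            main_out = self.model(x)
        finally:
            for handle in handles:
                handle.remove()
        extra = {layer_id: ext(intermediates[layer_id])
                 for layer_id, ext in zip(self.layer_ids, self.extensions)}
        return main_out, extra


model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
extender = ExtenderSketch(model, {"1": nn.Linear(3, 1)})
main_out, extra = extender(torch.zeros(5, 4))
print(main_out.shape, extra["1"].shape)  # torch.Size([5, 2]) torch.Size([5, 1])
```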

ModelStump

Obtain the intermediate output of a sub-module of a complete NN.
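A stump can be sketched with a hook that saves the output of the chosen sub-module and then aborts the forward pass, so that later layers are not computed. StumpSketch, _StopForward, and the stump_head parameter are hypothetical names. Aborting through an exception is one possible design choice and is not necessarily how ModelStump works.

```python
import torch
from torch import nn


class _StopForward(Exception):
    """Raised by the hook to abort the forward pass once the output is saved."""


class StumpSketch(nn.Module):
    """Hypothetical: return the output of sub-module `stump_head` of `model`."""

    def __init__(self, model: nn.Module, stump_head: str):
        super().__init__()
        self.wrapped_model = model
        self.stump_head = stump_head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        saved = {}

        def hook(module, inputs, output):
            saved["out"] = output
            raise _StopForward()  # skip all layers after the stump head

        layer = dict(self.wrapped_model.named_modules())[self.stump_head]
        handle = layer.register_forward_hook(hook)
        try:
            self.wrapped_model(x)
        except _StopForward:
            pass
        finally:
            handle.remove()
        return saved["out"]


model = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(), nn.Flatten(), nn.Linear(256, 2))
stump = StumpSketch(model, "1")
print(stump(torch.zeros(1, 3, 8, 8)).shape)  # torch.Size([1, 4, 8, 8])
```

Because the hook is removed in the finally clause, the full model can still be used normally after the stump has been called.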

Functions

dummy_output(model, input_size[, layer_ids])

Select the dummy output of the given layers (or of all layers) of model for an all-zero input tensor of size input_size.
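A minimal sketch of this behavior follows. The function name dummy_output_sketch is hypothetical, and its signature only mirrors the documented one loosely. It makes two assumptions: input_size excludes the batch dimension (a batch of one sample is prepended), and "all layers" means every sub-module returned by named_modules() except the root model.

```python
from typing import Dict, Optional, Sequence

import torch
from torch import nn


def dummy_output_sketch(model: nn.Module, input_size: Sequence[int],
                        layer_ids: Optional[Sequence[str]] = None) -> Dict[str, torch.Tensor]:
    """Hypothetical: outputs of `layer_ids` (default: all sub-modules) for an
    all-zero input. Assumes `input_size` excludes the batch dimension."""
    modules = dict(model.named_modules())
    if layer_ids is None:
        layer_ids = [name for name in modules if name]  # "" is the model itself
    outputs: Dict[str, torch.Tensor] = {}
    handles = []
    for layer_id in layer_ids:
        def hook(module, inputs, output, layer_id=layer_id):
            outputs[layer_id] = output
        handles.append(modules[layer_id].register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(torch.zeros(1, *input_size))  # batch of one all-zero sample
    finally:
        for handle in handles:
            handle.remove()
    return outputs


model = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(), nn.Flatten(), nn.Linear(256, 2))
outs = dummy_output_sketch(model, (3, 8, 8))
print(sorted(outs))  # ['0', '1', '2', '3']
```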

output_size(model, input_size[, ...])

Feed a dummy input of size input_size to model to determine the output size of the layer with ID layer_id.
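The following sketch shows the dummy-input approach for a single layer. output_size_sketch is a hypothetical name. It assumes that input_size excludes the batch dimension, that the returned size also omits it, and that layer_id=None selects the output of the whole model.

```python
from typing import Optional, Sequence, Tuple

import torch
from torch import nn


def output_size_sketch(model: nn.Module, input_size: Sequence[int],
                       layer_id: Optional[str] = None) -> Tuple[int, ...]:
    """Hypothetical: output size (without batch dimension) of layer `layer_id`,
    or of the whole model if `layer_id` is None."""
    layer = model if layer_id is None else dict(model.named_modules())[layer_id]
    saved = {}
    # dict.update returns None, so the hook leaves the layer output unchanged
    handle = layer.register_forward_hook(
        lambda module, inputs, output: saved.update(out=output))
    try:
        with torch.no_grad():
            model(torch.zeros(1, *input_size))
    finally:
        handle.remove()
    return tuple(saved["out"].shape[1:])  # drop the batch dimension


model = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(), nn.Flatten(), nn.Linear(256, 2))
print(output_size_sketch(model, (3, 8, 8), "0"))  # (4, 8, 8)
print(output_size_sketch(model, (3, 8, 8)))       # (2,)
```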

output_sizes(model, input_size[, layer_ids, ...])

Obtain the output sizes of the given layers (or of all layers) for the given input size.
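Several sizes can be collected in a single forward pass. output_sizes_sketch is a hypothetical name. It makes the same assumptions as above: input_size and the returned sizes exclude the batch dimension, and "all layers" means all sub-modules except the root model. As a design choice, the hooks store only shapes rather than tensors, which keeps memory use low for large models.

```python
from typing import Dict, Optional, Sequence, Tuple

import torch
from torch import nn


def output_sizes_sketch(model: nn.Module, input_size: Sequence[int],
                        layer_ids: Optional[Sequence[str]] = None) -> Dict[str, Tuple[int, ...]]:
    """Hypothetical: output sizes (without batch dimension) of `layer_ids`
    (default: all sub-modules), from one forward pass on an all-zero input."""
    modules = dict(model.named_modules())
    if layer_ids is None:
        layer_ids = [name for name in modules if name]  # skip the root model ""
    sizes: Dict[str, Tuple[int, ...]] = {}
    handles = []
    for layer_id in layer_ids:
        def hook(module, inputs, output, layer_id=layer_id):
            # Keep only the shape (minus batch dimension), not the tensor itself
            sizes[layer_id] = tuple(output.shape[1:])
        handles.append(modules[layer_id].register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(torch.zeros(1, *input_size))
    finally:
        for handle in handles:
            handle.remove()
    return sizes


model = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU(), nn.Flatten(), nn.Linear(256, 2))
print(output_sizes_sketch(model, (3, 8, 8)))
# {'0': (4, 8, 8), '1': (4, 8, 8), '2': (256,), '3': (2,)}
```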