ModelExtender

class hybrid_learning.concepts.models.model_extension.ModelExtender(model, extensions, return_orig_out=True)

Bases: ActivationMapGrabber

This class wraps a given model and extends its output. The extensions are models that take the intermediate output of the original model at given sub-modules. An extension is specified by a dictionary {<sub-module ID>: {<name>: <model>}},

where the sub-module must be a sub-module of the wrapped model, and <model> is the torch.nn.Module that is fed the sub-module output. The name must be unique amongst all registered models: this is checked when registering new extensions, and the name is used as key for the extension model outputs.
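
Because the names form one flat namespace, a name may appear only once across all sub-modules (within a single inner dict, Python keys are unique anyway). A minimal sketch of such a check, using a hypothetical helper check_unique_names that is not part of hybrid_learning; the sub-module IDs "features.3" and "features.7" and the placeholder objects standing in for torch.nn.Module instances are illustrative assumptions:

```python
from typing import Any, Dict, List, Set

# Hypothetical helper (not part of hybrid_learning): checks that every
# extension name in a spec {<sub-module ID>: {<name>: <model>}} is unique
# across all sub-modules, and returns the names in registration order.
def check_unique_names(extensions: Dict[str, Dict[str, Any]]) -> List[str]:
    names: List[str] = []
    seen: Set[str] = set()
    for module_id, named_models in extensions.items():
        for name in named_models:
            if name in seen:
                raise ValueError(
                    f"Extension name {name!r} at {module_id!r} is already used")
            seen.add(name)
            names.append(name)
    return names

# Placeholder objects stand in for torch.nn.Module instances.
spec = {
    "features.3": {"stripes": object()},
    "features.7": {"dots": object(), "edges": object()},
}
print(check_unique_names(spec))  # ['stripes', 'dots', 'edges']
```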

Extensions can be registered and unregistered using the corresponding methods register_extension() and unregister_extension().

The information about registered extensions can be accessed via the following properties:

  • extensions: extension models indexed by sub-module ID in the format described above

  • extension_models: a dict-like (ModuleDict) of the registered models, indexed by name

  • name_registrations: a dict of the registered extension names, indexed by sub-module ID
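
The relation between these views, plus the extension_names list, can be illustrated with plain dicts. This is a sketch under the assumption that the views are derived as shown; split_views is a hypothetical helper, not the library's implementation, and integers stand in for models:

```python
from typing import Any, Dict, List, Tuple

# Hypothetical helper (not part of hybrid_learning): derives the flat views
# from a nested spec {<sub-module ID>: {<name>: <model>}}.
def split_views(extensions: Dict[str, Dict[str, Any]]
                ) -> Tuple[Dict[str, List[str]], Dict[str, Any], List[str]]:
    # name_registrations: sub-module ID -> names of the extensions attached there
    name_registrations = {m_id: list(named) for m_id, named in extensions.items()}
    # extension_models: name -> model; flat because names are globally unique
    extension_models = {name: model
                        for named in extensions.values()
                        for name, model in named.items()}
    # extension_names: all registered names
    extension_names = list(extension_models)
    return name_registrations, extension_models, extension_names

regs, models, names = split_views({"a": {"x": 1, "y": 2}, "b": {"z": 3}})
print(regs)    # {'a': ['x', 'y'], 'b': ['z']}
print(models)  # {'x': 1, 'y': 2, 'z': 3}
print(names)   # ['x', 'y', 'z']
```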

The output of a forward pass is then a tuple of the main model output and a dict {<name>: <ext model output>}. If return_orig_out is False, only the dict is returned.
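
To make the data flow concrete, here is a torch-free toy version: a "model" made of named stages, extensions attached to stage IDs, and an output shaped as documented. Everything here (STAGES, EXTENSIONS, extended_forward, the stage IDs) is a hypothetical illustration, not the library's forward implementation, which collects intermediate outputs via hooks:

```python
from typing import Any, Callable, Dict, List, Tuple

# Toy stand-in for a wrapped model (hypothetical): named stages applied in order.
STAGES: List[Tuple[str, Callable[[Any], Any]]] = [
    ("double", lambda x: 2 * x),
    ("inc", lambda x: x + 1),
]
# Extensions in the documented format {<sub-module ID>: {<name>: <model>}}.
EXTENSIONS: Dict[str, Dict[str, Callable[[Any], Any]]] = {
    "double": {"neg": lambda a: -a},
    "inc": {"sq": lambda a: a * a},
}

def extended_forward(x: Any, return_orig_out: bool = True):
    intermediate: Dict[str, Any] = {}
    for module_id, stage in STAGES:
        x = stage(x)
        intermediate[module_id] = x  # what a forward hook would record
    # Feed each recorded intermediate output to the extensions attached there.
    ext_outs = {name: ext(intermediate[module_id])
                for module_id, named in EXTENSIONS.items()
                for name, ext in named.items()}
    return (x, ext_outs) if return_orig_out else ext_outs

print(extended_forward(3))         # (7, {'neg': -6, 'sq': 49})
print(extended_forward(3, False))  # {'neg': -6, 'sq': 49}
```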

Public Data Attributes:

name_registrations

Dict mapping main model sub-modules to their registered extension model names.

extensions

Nested dict holding all extension modules, indexed by sub-module ID and extension name.

extension_names

List of the names of all registered extensions.

Inherited from HooksHandle

registered_submodules

List of IDs of the registered sub-modules.

Inherited from Module

dump_patches

This allows better backward-compatibility (BC) support for load_state_dict().

T_destination

alias of TypeVar('T_destination', bound=Mapping[str, Tensor])

Public Methods:

register_extension(name, module_id, model)

Register a new extension model as name.

unregister_extension(name)

Unregister an existing extension by name.

register_extensions(new_extensions)

Register all specified new extensions.

forward(*inps)

PyTorch forward method.

Inherited from ActivationMapGrabber

stump(module_id)

Provide a ModelStump (in eval mode) which yields activation maps of the given sub-module.

Inherited from HooksHandle

register_submodule(module_id)

Register a further sub-module to extract intermediate output from.

unregister_submodule(module_id)

Unregister a submodule for intermediate output retrieval.

get_module_by_id(m_id)

Get actual sub-module object within wrapped model by module ID.

Inherited from Module

register_buffer(name, tensor[, persistent])

Adds a buffer to the module.

register_parameter(name, param)

Adds a parameter to the module.

add_module(name, module)

Adds a child module to the current module.

get_submodule(target)

Returns the submodule given by target if it exists, otherwise throws an error.

get_parameter(target)

Returns the parameter given by target if it exists, otherwise throws an error.

get_buffer(target)

Returns the buffer given by target if it exists, otherwise throws an error.

apply(fn)

Applies fn recursively to every submodule (as returned by .children()) as well as self.

cuda([device])

Moves all model parameters and buffers to the GPU.

xpu([device])

Moves all model parameters and buffers to the XPU.

cpu()

Moves all model parameters and buffers to the CPU.

type(dst_type)

Casts all parameters and buffers to dst_type.

float()

Casts all floating point parameters and buffers to float datatype.

double()

Casts all floating point parameters and buffers to double datatype.

half()

Casts all floating point parameters and buffers to half datatype.

bfloat16()

Casts all floating point parameters and buffers to bfloat16 datatype.

to_empty(*, device)

Moves the parameters and buffers to the specified device without copying storage.

to(*args, **kwargs)

Moves and/or casts the parameters and buffers.

register_backward_hook(hook)

Registers a backward hook on the module (deprecated in favor of register_full_backward_hook()).

register_full_backward_hook(hook)

Registers a backward hook on the module.

register_forward_pre_hook(hook)

Registers a forward pre-hook on the module.

register_forward_hook(hook)

Registers a forward hook on the module.

state_dict([destination, prefix, keep_vars])

Returns a dictionary containing a whole state of the module.

load_state_dict(state_dict[, strict])

Copies parameters and buffers from state_dict into this module and its descendants.

parameters([recurse])

Returns an iterator over module parameters.

named_parameters([prefix, recurse])

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

buffers([recurse])

Returns an iterator over module buffers.

named_buffers([prefix, recurse])

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

children()

Returns an iterator over immediate children modules.

named_children()

Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.

modules()

Returns an iterator over all modules in the network.

named_modules([memo, prefix, remove_duplicate])

Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

train([mode])

Sets the module in training mode.

eval()

Sets the module in evaluation mode.

requires_grad_([requires_grad])

Change if autograd should record operations on parameters in this module.

zero_grad([set_to_none])

Sets gradients of all model parameters to zero.

share_memory()

See torch.Tensor.share_memory_()

extra_repr()

Set the extra representation of the module.

Special Methods:

__init__(model, extensions[, return_orig_out])

Init.

Inherited from HooksHandle

__del__()

Unregister all hooks held by this handle on handle delete.

Inherited from Module

__call__(*input, **kwargs)

Call self as a function.

__setstate__(state)

__getattr__(name)

__setattr__(name, value)

Implement setattr(self, name, value).

__delattr__(name)

Implement delattr(self, name).

__repr__()

Return repr(self).

__dir__()

Default dir() implementation.


Parameters
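  • model (Module) – the model to wrap and extend

  • extensions (Dict[str, Dict[str, Module]]) – extensions in the format {<sub-module ID>: {<name>: <model>}}

  • return_orig_out (bool) – whether the forward pass returns the main model output together with the extension outputs, or only the extension outputs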

__init__(model, extensions, return_orig_out=True)

Init.

forward(*inps)

PyTorch forward method.

Returns

Tuple of the form (<main model out>, {<ext name>: <ext out>}), or only the dict {<ext name>: <ext out>} if return_orig_out is False.

Parameters

inps (Sequence[Tensor]) – the inputs passed on to the wrapped model

Return type

Union[Tuple[Any, Dict[str, Any]], Dict[str, Any]]

register_extension(name, module_id, model)

Register a new extension model as name. Updates the hooks needed for acquiring extension output.

Raises

ValueError if an extension is already registered under name.

Parameters
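  • name (str) – unique name of the new extension; used as key for its output

  • module_id (str) – ID of the sub-module of the wrapped model whose output is fed to the extension

  • model (Module) – the extension model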

Return type

None

register_extensions(new_extensions)

Register all specified new extensions.

Parameters

new_extensions (Dict[str, Dict[str, Module]]) – extensions in the format {module_id: {extension_name: extension_module}}

Raises

ValueError if an extension is already registered under one of the given names.

Return type

None

unregister_extension(name)

Unregister an existing extension by name. Updates the hooks and the registration lists.

Parameters

name (str) – name of the registered extension to remove

Return type

None
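
The bookkeeping described for register_extension(), register_extensions() and unregister_extension() can be sketched with plain dicts. ToyExtensionRegistry below is hypothetical, not the library's code: it omits the hook handling entirely, and its behaviour on partial failure in register_extensions() (earlier names stay registered) and on unknown names in unregister_extension() (a KeyError) are assumptions.

```python
from typing import Any, Dict, List

class ToyExtensionRegistry:
    """Hypothetical sketch; keeps the name -> model and the
    sub-module -> names mappings in sync, without any hooks."""

    def __init__(self) -> None:
        self.name_registrations: Dict[str, List[str]] = {}
        self.extension_models: Dict[str, Any] = {}

    def register_extension(self, name: str, module_id: str, model: Any) -> None:
        if name in self.extension_models:
            raise ValueError(f"Extension {name!r} is already registered")
        self.name_registrations.setdefault(module_id, []).append(name)
        self.extension_models[name] = model
        # The documented class additionally updates its hooks at this point.

    def register_extensions(self, new_extensions: Dict[str, Dict[str, Any]]) -> None:
        # Registers one by one; names registered before a failure are kept.
        for module_id, named_models in new_extensions.items():
            for name, model in named_models.items():
                self.register_extension(name, module_id, model)

    def unregister_extension(self, name: str) -> None:
        del self.extension_models[name]  # KeyError for unknown names
        for module_id, names in list(self.name_registrations.items()):
            if name in names:
                names.remove(name)
                if not names:
                    # No extension left at this sub-module: its hook
                    # would no longer be needed.
                    del self.name_registrations[module_id]

reg = ToyExtensionRegistry()
reg.register_extensions({"a": {"x": 1, "y": 2}, "b": {"z": 3}})
reg.unregister_extension("z")
print(reg.name_registrations)  # {'a': ['x', 'y']}
```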

extension_models: torch.nn.modules.container.ModuleDict

ModuleDict of the extension models, indexed by extension name. Only change it via register_extension() and unregister_extension(), as its keys must stay in sync with the registered sub-modules.

property extension_names: List[str]

List of the names of all registered extensions.

property extensions: Dict[str, Dict[str, torch.nn.modules.module.Module]]

Nested dict holding all extension modules, indexed by sub-module ID and extension name. Merges the information in name_registrations and extension_models.

Returns

Dict of the form {<sub-module ID>: {<ext name>: <registered ext model>}}. The name is unique amongst all registered extension models over all sub-modules.
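
Given that the keys of extension_models are the extension names listed in name_registrations, the nested view can be rebuilt from the two stored mappings. merge_views is a hypothetical helper, not the library's property implementation, and strings stand in for models:

```python
from typing import Any, Dict, List

# Hypothetical sketch: rebuild {<sub-module ID>: {<ext name>: <model>}}
# from {<sub-module ID>: [<ext name>, ...]} and {<ext name>: <model>}.
def merge_views(name_registrations: Dict[str, List[str]],
                extension_models: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    return {module_id: {name: extension_models[name] for name in names}
            for module_id, names in name_registrations.items()}

models = {"x": "model_x", "y": "model_y", "z": "model_z"}
regs = {"a": ["x", "y"], "b": ["z"]}
print(merge_views(regs, models))
# {'a': {'x': 'model_x', 'y': 'model_y'}, 'b': {'z': 'model_z'}}
```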

hook_handles: Dict[str, torch.utils.hooks.RemovableHandle]

Dictionary of hooks; for each sub-module to grab output from, a hook is registered. On each forward, the hook for a sub-module of ID m writes the intermediate output of the sub-module into _intermediate_outs[m]. The dictionary maps each sub-module ID to its hook handle.
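
A minimal sketch of such an output-grabbing hook, assuming the standard PyTorch forward-hook signature hook(module, args, output). make_output_grabber and the ID "features.3" are hypothetical, and the hook is called by hand so the snippet runs without torch; the dict intermediate_outs plays the role of _intermediate_outs.

```python
from typing import Any, Callable, Dict

# Hypothetical sketch: returns a callback with the PyTorch forward-hook
# signature that stores the sub-module output under its ID.
def make_output_grabber(store: Dict[str, Any], module_id: str
                        ) -> Callable[[Any, Any, Any], None]:
    def hook(module: Any, args: Any, output: Any) -> None:
        store[module_id] = output  # returning None leaves the output unchanged
    return hook

intermediate_outs: Dict[str, Any] = {}
hook = make_output_grabber(intermediate_outs, "features.3")

# With torch, registration would look like this (not executed here):
#   handle = model.get_submodule("features.3").register_forward_hook(hook)
#   ...
#   handle.remove()  # detach the hook again
hook(None, ("input",), "activation map")  # simulate PyTorch calling the hook
print(intermediate_outs)  # {'features.3': 'activation map'}
```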

property name_registrations: Dict[str, List[str]]

Dict mapping main model sub-modules to their registered extension model names. The names of the extensions must match those used as keys in extension_models: {<sub-module ID>: [<extension name>, ...]}

return_orig_out: bool

Whether to return a tuple (original_output, extension_outputs) or only the dict extension_outputs.

training: bool

wrapped_model: torch.nn.modules.module.Module

Original model from which intermediate and final output are retrieved.