Extending DNN model outputs
The module model_extension provides functionality to wrap models and extend their output by attaching further modules to intermediate layers. The base class for this is HooksHandle: it utilizes the PyTorch hooking mechanism to extract intermediate layer outputs (no code from the cited tutorial is used).
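As background, the hook mechanism can be sketched in plain PyTorch (an added illustration, not code from this module): a forward hook stores a sub-module's output on every forward pass.
>>> import torch
>>> sketch_model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())
>>> captured = {}
>>> handle = sketch_model[0].register_forward_hook(
...     lambda module, inputs, output: captured.update(linear=output))
>>> _ = sketch_model(torch.rand(1, 4))
>>> handle.remove()  # hooks should be removed once no longer needed
>>> captured['linear'].size()
torch.Size([1, 8])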
Derived classes are:
ActivationMapGrabber | Wrapper class to obtain intermediate outputs from models.
ModelExtender | This class wraps a given model and extends its output.
ModelStump | Obtain the intermediate output of a sub-module of a complete NN. Optionally apply a modification to the model stump output in the forward method.
Some exemplary applications are collected below.
Extend model output by activation maps
This can be achieved by the ActivationMapGrabber wrapper.
Mini-example
Consider a simple example network composed of two linear operations realizing \(nn(x) = 3\cdot(x+2)\):
>>> import torch
>>> class SampleNN(torch.nn.Module):
... def __init__(self):
... super(SampleNN, self).__init__()
... self.l1, self.l2 = torch.nn.Linear(1, 1), torch.nn.Linear(1, 1)
... def forward(self, x):
... return self.l2(self.l1(x))
...
>>> nn: SampleNN = SampleNN()
>>> nn.l1.weight.data, nn.l1.bias.data, nn.l2.weight.data, nn.l2.bias.data \
... = [torch.Tensor([i]) for i in [[1], 2, [3], 0]]
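For comparison (an added check), calling the unwrapped network directly only yields the final output:
>>> print(nn(torch.Tensor([[1]])))
tensor([[9.]]...)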
Now we can wrap the model and retrieve the output of l1.
The output of the wrapper now consists of a tuple of the output of the wrapped model and a dict with the outputs of the registered sub-modules. Given input 1, the output thus captures \(nn(1) = 3\cdot(1+2) = 9\) and \(l1(1) = 1+2 = 3\):
>>> from hybrid_learning.concepts.models.model_extension import ActivationMapGrabber
>>> wnn = ActivationMapGrabber(nn, ['l1'])
>>> print(wnn(torch.Tensor([[1]])))
(tensor([[9.]]...), {'l1': tensor([[3.]]...)})
More complex example
In the following, a pre-trained Mask R-CNN model is wrapped to obtain the outputs of the last two convolutional blocks of its backbone. These are handed over via a list of their keys in the named_modules() dict. Note that they are registered in the order the IDs are given.
>>> from torchvision.models.detection import maskrcnn_resnet50_fpn
>>> nn = maskrcnn_resnet50_fpn(pretrained=True)
>>> wnn = ActivationMapGrabber(
... model=nn,
... module_ids=['backbone.body.layer3', 'backbone.body.layer4']
... )
>>> print(wnn.registered_submodules)
['backbone.body.layer3', 'backbone.body.layer4']
Manual unregistration and (re-)registration of layers looks as follows (note that newly registered layers are simply appended, and a module cannot be registered twice):
>>> wnn.unregister_submodule('backbone.body.layer3') # unregister
>>> print(wnn.registered_submodules)
['backbone.body.layer4']
>>> wnn.register_submodule('backbone.body.layer3') # (re-)register
>>> print(wnn.registered_submodules)
['backbone.body.layer4', 'backbone.body.layer3']
>>> wnn.register_submodule('backbone.body.layer4') # register existing
>>> print(wnn.registered_submodules)
['backbone.body.layer4', 'backbone.body.layer3']
The output of the wrapped model now looks as follows:
>>> img_t = torch.rand(size=(3, 400,400)) # example
>>> wnn_out = wnn.eval()(img_t.unsqueeze(0))
>>> len(wnn_out)
2
>>> type(wnn_out[0]) # the Mask R-CNN output
<class 'list'>
>>> type(wnn_out[1]) # the layer outputs
<class 'dict'>
>>> for k in sorted(wnn_out[1].keys()):
... print(k, ":", wnn_out[1][k].size())
backbone.body.layer3 : torch.Size([1, 1024, 50, 50])
backbone.body.layer4 : torch.Size([1, 2048, 25, 25])
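The grabbed activation maps are ordinary tensors and can be processed further. Purely as an illustration (not functionality of this module), one could pool them into per-image feature vectors:
>>> import torch.nn.functional as F
>>> feats = {key: F.adaptive_avg_pool2d(act_map, 1).flatten(1)
...          for key, act_map in wnn_out[1].items()}
>>> feats['backbone.body.layer3'].size()
torch.Size([1, 1024])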
Attaching Further Modules
An application of accessing DNN intermediate outputs is to extend the DNN by attaching further (output) modules. As an example, we will extend a Mask R-CNN by several fully connected output neurons at two layers. The main model is:
>>> from torchvision.models.detection import maskrcnn_resnet50_fpn
>>> main_model = maskrcnn_resnet50_fpn(pretrained=True)
Defining the extended (wrapped) model
To extend the model, hand over the model and a dictionary of extensions. This dictionary maps layer IDs to named modules, i.e. dicts of extension modules indexed by unique IDs. Consider as extensions some simple linear models (mind the different sizes of different layer outputs):
>>> import torch, numpy
>>> linear_attach = lambda in_size: torch.nn.Sequential(
... torch.nn.Flatten(),
... torch.nn.Linear(int(numpy.prod(in_size)), out_features=1)
... )
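As a quick sanity check (an added illustration, not part of the original example), each such extension flattens a batch of activation maps and maps it to one scalar per sample:
>>> linear_attach([1024, 50, 50])(torch.rand(1, 1024, 50, 50)).size()
torch.Size([1, 1])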
>>>
>>> from hybrid_learning.concepts.models.model_extension import ModelExtender
>>> extended = ModelExtender(
... main_model, extensions={
... 'backbone.body.layer3': {'3_1': linear_attach([1024, 50, 50]),
... '3_2': linear_attach([1024, 50, 50])},
... 'backbone.body.layer4': {'4_1': linear_attach([2048, 25, 25])},
... })
>>> print(extended)
ModelExtender(
(wrapped_model): MaskRCNN(...)
(extension_models): ModuleDict(
(3_1): Sequential(...)
(3_2): Sequential(...)
(4_1): Sequential(...)
)
)
Each extension module ID must be unique amongst all IDs. The mapping from layers to the extensions attached to them is given by the registrations:
>>> extended.name_registrations
{'backbone.body.layer3': ['3_1', '3_2'], 'backbone.body.layer4': ['4_1']}
>>> extended.extensions
{'backbone.body.layer3': {'3_1': Sequential(...), '3_2': Sequential(...)},
'backbone.body.layer4': {'4_1': Sequential(...)}}
The extended output
When evaluated, the output is a tuple of the Mask R-CNN output and a dict with the output of each extension indexed by its ID. The extension outputs can be backpropagated as usual:
>>> out = extended.eval()(torch.rand((1, 3, 400, 400)))
>>> print(out)
([{'boxes': ...}],
{'3_1': tensor(...),
'3_2': tensor(...),
'4_1': tensor(...)})
>>> _ = out[1]['3_1'].backward(torch.randn(1, 1))
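Since the wrapped model and the extensions are registered as the sub-modules wrapped_model and extension_models (see the repr above), one can, for example, freeze the Mask R-CNN and train only the attachments. A minimal sketch, assuming these attribute names as the repr suggests:
>>> for param in extended.wrapped_model.parameters():
...     param.requires_grad = False  # freeze the wrapped Mask R-CNN
>>> optimizer = torch.optim.Adam(extended.extension_models.parameters())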
Handling Registrations of Attachments
For unregistration and (re-)registration of modules use register_extensions() and unregister_extension():
>>> extended.register_extensions(
... {'backbone.body.layer4': {'4_2': linear_attach([2048, 25, 25])}})
>>> extended.name_registrations
{'backbone.body.layer3': ['3_1', '3_2'],
'backbone.body.layer4': ['4_1', '4_2']}
>>> extended.register_extensions(
... {'backbone.body.layer4': {'4_2': linear_attach([2048, 25, 25])}})
Traceback (most recent call last):
...
ValueError: Tried to overwrite module under existing name: 4_2
>>>
>>> extended.unregister_extension('4_1')
>>> extended.name_registrations
{'backbone.body.layer3': ['3_1', '3_2'], 'backbone.body.layer4': ['4_2']}
>>> extended.unregister_extension('4_1')
Traceback (most recent call last):
...
KeyError: 'Tried to unregister extension of unknown name 4_1'
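To replace a registered extension, e.g. to swap in a freshly initialized head under an existing name, combine the two calls (a sketch using only the operations shown above):
>>> extended.unregister_extension('4_2')
>>> extended.register_extensions(
...     {'backbone.body.layer4': {'4_2': linear_attach([2048, 25, 25])}})
>>> extended.name_registrations
{'backbone.body.layer3': ['3_1', '3_2'], 'backbone.body.layer4': ['4_2']}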