Fixing relevance loss in max_pool_nd_inverse #19

Open · wants to merge 1 commit into master
inverter_util.py · 46 changes: 15 additions & 31 deletions
@@ -210,30 +210,6 @@ def silent_pass(self, m, in_tensor: torch.Tensor,
         # to store any specific data. Still useful for module tracking.
         pass
 
-    @staticmethod
-    def get_inv_max_pool_method(max_pool_instance):
-        """
-        Get dimension-specific max_pooling layer.
-        The forward pass and inversion are made in a
-        'dimensionality-agnostic' manner and are the same for
-        all nd instances of the layer, except for the functional
-        that needs to be used.
-
-        Args:
-            max_pool_instance: instance of max_pool layer.
-
-        Returns:
-            The correct functional used in the max_pooling layer.
-
-        """
-
-        conv_func_mapper = {
-            torch.nn.MaxPool1d: F.max_unpool1d,
-            torch.nn.MaxPool2d: F.max_unpool2d,
-            torch.nn.MaxPool3d: F.max_unpool3d
-        }
-        return conv_func_mapper[type(max_pool_instance)]
-
     def linear_inverse(self, m, relevance_in):
 
         if self.method == "e-rule":
@@ -347,18 +323,26 @@ def linear_fwd_hook(self, m, in_tensor: torch.Tensor,
         return
 
     def max_pool_nd_inverse(self, layer_instance, relevance_in):
 
         # In case the output had been reshaped for a linear layer,
         # make sure the relevance is put into the same shape as before.
         relevance_in = relevance_in.view(layer_instance.out_shape)
 
-        invert_pool = self.get_inv_max_pool_method(layer_instance)
-        inverted = invert_pool(relevance_in, layer_instance.indices,
-                               layer_instance.kernel_size, layer_instance.stride,
-                               layer_instance.padding, output_size=layer_instance.in_shape)
-        del layer_instance.indices
-
-        return inverted
+        # new_zeros keeps z on the same device and dtype as the relevance.
+        z = relevance_in.new_zeros(layer_instance.in_shape)
+        indices = layer_instance.indices
+        # scatter_add only works along a single dim, so for MaxPool2d and
+        # MaxPool3d the spatial dims have to be flattened first.
+        is_2d_or_3d_maxpool = isinstance(layer_instance,
+                                         (torch.nn.MaxPool2d, torch.nn.MaxPool3d))
+        if is_2d_or_3d_maxpool:
+            z = z.flatten(2, len(layer_instance.in_shape) - 1)
+            relevance_in = relevance_in.flatten(2, len(layer_instance.in_shape) - 1)
+            indices = indices.reshape(relevance_in.shape)
+
+        relevance_out = z.scatter_add(2, indices, relevance_in)
+
+        # for MaxPool2d and MaxPool3d the original shape has to be restored
+        if is_2d_or_3d_maxpool:
+            relevance_out = relevance_out.reshape(layer_instance.in_shape)
+
+        return relevance_out
 
     @module_tracker
     def max_pool_nd_fwd_hook(self, m, in_tensor: torch.Tensor,
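
For context, the relevance loss named in the title shows up whenever pooling windows overlap (stride smaller than kernel_size): several output positions then carry indices pointing at the same input element, and F.max_unpool2d effectively keeps only one of the duplicate contributions, so total relevance shrinks. scatter_add sums the duplicates instead and conserves the total. The following is a minimal, self-contained sketch of that failure mode; it is not part of the diff, and the tensor values are made up for illustration:

import torch
import torch.nn.functional as F

# A 3x3 input whose centre dominates every overlapping 2x2 window.
x = torch.tensor([[[[0., 0., 0.],
                    [0., 9., 0.],
                    [0., 0., 0.]]]])           # shape (1, 1, 3, 3)
out, idx = F.max_pool2d(x, kernel_size=2, stride=1, return_indices=True)
# out has shape (1, 1, 2, 2) and every window picks the centre,
# so idx holds the same flat index (4) four times.

relevance = torch.ones_like(out)               # total incoming relevance = 4.0

# Old path: max_unpool2d keeps only one write per duplicated index,
# so three of the four contributions vanish.
unpooled = F.max_unpool2d(relevance, idx, kernel_size=2, stride=1,
                          output_size=x.shape[2:])
print(unpooled.sum().item())                   # 1.0 -> relevance lost

# New path, mirroring the diff: scatter_add accumulates duplicates.
z = torch.zeros_like(x).flatten(2)             # (1, 1, 9)
r = relevance.flatten(2)                       # (1, 1, 4)
i = idx.reshape(r.shape)                       # (1, 1, 4)
restored = z.scatter_add(2, i, r).reshape(x.shape)
print(restored.sum().item())                   # 4.0 -> relevance conserved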
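
A related detail behind the is_2d_or_3d_maxpool guard: the indices returned by the max-pooling functionals index into the flattened spatial block of each channel, and for MaxPool1d they already come out in the pooled shape, so the dim-2 scatter_add needs no flattening there. A small sketch of the 1d case (again illustrative, not part of the diff):

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 10)                      # (N, C, L)
out, idx = F.max_pool1d(x, kernel_size=2, stride=2, return_indices=True)
relevance = torch.ones_like(out)               # total relevance = 2 * 4 * 5 = 40

# idx already has the pooled shape (2, 4, 5) and indexes along L,
# so scatter_add can run directly on dim 2.
z = torch.zeros_like(x)
relevance_out = z.scatter_add(2, idx, relevance)
print(relevance_out.sum().item())              # 40.0, nothing lost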