diff --git a/inverter_util.py b/inverter_util.py index 0a58bb5..bce19c4 100644 --- a/inverter_util.py +++ b/inverter_util.py @@ -210,30 +210,6 @@ def silent_pass(self, m, in_tensor: torch.Tensor, # to store any specific data. Still useful for module tracking. pass - @staticmethod - def get_inv_max_pool_method(max_pool_instance): - """ - Get dimension-specific max_pooling layer. - The forward pass and inversion are made in a - 'dimensionality-agnostic' manner and are the same for - all nd instances of the layer, except for the functional - that needs to be used. - - Args: - max_pool_instance: instance of max_pool layer. - - Returns: - The correct functional used in the max_pooling layer. - - """ - - conv_func_mapper = { - torch.nn.MaxPool1d: F.max_unpool1d, - torch.nn.MaxPool2d: F.max_unpool2d, - torch.nn.MaxPool3d: F.max_unpool3d - } - return conv_func_mapper[type(max_pool_instance)] - def linear_inverse(self, m, relevance_in): if self.method == "e-rule": @@ -347,18 +323,26 @@ def linear_fwd_hook(self, m, in_tensor: torch.Tensor, return def max_pool_nd_inverse(self, layer_instance, relevance_in): - # In case the output had been reshaped for a linear layer, # make sure the relevance is put into the same shape as before. relevance_in = relevance_in.view(layer_instance.out_shape) - invert_pool = self.get_inv_max_pool_method(layer_instance) - inverted = invert_pool(relevance_in, layer_instance.indices, - layer_instance.kernel_size, layer_instance.stride, - layer_instance.padding, output_size=layer_instance.in_shape) - del layer_instance.indices + z = torch.zeros(layer_instance.in_shape) + # since scatter add will only work on a single dim, for MaxPool2d and MaxPool3d tensors need some flattening + is_2d_or_3d_maxpool = isinstance(layer_instance, torch.nn.MaxPool2d) or isinstance(layer_instance, torch.nn.MaxPool3d) + if is_2d_or_3d_maxpool: + z = z.flatten(2, len(layer_instance.in_shape)-1) + relevance_in = relevance_in.flatten(2, len(layer_instance.in_shape)-1) + indices = layer_instance.indices.reshape(relevance_in.shape) + + relevance_out = z.scatter_add(2, indices, relevance_in) + + # for MaxPool2d and MaxPool3d the correct shape has to be restored + if is_2d_or_3d_maxpool: + relevance_out.reshape(layer_instance.in_shape) + + return relevance_out - return inverted @module_tracker def max_pool_nd_fwd_hook(self, m, in_tensor: torch.Tensor,