Densenet canonizations #171
base: master
Conversation
- The returned handles are reversed. This way, when two canonizers change the same parameter, removing the handles in the returned order restores the original model.
- The epsilon parameter is set to 0 during canonization
- Parameter dimensions are checked before merging, to prevent attempts to merge incompatible layers, as in DenseNets.
- Minor change in MergeBatchNorm: set batch_norm.eps = 0 in the register method instead of merge_batch_norm
- Add MergeBatchNormtoRight canonizer: merges a BatchNorm into a linear layer to its right. If the convolution has padding, a bias feature map needs to be computed and added to the output of the convolution to account for the batch norm bias.
- The canonizer did not work correctly when the convolution has a bias; this has been handled.
- The hook function was made lighter by discarding unneeded overhead computation.
- Minor change in MergeBatchNormtoRight: remove an unused variable
- Add DenseNetAdaptiveAvgPoolCanonizer: makes the last adaptive average pooling of torchvision DenseNets an explicit nn.Module object
- Add ThreshReLUMergeBatchNorm: canonizer for BatchNorm -> ReLU -> Linear chains. Adds backward and forward hooks to the ReLU in order to turn it into a ThreshReLU as defined in https://github.com/AlexBinder/LRP_Pytorch_Resnets_Densenet/blob/master/canonization_doc.pdf
- Add SequentialThreshCanonizer: a composite canonizer that applies DenseNetAvgPoolCanonizer, SequentialMergeBatchNorm, ThreshReLUMergeBatchNorm
- Add ThreshSequentialCanonizer: a composite canonizer that applies DenseNetAvgPoolCanonizer, ThreshReLUMergeBatchNorm, SequentialMergeBatchNorm

The last two are the recommended canonizers for torchvision DenseNet implementations. The standard SequentialMergeBatchNorm is needed to remove the initial BN->Conv chain in the architecture. The two canonizers result in different implementations of the same function because, in practice, dense blocks contain BN->ReLU->Conv->BN->ReLU->Conv chains, which leaves the possibility of applying SequentialMergeBatchNorm inside the dense blocks if it runs first. In practice, both canonizations remove the artifacts in the attribution maps; SequentialThreshCanonizer appears to be better quantitatively. A usage sketch follows below.
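For context, a minimal usage sketch of how such a canonizer would plug into a zennit composite. The class name SequentialThreshCanonizer and its import path are taken from this PR and are assumptions that may change after the rebase:

```python
import torch
from torchvision.models import densenet121

from zennit.attribution import Gradient
from zennit.composites import EpsilonPlusFlat
# assumed import path for the canonizer proposed in this PR
from zennit.torchvision import SequentialThreshCanonizer

model = densenet121().eval()

# the canonizer rewrites the BatchNorm/ReLU structure of the DenseNet
# for the duration of the composite context
composite = EpsilonPlusFlat(canonizers=[SequentialThreshCanonizer()])

data = torch.randn(1, 3, 224, 224, requires_grad=True)
with Gradient(model=model, composite=composite) as attributor:
    output, relevance = attributor(data, torch.eye(1000)[[0]])
print(relevance.shape)  # (1, 3, 224, 224)
```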
…orrectCompositeCanonizer in the ThreshSequentialCanonizer and SequentialThreshCanonizer classes
Docs: Fix docstrings in MergeBatchNormtoRight and ThreshReLUMergeBatchNorm
Hey Galip,
sorry for the very long hold-up. Let's try to finalize this. Ultimately, we need to rebase this. Maybe you can first introduce the changes and then rebase.
module.canonization_params = {}
module.canonization_params["bias_kernel"] = bias_kernel
let's store these in the canonizer itself, similar to MergeBatchNorm.linear_params and .batch_norm_params
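A minimal sketch of what that bookkeeping could look like; only the storage pattern is shown, the actual matching and merging logic of MergeBatchNormtoRight is omitted, and the names are illustrative:

```python
class MergeBatchNormtoRight:
    '''Sketch: keep canonization state on the canonizer, not on the module.'''

    def __init__(self):
        # analogous to MergeBatchNorm.linear_params / .batch_norm_params
        self.module = None
        self.bias_kernel = None

    def register(self, module, bias_kernel):
        # remember the module and the computed bias kernel on the canonizer,
        # instead of attaching module.canonization_params to the module
        self.module = module
        self.bias_kernel = bias_kernel

    def remove(self):
        # nothing was attached to the module, so only the canonizer is reset
        self.module = None
        self.bias_kernel = None
```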
module.bias.data = (original_weight * shift).sum(dim=1) + original_bias

# change batch_norm parameters to produce identity
batch_norm.running_mean.data = torch.zeros_like(batch_norm.running_mean.data)
batch_norm.running_var.data = torch.ones_like(batch_norm.running_var.data)
batch_norm.bias.data = torch.zeros_like(batch_norm.bias.data)
batch_norm.weight.data = torch.ones_like(batch_norm.weight.data)
batch_norm.eps = 0.
these need to be adapted to the new approach (see the current version of MergeBatchNorm)
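A minimal sketch of the shadowing approach for the batch-norm identity parameters, assuming the same object.__setattr__ pattern: the registered Parameters and buffers stay untouched and are only shadowed by instance attributes.

```python
import torch

batch_norm = torch.nn.BatchNorm2d(8).eval()

# shadow the registered parameters/buffers with identity values; the originals
# remain untouched in _parameters / _buffers
object.__setattr__(batch_norm, 'running_mean', torch.zeros_like(batch_norm.running_mean))
object.__setattr__(batch_norm, 'running_var', torch.ones_like(batch_norm.running_var))
object.__setattr__(batch_norm, 'bias', torch.zeros_like(batch_norm.bias))
object.__setattr__(batch_norm, 'weight', torch.ones_like(batch_norm.weight))
batch_norm.eps = 0.

x = torch.randn(2, 8, 4, 4)
assert torch.allclose(batch_norm(x), x)  # batch norm now acts as the identity
```

One way to undo this in remove() is to delete the shadowing instance attributes with object.__delattr__(batch_norm, name), which exposes the registered originals again.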
module.canonization_params = {}
module.canonization_params["bias_kernel"] = bias_kernel
return_handles.append(module.register_forward_hook(MergeBatchNormtoRight.convhook))
For the sake of not using Hooks, maybe we can wrap and overwrite the forward function (similar to the ResNet Canonizer)?
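A rough sketch of that alternative: shadow the bound forward method with an instance attribute instead of registering a hook. Here bias_kernel is only a placeholder tensor standing in for the feature map the canonizer computes:

```python
import torch

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
bias_kernel = torch.zeros(1, 8, 1, 1)  # placeholder for the canonizer's bias feature map

original_forward = conv.forward

def canonized_forward(x):
    # run the original convolution, then add the batch-norm bias contribution
    return original_forward(x) + bias_kernel

# assigning to the instance shadows nn.Module.forward, so nn.Module.__call__
# picks up the wrapped version without any hooks
conv.forward = canonized_forward

out = conv(torch.randn(1, 3, 16, 16))

# in remove(), deleting the instance attribute restores the original class method
del conv.forward
```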
temp_module = torch.nn.Conv2d(in_channels=module.in_channels, out_channels=module.out_channels,
                              kernel_size=module.kernel_size, padding=module.padding,
                              padding_mode=module.padding_mode, bias=False)
let's indent one line per kwarg
if isinstance(module, torch.nn.Conv2d):
    if module.padding == (0, 0):
        module.bias.data = (original_weight * shift[index]).sum(dim=[1, 2, 3]) + original_bias
this needs to be adapted to object.__setattr__(module, 'bias', (original_weight * shift[index]).sum(dim=[1, 2, 3]) + original_bias)
of instance, which is why deleting instance attributes with the same name reverts them to the original
function.
'''
self.module.features = Sequential(*list(self.module.features.children())[:-2])
If I remember correctly, you can slice Sequential, i.e. self.module.features = self.module.features[:-2]
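For reference, a quick check of the slicing behaviour (plain PyTorch, no zennit involved):

```python
import torch
from torch.nn import AdaptiveAvgPool2d, Conv2d, ReLU, Sequential

features = Sequential(Conv2d(3, 8, 3), ReLU(), AdaptiveAvgPool2d((1, 1)))

# slicing an nn.Sequential returns another nn.Sequential with the selected
# children, so the list()/unpacking roundtrip is not needed
trimmed = features[:-2]
print(type(trimmed).__name__, len(trimmed))  # Sequential 1
```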
'''
return DenseNetAdaptiveAvgPoolCanonizer()

def register(self, module, attributes):
missing docstring
for key in self.attribute_keys:
    delattr(self.module, key)

def forward(self, x):
missing docstring
return out


class DenseNetSeqThreshCanonizer(CompositeCanonizer):
missing docstring
))


class DenseNetThreshSeqCanonizer(CompositeCanonizer):
missing docstring
Hello,
Here is a summary of the contributions:
Furthermore, BN->ReLU->AvgPool->Linear chains are found and canonized using the same method, because batch normalization commutes with average pooling (a small numerical check of this is sketched at the end of this post).
6. Full proposed canonizers are added to torchvision.py. Another addition is DenseNetAdaptiveAvgPoolCanonizer, which is needed before applying other canonizers to DenseNets. It makes the final ReLU and AvgPooling layers of torchvision DenseNet objects explicit; by default, these are applied in the forward method of the model, not as nn.Module objects (see the sketch below).
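A sketch of what "making them explicit" amounts to: torchvision's DenseNet.forward applies the final ReLU and pooling functionally, so a canonizer has to append equivalent modules to model.features. The lines below (and the module names 'final_relu' / 'final_avgpool') only illustrate the idea, not the exact implementation:

```python
import torch
from torchvision.models import densenet121

model = densenet121()

# torchvision's DenseNet.forward roughly does:
#   out = F.relu(self.features(x), inplace=True)
#   out = F.adaptive_avg_pool2d(out, (1, 1))
#   out = torch.flatten(out, 1)
#   return self.classifier(out)
# Appending the functional steps as modules makes them visible to canonizers
# and LRP hooks (the forward method then has to be overwritten accordingly):
model.features.add_module('final_relu', torch.nn.ReLU(inplace=True))
model.features.add_module('final_avgpool', torch.nn.AdaptiveAvgPool2d((1, 1)))
```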
Thank you very much and I am looking forward to any kind of feedback!
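Regarding the claim above that batch normalization commutes with average pooling, here is a quick numerical check: in eval mode, BN is a fixed per-channel affine map, and averaging preserves affine maps.

```python
import torch

torch.manual_seed(0)
bn = torch.nn.BatchNorm2d(8).eval()
bn.weight.data.uniform_(0.5, 2.0)
bn.bias.data.uniform_(-1.0, 1.0)
bn.running_mean.uniform_(-1.0, 1.0)
bn.running_var.uniform_(0.5, 2.0)

pool = torch.nn.AvgPool2d(2)
x = torch.randn(4, 8, 16, 16)

# BN(AvgPool(x)) == AvgPool(BN(x)) up to floating point error
assert torch.allclose(bn(pool(x)), pool(bn(x)), atol=1e-5)
```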