From 553c18cd767fd5dece0c8e45aedf4c08d7759ee8 Mon Sep 17 00:00:00 2001
From: Aaron Adcock
Date: Mon, 2 Mar 2020 08:24:05 -0800
Subject: [PATCH] Make bn weight decay configurable (#65)

Summary:
Pull Request resolved: https://github.com/fairinternal/ClassyVision/pull/65

Make the bn weight decay configurable; for some datasets it might be desirable to turn it off.

Reviewed By: vreis

Differential Revision: D20140487

fbshipit-source-id: 77debf2c4600a080081668565d70b7a3ddc788f4
---
 classy_vision/models/resnext.py | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/classy_vision/models/resnext.py b/classy_vision/models/resnext.py
index d017d9314a..6bd525e34b 100644
--- a/classy_vision/models/resnext.py
+++ b/classy_vision/models/resnext.py
@@ -235,6 +235,7 @@ def __init__(
         base_width_and_cardinality: Optional[Union[Tuple, List]] = None,
         basic_layer: bool = False,
         final_bn_relu: bool = True,
+        bn_weight_decay: bool = False,
     ):
         """
         Implementation of `ResNeXt <https://arxiv.org/abs/1611.05431>`_.
@@ -251,6 +252,7 @@ def __init__(
         assert all(is_pos_int(n) for n in num_blocks)
         assert is_pos_int(init_planes) and is_pos_int(reduction)
         assert type(small_input) == bool
+        assert type(bn_weight_decay) == bool
         assert (
             type(zero_init_bn_residuals) == bool
         ), "zero_init_bn_residuals must be a boolean, set to true if gamma of last\
@@ -262,9 +264,11 @@ def __init__(
             and is_pos_int(base_width_and_cardinality[1])
         )
 
-        # we apply weight decay to batch norm if the model is a ResNeXt and we don't if
-        # it is a ResNet
-        self.bn_weight_decay = base_width_and_cardinality is not None
+        # Chooses whether to apply weight decay to the batch norm
+        # parameters. This improves results in some situations,
+        # e.g. ResNeXt models trained / evaluated on the ImageNet
+        # dataset, but can cause worse performance in other scenarios.
+        self.bn_weight_decay = bn_weight_decay
 
         # initial convolutional block:
         self.num_blocks = num_blocks
@@ -374,6 +378,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
             "basic_layer": config.get("basic_layer", False),
             "final_bn_relu": config.get("final_bn_relu", True),
             "zero_init_bn_residuals": config.get("zero_init_bn_residuals", False),
+            "bn_weight_decay": config.get("bn_weight_decay", False),
         }
 
         return cls(**config)
@@ -476,6 +481,12 @@ def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
         return cls()
 
 
+# Note: the ResNeXt models below all have weight decay enabled for the
+# batch norm parameters. We have found empirically that this gives
+# better results when training on ImageNet (~0.5pp of top-1 accuracy)
+# and brings our results in line with reported ImageNet results, but
+# when training on other datasets we have observed drops in accuracy
+# (for example, on the dataset used in https://arxiv.org/abs/1805.00932).
 @register_model("resnext50_32x4d")
 class ResNeXt50(ResNeXt):
     def __init__(self):
@@ -484,6 +495,7 @@ def __init__(self):
             basic_layer=False,
             zero_init_bn_residuals=True,
             base_width_and_cardinality=(4, 32),
+            bn_weight_decay=True,
         )
 
     @classmethod
@@ -499,6 +511,7 @@ def __init__(self):
             basic_layer=False,
             zero_init_bn_residuals=True,
             base_width_and_cardinality=(4, 32),
+            bn_weight_decay=True,
         )
 
     @classmethod
@@ -514,6 +527,7 @@ def __init__(self):
             basic_layer=False,
             zero_init_bn_residuals=True,
             base_width_and_cardinality=(4, 32),
+            bn_weight_decay=True,
         )
 
     @classmethod
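
Illustrative sketch (appended for review; not part of the patch and not
ClassyVision's actual optimizer code): the usual way a flag like
self.bn_weight_decay is consumed is when assembling optimizer parameter
groups, so that batch-norm parameters receive zero weight decay when the
flag is off. The helper name build_param_groups is hypothetical.

    import torch
    import torch.nn as nn

    def build_param_groups(model, weight_decay, bn_weight_decay):
        # Split parameters into batch-norm and non-batch-norm groups so
        # the optimizer can apply a different weight decay to each.
        bn_params, other_params = [], []
        for module in model.modules():
            params = list(module.parameters(recurse=False))
            if isinstance(module, nn.modules.batchnorm._BatchNorm):
                bn_params.extend(params)
            else:
                other_params.extend(params)
        return [
            {"params": other_params, "weight_decay": weight_decay},
            # BN parameters keep weight decay only when the flag is on.
            {"params": bn_params,
             "weight_decay": weight_decay if bn_weight_decay else 0.0},
        ]

    # Usage with a toy model:
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    optimizer = torch.optim.SGD(
        build_param_groups(model, weight_decay=1e-4, bn_weight_decay=False),
        lr=0.1,
        momentum=0.9,
    )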
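
And a minimal sketch of toggling the new knob through a model config;
the "name" and other non-bn keys here are assumptions about the
surrounding ClassyVision config schema, and only "bn_weight_decay" is
introduced by this patch (it defaults to False when absent).

    from classy_vision.models import build_model

    config = {
        "name": "resnext",                      # assumed registered model name
        "num_blocks": [3, 4, 6, 3],             # assumed required key
        "base_width_and_cardinality": [4, 32],  # assumed key, per from_config
        "zero_init_bn_residuals": True,
        "bn_weight_decay": False,               # the new flag: no WD on BN params
    }
    model = build_model(config)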