diff --git a/classy_vision/models/resnext.py b/classy_vision/models/resnext.py
index d017d9314a..6bd525e34b 100644
--- a/classy_vision/models/resnext.py
+++ b/classy_vision/models/resnext.py
@@ -235,6 +235,7 @@ def __init__(
         base_width_and_cardinality: Optional[Union[Tuple, List]] = None,
         basic_layer: bool = False,
         final_bn_relu: bool = True,
+        bn_weight_decay: bool = False,
     ):
         """
             Implementation of `ResNeXt <https://arxiv.org/pdf/1611.05431.pdf>`_.
@@ -251,6 +252,7 @@ def __init__(
         assert all(is_pos_int(n) for n in num_blocks)
         assert is_pos_int(init_planes) and is_pos_int(reduction)
         assert type(small_input) == bool
+        assert type(bn_weight_decay) == bool
         assert (
             type(zero_init_bn_residuals) == bool
         ), "zero_init_bn_residuals must be a boolean, set to true if gamma of last\
@@ -262,9 +264,11 @@ def __init__(
             and is_pos_int(base_width_and_cardinality[1])
         )
 
-        # we apply weight decay to batch norm if the model is a ResNeXt and we don't if
-        # it is a ResNet
-        self.bn_weight_decay = base_width_and_cardinality is not None
+        # Chooses whether to apply weight decay to the batch norm
+        # parameters. Enabling this improves results in some settings,
+        # e.g. ResNeXt models trained / evaluated on ImageNet, but it
+        # can hurt accuracy in other scenarios.
+        self.bn_weight_decay = bn_weight_decay
 
         # initial convolutional block:
         self.num_blocks = num_blocks
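
For orientation, here is a minimal sketch of how a flag like `self.bn_weight_decay` is typically consumed when the optimizer is built: batch norm parameters are collected into their own parameter group, and that group's weight decay is zeroed unless the flag is set. The `split_bn_params` helper and the SGD hyperparameters below are illustrative assumptions, not part of this diff or of Classy Vision's actual optimizer plumbing.

```python
import torch.nn as nn

_BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def split_bn_params(model: nn.Module):
    """Separate batch norm parameters from all other parameters."""
    bn_params, other_params = [], []
    for module in model.modules():
        if isinstance(module, _BN_TYPES):
            bn_params.extend(module.parameters(recurse=False))
        else:
            other_params.extend(module.parameters(recurse=False))
    return bn_params, other_params


# Hypothetical usage: keep weight decay for BN params only when the
# model opts in via bn_weight_decay.
# bn_params, other_params = split_bn_params(model)
# optimizer = torch.optim.SGD(
#     [
#         {"params": other_params, "weight_decay": 1e-4},
#         {"params": bn_params,
#          "weight_decay": 1e-4 if model.bn_weight_decay else 0.0},
#     ],
#     lr=0.1,
#     momentum=0.9,
# )
```
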
@@ -374,6 +378,7 @@ def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
             "basic_layer": config.get("basic_layer", False),
             "final_bn_relu": config.get("final_bn_relu", True),
             "zero_init_bn_residuals": config.get("zero_init_bn_residuals", False),
+            "bn_weight_decay": config.get("bn_weight_decay", False),
         }
         return cls(**config)
 
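With the `from_config` change above, the flag becomes configurable per model config. A usage sketch, assuming the base class is registered under the name `"resnext"` and using placeholder values for the other keys:

```python
from classy_vision.models import build_model

config = {
    "name": "resnext",                      # registered model name (assumed)
    "num_blocks": [3, 4, 6, 3],
    "base_width_and_cardinality": [4, 32],
    "bn_weight_decay": True,                # the new flag; defaults to False
}
model = build_model(config)
assert model.bn_weight_decay
```
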
@@ -476,6 +481,12 @@ def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
         return cls()
 
 
+# Note: the registered ResNeXt models below all enable weight decay on
+# the batch norm parameters. We have found empirically that this gives
+# better results when training on ImageNet (~0.5pp of top-1 accuracy)
+# and brings our results in line with previously reported ImageNet
+# numbers, but for training on other datasets we have observed drops
+# in accuracy (for example, the dataset used in https://arxiv.org/abs/1805.00932).
 @register_model("resnext50_32x4d")
 class ResNeXt50(ResNeXt):
     def __init__(self):
@@ -484,6 +495,7 @@ def __init__(self):
             basic_layer=False,
             zero_init_bn_residuals=True,
             base_width_and_cardinality=(4, 32),
+            bn_weight_decay=True,
         )
 
     @classmethod
@@ -499,6 +511,7 @@ def __init__(self):
             basic_layer=False,
             zero_init_bn_residuals=True,
             base_width_and_cardinality=(4, 32),
+            bn_weight_decay=True,
         )
 
     @classmethod
@@ -514,6 +527,7 @@ def __init__(self):
             basic_layer=False,
             zero_init_bn_residuals=True,
             base_width_and_cardinality=(4, 32),
+            bn_weight_decay=True,
         )
 
     @classmethod
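
Taken together, the presets in this file opt in to BN weight decay while the generic base class keeps the conservative default. A quick sketch of the resulting behavior, assuming both classes are exported from `classy_vision.models`:

```python
from classy_vision.models import ResNeXt, ResNeXt50

# Presets such as resnext50_32x4d enable BN weight decay explicitly.
assert ResNeXt50().bn_weight_decay is True

# A model built directly from the base class keeps the new default (False).
model = ResNeXt(num_blocks=[3, 4, 6, 3], base_width_and_cardinality=(4, 32))
assert model.bn_weight_decay is False
```
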