From 5fa2dd819a4ad124f7ab54f9d2b9c2b44de73031 Mon Sep 17 00:00:00 2001 From: suyeong Date: Tue, 23 Jan 2024 05:37:20 +0000 Subject: [PATCH 1/2] fix: add runtime_shape argument in batchnorm-related classes to match shape --- src/brevitas/nn/quant_bn.py | 31 +++++++++++++++++++++++++++-- src/brevitas/nn/quant_scale_bias.py | 3 ++- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/src/brevitas/nn/quant_bn.py b/src/brevitas/nn/quant_bn.py index a8047a690..97e52014f 100644 --- a/src/brevitas/nn/quant_bn.py +++ b/src/brevitas/nn/quant_bn.py @@ -16,6 +16,29 @@ class _BatchNormToQuantScaleBias(QuantScaleBias, ABC): + def __init__( + self, + num_features: int, + bias: bool = True, + weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat, + bias_quant: Optional[BiasQuantType] = None, + input_quant: Optional[ActQuantType] = None, + output_quant: Optional[ActQuantType] = None, + return_quant_tensor: bool = False, + runtime_shape=(1, -1, 1), + **kwargs): + QuantScaleBias.__init__( + self, + num_features=num_features, + weight_quant=weight_quant, + bias_quant=bias_quant, + input_quant=input_quant, + output_quant=output_quant, + return_quant_tensor=return_quant_tensor, + runtime_shape=runtime_shape, + **kwargs + ) + def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): @@ -59,16 +82,18 @@ def __init__( input_quant: Optional[ActQuantType] = None, output_quant: Optional[ActQuantType] = None, return_quant_tensor: bool = False, + runtime_shape=(1, -1, 1), **kwargs): super(BatchNorm1dToQuantScaleBias, self).__init__( num_features, bias=True, - runtime_shape=(1, -1, 1), + runtime_shape=runtime_shape, weight_quant=weight_quant, bias_quant=bias_quant, input_quant=input_quant, output_quant=output_quant, return_quant_tensor=return_quant_tensor, + runtime_shape=runtime_shape, **kwargs) self.eps = eps @@ -84,15 +109,17 @@ def __init__( input_quant: Optional[ActQuantType] = None, output_quant: Optional[ActQuantType] = None, return_quant_tensor: bool = False, + runtime_shape=(1, -1, 1, 1), **kwargs): super(BatchNorm2dToQuantScaleBias, self).__init__( num_features, bias=True, - runtime_shape=(1, -1, 1, 1), + runtime_shape=runtime_shape, weight_quant=weight_quant, bias_quant=bias_quant, input_quant=input_quant, output_quant=output_quant, return_quant_tensor=return_quant_tensor, + runtime_shape=runtime_shape, **kwargs) self.eps = eps diff --git a/src/brevitas/nn/quant_scale_bias.py b/src/brevitas/nn/quant_scale_bias.py index a97f54ed5..7a6427da1 100644 --- a/src/brevitas/nn/quant_scale_bias.py +++ b/src/brevitas/nn/quant_scale_bias.py @@ -48,8 +48,9 @@ def __init__( input_quant: Optional[ActQuantType] = None, output_quant: Optional[ActQuantType] = None, return_quant_tensor: bool = False, + runtime_shape=(1, -1, 1), **kwargs) -> None: - ScaleBias.__init__(self, num_features, bias) + ScaleBias.__init__(self, num_features, bias, runtime_shape=runtime_shape) QuantWBIOL.__init__( self, weight_quant=weight_quant, From 8d64d19bc2fd4116a160fd9b62171f83e88dc178 Mon Sep 17 00:00:00 2001 From: suyeong Date: Tue, 23 Jan 2024 05:49:41 +0000 Subject: [PATCH 2/2] fix: rm redundant code --- src/brevitas/nn/quant_bn.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/brevitas/nn/quant_bn.py b/src/brevitas/nn/quant_bn.py index 97e52014f..765e4918d 100644 --- a/src/brevitas/nn/quant_bn.py +++ b/src/brevitas/nn/quant_bn.py @@ -93,7 +93,6 @@ def __init__( input_quant=input_quant, output_quant=output_quant, return_quant_tensor=return_quant_tensor, - runtime_shape=runtime_shape, **kwargs) self.eps = eps @@ -120,6 +119,5 @@ def __init__( input_quant=input_quant, output_quant=output_quant, return_quant_tensor=return_quant_tensor, - runtime_shape=runtime_shape, **kwargs) self.eps = eps