From af37fcc34879764e056197ce16694bbf002def2b Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 4 Oct 2024 13:56:34 +0100 Subject: [PATCH 1/3] Fix (quant_tensor): Produce valid IntQuantTensor after AvgPool functional call --- src/brevitas/quant_tensor/int_torch_handler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/brevitas/quant_tensor/int_torch_handler.py b/src/brevitas/quant_tensor/int_torch_handler.py index 8882bd097..a9c6572fa 100644 --- a/src/brevitas/quant_tensor/int_torch_handler.py +++ b/src/brevitas/quant_tensor/int_torch_handler.py @@ -114,6 +114,7 @@ def avg_pool2d_handler( avg_scaling = kernel_size * kernel_size quant_input = quant_input.set(value=x) + quant_input = quant_input.set(scale=quant_input.scale / avg_scaling) quant_input = quant_input.set(bit_width=max_acc_bit_width(quant_input.bit_width, avg_scaling)) return quant_input @@ -133,6 +134,7 @@ def adaptive_avg_pool2d_handler(quant_input, output_shape): reduce_size = reduce(mul, k_size, 1) quant_input = quant_input.set(value=x) + quant_input = quant_input.set(scale=quant_input.scale / reduce_size) quant_input = quant_input.set(bit_width=max_acc_bit_width(quant_input.bit_width, reduce_size)) return quant_input From 5c89ef26ee587a25eb0a3feac7b97f9b741082b3 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 4 Oct 2024 13:58:50 +0100 Subject: [PATCH 2/3] Fix (core/trunc): Fix output scaling after truncation --- src/brevitas/core/quant/int.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/brevitas/core/quant/int.py b/src/brevitas/core/quant/int.py index cdb75df74..46127eb08 100644 --- a/src/brevitas/core/quant/int.py +++ b/src/brevitas/core/quant/int.py @@ -211,12 +211,13 @@ def forward(self, x: Tensor, scale: Tensor, zero_point: Tensor, output_bit_width = self.msb_clamp_bit_width_impl() trunc_bit_width = input_bit_width - output_bit_width trunc_scale = 2.0 ** trunc_bit_width + output_scale = scale * trunc_scale y = y / trunc_scale y = self.float_to_int_impl(y) y = y - zero_point - y = y * scale + y = y * output_scale y = self.delay_wrapper(x, y) - return y, scale, zero_point, output_bit_width + return y, output_scale, zero_point, output_bit_width class DecoupledRescalingIntQuantWithInput(DecoupledRescalingIntQuant): From 090bb14344442b2f74805edbd1fc0408fc741ed8 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 4 Oct 2024 13:59:46 +0100 Subject: [PATCH 3/3] Fix (nn/TruncAvgPool): Remove any quant tensor manual manipulation. --- src/brevitas/nn/quant_avg_pool.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/brevitas/nn/quant_avg_pool.py b/src/brevitas/nn/quant_avg_pool.py index 7a3f108da..31bac921b 100644 --- a/src/brevitas/nn/quant_avg_pool.py +++ b/src/brevitas/nn/quant_avg_pool.py @@ -54,6 +54,7 @@ def _avg_scaling(self): else: return self.kernel_size * self.kernel_size + # TODO: Replace with functional call def forward(self, input: Union[Tensor, QuantTensor]): x = self.unpack_input(input) @@ -62,8 +63,6 @@ def forward(self, input: Union[Tensor, QuantTensor]): if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled: y = AvgPool2d.forward(self, x) - rescaled_value = y.value * self._avg_scaling - y = y.set(value=rescaled_value) y = self.trunc_quant(y) else: y = AvgPool2d.forward(self, _unpack_quant_tensor(x)) @@ -111,6 +110,7 @@ def compute_kernel_size_stride(input_shape, output_shape): stride_list.append(stride) return kernel_size_list, stride_list + # TODO: Replace with functional call def forward(self, input: Union[Tensor, QuantTensor]): x = self.unpack_input(input) @@ -122,10 +122,6 @@ def forward(self, input: Union[Tensor, QuantTensor]): if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled: y = AdaptiveAvgPool2d.forward(self, x) - k_size, stride = self.compute_kernel_size_stride(x.value.shape[2:], y.value.shape[2:]) - reduce_size = reduce(mul, k_size, 1) - rescaled_value = y.value * reduce_size # remove avg scaling - y = y.set(value=rescaled_value) y = self.trunc_quant(y) else: y = AdaptiveAvgPool2d.forward(self, _unpack_quant_tensor(x))