From 8b8877cecbf8ab9f5b3e25fa0b797630f73c1ce2 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 26 Oct 2024 00:27:26 +0100 Subject: [PATCH 1/4] Fix: correct output scale compute --- src/brevitas/proxy/parameter_quant.py | 11 ++++++++--- src/brevitas/quant_tensor/int_torch_handler.py | 8 +++++++- src/brevitas/utils/torch_utils.py | 8 ++++++++ 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index f28233aed..c170ecfc2 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -22,7 +22,7 @@ from brevitas.quant_tensor import IntQuantTensor from brevitas.quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIO -from brevitas.utils.torch_utils import compute_channel_view_shape +from brevitas.utils.torch_utils import compute_channel_view_shape, is_broadcastable from .quant_proxy import QuantProxyFromInjector from .quant_proxy import QuantProxyProtocol @@ -309,7 +309,12 @@ def quant_output_scale_impl( channel_dim = -1 if isinstance(module, torch.nn.Linear) else 1 output_scale_shape = compute_channel_view_shape(input, channel_dim=channel_dim) output_scale = weight.scale.view(output_scale_shape) - output_scale = output_scale * input.scale.view(output_scale_shape) + + input_scale_view = input.scale.view(output_scale_shape) + if not is_broadcastable(output_scale.shape, input_scale_view.shape): + return None + + output_scale = output_scale * input_scale_view return output_scale def compute_bias_scale( @@ -336,8 +341,8 @@ def forward( weight: Optional[Union[Tensor, IntQuantTensor]] = None) -> Union[Tensor, IntQuantTensor]: out = x - input_scale = self.compute_bias_scale(input, weight) if self.is_quant_enabled: + input_scale = self.compute_bias_scale(input, weight) impl = self.export_handler if self.export_mode else self.tensor_quant if self.requires_input_scale and input_scale is None and self.is_quant_enabled: input_scale = self.scale() diff --git a/src/brevitas/quant_tensor/int_torch_handler.py b/src/brevitas/quant_tensor/int_torch_handler.py index 8882bd097..71cc5ea17 100644 --- a/src/brevitas/quant_tensor/int_torch_handler.py +++ b/src/brevitas/quant_tensor/int_torch_handler.py @@ -9,7 +9,7 @@ from brevitas.function.ops import max_int from brevitas.function.ops_ste import ceil_ste -from brevitas.utils.torch_utils import compute_channel_view_shape +from brevitas.utils.torch_utils import compute_channel_view_shape, is_broadcastable INT_QUANT_TENSOR_FN_HANDLER = {} @@ -198,6 +198,9 @@ def quant_layer(fn, quant_input, quant_weight, bias, *args, **kwargs): (quant_weight.zero_point != 0.0).any()): warnings.warn("Computing zero point of output accumulator not supported yet.") compute_output_quant_tensor = False + if output_scale is None: + warnings.warn("Could not compute output scale factor, returning Tensor") + compute_output_quant_tensor = False if compute_output_quant_tensor: if output_zero_point is None: @@ -232,6 +235,9 @@ def quant_output_scale_impl( quant_weight_scale = quant_weight_scale.view(output_scale_shape) if len(quant_input_scale.shape) == 0: quant_input_scale = quant_input_scale.view(output_scale_shape) + quant_input_scale = quant_input_scale.view(output_scale_shape) + if not is_broadcastable(output_scale_shape, quant_input_scale.shape): + return None output_scale = quant_weight_scale * quant_input_scale return output_scale diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py index 2f0d34fba..ac1cc84fd 100644 --- a/src/brevitas/utils/torch_utils.py +++ b/src/brevitas/utils/torch_utils.py @@ -113,3 +113,11 @@ def padding(x: torch.Tensor, group_size: int, group_dim: int) -> List[int]: padding[2 * group_dim] = group_size - size[group_dim] % group_size padding = list(reversed(padding)) return padding + +def is_broadcastable(tensor, other): + for a, b in zip(tensor[::-1], other[::-1]): + if a == 1 or b == 1 or a == b: + pass + else: + return False + return True \ No newline at end of file From f0e9a857ef88365e266934b6305a59c2a0dfb3f3 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 28 Oct 2024 09:36:10 +0000 Subject: [PATCH 2/4] precommit --- src/brevitas/proxy/parameter_quant.py | 3 ++- src/brevitas/quant_tensor/int_torch_handler.py | 3 ++- src/brevitas/utils/torch_utils.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index c170ecfc2..2144edd53 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -22,7 +22,8 @@ from brevitas.quant_tensor import IntQuantTensor from brevitas.quant_tensor import QuantTensor from brevitas.utils.quant_utils import _CachedIO -from brevitas.utils.torch_utils import compute_channel_view_shape, is_broadcastable +from brevitas.utils.torch_utils import compute_channel_view_shape +from brevitas.utils.torch_utils import is_broadcastable from .quant_proxy import QuantProxyFromInjector from .quant_proxy import QuantProxyProtocol diff --git a/src/brevitas/quant_tensor/int_torch_handler.py b/src/brevitas/quant_tensor/int_torch_handler.py index 71cc5ea17..91627633d 100644 --- a/src/brevitas/quant_tensor/int_torch_handler.py +++ b/src/brevitas/quant_tensor/int_torch_handler.py @@ -9,7 +9,8 @@ from brevitas.function.ops import max_int from brevitas.function.ops_ste import ceil_ste -from brevitas.utils.torch_utils import compute_channel_view_shape, is_broadcastable +from brevitas.utils.torch_utils import compute_channel_view_shape +from brevitas.utils.torch_utils import is_broadcastable INT_QUANT_TENSOR_FN_HANDLER = {} diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py index ac1cc84fd..8942c513a 100644 --- a/src/brevitas/utils/torch_utils.py +++ b/src/brevitas/utils/torch_utils.py @@ -114,10 +114,11 @@ def padding(x: torch.Tensor, group_size: int, group_dim: int) -> List[int]: padding = list(reversed(padding)) return padding + def is_broadcastable(tensor, other): for a, b in zip(tensor[::-1], other[::-1]): if a == 1 or b == 1 or a == b: pass else: return False - return True \ No newline at end of file + return True From dd05b5285b08c1092c176d7fb4fd14ec412d4b8e Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 28 Oct 2024 09:59:08 +0000 Subject: [PATCH 3/4] Fix --- src/brevitas/quant_tensor/int_torch_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/quant_tensor/int_torch_handler.py b/src/brevitas/quant_tensor/int_torch_handler.py index 91627633d..fb5db7ca1 100644 --- a/src/brevitas/quant_tensor/int_torch_handler.py +++ b/src/brevitas/quant_tensor/int_torch_handler.py @@ -237,7 +237,7 @@ def quant_output_scale_impl( if len(quant_input_scale.shape) == 0: quant_input_scale = quant_input_scale.view(output_scale_shape) quant_input_scale = quant_input_scale.view(output_scale_shape) - if not is_broadcastable(output_scale_shape, quant_input_scale.shape): + if not is_broadcastable(quant_weight_scale.shape, quant_input_scale.shape): return None output_scale = quant_weight_scale * quant_input_scale From 60367a61eb3d5a5dd5934f794a716407947d0a4a Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Wed, 30 Oct 2024 13:32:16 +0100 Subject: [PATCH 4/4] Update int_torch_handler.py --- src/brevitas/quant_tensor/int_torch_handler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/brevitas/quant_tensor/int_torch_handler.py b/src/brevitas/quant_tensor/int_torch_handler.py index fb5db7ca1..3258b8914 100644 --- a/src/brevitas/quant_tensor/int_torch_handler.py +++ b/src/brevitas/quant_tensor/int_torch_handler.py @@ -234,8 +234,6 @@ def quant_output_scale_impl( output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim) quant_weight_scale = quant_weight_scale.view(output_scale_shape) - if len(quant_input_scale.shape) == 0: - quant_input_scale = quant_input_scale.view(output_scale_shape) quant_input_scale = quant_input_scale.view(output_scale_shape) if not is_broadcastable(quant_weight_scale.shape, quant_input_scale.shape): return None