diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py
index f28233aed..2144edd53 100644
--- a/src/brevitas/proxy/parameter_quant.py
+++ b/src/brevitas/proxy/parameter_quant.py
@@ -23,6 +23,7 @@
 from brevitas.quant_tensor import QuantTensor
 from brevitas.utils.quant_utils import _CachedIO
 from brevitas.utils.torch_utils import compute_channel_view_shape
+from brevitas.utils.torch_utils import is_broadcastable
 
 from .quant_proxy import QuantProxyFromInjector
 from .quant_proxy import QuantProxyProtocol
@@ -309,7 +310,12 @@ def quant_output_scale_impl(
         channel_dim = -1 if isinstance(module, torch.nn.Linear) else 1
         output_scale_shape = compute_channel_view_shape(input, channel_dim=channel_dim)
         output_scale = weight.scale.view(output_scale_shape)
-        output_scale = output_scale * input.scale.view(output_scale_shape)
+
+        input_scale_view = input.scale.view(output_scale_shape)
+        if not is_broadcastable(output_scale.shape, input_scale_view.shape):
+            return None
+
+        output_scale = output_scale * input_scale_view
         return output_scale
 
     def compute_bias_scale(
@@ -336,8 +342,8 @@ def forward(
             input: Optional[Union[Tensor, IntQuantTensor]] = None,
             weight: Optional[Union[Tensor, IntQuantTensor]] = None) -> Union[Tensor, IntQuantTensor]:
         out = x
-        input_scale = self.compute_bias_scale(input, weight)
         if self.is_quant_enabled:
+            input_scale = self.compute_bias_scale(input, weight)
             impl = self.export_handler if self.export_mode else self.tensor_quant
             if self.requires_input_scale and input_scale is None and self.is_quant_enabled:
                 input_scale = self.scale()
diff --git a/src/brevitas/quant_tensor/int_torch_handler.py b/src/brevitas/quant_tensor/int_torch_handler.py
index 8882bd097..3258b8914 100644
--- a/src/brevitas/quant_tensor/int_torch_handler.py
+++ b/src/brevitas/quant_tensor/int_torch_handler.py
@@ -10,6 +10,7 @@
 from brevitas.function.ops import max_int
 from brevitas.function.ops_ste import ceil_ste
 from brevitas.utils.torch_utils import compute_channel_view_shape
+from brevitas.utils.torch_utils import is_broadcastable
 
 INT_QUANT_TENSOR_FN_HANDLER = {}
 
@@ -198,6 +199,9 @@ def quant_layer(fn, quant_input, quant_weight, bias, *args, **kwargs):
             (quant_weight.zero_point != 0.0).any()):
         warnings.warn("Computing zero point of output accumulator not supported yet.")
         compute_output_quant_tensor = False
+    if output_scale is None:
+        warnings.warn("Could not compute output scale factor, returning Tensor")
+        compute_output_quant_tensor = False
 
     if compute_output_quant_tensor:
         if output_zero_point is None:
@@ -230,8 +234,9 @@ def quant_output_scale_impl(
     channel_dim = -1 if fn == F.linear else 1
     output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim)
     quant_weight_scale = quant_weight_scale.view(output_scale_shape)
-    if len(quant_input_scale.shape) == 0:
-        quant_input_scale = quant_input_scale.view(output_scale_shape)
+    quant_input_scale = quant_input_scale.view(output_scale_shape)
+    if not is_broadcastable(quant_weight_scale.shape, quant_input_scale.shape):
+        return None
 
     output_scale = quant_weight_scale * quant_input_scale
     return output_scale
diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py
index 2f0d34fba..8942c513a 100644
--- a/src/brevitas/utils/torch_utils.py
+++ b/src/brevitas/utils/torch_utils.py
@@ -113,3 +113,12 @@ def padding(x: torch.Tensor, group_size: int, group_dim: int) -> List[int]:
         padding[2 * group_dim] = group_size - size[group_dim] % group_size
     padding = list(reversed(padding))
     return padding
+
+
+def is_broadcastable(tensor, other):
+    for a, b in zip(tensor[::-1], other[::-1]):
+        if a == 1 or b == 1 or a == b:
+            pass
+        else:
+            return False
+    return True
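
For reference, a minimal standalone sketch of the semantics of the new `is_broadcastable` helper (the parameters are named `tensor` and `other`, but the values passed in are shapes). It walks trailing dimensions right-to-left, mirroring PyTorch/NumPy broadcasting rules; the shapes below are illustrative only:

```python
def is_broadcastable(tensor, other):
    # Compare trailing dimensions right-to-left; two dims are compatible
    # when either is 1 or they are equal. Extra leading dims on either
    # side are never compared, matching torch.broadcast_shapes semantics.
    for a, b in zip(tensor[::-1], other[::-1]):
        if a == 1 or b == 1 or a == b:
            pass
        else:
            return False
    return True


# Per-channel weight scale vs. per-tensor input scale: broadcastable.
assert is_broadcastable((1, 64, 1, 1), (1, 1, 1, 1))
# Two different non-singleton channel dims: not broadcastable.
assert not is_broadcastable((1, 64, 1, 1), (1, 32, 1, 1))
# Ranks may differ; leading extra dims never block broadcasting.
assert is_broadcastable((64, 1, 1), (1, 64, 1, 1))
```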
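And a toy reproduction, under assumed shapes, of the case the new guard handles: a Linear layer where both the input and the weight carry per-channel scales. Viewed to the output channel shape, the input scale holds `in_features` entries and the weight scale `out_features` entries, so their product is undefined whenever the two differ; rather than raising at multiply time, the patched `quant_output_scale_impl` returns `None` and `quant_layer` falls back to a plain `Tensor` with a warning. This reuses `is_broadcastable` from the sketch above (or `brevitas.utils.torch_utils` once the patch is applied); the dimensions are made up:

```python
import torch

in_features, out_features = 32, 64
# Per-channel scales, viewed the way quant_output_scale_impl views them
# for a Linear layer (channel_dim=-1): one row vector per scale.
input_scale = torch.rand(in_features).view(1, in_features)     # (1, 32)
weight_scale = torch.rand(out_features).view(1, out_features)  # (1, 64)

if not is_broadcastable(weight_scale.shape, input_scale.shape):
    output_scale = None  # caller warns and returns a plain Tensor
else:
    output_scale = weight_scale * input_scale

assert output_scale is None
```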