Fix: correct output scale compute #1077

Open · wants to merge 4 commits into base: dev
src/brevitas/proxy/parameter_quant.py (10 changes: 8 additions & 2 deletions)

@@ -23,6 +23,7 @@
 from brevitas.quant_tensor import QuantTensor
 from brevitas.utils.quant_utils import _CachedIO
 from brevitas.utils.torch_utils import compute_channel_view_shape
+from brevitas.utils.torch_utils import is_broadcastable

 from .quant_proxy import QuantProxyFromInjector
 from .quant_proxy import QuantProxyProtocol
@@ -309,7 +310,12 @@ def quant_output_scale_impl(
         channel_dim = -1 if isinstance(module, torch.nn.Linear) else 1
         output_scale_shape = compute_channel_view_shape(input, channel_dim=channel_dim)
         output_scale = weight.scale.view(output_scale_shape)
-        output_scale = output_scale * input.scale.view(output_scale_shape)
+
+        input_scale_view = input.scale.view(output_scale_shape)
+        if not is_broadcastable(output_scale.shape, input_scale_view.shape):
+            return None
+
+        output_scale = output_scale * input_scale_view
         return output_scale

     def compute_bias_scale(
@@ -336,8 +342,8 @@ def forward(
             weight: Optional[Union[Tensor,
                              IntQuantTensor]] = None) -> Union[Tensor, IntQuantTensor]:
         out = x
-        input_scale = self.compute_bias_scale(input, weight)
         if self.is_quant_enabled:
+            input_scale = self.compute_bias_scale(input, weight)
             impl = self.export_handler if self.export_mode else self.tensor_quant
             if self.requires_input_scale and input_scale is None and self.is_quant_enabled:
                 input_scale = self.scale()
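Why the broadcast guard matters: if the viewed weight scale is per-output-channel while the viewed input scale is per-channel with a different channel count, the two shapes cannot broadcast and the multiplication would fail. A minimal, self-contained sketch of that failure mode (the concrete shapes are illustrative assumptions, not taken from the PR):

    # Illustrative shapes: a per-output-channel weight scale next to a
    # per-channel input scale with a different channel count.
    weight_scale_shape = (1, 64, 1, 1)  # hypothetical conv, 64 output channels
    input_scale_shape = (1, 32, 1, 1)   # hypothetical input scale, 32 channels

    def is_broadcastable(tensor, other):
        # Same helper this PR adds in src/brevitas/utils/torch_utils.py.
        for a, b in zip(tensor[::-1], other[::-1]):
            if not (a == 1 or b == 1 or a == b):
                return False
        return True

    assert not is_broadcastable(weight_scale_shape, input_scale_shape)
    # quant_output_scale_impl now returns None here instead of attempting
    # the multiplication.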
src/brevitas/quant_tensor/int_torch_handler.py (9 changes: 7 additions & 2 deletions)

@@ -10,6 +10,7 @@
 from brevitas.function.ops import max_int
 from brevitas.function.ops_ste import ceil_ste
 from brevitas.utils.torch_utils import compute_channel_view_shape
+from brevitas.utils.torch_utils import is_broadcastable

 INT_QUANT_TENSOR_FN_HANDLER = {}

@@ -198,6 +199,9 @@ def quant_layer(fn, quant_input, quant_weight, bias, *args, **kwargs):
                                         (quant_weight.zero_point != 0.0).any()):
         warnings.warn("Computing zero point of output accumulator not supported yet.")
         compute_output_quant_tensor = False
+    if output_scale is None:
+        warnings.warn("Could not compute output scale factor, returning Tensor")
+        compute_output_quant_tensor = False

     if compute_output_quant_tensor:
         if output_zero_point is None:
@@ -230,8 +234,9 @@ def quant_output_scale_impl(
     output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim)

     quant_weight_scale = quant_weight_scale.view(output_scale_shape)
-    if len(quant_input_scale.shape) == 0:
-        quant_input_scale = quant_input_scale.view(output_scale_shape)
+    quant_input_scale = quant_input_scale.view(output_scale_shape)
+    if not is_broadcastable(quant_weight_scale.shape, quant_input_scale.shape):
+        return None

     output_scale = quant_weight_scale * quant_input_scale
     return output_scale
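One consequence of this fallback: downstream code that always expected an IntQuantTensor from a quantized layer can now receive a plain torch.Tensor. A minimal defensive pattern (my own sketch; it assumes IntQuantTensor is importable from brevitas.quant_tensor, as the type hints in this diff suggest):

    import torch
    from brevitas.quant_tensor import IntQuantTensor  # assumed import path

    def consume_output(out):
        # After this change, `out` may be a plain Tensor whenever the output
        # scale could not be computed, so handle both cases explicitly.
        if isinstance(out, IntQuantTensor):
            return out.value  # dequantized values carried by the quant tensor
        return out  # plain torch.Tensor fallback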
src/brevitas/utils/torch_utils.py (9 changes: 9 additions & 0 deletions)

@@ -113,3 +113,12 @@ def padding(x: torch.Tensor, group_size: int, group_dim: int) -> List[int]:
         padding[2 * group_dim] = group_size - size[group_dim] % group_size
     padding = list(reversed(padding))
     return padding
+
+
+def is_broadcastable(tensor, other):
+    for a, b in zip(tensor[::-1], other[::-1]):
+        if a == 1 or b == 1 or a == b:
+            pass
+        else:
+            return False
+    return True
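The helper mirrors the right-to-left broadcasting rule used by NumPy and PyTorch. Note that despite the parameter names it operates on shapes rather than tensors, and any extra leading dimensions on the longer shape are treated as compatible, which matches broadcasting semantics. A few quick checks (my own examples):

    from brevitas.utils.torch_utils import is_broadcastable

    assert is_broadcastable((1, 64, 1, 1), (1, 1, 1, 1))       # per-channel vs per-tensor
    assert is_broadcastable((4, 1), (3,))                      # dims compared right to left
    assert not is_broadcastable((1, 64, 1, 1), (1, 32, 1, 1))  # channel counts disagree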