Feat (export): dequantize during export #1083

Open · wants to merge 5 commits into base: dev
src/brevitas/export/manager.py (5 additions, 0 deletions)

@@ -14,6 +14,8 @@
 from torch.nn import Module
 
 from brevitas import config
+from brevitas.graph.calibrate import disable_return_quant_tensor
+from brevitas.graph.calibrate import restore_return_quant_tensor
 from brevitas.nn.mixin.base import QuantLayerMixin
 from brevitas.nn.mixin.base import QuantRecurrentLayerMixin
 from brevitas.proxy.quant_proxy import QuantProxyProtocol
@@ -218,7 +220,10 @@ def jit_inference_trace(
         requires_grad_backup_dict = _force_requires_grad_false(module)
         # wrapping with a lambda forces inlining during tracing,
         # converts everything to const and removes unused params/buffers
+        return_quant_tensor_state = disable_return_quant_tensor(module)
         traced_model = torch.jit.trace(_JitTraceExportWrapper(module), args)
+        restore_return_quant_tensor(module, return_quant_tensor_state)
+
         # Hack to clone the function, otherwise restoring requires_grad
         # on module will break traced_model
         with BytesIO() as tmp:
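The pattern above temporarily forces every proxy to return plain Tensors while torch.jit.trace records the graph, then restores the original return_quant_tensor flags. A minimal sketch of the same bracket used standalone — the QuantLinear toy setup is illustrative, not from the PR:

```python
import torch
from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import restore_return_quant_tensor
from brevitas.nn import QuantLinear

model = QuantLinear(16, 8, bias=True, return_quant_tensor=True).eval()
args = torch.randn(1, 16)

# Same bracket as jit_inference_trace above: QuantTensor outputs cannot be
# recorded by the tracer, so they are switched off just for the trace.
state = disable_return_quant_tensor(model)
traced = torch.jit.trace(model, args)
restore_return_quant_tensor(model, state)  # model behaves as before
```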
src/brevitas/export/onnx/manager.py (5 additions, 1 deletion)

@@ -9,6 +9,9 @@
 
 from packaging import version
 
+from brevitas.graph.calibrate import disable_return_quant_tensor
+from brevitas.graph.calibrate import restore_return_quant_tensor
+
 try:
     import onnx
     import onnxoptimizer as opt
@@ -164,10 +167,11 @@ def export_onnx(
             else:
                 model_bytes = BytesIO()
                 export_target = model_bytes
+            return_quant_tensor_state = disable_return_quant_tensor(module)
 
             with PatchFp8Ops():
                 torch.onnx.export(module, args, export_target, **onnx_export_kwargs)
-
+            restore_return_quant_tensor(module, return_quant_tensor_state)
             # restore the model to previous properties
             module.apply(lambda m: _restore_act_caching_mode(m))
             cls.set_export_mode(module, enabled=False)
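The ONNX path applies the same disable/restore bracket around torch.onnx.export, so a model built with return_quant_tensor=True no longer needs manual flag flipping before export. A hedged usage sketch through the public QCDQ entry point — the model and output path are illustrative, and the exact export_onnx_qcdq keyword set is assumed from the current Brevitas API:

```python
import torch
from brevitas.export import export_onnx_qcdq
from brevitas.nn import QuantLinear

model = QuantLinear(16, 8, bias=True, return_quant_tensor=True).eval()

# export_onnx_qcdq routes through export_onnx above: return_quant_tensor is
# disabled only for the duration of torch.onnx.export, then restored.
export_onnx_qcdq(model, args=torch.randn(1, 16), export_path='quant_linear.onnx')
```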
src/brevitas/proxy/parameter_quant.py (4 additions, 2 deletions)

@@ -137,7 +137,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
                 out = self.create_quant_tensor(out)
         else:
             out = self.tensor_quant(x)
-            if is_dynamo_compiling():
+            if is_dynamo_compiling() or torch._C._get_tracing_state() is not None:
                 out = out[0]
             else:
                 out = self.create_quant_tensor(out)
@@ -273,6 +273,8 @@ def forward(
 
             impl = self.export_handler if self.export_mode else self.tensor_quant
             out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x, input_bit_width, input_is_signed)
+            if torch._C._get_tracing_state() is not None:
+                return out
             return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
         else:  # quantization disabled
             return x
@@ -350,7 +352,7 @@ def forward(
                 out, out_scale, out_zp, out_bit_width = impl(x, input_scale)
             else:
                 out, out_scale, out_zp, out_bit_width = impl(x)
-            if not is_dynamo_compiling():
+            if not is_dynamo_compiling() or torch._C._get_tracing_state() is not None:
                 out = IntQuantTensor(
                     out, out_scale, out_zp, out_bit_width, self.is_signed, self.training)
                 if not self.training and self.cache_inference_quant_bias:
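The new branches rely on torch._C._get_tracing_state(), which is non-None only while torch.jit.trace is actively recording; torch.jit.is_tracing() is the public wrapper around the same state. A minimal standalone sketch of the idiom, outside Brevitas:

```python
import torch

def report(x):
    # Non-None only while torch.jit.trace is recording this call.
    tracing = torch._C._get_tracing_state() is not None  # == torch.jit.is_tracing()
    # torch.jit.trace bakes the trace-time value of `tracing` into the graph,
    # which is exactly how the proxies above pick the plain-Tensor path.
    return x + 1 if tracing else x - 1

print(report(torch.zeros(1)))                    # eager: tensor([-1.])
traced = torch.jit.trace(report, torch.zeros(1), check_trace=False)
print(traced(torch.zeros(1)))                    # traced: tensor([1.])
```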
src/brevitas/proxy/runtime_quant.py (3 additions, 1 deletion)

@@ -188,7 +188,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
             # If y is an empty QuantTensor, we need to check if this is a passthrough proxy,
             # otherwise return a simple Tensor
 
-            if is_dynamo_compiling():
+            if is_dynamo_compiling() or torch._C._get_tracing_state() is not None:
                 out = y[0]
             else:
                 # If the second value (i.e., scale) is None, then quant is disabled
@@ -274,6 +274,8 @@ def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]:
         else:
             out_tuple = self.tensor_quant(x.value, x.scale, x.zero_point, x.bit_width)
         out_value, out_scale, out_zp, out_bit_width = out_tuple
+        if torch._C._get_tracing_state() is not None:
+            return out_value
         return IntQuantTensor(
             out_value, out_scale, out_zp, out_bit_width, x.signed, self.training)
     else:
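Taken together, the proxy changes mean a module configured with return_quant_tensor=True can be traced directly, with the trace seeing only dequantized plain Tensors. A hedged end-to-end check — the QuantIdentity model and the printed type names are illustrative assumptions, not output recorded from this PR:

```python
import torch
from brevitas.nn import QuantIdentity

model = QuantIdentity(return_quant_tensor=True).eval()
x = torch.randn(4, 16)

print(type(model(x)).__name__)   # IntQuantTensor in eager mode

# During tracing the runtime proxy hits the new _get_tracing_state() branch
# and returns the dequantized value tensor instead; check_trace is disabled
# because the eager output type intentionally differs.
traced = torch.jit.trace(model, x, check_trace=False)
print(type(traced(x)).__name__)  # Tensor
```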