From df8c7802aaaf23edbdfdb9d1e0ecc98118441b07 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 31 Oct 2024 17:22:28 +0000 Subject: [PATCH 1/5] Feat (export): dequantize during export --- src/brevitas/export/manager.py | 5 +++++ src/brevitas/export/onnx/manager.py | 6 +++++- src/brevitas/proxy/parameter_quant.py | 6 ++++-- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/brevitas/export/manager.py b/src/brevitas/export/manager.py index 7b7e7a145..68cd7c1e9 100644 --- a/src/brevitas/export/manager.py +++ b/src/brevitas/export/manager.py @@ -14,6 +14,8 @@ from torch.nn import Module from brevitas import config +from brevitas.graph.calibrate import disable_return_quant_tensor +from brevitas.graph.calibrate import restore_return_quant_tensor from brevitas.nn.mixin.base import QuantLayerMixin from brevitas.nn.mixin.base import QuantRecurrentLayerMixin from brevitas.proxy.quant_proxy import QuantProxyProtocol @@ -218,7 +220,10 @@ def jit_inference_trace( requires_grad_backup_dict = _force_requires_grad_false(module) # wrapping with a lambda forces inlining during tracing, # converts everything to const and removes unused params/buffers + return_quant_tensor_state = disable_return_quant_tensor(module) traced_model = torch.jit.trace(_JitTraceExportWrapper(module), args) + restore_return_quant_tensor(module, return_quant_tensor_state) + # Hack to clone the function, otherwise restoring requires_grad # on module will break traced_model with BytesIO() as tmp: diff --git a/src/brevitas/export/onnx/manager.py b/src/brevitas/export/onnx/manager.py index ae3270cc9..482c0a70f 100644 --- a/src/brevitas/export/onnx/manager.py +++ b/src/brevitas/export/onnx/manager.py @@ -9,6 +9,9 @@ from packaging import version +from brevitas.graph.calibrate import disable_return_quant_tensor +from brevitas.graph.calibrate import restore_return_quant_tensor + try: import onnx import onnxoptimizer as opt @@ -164,10 +167,11 @@ def export_onnx( else: model_bytes = BytesIO() export_target = model_bytes + return_quant_tensor_state = disable_return_quant_tensor(model) with PatchFp8Ops(): torch.onnx.export(module, args, export_target, **onnx_export_kwargs) - + restore_return_quant_tensor(model, return_quant_tensor_state) # restore the model to previous properties module.apply(lambda m: _restore_act_caching_mode(m)) cls.set_export_mode(module, enabled=False) diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index f28233aed..015b3982f 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -137,7 +137,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: out = self.create_quant_tensor(out) else: out = self.tensor_quant(x) - if is_dynamo_compiling(): + if is_dynamo_compiling() or self.export_mode: out = out[0] else: out = self.create_quant_tensor(out) @@ -273,6 +273,8 @@ def forward( impl = self.export_handler if self.export_mode else self.tensor_quant out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x, input_bit_width, input_is_signed) + if self.export_mode: + return out return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) else: # quantization disabled return x @@ -350,7 +352,7 @@ def forward( out, out_scale, out_zp, out_bit_width = impl(x, input_scale) else: out, out_scale, out_zp, out_bit_width = impl(x) - if not is_dynamo_compiling(): + if not is_dynamo_compiling() or not self.export_mode: out = IntQuantTensor( out, out_scale, out_zp, out_bit_width, self.is_signed, self.training) if not self.training and self.cache_inference_quant_bias: From cba513c023620d298742768c61b948c91d97d7be Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 31 Oct 2024 17:24:35 +0000 Subject: [PATCH 2/5] runtime proxy --- src/brevitas/proxy/runtime_quant.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 9feb593b4..746e829d0 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -188,7 +188,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: # If y is an empty QuantTensor, we need to check if this is a passthrough proxy, # otherwise return a simple Tensor - if is_dynamo_compiling(): + if is_dynamo_compiling() or self.export_mode: out = y[0] else: # If the second value (i.e., scale) is None, then quant is disabled @@ -274,6 +274,8 @@ def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]: else: out_tuple = self.tensor_quant(x.value, x.scale, x.zero_point, x.bit_width) out_value, out_scale, out_zp, out_bit_width = out_tuple + if self.export_mode: + return out_value return IntQuantTensor( out_value, out_scale, out_zp, out_bit_width, x.signed, self.training) else: From 9bf71fecdb91fd793582baa46c851d211c5e3542 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 31 Oct 2024 17:29:04 +0000 Subject: [PATCH 3/5] fix --- src/brevitas/export/onnx/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/export/onnx/manager.py b/src/brevitas/export/onnx/manager.py index 482c0a70f..444eb3227 100644 --- a/src/brevitas/export/onnx/manager.py +++ b/src/brevitas/export/onnx/manager.py @@ -171,7 +171,7 @@ def export_onnx( with PatchFp8Ops(): torch.onnx.export(module, args, export_target, **onnx_export_kwargs) - restore_return_quant_tensor(model, return_quant_tensor_state) + restore_return_quant_tensor(module, return_quant_tensor_state) # restore the model to previous properties module.apply(lambda m: _restore_act_caching_mode(m)) cls.set_export_mode(module, enabled=False) From ffef28d717ee49abef54f776500986578b91a00b Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 31 Oct 2024 17:31:31 +0000 Subject: [PATCH 4/5] fix name --- src/brevitas/export/onnx/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas/export/onnx/manager.py b/src/brevitas/export/onnx/manager.py index 444eb3227..73dc9f727 100644 --- a/src/brevitas/export/onnx/manager.py +++ b/src/brevitas/export/onnx/manager.py @@ -167,7 +167,7 @@ def export_onnx( else: model_bytes = BytesIO() export_target = model_bytes - return_quant_tensor_state = disable_return_quant_tensor(model) + return_quant_tensor_state = disable_return_quant_tensor(module) with PatchFp8Ops(): torch.onnx.export(module, args, export_target, **onnx_export_kwargs) From d561e66717dc3953a0a9f6e9f355ee6270db58fe Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Thu, 31 Oct 2024 17:45:38 +0000 Subject: [PATCH 5/5] change detect method --- src/brevitas/proxy/parameter_quant.py | 6 +++--- src/brevitas/proxy/runtime_quant.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py index 015b3982f..2eea157f5 100644 --- a/src/brevitas/proxy/parameter_quant.py +++ b/src/brevitas/proxy/parameter_quant.py @@ -137,7 +137,7 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]: out = self.create_quant_tensor(out) else: out = self.tensor_quant(x) - if is_dynamo_compiling() or self.export_mode: + if is_dynamo_compiling() or torch._C._get_tracing_state() is not None: out = out[0] else: out = self.create_quant_tensor(out) @@ -273,7 +273,7 @@ def forward( impl = self.export_handler if self.export_mode else self.tensor_quant out, scale, zero_point, bit_width, pre_scale, pre_zero_point = impl(x, input_bit_width, input_is_signed) - if self.export_mode: + if torch._C._get_tracing_state() is not None: return out return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training) else: # quantization disabled @@ -352,7 +352,7 @@ def forward( out, out_scale, out_zp, out_bit_width = impl(x, input_scale) else: out, out_scale, out_zp, out_bit_width = impl(x) - if not is_dynamo_compiling() or not self.export_mode: + if not is_dynamo_compiling() or torch._C._get_tracing_state() is not None: out = IntQuantTensor( out, out_scale, out_zp, out_bit_width, self.is_signed, self.training) if not self.training and self.cache_inference_quant_bias: diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 746e829d0..517c94896 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -188,7 +188,7 @@ def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]: # If y is an empty QuantTensor, we need to check if this is a passthrough proxy, # otherwise return a simple Tensor - if is_dynamo_compiling() or self.export_mode: + if is_dynamo_compiling() or torch._C._get_tracing_state() is not None: out = y[0] else: # If the second value (i.e., scale) is None, then quant is disabled @@ -274,7 +274,7 @@ def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]: else: out_tuple = self.tensor_quant(x.value, x.scale, x.zero_point, x.bit_width) out_value, out_scale, out_zp, out_bit_width = out_tuple - if self.export_mode: + if torch._C._get_tracing_state() is not None: return out_value return IntQuantTensor( out_value, out_scale, out_zp, out_bit_width, x.signed, self.training)