Skip to content

Commit

Permalink
precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 5, 2024
1 parent 1ddea1c commit f27071d
Showing 1 changed file with 41 additions and 20 deletions.
61 changes: 41 additions & 20 deletions src/brevitas/export/onnx/qonnx/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,26 @@ def prepare_for_export(self, module):
if module.is_quant_enabled:
self.validate(module)
self.symbolic_kwargs = {
'scale': module.scale(),
'exponent_bit_width': module.exponent_bit_width(),
'mantissa_bit_width': module.mantissa_bit_width(),
'exponent_bias': module.exponent_bias(),
'has_inf': module.inf_values() is not None,
'has_nan': module.nan_values() is not None,
'saturating': module.is_saturating(),
'has_subnormal': True, # Currently we only support subnormal
'rounding_mode': module.rounding_mode,
'max_float': torch.tensor(module.quant_injector.max_available_float).type_as(module.scale())}
'scale':
module.scale(),
'exponent_bit_width':
module.exponent_bit_width(),
'mantissa_bit_width':
module.mantissa_bit_width(),
'exponent_bias':
module.exponent_bias(),
'has_inf':
module.inf_values() is not None,
'has_nan':
module.nan_values() is not None,
'saturating':
module.is_saturating(),
'has_subnormal':
True, # Currently we only support subnormal
'rounding_mode':
module.rounding_mode,
'max_float':
torch.tensor(module.quant_injector.max_available_float).type_as(module.scale())}
self.return_args = {
'scale': module.scale(),
'zero_point': torch.zeros_like(module.scale()),
Expand Down Expand Up @@ -75,16 +85,27 @@ def prepare_for_export(self, module: WeightQuantProxyFromInjector):
first_qweight = module.tracked_module_list[0].quant_weight()
self.validate(first_qweight.zero_point)
self.symbolic_kwargs = {
'scale': first_qweight.scale,
'exponent_bit_width': first_qweight.exponent_bit_width,
'mantissa_bit_width': first_qweight.mantissa_bit_width,
'exponent_bias': first_qweight.exponent_bias,
'has_inf': first_qweight.inf_values is not None,
'has_nan': first_qweight.nan_values is not None,
'saturating': first_qweight.saturating,
'has_subnormal': True, # Currently we only support subnormal
'rounding_mode': module.rounding_mode,
'max_float': torch.tensor(module.quant_injector.max_available_float).type_as(first_qweight.scale)}
'scale':
first_qweight.scale,
'exponent_bit_width':
first_qweight.exponent_bit_width,
'mantissa_bit_width':
first_qweight.mantissa_bit_width,
'exponent_bias':
first_qweight.exponent_bias,
'has_inf':
first_qweight.inf_values is not None,
'has_nan':
first_qweight.nan_values is not None,
'saturating':
first_qweight.saturating,
'has_subnormal':
True, # Currently we only support subnormal
'rounding_mode':
module.rounding_mode,
'max_float':
torch.tensor(module.quant_injector.max_available_float
).type_as(first_qweight.scale)}
self.return_args = {
'scale': first_qweight.scale,
'zero_point': torch.zeros_like(first_qweight.scale),
Expand Down

0 comments on commit f27071d

Please sign in to comment.