diff --git a/src/brevitas/export/onnx/standard/function.py b/src/brevitas/export/onnx/standard/function.py index aed2782bc..03916c299 100644 --- a/src/brevitas/export/onnx/standard/function.py +++ b/src/brevitas/export/onnx/standard/function.py @@ -4,9 +4,43 @@ import onnx import torch from torch.autograd import Function +from torch.onnx.symbolic_helper import _get_tensor_sizes from brevitas.export.onnx import onnx_export_opset + +class MatMulNBitsFn(Function): + + @staticmethod + def symbolic(g, x, int_weights, scales, zero_points, K, N, bits, block_size): + ret = g.op( + 'com.microsoft::MatMulNBits', + x, + int_weights, + scales, + zero_points, + K_i=K, + N_i=N, + bits_i=bits, + block_size_i=block_size) + output_size = _get_tensor_sizes(x) + output_size[-1] = N + ret.setType(x.type().with_sizes(output_size)) + return ret + + @staticmethod + def forward(g, x, int_weights, scales, zero_points, K, N, bits, block_size): + dtype = x.dtype + device = x.device + shape = x.shape + out_shape = list(shape) + out_shape[-1] = N + # Only tensor metadata (shape, dtype, device) are preserved in the forward pass during + # tracing, not the correct value + out = torch.empty(out_shape, dtype=dtype, device=device) + return out + + AXIS_OPSET = 13 DATATYPE_DICT = { diff --git a/src/brevitas/nn/quant_avg_pool.py b/src/brevitas/nn/quant_avg_pool.py index 7a3f108da..ac2f74c41 100644 --- a/src/brevitas/nn/quant_avg_pool.py +++ b/src/brevitas/nn/quant_avg_pool.py @@ -117,7 +117,6 @@ def forward(self, input: Union[Tensor, QuantTensor]): # shortcut execution through the export impl during export if self.export_mode: out = self.export_handler(_unpack_quant_tensor(x)) - self._set_global_is_quant_layer(False) return out if isinstance(x, QuantTensor) and self.is_trunc_quant_enabled: diff --git a/src/brevitas/nn/quant_eltwise.py b/src/brevitas/nn/quant_eltwise.py index c395d5c9f..8c25873c5 100644 --- a/src/brevitas/nn/quant_eltwise.py +++ b/src/brevitas/nn/quant_eltwise.py @@ -37,7 +37,6 @@ def forward(self, input: Union[Tensor, QuantTensor], if self.export_mode: assert self.cache_quant_io_metadata_only, "Can't cache multiple inputs" out = self.export_handler(inp=input.value, other=other.value) - self._set_global_is_quant_layer(False) return out quant_input = self.input_quant(input) quant_other = self.input_quant(other) @@ -70,7 +69,6 @@ def forward(self, # shortcut execution through the export impl during export if self.export_mode: out = self.export_handler([qt.value for qt in quant_tensor_list]) - self._set_global_is_quant_layer(False) return out quant_tensor_list = [self.input_quant(qt) for qt in quant_tensor_list] # trigger an assert if scale factors and bit widths are None or different diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py index 215299837..f501afd62 100644 --- a/src/brevitas/nn/quant_layer.py +++ b/src/brevitas/nn/quant_layer.py @@ -47,7 +47,6 @@ def forward(self, input: Union[Tensor, QuantTensor]): # shortcut execution through the export impl during export if self.export_mode: out = self.export_handler(quant_input) - self._set_global_is_quant_layer(False) return out out = self.act_quant(quant_input) out = self.pack_output(out) @@ -139,7 +138,6 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe # shortcut execution through the export impl during export if self.export_mode: out = self.export_handler(inp) - self._set_global_is_quant_layer(False) return out quant_input = self.input_quant(inp) diff --git a/src/brevitas/nn/quant_upsample.py b/src/brevitas/nn/quant_upsample.py index f2735abf5..6b7c0ba71 100644 --- a/src/brevitas/nn/quant_upsample.py +++ b/src/brevitas/nn/quant_upsample.py @@ -40,7 +40,6 @@ def forward(self, input: Union[Tensor, QuantTensor]): x = self.unpack_input(input) if self.export_mode: out = self.export_handler(x.value) - self._set_global_is_quant_layer(False) return out y_value = interpolate(x.value, self.size, self.scale_factor, self.mode, self.align_corners) if self.mode != 'nearest': @@ -69,7 +68,6 @@ def forward(self, input: Union[Tensor, QuantTensor]): x = self.unpack_input(input) if self.export_mode: out = self.export_handler(x.value) - self._set_global_is_quant_layer(False) return out y_value = interpolate(x.value, self.size, self.scale_factor, self.mode, self.align_corners) # round interpolated values to scale @@ -97,7 +95,6 @@ def forward(self, input: Union[Tensor, QuantTensor]): x = self.unpack_input(input) if self.export_mode: out = self.export_handler(x.value) - self._set_global_is_quant_layer(False) return out y_value = interpolate(x.value, self.size, self.scale_factor, self.mode, self.align_corners) y = x.set(value=y_value) diff --git a/src/brevitas_examples/llm/llm_quant/export.py b/src/brevitas_examples/llm/llm_quant/export.py index 9f068bddd..49ce8dceb 100644 --- a/src/brevitas_examples/llm/llm_quant/export.py +++ b/src/brevitas_examples/llm/llm_quant/export.py @@ -10,6 +10,8 @@ import numpy as np import torch +from torch.nn import Module +from torch.onnx import register_custom_op_symbolic from brevitas.export.common.handler.base import BaseHandler from brevitas.export.manager import _set_layer_export_handler @@ -17,9 +19,12 @@ from brevitas.export.manager import _set_proxy_export_handler from brevitas.export.manager import _set_proxy_export_mode from brevitas.export.manager import BaseManager +from brevitas.export.onnx.handler import ONNXBaseHandler +from brevitas.export.onnx.standard.function import MatMulNBitsFn from brevitas.function.ops import max_int from brevitas.function.ops import min_int from brevitas.nn import QuantLinear +from brevitas.proxy.groupwise_int_parameter_quant import GroupwiseWeightQuantProxyFromInjector from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector @@ -52,27 +57,6 @@ def __init__(self): self.bit_width = None self.dtype = None - def scaling_impl(self, proxy_module): - return proxy_module.tensor_quant.scaling_impl - - def zero_point_impl(self, proxy_module): - return proxy_module.tensor_quant.zero_point_impl - - def bit_width_impl(self, proxy_module): - return proxy_module.tensor_quant.msb_clamp_bit_width_impl - - def export_scale(self, proxy_module, bit_width): - scaling_impl = self.scaling_impl(proxy_module) - int_scaling_impl = proxy_module.tensor_quant.int_scaling_impl - int_threshold = int_scaling_impl(bit_width) - threshold = scaling_impl.wrapped_scaling_impl.stats_scaling_impl( - scaling_impl.wrapped_scaling_impl.parameter_list_stats()) - return threshold / int_threshold - - def export_zero_point(self, proxy_module, scale, bit_width): - zero_point_impl = self.zero_point_impl(proxy_module) - return zero_point_impl.unexpanded_zero_point(scale, bit_width) - @abstractmethod def prepare_for_export(self, module): pass @@ -83,6 +67,7 @@ def forward(self, x): class WeightBlockQuantProxyHandler(WeightBlockQuantHandlerBase): + handled_layer = GroupwiseWeightQuantProxyFromInjector def __init__(self): super().__init__() @@ -93,20 +78,18 @@ def __init__(self): def prepare_for_export(self, module): assert len(module.tracked_module_list) == 1, "Shared quantizers not supported." - self.bit_width = self.bit_width_impl(module)() - assert self.bit_width <= 8., "Only 8b or lower is supported." quant_layer = module.tracked_module_list[0] quant_weight = quant_layer.quant_weight() + self.bit_width = quant_weight.bit_width + assert self.bit_width <= 8., "Only 8b or lower is supported." signed = module.is_signed self.int_dtype = torch.int8 if signed else torch.uint8 self.dtype = quant_weight.value.dtype - self.scale = self.export_scale(module, self.bit_width).detach() - self.expanded_groupwise_shape = self.scaling_impl(module).expanded_groupwise_shape - self.reshaped_groupwise_shape = self.scaling_impl(module).reshaped_groupwise_shape + self.scale = quant_weight.scale_ + self.expanded_scaling_shape = quant_weight.value_.shape + self.reshaped_scaling_shape = quant_weight.value.shape if (quant_weight.zero_point != 0.).any(): - self.zero_point = self.export_zero_point(module, self.scale, self.bit_width).detach() - self.expanded_zero_point_shape = self.zero_point_impl(module).expanded_zero_point_shape - self.reshaped_zero_point_shape = self.zero_point_impl(module).reshaped_zero_point_shape + self.zero_point = quant_weight.zero_point_ else: self.zero_point = None @@ -131,15 +114,9 @@ def forward(self, x): x = (x.type(self.dtype) - zero_point) * scale # Fix shape post quantization - scale = scale.expand(self.expanded_groupwise_shape).contiguous().view( - self.reshaped_groupwise_shape) # If zero_point is not defined, propagate same shape as scale if self.zero_point is None: zero_point = torch.zeros_like(scale).type(self.int_dtype) - else: - zero_point = zero_point.expand(self.expanded_zero_point_shape).contiguous().view( - self.reshaped_zero_point_shape).type(self.int_dtype) - x = x.view(self.reshaped_groupwise_shape) return x, scale, zero_point, bit_width @@ -208,18 +185,17 @@ def lcm(x, y): raise ValueError(f"Bit width {bit_width} not supported.") def prepare_for_export(self, module): - self.bit_width = self.bit_width_impl(module.weight_quant)() - assert self.bit_width <= 8., "Only 8b or lower is supported." quant_weight = module.quant_weight() + self.bit_width = quant_weight.bit_width + assert self.bit_width <= 8., "Only 8b or lower is supported." self.bias = module.bias - self.scale = self.export_scale(module.weight_quant, self.bit_width) + self.scale = quant_weight.scale_ if (quant_weight.zero_point != 0.).any(): - self.zero_point = self.export_zero_point( - module.weight_quant, self.scale, self.bit_width) + self.zero_point = quant_weight.zero_point_ else: # if there is no zero-point, export zeroes in the shape of scale self.zero_point = torch.zeros_like(self.scale) - self.group_size = module.weight_quant.quant_injector.block_size + self.group_size = quant_weight.group_size self.bit_width = int(self.bit_width.cpu().item()) self.int_weight = self.pack_int_weights(self.bit_width, quant_weight.int().detach()) @@ -237,10 +213,12 @@ def set_export_handler(cls, module): _set_proxy_export_handler(cls, module) -def block_quant_layer_level_manager(export_handlers): +def block_quant_layer_level_manager(export_handlers, target=None, custom_fns_to_register=None): class BlockQuantLayerLevelManager(BaseManager): handlers = export_handlers + target_name = '' if target is None else target + custom_fns = [] if custom_fns_to_register is None else custom_fns_to_register @classmethod def set_export_handler(cls, module): @@ -281,3 +259,92 @@ def replace_call_fn_target(graph_model, src, target): node.target = target graph_model.graph.lint() graph_model.recompile() + + +class ONNXLinearWeightBlockQuantHandlerFwd(ONNXBaseHandler, WeightBlockQuantHandlerBase): + handled_layer = QuantLinear + + def __init__(self): + super(ONNXLinearWeightBlockQuantHandlerFwd, self).__init__() + self.group_size = None + + def pack_int_weights(self, bit_width, int_weights, zero_point): + assert int_weights.dtype in [torch.uint8, torch.int8], "Packing requires (u)int8 input." + assert bit_width == 4, "Only 4 bit quantization export is supported at the moment" + + is_symmetric = torch.sum(zero_point) == 0 + zero_point = zero_point.to(torch.uint8) + rows, cols = int_weights.shape + group_size = self.group_size + blob_size = group_size // 2 + k_blocks = (rows + group_size - 1) // group_size + padded_rows = k_blocks * group_size + pad_len = padded_rows - rows + + # ONNX operator assumes implicit zp of 8 (largest negative number in Po2) + # If we are in a "symmetric" quantized scenario, we need to add this implicit zero point + # Otherwise it has already been added during the convesion to integer. + # This allows to pack weights always in unsigned integer. + zp = 0 if not int_weights.dtype == torch.int8 else 8 + int_weights += zp + if pad_len > 0: + int_weights = torch.nn.functional(int_weights, (0, 0, 0, pad_len)) + packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") + rows, cols = int_weights.shape + int_weights = int_weights.t() + for n in range(cols): + for k_id in range(0, rows, group_size): + blk_int0 = (int_weights[n, k_id:k_id + group_size:2].numpy()).astype("uint8") + blk_int1 = (int_weights[n, k_id + 1:k_id + group_size:2].numpy()).astype("uint8") + packed[n, k_id // group_size] = np.bitwise_or(blk_int0, np.left_shift(blk_int1, 4)) + + zero_point = zero_point.to(torch.uint8).flatten() + + # The constant value 136 is derived from the source code in ORT test suite. + # https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/test/python/quantization/test_quantizeblockwise_4bits.py + base_zp = 136 if is_symmetric else 0 + packed_zp = base_zp * torch.ones( + (zero_point.shape[0] + 1) // 2, device=int_weights.device, dtype=torch.uint8) + + i = 0 + for column in range(packed_zp.shape[0]): + for j in range(i, i + (8 // bit_width)): + shift_factor = (bit_width * (j - i)) + packed_zp[column] |= zero_point[j] << shift_factor + i += 8 // bit_width + return torch.tensor(packed), packed_zp + + def prepare_for_export(self, module): + quant_weight = module.quant_weight() + self.bit_width = quant_weight.bit_width + assert self.bit_width <= 8., "Only 8b or lower is supported." + self.bias = module.bias + self.scale = quant_weight.scale_ + if (quant_weight.zero_point != 0.).any(): + self.zero_point = quant_weight.zero_point_ + else: + # if there is no zero-point, export zeroes in the shape of scale + self.zero_point = torch.zeros_like(self.scale) + self.group_size = module.weight_quant.quant_injector.group_size + self.bit_width = int(self.bit_width.cpu().item()) + self.int_weight, self.zero_point = self.pack_int_weights(self.bit_width, quant_weight.int().t().detach(), self.zero_point) + self.weight_shape = module.weight.shape + + def symbolic_execution(self, x): + int_weights = self.int_weight + scale = self.scale + bit_width = self.bit_width + N, K = self.weight_shape + out = MatMulNBitsFn.apply( + x, int_weights, scale.flatten(), self.zero_point, K, N, bit_width, self.group_size) + return out + + +def export_packed_onnx(model, input, export_path): + export_class = block_quant_layer_level_manager( + export_handlers=[ONNXLinearWeightBlockQuantHandlerFwd], + target='', + custom_fns_to_register=MatMulNBitsFn) + + with torch.inference_mode(), brevitas_layer_export_mode(model, export_class): + torch.onnx.export(model, input, export_path)