From 6557961e7f222050b2caac21ba308e6ecc94dcda Mon Sep 17 00:00:00 2001
From: Carl
Date: Wed, 15 Jan 2025 18:57:52 +0000
Subject: [PATCH] Porting TF fake_quant_with_min_max functions (#20641)

* QAT (squashed this time) (#1)

* adds fake_quant_with_min_max functions from TF to keras3

* Addresses PR review comments

* drops another type hint

* swaps out if statements, change float() to ops.cast and adds fake_quant_with_min_max_vars function

* fix missed if statement, adds gradient tests via main function for tf and torch

* fix unbound variable error when not using torch or tf backend (#2)

Refactor to use backend specific gradient functions in tests and merges logic into single function

* More QAT function revisions (#3)

This PR addresses review feedback to fix implementation and to move tests to using named_parameters rather than individual functions.

* Qat revisions (#4)

Adds axis param and fixes logic for per channel function

* updated implementation

* removed redundant functions
---
 .../_tf_keras/keras/quantizers/__init__.py |   3 +
 keras/api/quantizers/__init__.py           |   3 +
 keras/src/quantizers/__init__.py           |   1 +
 keras/src/quantizers/quantizers.py         | 137 +++++++
 keras/src/quantizers/quantizers_test.py    | 381 ++++++++++++++++++
 5 files changed, 525 insertions(+)

diff --git a/keras/api/_tf_keras/keras/quantizers/__init__.py b/keras/api/_tf_keras/keras/quantizers/__init__.py
index d8a209bbb62..8b11f6a3d63 100644
--- a/keras/api/_tf_keras/keras/quantizers/__init__.py
+++ b/keras/api/_tf_keras/keras/quantizers/__init__.py
@@ -12,4 +12,7 @@
 from keras.src.quantizers.quantizers import abs_max_quantize
 from keras.src.quantizers.quantizers import compute_float8_amax_history
 from keras.src.quantizers.quantizers import compute_float8_scale
+from keras.src.quantizers.quantizers import (
+    fake_quant_with_min_max_vars as fake_quant_with_min_max_vars_per_channel,
+)
 from keras.src.quantizers.quantizers import quantize_and_dequantize
diff --git a/keras/api/quantizers/__init__.py b/keras/api/quantizers/__init__.py
index d8a209bbb62..8b11f6a3d63 100644
--- a/keras/api/quantizers/__init__.py
+++ b/keras/api/quantizers/__init__.py
@@ -12,4 +12,7 @@
 from keras.src.quantizers.quantizers import abs_max_quantize
 from keras.src.quantizers.quantizers import compute_float8_amax_history
 from keras.src.quantizers.quantizers import compute_float8_scale
+from keras.src.quantizers.quantizers import (
+    fake_quant_with_min_max_vars as fake_quant_with_min_max_vars_per_channel,
+)
 from keras.src.quantizers.quantizers import quantize_and_dequantize
diff --git a/keras/src/quantizers/__init__.py b/keras/src/quantizers/__init__.py
index b12d5cc84d7..dc7643e1e82 100644
--- a/keras/src/quantizers/__init__.py
+++ b/keras/src/quantizers/__init__.py
@@ -6,6 +6,7 @@
 from keras.src.quantizers.quantizers import abs_max_quantize
 from keras.src.quantizers.quantizers import compute_float8_amax_history
 from keras.src.quantizers.quantizers import compute_float8_scale
+from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars
 from keras.src.quantizers.quantizers import quantize_and_dequantize
 from keras.src.saving import serialization_lib
 from keras.src.utils.naming import to_snake_case
diff --git a/keras/src/quantizers/quantizers.py b/keras/src/quantizers/quantizers.py
index 3e4aac181e1..2f7db0c9787 100644
--- a/keras/src/quantizers/quantizers.py
+++ b/keras/src/quantizers/quantizers.py
@@ -4,6 +4,7 @@
 from keras.src import backend
 from keras.src import ops
 from keras.src.api_export import keras_export
+from keras.src.backend.common.backend_utils import canonicalize_axis
 from keras.src.backend.common.backend_utils import standardize_axis_for_numpy
 
 """Int8-related classes and methods"""
@@ -127,6 +128,142 @@ def get_config(self):
         }
+
+def adjust_and_nudge(min_range, max_range, num_bits, narrow_range):
+    """Adjusts and nudges the quantization range for better accuracy."""
+
+    quant_max = ops.cast(ops.subtract(ops.power(2, num_bits), 1.0), "float32")
+
+    quant_min = ops.cast(0.0 if not narrow_range else 1.0, "float32")
+
+    # Calculate the scale (the float width of one quantization step)
+    scale = ops.divide(
+        ops.subtract(max_range, min_range), ops.subtract(quant_max, quant_min)
+    )
+
+    inv_scale = ops.reciprocal(scale)
+
+    # Calculate the zero point from the min range
+    zero_point_from_min = quant_min - ops.divide(min_range, scale)
+
+    # Clip the zero point into the valid range [quant_min, quant_max]
+    zero_point = ops.clip(zero_point_from_min, quant_min, quant_max)
+
+    # Nudge the zero point to the nearest integer
+    nudged_zero_point = ops.round(zero_point)
+
+    # Calculate nudged limits
+    nudged_min = ops.multiply(ops.subtract(quant_min, nudged_zero_point), scale)
+    nudged_max = ops.multiply(ops.subtract(quant_max, nudged_zero_point), scale)
+
+    return nudged_min, nudged_max, scale, inv_scale
+
+
+@keras_export("keras.quantizers.fake_quant_with_min_max_vars_per_channel")
+def fake_quant_with_min_max_vars(
+    inputs,
+    min_vals,
+    max_vals,
+    num_bits,
+    narrow_range=False,
+    axis=None,
+):
+    """
+    Perform per-tensor or per-channel fake quantization.
+
+    `[min_vals, max_vals]` define the clamping range for the `inputs`.
+
+    The `inputs` are quantized into the quantization range:
+    - `[0, 2^num_bits - 1]` when `narrow_range=False`
+    - `[1, 2^num_bits - 1]` when `narrow_range=True`
+
+    After quantization, the values are dequantized and output as floats within
+    the `[min_vals, max_vals]` interval.
+
+    This operation supports gradient computation, allowing `min_vals` and
+    `max_vals` to be trained.
+
+    Args:
+        inputs: Input tensor of float dtype.
+        min_vals: A global minimum scalar or a per-channel minimum tensor.
+        max_vals: A global maximum scalar or a per-channel maximum tensor.
+        num_bits: Quantization bit width (e.g., `8` for int8).
+        narrow_range: Whether to use the narrow quantization range.
+        axis: Axis along which to perform per-channel quantization. If `None`,
+            per-tensor quantization is performed. Defaults to `None`.
+
+    Returns:
+        Fake-quantized tensor
+    """
+    inputs = ops.convert_to_tensor(inputs)
+    min_vals = ops.convert_to_tensor(min_vals)
+    max_vals = ops.convert_to_tensor(max_vals)
+
+    if axis is not None:
+        axis = canonicalize_axis(axis, inputs.ndim)
+
+    @ops.custom_gradient
+    def _fake_quant_with_min_max_vars_per_channel(x, min_val, max_val):
+        # Calculate quantization parameters for all channels at once
+        nudged_min, nudged_max, scale, inv_scale = adjust_and_nudge(
+            min_val, max_val, num_bits, narrow_range
+        )
+
+        quant_zero = ops.floor(
+            ops.add(ops.multiply(-nudged_min, inv_scale), 0.5)
+        )
+        x_clamped = ops.clip(x, nudged_min, nudged_max)
+        x_clamped_shifted = ops.subtract(x_clamped, nudged_min)
+        result = ops.multiply(
+            ops.floor(
+                ops.add(
+                    ops.subtract(
+                        ops.multiply(x_clamped_shifted, inv_scale), quant_zero
+                    ),
+                    0.5,
+                )
+            ),
+            scale,
+        )
+
+        # Create gradient mask for all channels
+        masks = ops.cast(
+            (x >= nudged_min) & (x <= nudged_max),
+            dtype="float32",
+        )
+
+        def grad(*args, upstream=None):
+            if upstream is None:
+                (upstream,) = args
+
+            # Gradient for x
+            dx = ops.multiply(upstream, masks)
+            axes = [i for i in range(len(dx.shape)) if i != axis]
+            # Gradient for min_val
+            # When x is clipped to min, the gradient flows to min_val
+            min_mask = ops.cast(x <= nudged_min, dtype="float32")
+            grad_min = ops.multiply(upstream, min_mask)
+            if axis is not None:
+                grad_min = ops.sum(grad_min, axis=axes)
+            else:
+                grad_min = ops.sum(grad_min)
+
+            # Gradient for max_val
+            # When x is clipped to max, the gradient flows to max_val
+            max_mask = ops.cast(x >= nudged_max, dtype="float32")
+            grad_max = ops.multiply(upstream, max_mask)
+            if axis is not None:
+                grad_max = ops.sum(grad_max, axis=axes)
+            else:
+                grad_max = ops.sum(grad_max)
+
+            return dx, grad_min, grad_max
+
+        return result, grad
+
+    return _fake_quant_with_min_max_vars_per_channel(inputs, min_vals, max_vals)
+
+
 """Float8-related methods"""
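For reference, here is a minimal standalone sketch (not part of the diff) of the range-nudging arithmetic that `adjust_and_nudge` implements above, written with plain Python floats instead of `keras.ops` so it can be checked by hand. The helper name `nudge_range` is hypothetical; its output matches the test expectations in the test file below, e.g. `min=0.5, max=128.0, num_bits=8, narrow_range=False` nudges to `[0.0, 127.5]` with a step of `0.5`.

def nudge_range(min_range, max_range, num_bits, narrow_range=False):
    # Illustrative re-derivation only; the library code uses keras.ops.
    # Quantized integer range: [0, 2^bits - 1], or [1, 2^bits - 1] when narrow.
    quant_min = 1.0 if narrow_range else 0.0
    quant_max = float(2**num_bits - 1)
    # Step size (scale): float width of one quantized level.
    scale = (max_range - min_range) / (quant_max - quant_min)
    # Zero point implied by min_range, clipped into range and rounded
    # to an integer level.
    zero_point = min(max(quant_min - min_range / scale, quant_min), quant_max)
    nudged_zero_point = round(zero_point)
    # Nudged float limits that land exactly on quantized levels.
    nudged_min = (quant_min - nudged_zero_point) * scale
    nudged_max = (quant_max - nudged_zero_point) * scale
    return nudged_min, nudged_max, scale

print(nudge_range(0.5, 128.0, 8))   # (0.0, 127.5, 0.5)
print(nudge_range(-0.1, 127.4, 8))  # approximately (0.0, 127.5, 0.5)
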
"input_maxs": [127.4], + "num_bits": 8, + "expected_nudged_input_mins": [0.0], + "expected_nudged_input_maxs": [127.5], + "expected_steps": [0.5], + "axis": None, + }, + { + "testcase_name": "narrow_8bits_input_mins_0.0_input_maxs_254.0", + "narrow_range": True, + "input_mins": [0.0], + "input_maxs": [254.0], + "num_bits": 8, + "expected_nudged_input_mins": [0.0], + "expected_nudged_input_maxs": [254.0], + "expected_steps": [1.0], + "axis": None, + }, + { + "testcase_name": "narrow_8bits_input_mins_0.1_input_maxs_127.1", + "narrow_range": True, + "input_mins": [0.1], + "input_maxs": [127.1], + "num_bits": 8, + "expected_nudged_input_mins": [0.0], + "expected_nudged_input_maxs": [127.0], + "expected_steps": [0.5], + "axis": None, + }, + { + "testcase_name": ( + "narrow_8bits_input_mins_-127.1_input_maxs_-0.1" + ), + "narrow_range": True, + "input_mins": [-127.1], + "input_maxs": [-0.1], + "num_bits": 8, + "expected_nudged_input_mins": [-127.0], + "expected_nudged_input_maxs": [0.0], + "expected_steps": [0.5], + "axis": None, + }, + { + "testcase_name": ( + "narrow_8bits_input_mins_-0.1_input_maxs_126.9" + ), + "narrow_range": True, + "input_mins": [-0.1], + "input_maxs": [126.9], + "num_bits": 8, + "expected_nudged_input_mins": [0.0], + "expected_nudged_input_maxs": [127.0], + "expected_steps": [0.5], + "axis": None, + }, + { + "testcase_name": "wide_7bits_input_mins_0.0_input_maxs_127.0", + "narrow_range": False, + "input_mins": [0.0], + "input_maxs": [127.0], + "num_bits": 7, + "expected_nudged_input_mins": [0.0], + "expected_nudged_input_maxs": [127.0], + "expected_steps": [1.0], + "axis": None, + }, + { + "testcase_name": "wide_7bits_input_mins_0.5_input_maxs_64.0", + "narrow_range": False, + "input_mins": [0.5], + "input_maxs": [64.0], + "num_bits": 7, + "expected_nudged_input_mins": [0.0], + "expected_nudged_input_maxs": [63.5], + "expected_steps": [0.5], + "axis": None, + }, + { + "testcase_name": "wide_7bits_input_mins_-64.0_input_maxs_-0.5", + "narrow_range": False, + "input_mins": [-64.0], + "input_maxs": [-0.5], + "num_bits": 7, + "expected_nudged_input_mins": [-63.5], + "expected_nudged_input_maxs": [0.0], + "expected_steps": [0.5], + "axis": None, + }, + { + "testcase_name": "wide_7bits_input_mins_-0.1_input_maxs_63.4", + "narrow_range": False, + "input_mins": [-0.1], + "input_maxs": [63.4], + "num_bits": 7, + "expected_nudged_input_mins": [0.0], + "expected_nudged_input_maxs": [63.5], + "expected_steps": [0.5], + "axis": None, + }, + { + "testcase_name": "narrow_7bits_input_mins_0.0_input_maxs_126.0", + "narrow_range": True, + "input_mins": [0.0], + "input_maxs": [126.0], + "num_bits": 7, + "expected_nudged_input_mins": [0.0], + "expected_nudged_input_maxs": [126.0], + "expected_steps": [1.0], + "axis": None, + }, + { + "testcase_name": "narrow_7bits_input_mins_0.1_input_maxs_63.1", + "narrow_range": True, + "input_mins": [0.1], + "input_maxs": [63.1], + "num_bits": 7, + "expected_nudged_input_mins": [0.0], + "expected_nudged_input_maxs": [63.0], + "expected_steps": [0.5], + "axis": None, + }, + { + "testcase_name": ( + "narrow_7bits_input_mins_-63.1_input_maxs_-0.1" + ), + "narrow_range": True, + "input_mins": [-63.1], + "input_maxs": [-0.1], + "num_bits": 7, + "expected_nudged_input_mins": [-63.0], + "expected_nudged_input_maxs": [0.0], + "expected_steps": [0.5], + "axis": None, + }, + { + "testcase_name": "narrow_7bits_input_mins_-0.1_input_maxs_62.9", + "narrow_range": True, + "input_mins": [-0.1], + "input_maxs": [62.9], + "num_bits": 7, + "expected_nudged_input_mins": 
+                "expected_nudged_input_maxs": [63.0],
+                "expected_steps": [0.5],
+                "axis": None,
+            },
+            {
+                "testcase_name": "wide_8bits_multi_channel",
+                "narrow_range": False,
+                "input_mins": [0.0, 0.5, -128.0, -0.1],
+                "input_maxs": [255.0, 128.0, -0.5, 127.4],
+                "num_bits": 8,
+                "expected_nudged_input_mins": [0.0, 0.0, -127.5, 0.0],
+                "expected_nudged_input_maxs": [255.0, 127.5, 0.0, 127.5],
+                "expected_steps": [1.0, 0.5, 0.5, 0.5],
+                "axis": 1,
+            },
+            {
+                "testcase_name": "narrow_8bits_multi_channel",
+                "narrow_range": True,
+                "input_mins": [0.0, 0.1, -127.1, -0.1],
+                "input_maxs": [254.0, 127.1, -0.1, 126.9],
+                "num_bits": 8,
+                "expected_nudged_input_mins": [0.0, 0.0, -127.0, 0.0],
+                "expected_nudged_input_maxs": [254.0, 127.0, 0.0, 127.0],
+                "expected_steps": [1.0, 0.5, 0.5, 0.5],
+                "axis": 1,
+            },
+            {
+                "testcase_name": "wide_7bits_multi_channel",
+                "narrow_range": False,
+                "input_mins": [0.0, 0.5, -64.0, -0.1],
+                "input_maxs": [127.0, 64.0, -0.5, 63.4],
+                "num_bits": 7,
+                "expected_nudged_input_mins": [0.0, 0.0, -63.5, 0.0],
+                "expected_nudged_input_maxs": [127.0, 63.5, 0.0, 63.5],
+                "expected_steps": [1.0, 0.5, 0.5, 0.5],
+                "axis": 1,
+            },
+            {
+                "testcase_name": "narrow_7bits_multi_channel",
+                "narrow_range": True,
+                "input_mins": [0.0, 0.1, -63.1, -0.1],
+                "input_maxs": [126.0, 63.1, -0.1, 62.9],
+                "num_bits": 7,
+                "expected_nudged_input_mins": [0.0, 0.0, -63.0, 0.0],
+                "expected_nudged_input_maxs": [126.0, 63.0, 0.0, 63.0],
+                "expected_steps": [1.0, 0.5, 0.5, 0.5],
+                "axis": 1,
+            },
+        ]
+    )
+    def test_op(
+        self,
+        input_mins,
+        input_maxs,
+        num_bits,
+        narrow_range,
+        axis,
+        expected_nudged_input_mins,
+        expected_nudged_input_maxs,
+        expected_steps,
+    ):
+        num_channels = len(input_mins)
+        inputs_list = []
+        expected_list = []
+        initial_gradients_list = []
+        expected_backprops_wrt_input_list = []
+        for i in range(num_channels):
+            expected_nudged_input_min = expected_nudged_input_mins[i]
+            expected_nudged_input_max = expected_nudged_input_maxs[i]
+            expected_step = expected_steps[i]
+
+            inputs_list.append(
+                [
+                    expected_nudged_input_min - expected_step,
+                    expected_nudged_input_min - 0.01,
+                    expected_nudged_input_min,
+                    expected_nudged_input_min + 0.01,
+                    expected_nudged_input_min + expected_step - 0.01,
+                    expected_nudged_input_min + expected_step,
+                    expected_nudged_input_min + expected_step + 0.01,
+                    expected_nudged_input_max - 0.01,
+                    expected_nudged_input_max,
+                    expected_nudged_input_max + 0.01,
+                    expected_nudged_input_max + expected_step,
+                ]
+            )
+            expected_list.append(
+                [
+                    expected_nudged_input_min,
+                    expected_nudged_input_min,
+                    expected_nudged_input_min,
+                    expected_nudged_input_min,
+                    expected_nudged_input_min + expected_step,
+                    expected_nudged_input_min + expected_step,
+                    expected_nudged_input_min + expected_step,
+                    expected_nudged_input_max,
+                    expected_nudged_input_max,
+                    expected_nudged_input_max,
+                    expected_nudged_input_max,
+                ]
+            )
+            initial_gradients_list.append(
+                list(range(1, len(inputs_list[-1]) + 1))
+            )
+            expected_backprops_wrt_input_list.append(
+                [0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0]
+            )
+        inputs = ops.transpose(ops.array(inputs_list, dtype="float32"))
+        expected = ops.transpose(ops.array(expected_list, dtype="float32"))
+        expected_backprops_wrt_input = ops.transpose(
+            ops.array(expected_backprops_wrt_input_list, dtype="float32")
+        )
+        input_min = ops.array(input_mins, dtype="float32")
+        input_max = ops.array(input_maxs, dtype="float32")
+        initial_gradients = ops.transpose(
+            ops.array(initial_gradients_list, dtype="float32")
+        )
backend.backend() == "tensorflow": + import tensorflow as tf + + @tf.function(jit_compile=True) + def test_op( + inputs, input_mins, input_maxs, num_bits, narrow_range, axis + ): + with tf.GradientTape() as tape: + tape.watch(inputs) + result = quantizers.fake_quant_with_min_max_vars( + inputs, + input_mins, + input_maxs, + num_bits, + narrow_range, + axis, + ) + return initial_gradients * tape.gradient(result, inputs) + + gradients = test_op( + inputs, input_mins, input_maxs, num_bits, narrow_range, axis + ) + # test gradients + self.assertAllClose(gradients, expected_backprops_wrt_input) + + if backend.backend() == "torch": + import torch + + def test_op(inputs, input_mins, input_maxs, num_bits, narrow_range): + # Create tensor and enable gradient tracking + inputs = torch.tensor( + inputs, dtype=torch.float32, requires_grad=True + ) + + # Apply the quantization operation + result = quantizers.fake_quant_with_min_max_vars( + inputs, input_mins, input_maxs, num_bits, narrow_range + ) + + # Compute gradients + result.backward(torch.ones_like(result)) + + return initial_gradients * inputs.grad + + gradients = test_op( + inputs, input_min, input_max, num_bits, narrow_range + ) + # test gradients + self.assertAllClose(gradients, expected_backprops_wrt_input) + + if backend.backend() == "jax": + import jax + + def test_op(inputs, input_mins, input_maxs, num_bits, narrow_range): + # Define the function to compute gradients for + def quantize_fn(x): + return quantizers.fake_quant_with_min_max_vars( + x, input_mins, input_maxs, num_bits, narrow_range + ) + + _, f_vjp = jax.vjp(quantize_fn, inputs) + # NOTE:python 3.10 input_gradients = f_vjp.args[0].args[0][0] ! + input_gradients = f_vjp.args[0].args[0][1] + + return ops.multiply(initial_gradients, input_gradients) + + gradients = test_op( + inputs, input_min, input_max, num_bits, narrow_range + ) + # test gradients + self.assertAllClose(gradients, expected_backprops_wrt_input) + outputs = quantizers.fake_quant_with_min_max_vars( + inputs, + input_min, + input_max, + num_bits=num_bits, + narrow_range=narrow_range, + axis=axis, + ) + self.assertAllClose(outputs, expected)