Porting TF fake_quant_with_min_max functions (#20641)
* QAT (squashed this time) (#1)

* Adds the fake_quant_with_min_max functions from TF to Keras 3 (a usage sketch follows this list)

* Addresses PR review comments

* Drops another type hint

* Swaps out if statements, changes float() to ops.cast, and adds the fake_quant_with_min_max_vars function

* Fixes a missed if statement and adds gradient tests via a main function for the TF and torch backends

* Fixes an unbound variable error when not using the torch or TF backend (#2)

Refactors the tests to use backend-specific gradient functions and merges the logic into a single function

* More QAT function revisions (#3)

This PR addresses review feedback by fixing the implementation and moving the tests to named_parameters rather than individual functions.

* QAT revisions (#4)

Adds an axis parameter and fixes the logic for the per-channel function

* Updated the implementation

* Removed redundant functions
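
A minimal usage sketch of the new quantizer (illustrative only, not part of this commit; it assumes the public alias keras.quantizers.fake_quant_with_min_max_vars_per_channel that the __init__.py diffs below add for the internal fake_quant_with_min_max_vars):

import numpy as np
from keras import quantizers

x = np.array([[-1.2, 0.3, 0.7, 2.5]], dtype="float32")

# Per-tensor: scalar min/max, axis=None (the default).
y = quantizers.fake_quant_with_min_max_vars_per_channel(
    x, min_vals=-1.0, max_vals=1.0, num_bits=8
)

# Per-channel: one (min, max) pair per channel along `axis`.
y_per_channel = quantizers.fake_quant_with_min_max_vars_per_channel(
    x,
    min_vals=np.array([-1.0, -1.0, -1.0, -3.0], dtype="float32"),
    max_vals=np.array([1.0, 1.0, 1.0, 3.0], dtype="float32"),
    num_bits=8,
    axis=-1,
)

Both calls return float tensors whose values are clamped to the given range and snapped to the dequantized grid of 2^num_bits levels.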
doncarlos999 authored Jan 15, 2025
1 parent e345cbd commit 6557961
Showing 5 changed files with 525 additions and 0 deletions.
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/quantizers/__init__.py
@@ -12,4 +12,7 @@
from keras.src.quantizers.quantizers import abs_max_quantize
from keras.src.quantizers.quantizers import compute_float8_amax_history
from keras.src.quantizers.quantizers import compute_float8_scale
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars as fake_quant_with_min_max_vars_per_channel,
)
from keras.src.quantizers.quantizers import quantize_and_dequantize
3 changes: 3 additions & 0 deletions keras/api/quantizers/__init__.py
@@ -12,4 +12,7 @@
from keras.src.quantizers.quantizers import abs_max_quantize
from keras.src.quantizers.quantizers import compute_float8_amax_history
from keras.src.quantizers.quantizers import compute_float8_scale
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars as fake_quant_with_min_max_vars_per_channel,
)
from keras.src.quantizers.quantizers import quantize_and_dequantize
1 change: 1 addition & 0 deletions keras/src/quantizers/__init__.py
@@ -6,6 +6,7 @@
from keras.src.quantizers.quantizers import abs_max_quantize
from keras.src.quantizers.quantizers import compute_float8_amax_history
from keras.src.quantizers.quantizers import compute_float8_scale
from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars
from keras.src.quantizers.quantizers import quantize_and_dequantize
from keras.src.saving import serialization_lib
from keras.src.utils.naming import to_snake_case
137 changes: 137 additions & 0 deletions keras/src/quantizers/quantizers.py
@@ -4,6 +4,7 @@
from keras.src import backend
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.backend.common.backend_utils import canonicalize_axis
from keras.src.backend.common.backend_utils import standardize_axis_for_numpy

"""Int8-related classes and methods"""
@@ -127,6 +128,142 @@ def get_config(self):
}


def adjust_and_nudge(min_range, max_range, num_bits, narrow_range):
"""Adjusts and nudges the quantization range for better accuracy."""

quant_max = ops.cast(ops.subtract(ops.power(2, num_bits), 1.0), "float32")

quant_min = ops.cast(0.0 if not narrow_range else 1.0, "float32")

# Calculate the quantization step size (assumes max_range > min_range)
scale = ops.divide(
ops.subtract(max_range, min_range), ops.subtract(quant_max, quant_min)
)

inv_scale = ops.reciprocal(scale)

# Calculate the zero point from the min range
zero_point_from_min = quant_min - ops.divide(min_range, scale)

# Ensure zero point is within valid range [0, quant_max]
zero_point = ops.clip(zero_point_from_min, quant_min, quant_max)

# Nudge the zero point to the nearest integer so that 0.0 is exactly representable
nudged_zero_point = ops.round(zero_point)

# Calculate nudged limits
nudged_min = ops.multiply(ops.subtract(quant_min, nudged_zero_point), scale)
nudged_max = ops.multiply(ops.subtract(quant_max, nudged_zero_point), scale)

return nudged_min, nudged_max, scale, inv_scale
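

# Illustrative worked example (not part of the committed file):
# adjust_and_nudge(min_range=-1.5, max_range=6.0, num_bits=3, narrow_range=False)
#   quant_min, quant_max   = 0.0, 7.0
#   scale                  = (6.0 - (-1.5)) / 7.0  ~= 1.0714
#   zero_point_from_min    = 0.0 - (-1.5 / 1.0714) = 1.4 -> clip to [0, 7] -> round to 1.0
#   nudged_min, nudged_max = (0 - 1) * 1.0714, (7 - 1) * 1.0714  ~= -1.0714, 6.4286
# Rounding the zero point means the real value 0.0 falls exactly on a
# quantization level (here level 1), at the cost of slightly shifting the
# requested [min_range, max_range] interval.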


@keras_export("keras.quantizers.fake_quant_with_min_max_vars_per_channel")
def fake_quant_with_min_max_vars(
inputs,
min_vals,
max_vals,
num_bits,
narrow_range=False,
axis=None,
):
"""
Perform per-tensor or per-channel fake quantization.
`[min_vals, max_vals]` define the clamping range for the `inputs`.
The `inputs` are quantized into the quantization range:
- `[0, 2^num_bits - 1]` when `narrow_range=False`
- `[1, 2^num_bits - 1]` when `narrow_range=True`
After quantization, the values are dequantized and output as floats within
the `[min_vals, max_vals]` interval.
This operation supports gradient computation, allowing `min_vals` and
`max_vals` to be trained.
Args:
inputs: Input tensor of float dtype.
min_vals: A global minimum scalar or a per-channel minimum tensor.
max_vals: A global maximum scalar or a per-channel maximum tensor.
num_bits: Quantization bit width (e.g., `8` for int8).
narrow_range: Whether to use narrow quantization range.
axis: Axis along which to perform per-channel quantization. If `None`,
per-tensor quantization is performed. Defaults to `None`.
Returns:
Fake-quantized tensor
"""
inputs = ops.convert_to_tensor(inputs)
min_vals = ops.convert_to_tensor(min_vals)
max_vals = ops.convert_to_tensor(max_vals)

if axis is not None:
axis = canonicalize_axis(axis, inputs.ndim)

@ops.custom_gradient
def _fake_quant_with_min_max_vars_per_channel(x, min_val, max_val):
# Calculate quantization parameters for all channels at once
nudged_min, nudged_max, scale, inv_scale = adjust_and_nudge(
min_val, max_val, num_bits, narrow_range
)

quant_zero = ops.floor(
ops.add(ops.multiply(-nudged_min, inv_scale), 0.5)
)
x_clamped = ops.clip(x, nudged_min, nudged_max)
x_clamped_shifted = ops.subtract(x_clamped, nudged_min)
result = ops.multiply(
ops.floor(
ops.add(
ops.subtract(
ops.multiply(x_clamped_shifted, inv_scale), quant_zero
),
0.5,
)
),
scale,
)
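
# Illustrative note (not part of the committed file): because quant_zero is an
# integer, the expression above is equivalent (up to floating-point rounding) to
#   round((clip(x, nudged_min, nudged_max) - nudged_min) / scale) * scale + nudged_min
# i.e. the standard fake-quant forward pass; floor(v + 0.5) implements
# round-half-up.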

# Create gradient mask for all channels
masks = ops.cast(
(x >= nudged_min) & (x <= nudged_max),
dtype="float32",
)

def grad(*args, upstream=None):
if upstream is None:
(upstream,) = args

# Gradient for x
dx = ops.multiply(upstream, masks)
axes = [i for i in range(len(dx.shape)) if i != axis]
# Gradient for min_val
# When x is clipped to min, the gradient flows to min_val
min_mask = ops.cast(x <= nudged_min, dtype="float32")
grad_min = ops.multiply(upstream, min_mask)
if axis is not None:
grad_min = ops.sum(grad_min, axis=axes)
else:
grad_min = ops.sum(grad_min)

# Gradient for max_val
# When x is clipped to max, the gradient flows to max_val
max_mask = ops.cast(x >= nudged_max, dtype="float32")
grad_max = ops.multiply(upstream, max_mask)
if axis is not None:
grad_max = ops.sum(grad_max, axis=axes)
else:
grad_max = ops.sum(grad_max)

return dx, grad_min, grad_max

return result, grad

return _fake_quant_with_min_max_vars_per_channel(inputs, min_vals, max_vals)
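

# Illustrative gradient behaviour (not part of the committed file): for a
# per-tensor call with min_vals=0.0, max_vals=1.0, num_bits=8 (nudged range
# [0.0, 1.0]), inputs x = [-0.2, 0.3, 0.7, 1.4] and upstream gradient
# u = [u0, u1, u2, u3]:
#   dx       = [0, u1, u2, 0]  # straight-through inside the range, 0 where clipped
#   grad_min = u0              # sum of upstream where x was clipped below
#   grad_max = u3              # sum of upstream where x was clipped above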


"""Float8-related methods"""


