Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Porting TF fake_quant_with_min_max functions #20641

Merged
merged 16 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions keras/api/_tf_keras/keras/quantizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,9 @@
from keras.src.quantizers.quantizers import abs_max_quantize
from keras.src.quantizers.quantizers import compute_float8_amax_history
from keras.src.quantizers.quantizers import compute_float8_scale
from keras.src.quantizers.quantizers import fake_quant_with_min_max_args
from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars_per_channel,
)
from keras.src.quantizers.quantizers import quantize_and_dequantize
5 changes: 5 additions & 0 deletions keras/api/quantizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,9 @@
from keras.src.quantizers.quantizers import abs_max_quantize
from keras.src.quantizers.quantizers import compute_float8_amax_history
from keras.src.quantizers.quantizers import compute_float8_scale
from keras.src.quantizers.quantizers import fake_quant_with_min_max_args
from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars_per_channel,
)
from keras.src.quantizers.quantizers import quantize_and_dequantize
5 changes: 5 additions & 0 deletions keras/src/quantizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
from keras.src.quantizers.quantizers import abs_max_quantize
from keras.src.quantizers.quantizers import compute_float8_amax_history
from keras.src.quantizers.quantizers import compute_float8_scale
from keras.src.quantizers.quantizers import fake_quant_with_min_max_args
from keras.src.quantizers.quantizers import fake_quant_with_min_max_vars
from keras.src.quantizers.quantizers import (
fake_quant_with_min_max_vars_per_channel,
)
from keras.src.quantizers.quantizers import quantize_and_dequantize
from keras.src.saving import serialization_lib
from keras.src.utils.naming import to_snake_case
Expand Down
175 changes: 175 additions & 0 deletions keras/src/quantizers/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,181 @@ def get_config(self):
}


def adjust_and_nudge(min_range, max_range, num_bits, narrow_range):
doncarlos999 marked this conversation as resolved.
Show resolved Hide resolved
"""Adjusts and nudges the quantization range for better accuracy."""
if num_bits < 2:
raise ValueError("num_bits must be >= 2")

n_steps = ops.cast(2**num_bits - 1, "float32")
n_steps = n_steps if not narrow_range else n_steps - 1.0

# Handle the case where min and max are too close
# if abs(max_range - min_range) < 1e-10:
# return min_range, max_range, 1.0

# Calculate the step size
step_size = (max_range - min_range) / n_steps
doncarlos999 marked this conversation as resolved.
Show resolved Hide resolved

# Calculate the reciprocal of the step size
inv_step_size = 1.0 / step_size

# Round the reciprocal to get an integer
rounded_inv_step_size = ops.round(inv_step_size)

# Calculate the final step size
final_step_size = 1.0 / rounded_inv_step_size

# Calculate the quantized min/max values, ensuring accurate rounding
quantized_min = (
ops.round(min_range * rounded_inv_step_size) / rounded_inv_step_size
)
quantized_max = (
ops.round(max_range * rounded_inv_step_size) / rounded_inv_step_size
)

# Convert quantization limits to float
quant_min_float = ops.cast(quantized_min, "float32")
quant_max_float = ops.cast(quantized_max, "float32")

# Calculate the scale
nudged_scale = (max_range - min_range) / (quant_max_float - quant_min_float)

# Calculate zero point from min
zero_point_from_min = quant_min_float - min_range / nudged_scale

# Determine nudged zero point
nudged_zero_point = ops.where(
zero_point_from_min < quant_min_float,
quantized_min,
ops.where(
zero_point_from_min > quant_max_float,
quantized_max,
ops.round(zero_point_from_min),
),
)

# Calculate nudged min and max
nudged_min = (quant_min_float - nudged_zero_point) * nudged_scale
nudged_max = (quant_max_float - nudged_zero_point) * nudged_scale

return (
nudged_min,
nudged_max,
final_step_size,
) # Returning nudged values and scale


@keras_export("keras.quantizers.fake_quant_with_min_max_vars_per_channel")
def fake_quant_with_min_max_vars_per_channel(
doncarlos999 marked this conversation as resolved.
Show resolved Hide resolved
inputs,
min_vals,
max_vals,
num_bits,
narrow_range,
doncarlos999 marked this conversation as resolved.
Show resolved Hide resolved
):
"""
Perform per-channel fake quantization with custom gradient using vectorized
operations.
doncarlos999 marked this conversation as resolved.
Show resolved Hide resolved

Args:
inputs: Input tensor of float type
min_vals: Per-channel minimum values
max_vals: Per-channel maximum values
num_bits: Quantization bit width (2-16)
narrow_range: Whether to use narrow quantization range

Returns:
Fake-quantized tensor
"""
inputs = ops.convert_to_tensor(inputs)
min_vals = ops.convert_to_tensor(min_vals)
max_vals = ops.convert_to_tensor(max_vals)

@ops.custom_gradient
def _fake_quant_with_min_max_vars_per_channel(x, min_val, max_val):
# Calculate quantization parameters for all channels at once
qnt_min, qnt_max, step_size = adjust_and_nudge(
min_val, max_val, num_bits, narrow_range
)

# Calculate number of steps
n_steps = 2**num_bits - 1
if narrow_range:
n_steps -= 1

# Expand dimensions to allow broadcasting
qnt_min = ops.expand_dims(qnt_min, axis=list(range(len(x.shape) - 1)))
qnt_max = ops.expand_dims(qnt_max, axis=list(range(len(x.shape) - 1)))
step_size = ops.expand_dims(
step_size, axis=list(range(len(x.shape) - 1))
)

# Clip and quantize all channels simultaneously
x_clipped = ops.clip(x, qnt_min, qnt_max)
x_norm = (x_clipped - qnt_min) / step_size
x_quantized = ops.round(x_norm)
x_quantized = ops.clip(x_quantized, 0.0, n_steps)
result = x_quantized * step_size + qnt_min

# Create gradient mask for all channels
masks = ops.cast(
(x >= qnt_min) & (x <= qnt_max),
dtype=np.float32,
doncarlos999 marked this conversation as resolved.
Show resolved Hide resolved
)

def grad(*args, upstream=None):
if upstream is None:
(upstream,) = args

# Gradient for x
dx = ops.multiply(upstream, masks)

# Gradient for min_val
# When x is clipped to min, the gradient flows to min_val
min_mask = ops.cast(x <= qnt_min, dtype=np.float32)
doncarlos999 marked this conversation as resolved.
Show resolved Hide resolved
dims_to_reduce = list(range(len(x.shape) - 1))
grad_min = ops.sum(upstream * min_mask, axis=dims_to_reduce)

# Gradient for max_val
# When x is clipped to max, the gradient flows to max_val
max_mask = ops.cast(x >= qnt_max, dtype=np.float32)
doncarlos999 marked this conversation as resolved.
Show resolved Hide resolved
grad_max = ops.sum(upstream * max_mask, axis=dims_to_reduce)

return dx, grad_min, grad_max

return result, grad

return _fake_quant_with_min_max_vars_per_channel(inputs, min_vals, max_vals)


@keras_export("keras.quantizers.fake_quant_with_min_max_args")
def fake_quant_with_min_max_args(
inputs,
min_vals,
max_vals,
num_bits=8,
narrow_range=False,
):
"""Fake quantization operation matching TensorFlow's implementation."""
return fake_quant_with_min_max_vars_per_channel(
inputs, min_vals, max_vals, num_bits, narrow_range
)


@keras_export("keras.quantizers.fake_quant_with_min_max_vars")
def fake_quant_with_min_max_vars(
inputs,
min_vals,
max_vals,
num_bits=8,
narrow_range=False,
):
"""Fake quantization operation matching TensorFlow's implementation."""
return fake_quant_with_min_max_vars_per_channel(
inputs, min_vals, max_vals, num_bits, narrow_range
)


"""Float8-related methods"""


Expand Down
Loading
Loading