From 8e7a1a4f5a73d6d68173c4cbddb6155a69e33fe2 Mon Sep 17 00:00:00 2001 From: Akshaya Purohit Date: Thu, 2 Jan 2025 16:18:41 -0800 Subject: [PATCH] No public description PiperOrigin-RevId: 711551176 Change-Id: I78f6cc43a90c86e2249a53d597f42761f252c7d9 --- qkeras/__init__.py | 1 + qkeras/quantizer_registry.py | 34 +++++ qkeras/quantizers.py | 224 ++++++++++++++++++------------- tests/quantizer_registry_test.py | 51 +++++++ 4 files changed, 216 insertions(+), 94 deletions(-) create mode 100644 qkeras/quantizer_registry.py create mode 100644 tests/quantizer_registry_test.py diff --git a/qkeras/__init__.py b/qkeras/__init__.py index 4f1a069..3612e40 100644 --- a/qkeras/__init__.py +++ b/qkeras/__init__.py @@ -35,6 +35,7 @@ #from .qtools.run_qtools import QTools #from .qtools.settings import cfg from .quantizers import * # pylint: disable=wildcard-import +from .registry import * # pylint: disable=wildcard-import from .safe_eval import * # pylint: disable=wildcard-import diff --git a/qkeras/quantizer_registry.py b/qkeras/quantizer_registry.py new file mode 100644 index 0000000..25e59f8 --- /dev/null +++ b/qkeras/quantizer_registry.py @@ -0,0 +1,34 @@ +# Copyright 2024 Google LLC +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Registry for QKeras quantizers.""" + +from . import registry + +# Global registry for all QKeras quantizers. +_QUANTIZERS_REGISTRY = registry.Registry() + + +def register_quantizer(quantizer): + """Decorator for registering a quantizer.""" + _QUANTIZERS_REGISTRY.register(quantizer) + # Return the quantizer after registering. This ensures any registered + # quantizer class is properly defined. + return quantizer + + +def lookup_quantizer(name): + """Retrieves a quantizer from the quantizers registry.""" + return _QUANTIZERS_REGISTRY.lookup(name) diff --git a/qkeras/quantizers.py b/qkeras/quantizers.py index 5876bec..ddc76de 100644 --- a/qkeras/quantizers.py +++ b/qkeras/quantizers.py @@ -17,19 +17,22 @@ from __future__ import division from __future__ import print_function -import six import re -from typing import List, Any, Tuple +from typing import Any, List, Tuple + import numpy as np -import tensorflow.compat.v2 as tf -import tensorflow.keras.backend as K +import six from six.moves import range +import tensorflow.compat.v2 as tf from tensorflow.keras import initializers +import tensorflow.keras.backend as K from tensorflow.keras.utils import deserialize_keras_object -from tensorflow.python.framework import smart_cond as tf_utils -from .safe_eval import safe_eval -# from .google_internals.experimental_quantizers import quantized_bits_learnable_scale + +from . 
import quantizer_registry # from .google_internals.experimental_quantizers import parametric_quantizer_d_xmax +# from .google_internals.experimental_quantizers import quantized_bits_learnable_scale +from .safe_eval import safe_eval +from tensorflow.python.framework import smart_cond as tf_utils # # Library of auxiliary functions @@ -747,6 +750,8 @@ def trainable_variables(self): def non_trainable_variables(self): return () + +@quantizer_registry.register_quantizer class quantized_linear(BaseQuantizer): """Linear quantization with fixed number of bits. @@ -981,7 +986,7 @@ def use_stochastic_rounding(self): @property def scale_axis(self): return self._scale_axis - + @property def use_variables(self): return self._use_variables @@ -989,7 +994,7 @@ def use_variables(self): @property def scale(self): return self.quantization_scale / self.data_type_scale - + @property def data_type_scale(self): """Quantization scale for the data type""" @@ -1008,7 +1013,7 @@ def use_sign_function(self): """Return true if using sign function for quantization""" return (self.bits == 1.0) and self.keep_negative - + @property def default_quantization_scale(self): """Calculate and set quantization_scale default""" @@ -1018,7 +1023,7 @@ def default_quantization_scale(self): # Quantization scale given by alpha if self.alpha is not None and not self.auto_alpha: - quantization_scale = self.alpha * self.data_type_scale + quantization_scale = self.alpha * self.data_type_scale return quantization_scale @@ -1046,7 +1051,7 @@ def __call__(self, x): # Data type conversion x = K.cast_to_floatx(x) shape = x.shape - + if self.auto_alpha: # get data-dependent quantization scale quantization_scale = self._get_auto_quantization_scale(x) @@ -1062,7 +1067,7 @@ def __call__(self, x): res.set_shape(shape) return res - + def _scale_clip_and_round(self, x, quantization_scale): """Scale, clip, and round x to an integer value in a limited range Note that the internal shift is needed for 1-bit quantization to ensure @@ -1076,8 +1081,8 @@ def _scale_clip_and_round(self, x, quantization_scale): scaled_x = x / quantization_scale clipped_scaled_x = K.clip(scaled_x, clip_min, clip_max) - # Round through to nearest integer, using straight-through estimator - # for gradient computations. + # Round through to nearest integer, using straight-through estimator + # for gradient computations. 
scaled_xq = _round_through( clipped_scaled_x - shift, use_stochastic_rounding=self.use_stochastic_rounding, @@ -1085,7 +1090,7 @@ def _scale_clip_and_round(self, x, quantization_scale): ) return scaled_xq + shift - + def _get_auto_quantization_scale(self, x): """Get quantization_scale, either from self or from input x""" @@ -1111,7 +1116,7 @@ def _get_quantization_scale_from_max_data(self, x): clip_min, clip_max = self.get_clip_bounds() clip_range = clip_max - clip_min - + # get quantization scale- depends on whether we are keeping negative # divide by clip range to ensure that we clip right at the max of x if self.keep_negative: @@ -1151,12 +1156,12 @@ def loop_cond(last_quantization_scale, quantization_scale): tf.not_equal(last_quantization_scale, quantization_scale)) return tensors_not_equal - # Need a tensor of the same shape as quantization_scale that + # Need a tensor of the same shape as quantization_scale that # does not equal quantization_scale dummy_quantization_scale = -tf.ones_like(quantization_scale) # For 1-bit quantization, po2 autoscale loop is guaranteed to converge - # after 1 iteration + # after 1 iteration max_iterations = 1 if self.use_sign_function else 5 _, quantization_scale = tf.while_loop( @@ -1198,7 +1203,7 @@ def range(self): neg_array = K.cast_to_floatx(tf.range(clip_min, 0)) return self.quantization_scale * tf.concat([pos_array, neg_array], axis=0) - + def __str__(self): # Main parameters always printed in string @@ -1242,6 +1247,8 @@ def get_config(self): } return config + +@quantizer_registry.register_quantizer class quantized_bits(BaseQuantizer): # pylint: disable=invalid-name """Legacy quantizer: Quantizes the number to a number of bits. @@ -1557,6 +1564,7 @@ def get_config(self): return config +@quantizer_registry.register_quantizer class bernoulli(BaseQuantizer): # pylint: disable=invalid-name """Computes a Bernoulli sample with probability sigmoid(x). @@ -1671,6 +1679,7 @@ def get_config(self): return config +@quantizer_registry.register_quantizer class ternary(BaseQuantizer): # pylint: disable=invalid-name """Computes an activation function returning -alpha, 0 or +alpha. @@ -1820,6 +1829,7 @@ def get_config(self): return config +@quantizer_registry.register_quantizer class stochastic_ternary(ternary): # pylint: disable=invalid-name """Computes a stochastic activation function returning -alpha, 0 or +alpha. @@ -1831,8 +1841,8 @@ class stochastic_ternary(ternary): # pylint: disable=invalid-name bits: number of bits to perform quantization. alpha: ternary is -alpha or +alpha, or "auto" or "auto_po2". threshold: (1-threshold) specifies the spread of the +1 and -1 values. - temperature: amplifier factor for sigmoid function, making stochastic - less stochastic as it moves away from 0. + temperature: amplifier factor for sigmoid function, making stochastic less + stochastic as it moves away from 0. use_real_sigmoid: use real sigmoid for probability. number_of_unrolls: number of times we iterate between scale and threshold. @@ -1840,12 +1850,17 @@ class stochastic_ternary(ternary): # pylint: disable=invalid-name Computation of sign with stochastic sampling with straight through gradient. 
""" - def __init__(self, alpha=None, threshold=None, temperature=8.0, - use_real_sigmoid=True, number_of_unrolls=5): - super(stochastic_ternary, self).__init__( - alpha=alpha, - threshold=threshold, - number_of_unrolls=number_of_unrolls) + def __init__( + self, + alpha=None, + threshold=None, + temperature=8.0, + use_real_sigmoid=True, + number_of_unrolls=5, + ): + super().__init__( + alpha=alpha, threshold=threshold, number_of_unrolls=number_of_unrolls + ) self.bits = 2 self.alpha = alpha @@ -1903,10 +1918,11 @@ def stochastic_output(): x_std = K.std(x, axis=axis, keepdims=True) m = K.max(tf.abs(x), axis=axis, keepdims=True) - scale = 2.*m/3. + scale = 2.0 * m / 3.0 if self.alpha == "auto_po2": - scale = K.pow(2.0, - tf.math.round(K.log(scale + K.epsilon()) / np.log(2.0))) + scale = K.pow( + 2.0, tf.math.round(K.log(scale + K.epsilon()) / np.log(2.0)) + ) for _ in range(self.number_of_unrolls): T = scale / 2.0 q_ns = K.cast(tf.abs(x) >= T, K.floatx()) * K.sign(x) @@ -1924,18 +1940,17 @@ def stochastic_output(): r0 = tf.random.uniform(tf.shape(p0)) r1 = tf.random.uniform(tf.shape(p1)) q0 = tf.sign(p0 - r0) - q0 += (1.0 - tf.abs(q0)) + q0 += 1.0 - tf.abs(q0) q1 = tf.sign(p1 - r1) - q1 += (1.0 - tf.abs(q1)) + q1 += 1.0 - tf.abs(q1) q = (q0 + q1) / 2.0 self.scale = scale return x + tf.stop_gradient(-x + scale * q) output = tf_utils.smart_cond( - K.learning_phase(), - stochastic_output, - lambda: ternary.__call__(self, x)) + K.learning_phase(), stochastic_output, lambda: ternary.__call__(self, x) + ) return output def _set_trainable_parameter(self): @@ -1966,11 +1981,12 @@ def get_config(self): "threshold": self.threshold, "temperature": self.temperature, "use_real_sigmoid": self.use_real_sigmoid, - "number_of_unrolls": self.number_of_unrolls + "number_of_unrolls": self.number_of_unrolls, } return config +@quantizer_registry.register_quantizer class binary(BaseQuantizer): # pylint: disable=invalid-name """Computes the sign(x) returning a value between -alpha and alpha. @@ -2170,6 +2186,7 @@ def get_config(self): return config +@quantizer_registry.register_quantizer class stochastic_binary(binary): # pylint: disable=invalid-name """Computes a stochastic activation function returning -alpha or +alpha. @@ -2181,7 +2198,7 @@ class stochastic_binary(binary): # pylint: disable=invalid-name alpha: binary is -alpha or +alpha, or "auto" or "auto_po2". bits: number of bits to perform quantization. temperature: amplifier factor for sigmoid function, making stochastic - behavior less stochastic as it moves away from 0. + behavior less stochastic as it moves away from 0. use_real_sigmoid: use real sigmoid from tensorflow for probablity. Returns: @@ -2233,15 +2250,16 @@ def stochastic_output(): r = tf.random.uniform(tf.shape(x)) q = tf.sign(p - r) - q += (1.0 - tf.abs(q)) + q += 1.0 - tf.abs(q) q_non_stochastic = tf.sign(x) - q_non_stochastic += (1.0 - tf.abs(q_non_stochastic)) + q_non_stochastic += 1.0 - tf.abs(q_non_stochastic) scale = _get_least_squares_scale(self.alpha, x, q_non_stochastic) self.scale = scale return x + tf.stop_gradient(-x + scale * q) output = tf_utils.smart_cond( - K.learning_phase(), stochastic_output, lambda: binary.__call__(self, x)) + K.learning_phase(), stochastic_output, lambda: binary.__call__(self, x) + ) return output def _set_trainable_parameter(self): @@ -2275,6 +2293,7 @@ def get_config(self): return config +@quantizer_registry.register_quantizer class quantized_relu(BaseQuantizer): # pylint: disable=invalid-name """Computes a quantized relu to a number of bits. 
@@ -2496,8 +2515,7 @@ def get_config(self): return config - - +@quantizer_registry.register_quantizer class quantized_ulaw(BaseQuantizer): # pylint: disable=invalid-name """Computes a u-law quantization. @@ -2571,6 +2589,7 @@ def get_config(self): return config +@quantizer_registry.register_quantizer class quantized_tanh(BaseQuantizer): # pylint: disable=invalid-name """Computes a quantized tanh to a number of bits. @@ -2640,6 +2659,7 @@ def get_config(self): return config +@quantizer_registry.register_quantizer class quantized_sigmoid(BaseQuantizer): # pylint: disable=invalid-name """Computes a quantized sigmoid to a number of bits. @@ -2828,6 +2848,7 @@ def _get_min_max_exponents(non_sign_bits, need_exponent_sign_bit, return min_exp, max_exp +@quantizer_registry.register_quantizer class quantized_po2(BaseQuantizer): # pylint: disable=invalid-name """Quantizes to the closest power of 2. @@ -2964,6 +2985,7 @@ def get_config(self): return config +@quantizer_registry.register_quantizer class quantized_relu_po2(BaseQuantizer): # pylint: disable=invalid-name """Quantizes x to the closest power of 2 when x > 0 @@ -3137,8 +3159,10 @@ def get_config(self): return config +@quantizer_registry.register_quantizer class quantized_hswish(quantized_bits): # pylint: disable=invalid-name """Computes a quantized hard swish to a number of bits. + # TODO(mschoenb97): Update to inherit from quantized_linear. Equation of h-swisth function in mobilenet v3: @@ -3151,47 +3175,48 @@ class quantized_hswish(quantized_bits): # pylint: disable=invalid-name symmetric: if True, the quantization is in symmetric mode, which puts restricted range for the quantizer. Otherwise, it is in asymmetric mode, which uses the full range. - alpha: a tensor or None, the scaling factor per channel. - If None, the scaling factor is 1 for all channels. + alpha: a tensor or None, the scaling factor per channel. If None, the + scaling factor is 1 for all channels. use_stochastic_rounding: if true, we perform stochastic rounding. This - parameter is passed on to the underlying quantizer quantized_bits which - is used to quantize h_swish. + parameter is passed on to the underlying quantizer quantized_bits which is + used to quantize h_swish. scale_axis: which axis to calculate scale from qnoise_factor: float. a scalar from 0 to 1 that represents the level of quantization noise to add. This controls the amount of the quantization - noise to add to the outputs by changing the weighted sum of - (1 - qnoise_factor)*unquantized_x + qnoise_factor*quantized_x. + noise to add to the outputs by changing the weighted sum of (1 - + qnoise_factor)*unquantized_x + qnoise_factor*quantized_x. var_name: String or None. A variable name shared between the tf.Variables created in the build function. If None, it is generated automatically. use_ste: Bool. Whether to use "straight-through estimator" (STE) method or - not. + not. use_variables: Bool. Whether to make the quantizer variables to be dynamic tf.Variables or not. - relu_shift: integer type, representing the shift amount - of the unquantized relu. + relu_shift: integer type, representing the shift amount of the unquantized + relu. relu_upper_bound: integer type, representing an upper bound of the unquantized relu. If None, we apply relu without the upper bound when "is_quantized_clip" is set to false (true by default). Note: The quantized relu uses the quantization parameters (bits and - integer) to upper bound. So it is important to set relu_upper_bound - appropriately to the quantization parameters. 
"is_quantized_clip" - has precedence over "relu_upper_bound" for backward compatibility. - + integer) to upper bound. So it is important to set relu_upper_bound + appropriately to the quantization parameters. "is_quantized_clip" has + precedence over "relu_upper_bound" for backward compatibility. """ - def __init__(self, - bits=8, - integer=0, - symmetric=0, - alpha=None, - use_stochastic_rounding=False, - scale_axis=None, - qnoise_factor=1.0, - var_name=None, - use_variables=False, - relu_shift: int = 3, - relu_upper_bound: int = 6): - super(quantized_hswish, self).__init__( + def __init__( + self, + bits=8, + integer=0, + symmetric=0, + alpha=None, + use_stochastic_rounding=False, + scale_axis=None, + qnoise_factor=1.0, + var_name=None, + use_variables=False, + relu_shift: int = 3, + relu_upper_bound: int = 6, + ): + super().__init__( bits=bits, integer=integer, symmetric=symmetric, @@ -3201,26 +3226,33 @@ def __init__(self, scale_axis=scale_axis, qnoise_factor=qnoise_factor, var_name=var_name, - use_variables=use_variables) + use_variables=use_variables, + ) self.relu_shift = relu_shift self.relu_upper_bound = relu_upper_bound def __str__(self): - """ Converts Tensors to printable strings.""" + """Converts Tensors to printable strings.""" - integer_bits = ( - re.sub(r"\[(\d)\]", r"\g<1>", - str(self.integer.numpy() if isinstance(self.integer, tf.Variable) - else self.integer))) + integer_bits = re.sub( + r"\[(\d)\]", + r"\g<1>", + str( + self.integer.numpy() + if isinstance(self.integer, tf.Variable) + else self.integer + ), + ) assert isinstance(integer_bits, int) - flags = [str(self.bits), - integer_bits, - str(int(self.symmetric)), - "relu_shift=" + str(self.relu_shift), - "relu_upper_bound=" + str(self.relu_upper_bound) - ] + flags = [ + str(self.bits), + integer_bits, + str(int(self.symmetric)), + "relu_shift=" + str(self.relu_shift), + "relu_upper_bound=" + str(self.relu_upper_bound), + ] if not self.keep_negative: flags.append("keep_negative=False") @@ -3230,22 +3262,26 @@ def __str__(self): alpha = "'" + alpha + "'" flags.append("alpha=" + alpha) if self.use_stochastic_rounding: - flags.append("use_stochastic_rounding=" + - str(int(self.use_stochastic_rounding))) + flags.append( + "use_stochastic_rounding=" + str(int(self.use_stochastic_rounding)) + ) return "quantized_hswish(" + ",".join(flags) + ")" def __call__(self, x): assert self.relu_upper_bound > 0, ( - f"relu_upper_bound must be a positive value, " - f"found {self.relu_upper_bound} instead") - assert self.relu_shift > 0, ( - f"relu_shift must be a positive value, " - f"found {self.relu_shift} instead") + "relu_upper_bound must be a positive value, " + f"found {self.relu_upper_bound} instead" + ) + assert ( + self.relu_shift > 0 + ), f"relu_shift must be a positive value, found {self.relu_shift} instead" x = K.cast_to_floatx(x) shift_x = x + self.relu_shift - relu_x = tf.where(shift_x <= self.relu_upper_bound, - K.relu(shift_x, alpha=False), - tf.ones_like(shift_x) * self.relu_upper_bound) + relu_x = tf.where( + shift_x <= self.relu_upper_bound, + K.relu(shift_x, alpha=False), + tf.ones_like(shift_x) * self.relu_upper_bound, + ) hswish_x = tf.math.multiply(x, relu_x) / self.relu_upper_bound return super(quantized_hswish, self).__call__(hswish_x) @@ -3275,14 +3311,14 @@ def get_config(self): config = { "relu_shift": self.relu_shift, - "relu_upper_bound": self.relu_upper_bound + "relu_upper_bound": self.relu_upper_bound, } - out_config = dict( - list(base_config.items()) + list(config.items())) + out_config = 
dict(list(base_config.items()) + list(config.items())) return out_config +# TODO(akshayap): Update to use registry for quantizers instead of globals(). def get_quantizer(identifier): """Gets the quantizer. diff --git a/tests/quantizer_registry_test.py b/tests/quantizer_registry_test.py new file mode 100644 index 0000000..4d5152f --- /dev/null +++ b/tests/quantizer_registry_test.py @@ -0,0 +1,51 @@ +# Copyright 2024 Google LLC +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for QKeras quantizer registry.""" + +import numpy as np +import pytest + +from qkeras import quantizer_registry +from qkeras import quantizers + + +@pytest.mark.parametrize( + "quantizer_name", + [ + "quantized_linear", + "quantized_bits", + "bernoulli", + "ternary", + "stochastic_ternary", + "binary", + "stochastic_binary", + "quantized_relu", + "quantized_ulaw", + "quantized_tanh", + "quantized_sigmoid", + "quantized_po2", + "quantized_relu_po2", + "quantized_hswish", + ], +) +def test_lookup(quantizer_name): + quantizer = quantizer_registry.lookup_quantizer(quantizer_name) + is_class_instance = isinstance(quantizer, type) + np.testing.assert_equal(is_class_instance, True) + + +if __name__ == "__main__": + pytest.main([__file__])
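
A small usage sketch (not part of the patch itself): the example below shows how the registry added in qkeras/quantizer_registry.py is intended to be used once this change lands. It relies only on the register_quantizer decorator and lookup_quantizer function introduced above; the my_rounding_quantizer class is purely hypothetical, and the assumption that entries are keyed by class name follows the lookups exercised in tests/quantizer_registry_test.py.

    # Illustrative sketch only -- not part of this change. Assumes the
    # register_quantizer/lookup_quantizer API added in qkeras/quantizer_registry.py
    # and that registry entries are keyed by class name, as the new test implies.
    import numpy as np
    import tensorflow.compat.v2 as tf

    from qkeras import quantizer_registry


    # Hypothetical user-defined quantizer, registered the same way the built-in
    # quantizers in this patch are (via the class decorator).
    @quantizer_registry.register_quantizer
    class my_rounding_quantizer:  # pylint: disable=invalid-name
      """Toy quantizer that simply rounds its input."""

      def __call__(self, x):
        x = tf.convert_to_tensor(x, dtype=tf.float32)
        # Plain rounding; a real quantizer would also scale, clip, etc.
        return tf.round(x)


    # Built-in quantizers registered in this change can be looked up by name
    # and instantiated as usual...
    quantized_bits_cls = quantizer_registry.lookup_quantizer("quantized_bits")
    q = quantized_bits_cls(bits=4, integer=0)

    # ...and so can the hypothetical quantizer registered above.
    my_cls = quantizer_registry.lookup_quantizer("my_rounding_quantizer")
    print(my_cls()(np.array([0.2, 1.7, -0.6])).numpy())  # [ 0.  2. -1.]

Because register_quantizer returns the class it registers, decorating the existing quantizer classes is a no-op for every current call site; only the new lookup path changes.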