From c8c9590d7a2169ed76321dda2d8117031ed3c302 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Thu, 21 Mar 2024 10:27:27 +0800 Subject: [PATCH 1/5] Update `build` in `Dense` and `EinsumDense` for quantized dtype --- keras/layers/core/dense.py | 153 +++++++++++-------- keras/layers/core/dense_test.py | 15 ++ keras/layers/core/einsum_dense.py | 204 +++++++++++++++---------- keras/layers/core/einsum_dense_test.py | 21 +++ 4 files changed, 252 insertions(+), 141 deletions(-) diff --git a/keras/layers/core/dense.py b/keras/layers/core/dense.py index a2210c83241b..62a0ecc59727 100644 --- a/keras/layers/core/dense.py +++ b/keras/layers/core/dense.py @@ -102,13 +102,18 @@ def __init__( def build(self, input_shape): input_dim = input_shape[-1] - self._kernel = self.add_weight( - name="kernel", - shape=(input_dim, self.units), - initializer=self.kernel_initializer, - regularizer=self.kernel_regularizer, - constraint=self.kernel_constraint, - ) + if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy): + self.quantized_build( + input_shape, mode=self.dtype_policy.quantization_mode + ) + else: + self._kernel = self.add_weight( + name="kernel", + shape=(input_dim, self.units), + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + ) if self.use_bias: self.bias = self.add_weight( name="bias", @@ -120,11 +125,12 @@ def build(self, input_shape): else: self.bias = None self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim}) - self.built = True if self.lora_rank: self.enable_lora(self.lora_rank) if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy): - self.quantize(self.dtype_policy.quantization_mode) + if self.bias is not None: + self.bias.trainable = False + self.built = True @property def kernel(self): @@ -146,20 +152,6 @@ def call(self, inputs): x = self.activation(x) return x - def quantized_call(self, inputs): - if self.lora_enabled: - raise ValueError("`quantized_call` doesn't support lora weights") - inputs, inputs_scale = self.inputs_quantizer(inputs) - x = ops.matmul(inputs, self.kernel) - # De-scale outputs - x = ops.cast(x, self.compute_dtype) - x = ops.divide(x, ops.multiply(inputs_scale, self.kernel_scale)) - if self.bias is not None: - x = ops.add(x, self.bias) - if self.activation is not None: - x = self.activation(x) - return x - def compute_output_shape(self, input_shape): output_shape = list(input_shape) output_shape[-1] = self.units @@ -201,6 +193,82 @@ def enable_lora( self.lora_enabled = True self.lora_rank = rank + def save_own_variables(self, store): + if not self.lora_enabled: + return super().save_own_variables(store) + + kernel_value = self.kernel + store["0"] = kernel_value + if self.use_bias: + store["1"] = self.bias + + def load_own_variables(self, store): + if not self.lora_enabled: + return super().load_own_variables(store) + self._kernel.assign(store["0"]) + if self.use_bias: + self.bias.assign(store["1"]) + self.lora_kernel_a.assign(np.zeros(self.lora_kernel_a.shape)) + self.lora_kernel_b.assign(np.zeros(self.lora_kernel_b.shape)) + + def get_config(self): + base_config = super().get_config() + config = { + "units": self.units, + "activation": activations.serialize(self.activation), + "use_bias": self.use_bias, + "kernel_initializer": initializers.serialize( + self.kernel_initializer + ), + "bias_initializer": initializers.serialize(self.bias_initializer), + "kernel_regularizer": regularizers.serialize( + self.kernel_regularizer 
+ ), + "bias_regularizer": regularizers.serialize(self.bias_regularizer), + "kernel_constraint": constraints.serialize(self.kernel_constraint), + "bias_constraint": constraints.serialize(self.bias_constraint), + } + if self.lora_rank: + config["lora_rank"] = self.lora_rank + return {**base_config, **config} + + """Quantization-related methods""" + + def quantized_build(self, input_shape, mode): + input_dim = input_shape[-1] + kernel_shape = (input_dim, self.units) + if mode == "int8": + self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1) + self._kernel = self.add_weight( + name="kernel", + shape=kernel_shape, + initializer="zeros", + dtype="int8", + trainable=False, + ) + kernel_scale_shape = (1, kernel_shape[1]) + self.kernel_scale = self.add_weight( + name="kernel_scale", + shape=kernel_scale_shape, + initializer="zeros", + dtype=self.compute_dtype, + trainable=False, + ) + + def quantized_call(self, inputs): + if self.lora_enabled: + raise ValueError("`quantized_call` doesn't support lora weights") + inputs, inputs_scale = self.inputs_quantizer(inputs) + x = ops.matmul(inputs, self.kernel) + # De-scale outputs + x = ops.cast(x, self.compute_dtype) + x = ops.divide(x, ops.multiply(inputs_scale, self.kernel_scale)) + if self.bias is not None: + x = ops.add(x, self.bias) + if self.activation is not None: + x = self.activation(x) + return x + def quantize(self, mode): self._check_quantize_args(mode, self.compute_dtype) if mode == "int8": @@ -261,42 +329,3 @@ def _merge_lora_into_kernel(self, untrack=False): self.lora_kernel_b = self._untrack_variable(self.lora_kernel_b) self._tracker.lock() self.lora_rank = None - - def save_own_variables(self, store): - if not self.lora_enabled: - return super().save_own_variables(store) - - kernel_value = self.kernel - store["0"] = kernel_value - if self.use_bias: - store["1"] = self.bias - - def load_own_variables(self, store): - if not self.lora_enabled: - return super().load_own_variables(store) - self._kernel.assign(store["0"]) - if self.use_bias: - self.bias.assign(store["1"]) - self.lora_kernel_a.assign(np.zeros(self.lora_kernel_a.shape)) - self.lora_kernel_b.assign(np.zeros(self.lora_kernel_b.shape)) - - def get_config(self): - base_config = super().get_config() - config = { - "units": self.units, - "activation": activations.serialize(self.activation), - "use_bias": self.use_bias, - "kernel_initializer": initializers.serialize( - self.kernel_initializer - ), - "bias_initializer": initializers.serialize(self.bias_initializer), - "kernel_regularizer": regularizers.serialize( - self.kernel_regularizer - ), - "bias_regularizer": regularizers.serialize(self.bias_regularizer), - "kernel_constraint": constraints.serialize(self.kernel_constraint), - "bias_constraint": constraints.serialize(self.bias_constraint), - } - if self.lora_rank: - config["lora_rank"] = self.lora_rank - return {**base_config, **config} diff --git a/keras/layers/core/dense_test.py b/keras/layers/core/dense_test.py index fcfffed16533..5c0b2d6d2ca2 100644 --- a/keras/layers/core/dense_test.py +++ b/keras/layers/core/dense_test.py @@ -310,6 +310,13 @@ def test_quantize_int8(self): layer.build((None, 8)) layer.quantize("int8") + # Verify weights dtype + self.assertEqual(backend.standardize_dtype(layer._kernel.dtype), "int8") + self.assertEqual( + backend.standardize_dtype(layer.kernel_scale.dtype), + layer.compute_dtype, + ) + # Try eager call x = np.random.random((2, 8)) _ = layer(x) @@ -337,6 +344,14 @@ def test_quantize_int8(self): x = np.random.random((2, 8)) _ = layer(x) + # 
Try building with quantized dtype policy + layer = layers.Dense(units=16, dtype="int8_from_mixed_bfloat16") + layer.build((None, 8)) + self.assertEqual(backend.standardize_dtype(layer._kernel.dtype), "int8") + self.assertEqual( + backend.standardize_dtype(layer.kernel_scale.dtype), "bfloat16" + ) + @pytest.mark.requires_trainable_backend def test_quantize_dtype_argument(self): self.run_layer_test( diff --git a/keras/layers/core/einsum_dense.py b/keras/layers/core/einsum_dense.py index e7609e75539e..682c437d4efa 100644 --- a/keras/layers/core/einsum_dense.py +++ b/keras/layers/core/einsum_dense.py @@ -153,16 +153,22 @@ def build(self, input_shape): ) kernel_shape, bias_shape, full_output_shape = shape_data self.full_output_shape = tuple(full_output_shape) - self._kernel = self.add_weight( - name="kernel", - shape=tuple(kernel_shape), - initializer=self.kernel_initializer, - regularizer=self.kernel_regularizer, - constraint=self.kernel_constraint, - dtype=self.dtype, - trainable=True, - ) - + # `quantized_build` needs `self.input_spec` + self.input_spec = InputSpec(ndim=len(input_shape)) + if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy): + self.quantized_build( + input_shape, mode=self.dtype_policy.quantization_mode + ) + else: + self._kernel = self.add_weight( + name="kernel", + shape=tuple(kernel_shape), + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint, + dtype=self.dtype, + trainable=True, + ) if bias_shape is not None: self.bias = self.add_weight( name="bias", @@ -175,12 +181,12 @@ def build(self, input_shape): ) else: self.bias = None - self.input_spec = InputSpec(ndim=len(input_shape)) - self.built = True if self.lora_rank: self.enable_lora(self.lora_rank) if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy): - self.quantize(self.dtype_policy.quantization_mode) + if self.bias is not None: + self.bias.trainable = False + self.built = True @property def kernel(self): @@ -197,6 +203,68 @@ def kernel(self): def compute_output_shape(self, _): return self.full_output_shape + def call(self, inputs): + x = ops.einsum(self.equation, inputs, self.kernel) + if self.bias is not None: + x += self.bias + if self.activation is not None: + x = self.activation(x) + return x + + def enable_lora( + self, rank, a_initializer="he_uniform", b_initializer="zeros" + ): + if self.kernel_constraint: + raise ValueError( + "Lora is incompatible with kernel constraints. " + "In order to enable lora on this layer, remove the " + "`kernel_constraint` argument." + ) + if not self.built: + raise ValueError( + "Cannot enable lora on a layer that isn't yet built." + ) + if self.lora_enabled: + raise ValueError( + "lora is already enabled. " + "This can only be done once per layer." 
+ ) + self._tracker.unlock() + self.lora_kernel_a = self.add_weight( + name="lora_kernel_a", + shape=(self.kernel.shape[:-1] + (rank,)), + initializer=initializers.get(a_initializer), + regularizer=self.kernel_regularizer, + ) + self.lora_kernel_b = self.add_weight( + name="lora_kernel_b", + shape=(rank, self.kernel.shape[-1]), + initializer=initializers.get(b_initializer), + regularizer=self.kernel_regularizer, + ) + self._kernel.trainable = False + self._tracker.lock() + self.lora_enabled = True + self.lora_rank = rank + + def save_own_variables(self, store): + if not self.lora_enabled: + return super().save_own_variables(store) + + kernel_value = self.kernel + store["0"] = kernel_value + if self.bias is not None: + store["1"] = self.bias + + def load_own_variables(self, store): + if not self.lora_enabled: + return super().load_own_variables(store) + self._kernel.assign(store["0"]) + if self.bias is not None: + self.bias.assign(store["1"]) + self.lora_kernel_a.assign(np.zeros(self.lora_kernel_a.shape)) + self.lora_kernel_b.assign(np.zeros(self.lora_kernel_b.shape)) + def get_config(self): base_config = super().get_config() config = { @@ -222,13 +290,50 @@ def get_config(self): config["lora_rank"] = self.lora_rank return {**base_config, **config} - def call(self, inputs): - x = ops.einsum(self.equation, inputs, self.kernel) - if self.bias is not None: - x += self.bias - if self.activation is not None: - x = self.activation(x) - return x + """Quantization-related methods""" + + def quantized_build(self, input_shape, mode): + shape_data = _analyze_einsum_string( + self.equation, + self.bias_axes, + input_shape, + self.partial_output_shape, + ) + kernel_shape, _, _ = shape_data + if mode == "int8": + ( + self._input_reduced_axes, + self._kernel_reduced_axes, + self._input_transpose_axes, + self._kernel_transpose_axes, + self._input_expand_axes, + self._kernel_expand_axes, + self._input_squeeze_axes, + self._kernel_squeeze_axes, + ) = _analyze_quantization_info(self.equation, self.input_spec.ndim) + self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1) + self._kernel = self.add_weight( + name="kernel", + shape=kernel_shape, + initializer="zeros", + dtype="int8", + trainable=False, + ) + kernel_scale_shape = np.array(kernel_shape) + kernel_scale_shape[self._kernel_reduced_axes] = 1 + kernel_scale_shape = kernel_scale_shape[self._kernel_transpose_axes] + kernel_scale_shape = kernel_scale_shape.tolist() + for a in self._kernel_expand_axes: + kernel_scale_shape.insert(a, 1) + for a in self._kernel_squeeze_axes: + kernel_scale_shape.pop(a) + self.kernel_scale = self.add_weight( + name="kernel_scale", + shape=kernel_scale_shape, + initializer="zeros", + dtype=self.compute_dtype, + trainable=False, + ) def quantized_call(self, inputs): if self.lora_enabled: @@ -254,42 +359,6 @@ def quantized_call(self, inputs): x = self.activation(x) return x - def enable_lora( - self, rank, a_initializer="he_uniform", b_initializer="zeros" - ): - if self.kernel_constraint: - raise ValueError( - "Lora is incompatible with kernel constraints. " - "In order to enable lora on this layer, remove the " - "`kernel_constraint` argument." - ) - if not self.built: - raise ValueError( - "Cannot enable lora on a layer that isn't yet built." - ) - if self.lora_enabled: - raise ValueError( - "lora is already enabled. " - "This can only be done once per layer." 
- ) - self._tracker.unlock() - self.lora_kernel_a = self.add_weight( - name="lora_kernel_a", - shape=(self.kernel.shape[:-1] + (rank,)), - initializer=initializers.get(a_initializer), - regularizer=self.kernel_regularizer, - ) - self.lora_kernel_b = self.add_weight( - name="lora_kernel_b", - shape=(rank, self.kernel.shape[-1]), - initializer=initializers.get(b_initializer), - regularizer=self.kernel_regularizer, - ) - self._kernel.trainable = False - self._tracker.lock() - self.lora_enabled = True - self.lora_rank = rank - def quantize(self, mode): self._check_quantize_args(mode, self.compute_dtype) if mode == "int8": @@ -297,11 +366,6 @@ def quantize(self, mode): raise ValueError("`quantize` can only be done once per layer.") # Merge lora-related parameters to make use of fully int8 kernel self._merge_lora_into_kernel() - - if self.input_spec is None: - raise ValueError( - f"Cannot quantize {self.name} that isn't yet built." - ) ( self._input_reduced_axes, self._kernel_reduced_axes, @@ -376,24 +440,6 @@ def _merge_lora_into_kernel(self, untrack=False): self._tracker.lock() self.lora_rank = None - def save_own_variables(self, store): - if not self.lora_enabled: - return super().save_own_variables(store) - - kernel_value = self.kernel - store["0"] = kernel_value - if self.bias is not None: - store["1"] = self.bias - - def load_own_variables(self, store): - if not self.lora_enabled: - return super().load_own_variables(store) - self._kernel.assign(store["0"]) - if self.bias is not None: - self.bias.assign(store["1"]) - self.lora_kernel_a.assign(np.zeros(self.lora_kernel_a.shape)) - self.lora_kernel_b.assign(np.zeros(self.lora_kernel_b.shape)) - def _analyze_einsum_string(equation, bias_axes, input_shape, output_shape): """Analyzes an einsum string to determine the required weight shape.""" diff --git a/keras/layers/core/einsum_dense_test.py b/keras/layers/core/einsum_dense_test.py index e7126ed3f3c4..dab9ec367ac4 100644 --- a/keras/layers/core/einsum_dense_test.py +++ b/keras/layers/core/einsum_dense_test.py @@ -4,6 +4,7 @@ import pytest from absl.testing import parameterized +from keras import backend from keras import constraints from keras import layers from keras import models @@ -382,6 +383,13 @@ def test_quantize_int8(self): layer.build((None, 3)) layer.quantize("int8") + # Verify weights dtype + self.assertEqual(backend.standardize_dtype(layer._kernel.dtype), "int8") + self.assertEqual( + backend.standardize_dtype(layer.kernel_scale.dtype), + layer.compute_dtype, + ) + # Try eager call x = np.random.random((2, 3)) _ = layer(x) @@ -413,6 +421,19 @@ def test_quantize_int8(self): x = np.random.random((2, 3)) _ = layer(x) + # Try building with quantized dtype policy + layer = layers.EinsumDense( + equation="ab,bcd->acd", + output_shape=(8, 32), + bias_axes="d", + dtype="int8_from_mixed_bfloat16", + ) + layer.build((None, 3)) + self.assertEqual(backend.standardize_dtype(layer._kernel.dtype), "int8") + self.assertEqual( + backend.standardize_dtype(layer.kernel_scale.dtype), "bfloat16" + ) + @pytest.mark.requires_trainable_backend def test_quantize_dtype_argument(self): self.run_layer_test( From fdb90c606392ee8d6949fa9e0975c880be919e86 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Thu, 21 Mar 2024 10:33:44 +0800 Subject: [PATCH 2/5] Fix `self.built` bug --- keras/layers/core/dense.py | 2 +- keras/layers/core/einsum_dense.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/keras/layers/core/dense.py 
b/keras/layers/core/dense.py index 62a0ecc59727..f40248adad10 100644 --- a/keras/layers/core/dense.py +++ b/keras/layers/core/dense.py @@ -125,12 +125,12 @@ def build(self, input_shape): else: self.bias = None self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim}) + self.built = True if self.lora_rank: self.enable_lora(self.lora_rank) if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy): if self.bias is not None: self.bias.trainable = False - self.built = True @property def kernel(self): diff --git a/keras/layers/core/einsum_dense.py b/keras/layers/core/einsum_dense.py index 682c437d4efa..2e3549c3aa0a 100644 --- a/keras/layers/core/einsum_dense.py +++ b/keras/layers/core/einsum_dense.py @@ -181,12 +181,12 @@ def build(self, input_shape): ) else: self.bias = None + self.built = True if self.lora_rank: self.enable_lora(self.lora_rank) if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy): if self.bias is not None: self.bias.trainable = False - self.built = True @property def kernel(self): From 711e5fd41bd0d5a250fa15bfd39a9d977be98623 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Thu, 21 Mar 2024 11:01:01 +0800 Subject: [PATCH 3/5] Improve test coverage --- keras/layers/core/einsum_dense_test.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/keras/layers/core/einsum_dense_test.py b/keras/layers/core/einsum_dense_test.py index dab9ec367ac4..02b0320b05be 100644 --- a/keras/layers/core/einsum_dense_test.py +++ b/keras/layers/core/einsum_dense_test.py @@ -423,16 +423,26 @@ def test_quantize_int8(self): # Try building with quantized dtype policy layer = layers.EinsumDense( - equation="ab,bcd->acd", - output_shape=(8, 32), + equation="abcde,afce->acdbf", + output_shape=(2, 4, 8, 16), bias_axes="d", dtype="int8_from_mixed_bfloat16", ) - layer.build((None, 3)) + layer.build((1, 8, 2, 4, 32)) self.assertEqual(backend.standardize_dtype(layer._kernel.dtype), "int8") self.assertEqual( backend.standardize_dtype(layer.kernel_scale.dtype), "bfloat16" ) + layer = layers.EinsumDense( + equation="a,b->ab", + output_shape=(4,), + dtype="int8_from_float32", + ) + layer.build((None,)) + self.assertEqual(backend.standardize_dtype(layer._kernel.dtype), "int8") + self.assertEqual( + backend.standardize_dtype(layer.kernel_scale.dtype), "float32" + ) @pytest.mark.requires_trainable_backend def test_quantize_dtype_argument(self): From 9e41d29ed9cbd9b5e8d6543489025eea16129e69 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Thu, 21 Mar 2024 11:37:51 +0800 Subject: [PATCH 4/5] Update tests --- keras/layers/core/einsum_dense.py | 4 ++-- keras/layers/core/einsum_dense_test.py | 14 ++++++++++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/keras/layers/core/einsum_dense.py b/keras/layers/core/einsum_dense.py index 2e3549c3aa0a..b36a4b0fe245 100644 --- a/keras/layers/core/einsum_dense.py +++ b/keras/layers/core/einsum_dense.py @@ -323,9 +323,9 @@ def quantized_build(self, input_shape, mode): kernel_scale_shape[self._kernel_reduced_axes] = 1 kernel_scale_shape = kernel_scale_shape[self._kernel_transpose_axes] kernel_scale_shape = kernel_scale_shape.tolist() - for a in self._kernel_expand_axes: + for a in sorted(self._kernel_expand_axes): kernel_scale_shape.insert(a, 1) - for a in self._kernel_squeeze_axes: + for a in sorted(self._kernel_squeeze_axes, reverse=True): kernel_scale_shape.pop(a) self.kernel_scale = self.add_weight( 
name="kernel_scale", diff --git a/keras/layers/core/einsum_dense_test.py b/keras/layers/core/einsum_dense_test.py index 02b0320b05be..beb39a137109 100644 --- a/keras/layers/core/einsum_dense_test.py +++ b/keras/layers/core/einsum_dense_test.py @@ -423,7 +423,7 @@ def test_quantize_int8(self): # Try building with quantized dtype policy layer = layers.EinsumDense( - equation="abcde,afce->acdbf", + equation="abcde,afce->acdbf", # Test reduce and transpose output_shape=(2, 4, 8, 16), bias_axes="d", dtype="int8_from_mixed_bfloat16", @@ -434,7 +434,7 @@ def test_quantize_int8(self): backend.standardize_dtype(layer.kernel_scale.dtype), "bfloat16" ) layer = layers.EinsumDense( - equation="a,b->ab", + equation="a,b->ab", # Test expand output_shape=(4,), dtype="int8_from_float32", ) @@ -443,6 +443,16 @@ def test_quantize_int8(self): self.assertEqual( backend.standardize_dtype(layer.kernel_scale.dtype), "float32" ) + layer = layers.EinsumDense( + equation="ab,ab->a", # Test squeeze + output_shape=(2,), + dtype="int8_from_float32", + ) + layer.build((2, 4)) + self.assertEqual(backend.standardize_dtype(layer._kernel.dtype), "int8") + self.assertEqual( + backend.standardize_dtype(layer.kernel_scale.dtype), "float32" + ) @pytest.mark.requires_trainable_backend def test_quantize_dtype_argument(self): From 11ef4c4b3c6c3dae02c3f24f097ff57a40de8f25 Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Fri, 22 Mar 2024 09:13:59 +0800 Subject: [PATCH 5/5] Use 1d vector for scaling factor in Dense and utilize `variable_dtype` for it both in Dense and EinsumDense --- keras/layers/core/dense.py | 16 +++++++--------- keras/layers/core/dense_test.py | 4 ++-- keras/layers/core/einsum_dense.py | 4 +--- keras/layers/core/einsum_dense_test.py | 4 ++-- 4 files changed, 12 insertions(+), 16 deletions(-) diff --git a/keras/layers/core/dense.py b/keras/layers/core/dense.py index f40248adad10..874e230b281c 100644 --- a/keras/layers/core/dense.py +++ b/keras/layers/core/dense.py @@ -236,22 +236,19 @@ def get_config(self): def quantized_build(self, input_shape, mode): input_dim = input_shape[-1] - kernel_shape = (input_dim, self.units) if mode == "int8": self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1) self._kernel = self.add_weight( name="kernel", - shape=kernel_shape, + shape=(input_dim, self.units), initializer="zeros", dtype="int8", trainable=False, ) - kernel_scale_shape = (1, kernel_shape[1]) self.kernel_scale = self.add_weight( name="kernel_scale", - shape=kernel_scale_shape, - initializer="zeros", - dtype=self.compute_dtype, + shape=(self.units,), + initializer="ones", trainable=False, ) @@ -282,7 +279,9 @@ def quantize(self, mode): kernel_value, kernel_scale = quantizers.abs_max_quantize( self._kernel, axis=0 ) - kernel_scale = ops.cast(kernel_scale, self.compute_dtype) + kernel_scale = ops.cast( + ops.squeeze(kernel_scale, axis=0), self.compute_dtype + ) self._tracker.unlock() self._untrack_variable(self._kernel) self._kernel = self.add_weight( @@ -295,10 +294,9 @@ def quantize(self, mode): ) self.kernel_scale = self.add_weight( name="kernel_scale", - shape=kernel_scale.shape, + shape=(self.units,), # Prevent adding a large constant to the computation graph initializer=lambda shape, dtype: kernel_scale, - dtype=self.compute_dtype, trainable=False, ) if self.bias is not None: diff --git a/keras/layers/core/dense_test.py b/keras/layers/core/dense_test.py index 5c0b2d6d2ca2..6da88db44a72 100644 --- a/keras/layers/core/dense_test.py +++ 
b/keras/layers/core/dense_test.py @@ -314,7 +314,7 @@ def test_quantize_int8(self): self.assertEqual(backend.standardize_dtype(layer._kernel.dtype), "int8") self.assertEqual( backend.standardize_dtype(layer.kernel_scale.dtype), - layer.compute_dtype, + layer.variable_dtype, ) # Try eager call @@ -349,7 +349,7 @@ def test_quantize_int8(self): layer.build((None, 8)) self.assertEqual(backend.standardize_dtype(layer._kernel.dtype), "int8") self.assertEqual( - backend.standardize_dtype(layer.kernel_scale.dtype), "bfloat16" + backend.standardize_dtype(layer.kernel_scale.dtype), "float32" ) @pytest.mark.requires_trainable_backend diff --git a/keras/layers/core/einsum_dense.py b/keras/layers/core/einsum_dense.py index b36a4b0fe245..5fc66f6d6ef0 100644 --- a/keras/layers/core/einsum_dense.py +++ b/keras/layers/core/einsum_dense.py @@ -330,8 +330,7 @@ def quantized_build(self, input_shape, mode): self.kernel_scale = self.add_weight( name="kernel_scale", shape=kernel_scale_shape, - initializer="zeros", - dtype=self.compute_dtype, + initializer="ones", trainable=False, ) @@ -411,7 +410,6 @@ def quantize(self, mode): shape=kernel_scale.shape, # Prevent adding a large constant to the computation graph initializer=lambda shape, dtype: kernel_scale, - dtype=self.compute_dtype, trainable=False, ) if self.bias is not None: diff --git a/keras/layers/core/einsum_dense_test.py b/keras/layers/core/einsum_dense_test.py index beb39a137109..95eca16afab7 100644 --- a/keras/layers/core/einsum_dense_test.py +++ b/keras/layers/core/einsum_dense_test.py @@ -387,7 +387,7 @@ def test_quantize_int8(self): self.assertEqual(backend.standardize_dtype(layer._kernel.dtype), "int8") self.assertEqual( backend.standardize_dtype(layer.kernel_scale.dtype), - layer.compute_dtype, + layer.variable_dtype, ) # Try eager call @@ -431,7 +431,7 @@ def test_quantize_int8(self): layer.build((1, 8, 2, 4, 32)) self.assertEqual(backend.standardize_dtype(layer._kernel.dtype), "int8") self.assertEqual( - backend.standardize_dtype(layer.kernel_scale.dtype), "bfloat16" + backend.standardize_dtype(layer.kernel_scale.dtype), "float32" ) layer = layers.EinsumDense( equation="a,b->ab", # Test expand
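# A minimal usage sketch, assuming the Keras 3 API exercised by the tests in the
# patches above: with these changes, constructing a layer under a
# `QuantizedDTypePolicy` (e.g. "int8_from_float32") creates the int8 kernel and
# its scale directly in `build()` via `quantized_build()`, instead of requiring
# a separate `quantize()` call on an already-built layer.
import numpy as np

from keras import backend
from keras import layers

layer = layers.Dense(units=16, dtype="int8_from_float32")
layer.build((None, 8))

# The kernel is stored as int8; after PATCH 5/5 the scale is a 1-D vector of
# shape (units,) kept in the layer's variable dtype (here float32).
assert backend.standardize_dtype(layer._kernel.dtype) == "int8"
assert backend.standardize_dtype(layer.kernel_scale.dtype) == "float32"
assert tuple(layer.kernel_scale.shape) == (16,)

# Calling the layer dispatches to `quantized_call`: inputs are quantized with
# `AbsMaxQuantizer(axis=-1)`, multiplied against the int8 kernel, then de-scaled
# by the product of the input and kernel scales.
x = np.random.random((2, 8)).astype("float32")
y = layer(x)
assert tuple(y.shape) == (2, 16)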