diff --git a/keras/src/backend/common/variables.py b/keras/src/backend/common/variables.py index 2d067f1ab897..db10cedabaaf 100644 --- a/keras/src/backend/common/variables.py +++ b/keras/src/backend/common/variables.py @@ -33,9 +33,11 @@ class Variable: autocast: Optional. Boolean indicating whether the variable supports autocasting. If `True`, the layer may first convert the variable to the compute data type when accessed. Defaults to `True`. - aggregation: Optional. String specifying how a distributed variable will - be aggregated. This serves as a semantic annotation, to be taken - into account by downstream backends or users. Defaults to `"mean"`. + aggregation: Optional string, one of `None`, `"none"`, `"mean"`, + `"sum"` or `"only_first_replica"` specifying how a distributed + variable will be aggregated. This serves as a semantic annotation, + to be taken into account by downstream backends or users. Defaults + to `"none"`. name: Optional. A unique name for the variable. Automatically generated if not set. @@ -93,7 +95,7 @@ def __init__( dtype=None, trainable=True, autocast=True, - aggregation="mean", + aggregation="none", name=None, ): name = name or auto_name(self.__class__.__name__) @@ -103,12 +105,21 @@ def __init__( "cannot contain character `/`. " f"Received: name={name}" ) - if aggregation not in ("none", "mean", "sum", "only_first_replica"): + if aggregation not in ( + None, + "none", + "mean", + "sum", + "only_first_replica", + ): raise ValueError( "Invalid valid for argument `aggregation`. Expected " - "one of {'none', 'mean', 'sum', 'only_first_replica'}. " + "one of `None`, `'none'`, `'mean'`, `'sum'`, " + "`'only_first_replica'`. 
" f"Received: aggregation={aggregation}" ) + if aggregation is None: + aggregation = "none" self._name = name parent_path = current_path() if parent_path: diff --git a/keras/src/backend/tensorflow/distribute_test.py b/keras/src/backend/tensorflow/distribute_test.py index 3c29777c5821..195eb999d35a 100644 --- a/keras/src/backend/tensorflow/distribute_test.py +++ b/keras/src/backend/tensorflow/distribute_test.py @@ -130,8 +130,8 @@ def test_variable_aggregation(self): with strategy.scope(): x = np.random.random((4, 4)) v1 = backend.Variable(x, dtype="float32") - self.assertEqual(v1.aggregation, "mean") - self.assertEqual(v1.value.aggregation, tf.VariableAggregation.MEAN) + self.assertEqual(v1.aggregation, "none") + self.assertEqual(v1.value.aggregation, tf.VariableAggregation.NONE) v2 = backend.Variable(x, dtype="float32", aggregation="sum") self.assertEqual(v2.aggregation, "sum") diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 2508153d23c2..1de2ba0f2350 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -493,7 +493,7 @@ def add_weight( autocast=True, regularizer=None, constraint=None, - aggregation="mean", + aggregation="none", name=None, ): """Add a weight variable to the layer. @@ -520,10 +520,11 @@ def add_weight( constraint: Contrainst object to call on the variable after any optimizer update, or string name of a built-in constraint. Defaults to `None`. - aggregation: String, one of `'mean'`, `'sum'`, - `'only_first_replica'`. Annotates the variable with the type - of multi-replica aggregation to be used for this variable - when writing custom data parallel training loops. + aggregation: Optional string, one of `None`, `"none"`, `"mean"`, + `"sum"` or `"only_first_replica"`. Annotates the variable with + the type of multi-replica aggregation to be used for this + variable when writing custom data parallel training loops. + Defaults to `"none"`. name: String name of the variable. Useful for debugging purposes. 
""" self._check_super_called() diff --git a/keras/src/optimizers/base_optimizer.py b/keras/src/optimizers/base_optimizer.py index 8717192fa84f..57833afadc7e 100644 --- a/keras/src/optimizers/base_optimizer.py +++ b/keras/src/optimizers/base_optimizer.py @@ -245,9 +245,29 @@ def add_variable( shape, initializer="zeros", dtype=None, - aggregation="mean", + aggregation="none", name=None, ): + """Add a variable to the optimizer. + + Args: + shape: Shape tuple for the variable. Must be fully-defined + (no `None` entries). + initializer: Initializer object to use to populate the initial + variable value, or string name of a built-in initializer + (e.g. `"random_normal"`). Defaults to `"zeros"`. + dtype: Dtype of the variable to create, e.g. `"float32"`. If + unspecified, defaults to the `keras.backend.floatx()`. + aggregation: Optional string, one of `None`, `"none"`, `"mean"`, + `"sum"` or `"only_first_replica"`. Annotates the variable with + the type of multi-replica aggregation to be used for this + variable when writing custom data parallel training loops. + Defaults to `"none"`. + name: String name of the variable. Useful for debugging purposes. + + Returns: + An optimizer variable, in the format of `keras.Variable`. + """ self._check_super_called() initializer = initializers.get(initializer) with backend.name_scope(self.name, caller=self): @@ -265,8 +285,27 @@ def add_variable( def add_variable_from_reference( self, reference_variable, name=None, initializer="zeros" ): - """Add an all-zeros variable with the shape and dtype of a reference - variable. + """Add an optimizer variable from the model variable. + + Create an optimizer variable based on the information of model variable. + For example, in SGD optimizer momemtum, for each model variable, a + corresponding momemtum variable is created of the same shape and dtype. + + Args: + reference_variable: `keras.Variable`. The corresponding model + variable to the optimizer variable to be created. 
+ name: Optional string. The name prefix of the optimizer variable to + be created. If not provided, it will be set to `"var"`. The + variable name will follow the pattern + `{variable_name}_{reference_variable.name}`, + e.g., `momentum/dense_1`. Defaults to `None`. + initializer: Initializer object to use to populate the initial + variable value, or string name of a built-in initializer + (e.g. `"random_normal"`). If unspecified, defaults to + `"zeros"`. + + Returns: + An optimizer variable, in the format of `keras.Variable`. """ name = name or "var" if hasattr(reference_variable, "path"):