Skip to content

Commit

Permalink
Fix the aggregation in the codebase
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 committed Dec 31, 2024
1 parent 6ce93a4 commit 53d6eb4
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 16 deletions.
23 changes: 17 additions & 6 deletions keras/src/backend/common/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@ class Variable:
autocast: Optional. Boolean indicating whether the variable supports
autocasting. If `True`, the layer may first convert the variable
to the compute data type when accessed. Defaults to `True`.
aggregation: Optional. String specifying how a distributed variable will
be aggregated. This serves as a semantic annotation, to be taken
into account by downstream backends or users. Defaults to `"mean"`.
aggregation: Optional string, one of `None`, `"none"`, `"mean"`,
`"sum"` or `"only_first_replica"` specifying how a distributed
variable will be aggregated. This serves as a semantic annotation,
to be taken into account by downstream backends or users. Defaults
to `"none"`.
name: Optional. A unique name for the variable. Automatically generated
if not set.
Expand Down Expand Up @@ -93,7 +95,7 @@ def __init__(
dtype=None,
trainable=True,
autocast=True,
aggregation="mean",
aggregation="none",
name=None,
):
name = name or auto_name(self.__class__.__name__)
Expand All @@ -103,12 +105,21 @@ def __init__(
"cannot contain character `/`. "
f"Received: name={name}"
)
if aggregation not in ("none", "mean", "sum", "only_first_replica"):
if aggregation not in (
None,
"none",
"mean",
"sum",
"only_first_replica",
):
raise ValueError(
"Invalid valid for argument `aggregation`. Expected "
"one of {'none', 'mean', 'sum', 'only_first_replica'}. "
"one of `None`, `'none'`, `'mean'`, `'sum'`, "
"`'only_first_replica'`. "
f"Received: aggregation={aggregation}"
)
if aggregation is None:
aggregation = "none"
self._name = name
parent_path = current_path()
if parent_path:
Expand Down
4 changes: 2 additions & 2 deletions keras/src/backend/tensorflow/distribute_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def test_variable_aggregation(self):
with strategy.scope():
x = np.random.random((4, 4))
v1 = backend.Variable(x, dtype="float32")
self.assertEqual(v1.aggregation, "mean")
self.assertEqual(v1.value.aggregation, tf.VariableAggregation.MEAN)
self.assertEqual(v1.aggregation, "none")
self.assertEqual(v1.value.aggregation, tf.VariableAggregation.NONE)

v2 = backend.Variable(x, dtype="float32", aggregation="sum")
self.assertEqual(v2.aggregation, "sum")
Expand Down
11 changes: 6 additions & 5 deletions keras/src/layers/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def add_weight(
autocast=True,
regularizer=None,
constraint=None,
aggregation="mean",
aggregation="none",
name=None,
):
"""Add a weight variable to the layer.
Expand All @@ -520,10 +520,11 @@ def add_weight(
constraint: Constraint object to call on the variable after any
optimizer update, or string name of a built-in constraint.
Defaults to `None`.
aggregation: String, one of `'mean'`, `'sum'`,
`'only_first_replica'`. Annotates the variable with the type
of multi-replica aggregation to be used for this variable
when writing custom data parallel training loops.
aggregation: Optional string, one of `None`, `"none"`, `"mean"`,
`"sum"` or `"only_first_replica"`. Annotates the variable with
the type of multi-replica aggregation to be used for this
variable when writing custom data parallel training loops.
Defaults to `"none"`.
name: String name of the variable. Useful for debugging purposes.
"""
self._check_super_called()
Expand Down
45 changes: 42 additions & 3 deletions keras/src/optimizers/base_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,29 @@ def add_variable(
shape,
initializer="zeros",
dtype=None,
aggregation="mean",
aggregation="none",
name=None,
):
"""Add a variable to the optimizer.
Args:
shape: Shape tuple for the variable. Must be fully-defined
(no `None` entries).
initializer: Initializer object to use to populate the initial
variable value, or string name of a built-in initializer
(e.g. `"random_normal"`). Defaults to `"zeros"`.
dtype: Dtype of the variable to create, e.g. `"float32"`. If
unspecified, defaults to `keras.backend.floatx()`.
aggregation: Optional string, one of `None`, `"none"`, `"mean"`,
`"sum"` or `"only_first_replica"`. Annotates the variable with
the type of multi-replica aggregation to be used for this
variable when writing custom data parallel training loops.
Defaults to `"none"`.
name: String name of the variable. Useful for debugging purposes.
Returns:
An optimizer variable, in the format of `keras.Variable`.
"""
self._check_super_called()
initializer = initializers.get(initializer)
with backend.name_scope(self.name, caller=self):
Expand All @@ -265,8 +285,27 @@ def add_variable(
def add_variable_from_reference(
self, reference_variable, name=None, initializer="zeros"
):
"""Add an all-zeros variable with the shape and dtype of a reference
variable.
"""Add an optimizer variable from the model variable.
Create an optimizer variable based on the information of a model variable.
For example, in the SGD optimizer with momentum, for each model variable, a
corresponding momentum variable is created with the same shape and dtype.
Args:
reference_variable: `keras.Variable`. The corresponding model
variable to the optimizer variable to be created.
name: Optional string. The name prefix of the optimizer variable to
be created. If not provided, it will be set to `"var"`. The
variable name will follow the pattern
`{variable_name}_{reference_variable.name}`,
e.g., `momentum/dense_1`. Defaults to `None`.
initializer: Initializer object to use to populate the initial
variable value, or string name of a built-in initializer
(e.g. `"random_normal"`). If unspecified, defaults to
`"zeros"`.
Returns:
An optimizer variable, in the format of `keras.Variable`.
"""
name = name or "var"
if hasattr(reference_variable, "path"):
Expand Down

0 comments on commit 53d6eb4

Please sign in to comment.