Implement syncbn for TensorFlow (#18671)
* Implement syncbn for TensorFlow

* code refined

* fix: make calculate_mean_and_var private
edwardyehuang authored Oct 23, 2023
1 parent 080276f commit 666b8d3
Showing 1 changed file with 67 additions and 3 deletions.
keras/layers/normalization/batch_normalization.py
@@ -1,3 +1,4 @@
from keras import backend
from keras import constraints
from keras import initializers
from keras import ops
@@ -67,6 +68,10 @@ class BatchNormalization(Layer):
        gamma_regularizer: Optional regularizer for the gamma weight.
        beta_constraint: Optional constraint for the beta weight.
        gamma_constraint: Optional constraint for the gamma weight.
        synchronized: If `True`, synchronizes the global batch statistics
            (mean and variance) for the layer across all devices at each
            training step in a distributed training strategy. If `False`,
            each replica uses its own local batch statistics.
        **kwargs: Base layer keyword arguments (e.g. `name` and `dtype`).

    Call arguments:
@@ -125,10 +130,17 @@ def __init__(
        gamma_regularizer=None,
        beta_constraint=None,
        gamma_constraint=None,
        synchronized=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.axis = int(axis)

        if synchronized and backend.backend() != "tensorflow":
            raise ValueError(
                "Argument synchronized=True is only supported "
                "with the TensorFlow backend."
            )
        self.synchronized = synchronized

        self.momentum = float(momentum)
        self.epsilon = float(epsilon)
        self.center = center
@@ -187,6 +199,60 @@ def build(self, input_shape):

    def compute_output_shape(self, input_shape):
        return input_shape


    def _calculate_mean_and_var(self, inputs):
        if not self.synchronized:
            return ops.moments(
                inputs, axes=self._reduction_axes, keepdims=True
            )

        axes = self._reduction_axes
        y = inputs

        if backend.backend() == "tensorflow":
            from keras.utils.module_utils import tensorflow as tf

            replica_ctx = tf.distribute.get_replica_context()

            if replica_ctx:
                local_sum = ops.sum(y, axis=axes, keepdims=True)
                local_squared_sum = ops.sum(
                    ops.square(y), axis=axes, keepdims=True
                )
                # Assumes axes[0] is the batch dimension.
                batch_size = ops.cast(ops.shape(y)[axes[0]], dtype="float32")

                y_sum = replica_ctx.all_reduce(
                    tf.distribute.ReduceOp.SUM, local_sum
                )
                y_squared_sum = replica_ctx.all_reduce(
                    tf.distribute.ReduceOp.SUM, local_squared_sum
                )
                global_batch_size = replica_ctx.all_reduce(
                    tf.distribute.ReduceOp.SUM, batch_size
                )

                axes_vals = [
                    ops.shape(y)[axes[i]] for i in range(1, len(axes))
                ]
                multiplier = ops.cast(ops.prod(axes_vals), "float32")
                multiplier = multiplier * global_batch_size

                mean = y_sum / multiplier
                y_squared_mean = y_squared_sum / multiplier
                # var = E(x^2) - E(x)^2
                variance = y_squared_mean - ops.square(mean)
            else:
                # Compute the true mean while keeping the dims for proper
                # broadcasting.
                mean = ops.mean(y, axes, keepdims=True)
                # Sample variance, not unbiased variance.
                # Note: stop_gradient does not change the gradient that gets
                # backpropagated to the mean from the variance calculation,
                # because that gradient is zero.
                variance = ops.mean(
                    tf.math.squared_difference(y, ops.stop_gradient(mean)),
                    axes,
                    keepdims=True,
                )

            mean = ops.cast(mean, inputs.dtype)
            variance = ops.cast(variance, inputs.dtype)

            return mean, variance
        elif backend.backend() == "jax":
            raise NotImplementedError
        elif backend.backend() == "torch":
            raise NotImplementedError


    def call(self, inputs, training=None, mask=None):
        input_dtype = standardize_dtype(inputs.dtype)
@@ -198,9 +264,7 @@ def call(self, inputs, training=None, mask=None):
        broadcast_shape = [1] * len(inputs.shape)
        broadcast_shape[self.axis] = inputs.shape[self.axis]
        if training and self.trainable:
-            mean, variance = ops.moments(
-                inputs, axes=self._reduction_axes, keepdims=True
-            )
+            mean, variance = self._calculate_mean_and_var(inputs)
            moving_mean = ops.cast(self.moving_mean, inputs.dtype)
            moving_variance = ops.cast(self.moving_variance, inputs.dtype)
            self.moving_mean.assign(
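For context, here is a minimal sketch of how the new argument is meant to be used. It assumes the TensorFlow backend and a `tf.distribute.MirroredStrategy`; the model itself is hypothetical:

import tensorflow as tf
import keras

# Synchronized batch norm only changes behavior when the layer runs
# under a tf.distribute strategy on the TensorFlow backend.
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = keras.Sequential(
        [
            keras.layers.Dense(16),
            # Each training step all-reduces the batch statistics across
            # replicas instead of normalizing with per-replica statistics.
            keras.layers.BatchNormalization(synchronized=True),
            keras.layers.Activation("relu"),
        ]
    )

With `synchronized=False` (the default), each replica normalizes with its own local mean and variance, which can diverge noticeably when per-replica batches are small.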

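The cross-replica path recovers the global statistics from all-reduced sums via var = E(x^2) - E(x)^2. A small NumPy check of that identity (the replica shards here are illustrative):

import numpy as np

# Two hypothetical per-replica shards of one global batch.
replica_batches = [np.random.randn(8, 4), np.random.randn(8, 4)]

# Pool sums, squared sums, and sample counts, mirroring the three
# all_reduce calls in _calculate_mean_and_var.
total_sum = sum(b.sum(axis=0) for b in replica_batches)
total_squared_sum = sum(np.square(b).sum(axis=0) for b in replica_batches)
count = sum(b.shape[0] for b in replica_batches)

mean = total_sum / count
variance = total_squared_sum / count - np.square(mean)  # E(x^2) - E(x)^2

# The pooled statistics match those of the concatenated global batch.
global_batch = np.concatenate(replica_batches, axis=0)
assert np.allclose(mean, global_batch.mean(axis=0))
assert np.allclose(variance, global_batch.var(axis=0))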