Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update nn.py #18674

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 27 additions & 25 deletions keras/backend/tensorflow/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,33 +697,35 @@ def _compute_moments_sync(x, axes, keepdims):
return mean, variance


def _compute_moments(x, axes, keepdims):
    """Compute the mean and variance of `x` over `axes`.

    float16 inputs are promoted to float32 for the reductions and the
    results are converted back to float16 before returning.

    Args:
        x: Input tensor.
        axes: Axis or axes to reduce over.
        keepdims: If true, the reduced axes are kept with size 1;
            otherwise they are squeezed out of the results.

    Returns:
        A `(mean, variance)` tuple.
    """
    # The dynamic range of float16 is too limited for statistics. As a
    # workaround, we simply perform the operations on float32 and convert
    # back to float16.
    need_cast = False
    ori_dtype = standardize_dtype(x.dtype)
    if ori_dtype == "float16":
        need_cast = True
        x = cast(x, "float32")

    # Always reduce with keepdims=True so that mean and variance keep matching
    # shapes; the reduced axes are squeezed afterwards if requested.
    mean = tf.reduce_mean(x, axes, keepdims=True)

    # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$. It is
    # faster but less numerically stable.
    # Note: stop_gradient does not change the gradient to the mean, because
    # that gradient is zero.
    variance = tf.reduce_mean(
        tf.square(x), axis=axes, keepdims=True
    ) - tf.square(tf.stop_gradient(mean))

    if not keepdims:
        mean = tf.squeeze(mean, axes)
        variance = tf.squeeze(variance, axes)
    if need_cast:
        # Avoid overflow and underflow when casting from float32 back to
        # float16.
        mean = tf.clip_by_value(mean, tf.float16.min, tf.float16.max)
        variance = tf.clip_by_value(variance, tf.float16.min, tf.float16.max)
        mean = cast(mean, ori_dtype)
        variance = cast(variance, ori_dtype)
    return mean, variance
def _compute_moments_sync(x, axes, keepdims):
    """Compute the mean and variance of `x` over `axes` across replicas.

    Inside a `tf.distribute` replica context, the per-replica sum, squared
    sum and element count are all-reduced so that every replica obtains the
    same global statistics. Outside of a replica context this falls back to
    `_compute_moments`.

    float16 inputs are promoted to float32 for the computation and the
    results are returned as float16.

    Args:
        x: Input tensor.
        axes: Axis or axes to reduce over.
        keepdims: If true, the reduced axes are kept with size 1;
            otherwise they are squeezed out of the results.

    Returns:
        A `(mean, variance)` tuple.
    """
    # The dynamic range of float16 is too limited for statistics, so the
    # computation is carried out in float32.
    need_cast = x.dtype == tf.float16
    y = tf.cast(x, tf.float32) if need_cast else x

    replica_ctx = tf.distribute.get_replica_context()
    if not replica_ctx:
        mean, variance = _compute_moments(y, axes, keepdims)
    else:
        local_sum = tf.reduce_sum(y, axis=axes, keepdims=True)
        local_squared_sum = tf.reduce_sum(
            tf.square(y), axis=axes, keepdims=True
        )
        local_count = tf.reduce_sum(
            tf.ones_like(y, name="count"), axis=axes, keepdims=True
        )

        # All-reduce the sum, squared sum and count across replicas.
        y_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, local_sum)
        y_squared_sum = replica_ctx.all_reduce(
            tf.distribute.ReduceOp.SUM, local_squared_sum
        )
        count_sum = replica_ctx.all_reduce(
            tf.distribute.ReduceOp.SUM, local_count
        )

        # divide_no_nan yields 0 instead of NaN/Inf when the count is zero.
        mean = tf.math.divide_no_nan(y_sum, count_sum)
        y_squared_mean = tf.math.divide_no_nan(y_squared_sum, count_sum)
        # var = E(x^2) - E(x)^2. Rounding can push the difference slightly
        # below zero, so clamp it to keep the variance non-negative.
        variance = tf.maximum(y_squared_mean - tf.square(mean), 0.0)

        if not keepdims:
            mean = tf.squeeze(mean, axes)
            variance = tf.squeeze(variance, axes)

    if need_cast:
        # Avoid overflow and underflow when casting from float32 back to
        # float16.
        mean = tf.clip_by_value(mean, tf.float16.min, tf.float16.max)
        variance = tf.clip_by_value(variance, tf.float16.min, tf.float16.max)
        mean = tf.cast(mean, tf.float16)
        variance = tf.cast(variance, tf.float16)

    return mean, variance
Loading