Skip to content

Commit

Permalink
Add unit test for sync batch norm under distribution strategy.
Browse files Browse the repository at this point in the history
  • Loading branch information
qlzh727 committed Oct 23, 2023
1 parent 2ad8e07 commit 47d7970
Showing 1 changed file with 36 additions and 0 deletions.
36 changes: 36 additions & 0 deletions keras/ops/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1462,3 +1462,39 @@ def test_moments_sync(self):
expected_variance = np.var(x, axis=(0, 1, 2), keepdims=True)
self.assertAllClose(mean, expected_mean, atol=1e-5, rtol=1e-5)
self.assertAllClose(variance, expected_variance, atol=1e-5, rtol=1e-5)

@parameterized.product(dtype=["float16", "float32"])
@pytest.mark.skipif(
backend.backend() != "tensorflow",
reason="synchronized=True only implemented for TF backend",
)
def test_moments_sync_with_distribution_strategy(self, dtype):
from keras.utils.module_utils import tensorflow as tf

# Config 2 CPUs for testing.
logical_cpus = tf.config.list_logical_devices("CPU")
if len(logical_cpus) == 1:
from tensorflow.python.eager import context

context._reset_context()
tf.config.set_logical_device_configuration(
tf.config.list_physical_devices("CPU")[0],
[
tf.config.LogicalDeviceConfiguration(),
tf.config.LogicalDeviceConfiguration(),
],
)

@tf.function()
def test_on_moments(inputs):
return knn.moments(
inputs, axes=-1, keepdims=True, synchronized=True
)

# Test output of moments.
inputs = tf.constant([5.0, 9.0, 1.0, 3.0], dtype=dtype)
strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])
with strategy.scope():
mean, variance = strategy.run(test_on_moments, args=(inputs,))
self.assertEqual(mean.values[0], 4.5)
self.assertEqual(variance.values[0], 8.75)

0 comments on commit 47d7970

Please sign in to comment.