From d6e2cf832375c1c2b879e7bd94d0014a7f7983b9 Mon Sep 17 00:00:00 2001 From: BinyanHu Date: Wed, 19 Aug 2020 14:05:17 +0800 Subject: [PATCH 1/5] add test for reloading lookahead optimizer --- .../optimizers/tests/lookahead_test.py | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/tensorflow_addons/optimizers/tests/lookahead_test.py b/tensorflow_addons/optimizers/tests/lookahead_test.py index 4802bf0fd5..7691c4ffe6 100644 --- a/tensorflow_addons/optimizers/tests/lookahead_test.py +++ b/tensorflow_addons/optimizers/tests/lookahead_test.py @@ -14,9 +14,12 @@ # ============================================================================== """Tests for Lookahead optimizer.""" +import os + import numpy as np import pytest import tensorflow as tf +import tempfile from tensorflow_addons.optimizers import Lookahead from tensorflow_addons.utils import test_utils @@ -186,3 +189,58 @@ def test_serialization(): config = tf.keras.optimizers.serialize(optimizer) new_optimizer = tf.keras.optimizers.deserialize(config) assert new_optimizer.get_config() == optimizer.get_config() + + +def assert_same_optimizer_states(optimizer, new_optimizer): + # Remove the iteration variable + weights = [] + for weight in optimizer.weights: + if "iter" not in weight.name: + weights.append(weight) + new_weights = [] + for weight in new_optimizer.weights: + if "iter" not in weight.name: + new_weights.append(weight) + + assert len(weights) == len(new_weights) + + weights = sorted(weights, key=lambda w: w.name) + new_weights = sorted(new_weights, key=lambda w: w.name) + + for weight, new_weight in zip(weights, new_weights): + assert np.all(weight == new_weight) + + # Assert recursively + if hasattr(optimizer, "_optimizer"): + assert_same_optimizer_states(optimizer._optimizer, new_optimizer._optimizer) + + +def test_save_load(): + # NOTE currely passes with "sgd" instead of "adam" + x = np.random.standard_normal((10000, 3)) + w = np.random.standard_normal((3, 1)) + y = np.dot(x, w) + np.random.standard_normal((10000, 1)) * 1e-4 + + model = tf.keras.models.Sequential() + model.add(tf.keras.layers.Dense(input_shape=(3,), units=1)) + model.compile(Lookahead("adam"), loss="mse") + + model.fit(x, y, epochs=1) + + with tempfile.TemporaryDirectory() as ckpt_dir: + ckpt_path = os.path.join(ckpt_dir, "model.ckpt") + model.save_weights(ckpt_path) + + # Rebuild and reload + new_model = tf.keras.models.Sequential() + new_model.add(tf.keras.layers.Dense(input_shape=(3,), units=1)) + new_model.compile(Lookahead("adam"), loss="mse") + new_model.load_weights(ckpt_path) + + # Trigger optimizer initialization + try: + new_model.fit(x, y, epochs=1, steps_per_epoch=0) + except UnboundLocalError: + pass + + assert_same_optimizer_states(model.optimizer, new_model.optimizer) From 3dc401a084d817f04accfcd62b3e94ca98242ed4 Mon Sep 17 00:00:00 2001 From: BinyanHu Date: Thu, 20 Aug 2020 16:43:45 +0800 Subject: [PATCH 2/5] include sgd and adam, whole graph and weights only --- .../optimizers/tests/lookahead_test.py | 53 ++++++++++++------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/tensorflow_addons/optimizers/tests/lookahead_test.py b/tensorflow_addons/optimizers/tests/lookahead_test.py index 7691c4ffe6..46ac8af66a 100644 --- a/tensorflow_addons/optimizers/tests/lookahead_test.py +++ b/tensorflow_addons/optimizers/tests/lookahead_test.py @@ -191,6 +191,15 @@ def test_serialization(): assert new_optimizer.get_config() == optimizer.get_config() +def _init_model(optimizer, init_w): + model = tf.keras.models.Sequential() + dense = tf.keras.layers.Dense(input_shape=(3,), units=1) + model.add(dense) + model.compile(Lookahead(optimizer), loss="mse") + dense.set_weights([init_w, np.zeros(1,)]) + return model + + def assert_same_optimizer_states(optimizer, new_optimizer): # Remove the iteration variable weights = [] @@ -208,39 +217,43 @@ def assert_same_optimizer_states(optimizer, new_optimizer): new_weights = sorted(new_weights, key=lambda w: w.name) for weight, new_weight in zip(weights, new_weights): - assert np.all(weight == new_weight) + assert np.allclose(weight.numpy(), new_weight.numpy(), atol=1e-4) # Assert recursively if hasattr(optimizer, "_optimizer"): assert_same_optimizer_states(optimizer._optimizer, new_optimizer._optimizer) -def test_save_load(): - # NOTE currely passes with "sgd" instead of "adam" +@pytest.mark.parametrize("optimizer", ["sgd", "adam"]) +@pytest.mark.parametrize("weights_only", [False, True]) +def test_save_load(optimizer, weights_only): x = np.random.standard_normal((10000, 3)) w = np.random.standard_normal((3, 1)) y = np.dot(x, w) + np.random.standard_normal((10000, 1)) * 1e-4 - model = tf.keras.models.Sequential() - model.add(tf.keras.layers.Dense(input_shape=(3,), units=1)) - model.compile(Lookahead("adam"), loss="mse") + init_w = np.random.standard_normal((3, 1)) - model.fit(x, y, epochs=1) + model = _init_model(optimizer, init_w) + model.fit(x, y, epochs=2, shuffle=False) with tempfile.TemporaryDirectory() as ckpt_dir: + new_model = _init_model(optimizer, init_w) + new_model.fit(x, y, epochs=1, shuffle=False) + ckpt_path = os.path.join(ckpt_dir, "model.ckpt") - model.save_weights(ckpt_path) - - # Rebuild and reload - new_model = tf.keras.models.Sequential() - new_model.add(tf.keras.layers.Dense(input_shape=(3,), units=1)) - new_model.compile(Lookahead("adam"), loss="mse") - new_model.load_weights(ckpt_path) - - # Trigger optimizer initialization - try: - new_model.fit(x, y, epochs=1, steps_per_epoch=0) - except UnboundLocalError: - pass + if weights_only: + new_model.save_weights(ckpt_path) + new_model = _init_model(optimizer, init_w) + new_model.load_weights(ckpt_path) + else: + new_model.save(ckpt_path) + new_model = tf.keras.models.load_model( + ckpt_path, + custom_objects={ + "Lookahead": Lookahead, + } + ) + + new_model.fit(x, y, epochs=1, shuffle=False) assert_same_optimizer_states(model.optimizer, new_model.optimizer) From 6c6f6322d6698f662c3d39f1f29a343acb6d7fdb Mon Sep 17 00:00:00 2001 From: BinyanHu Date: Thu, 20 Aug 2020 16:49:58 +0800 Subject: [PATCH 3/5] reformat test file by black --- tensorflow_addons/optimizers/tests/lookahead_test.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tensorflow_addons/optimizers/tests/lookahead_test.py b/tensorflow_addons/optimizers/tests/lookahead_test.py index 46ac8af66a..a4f8419d1d 100644 --- a/tensorflow_addons/optimizers/tests/lookahead_test.py +++ b/tensorflow_addons/optimizers/tests/lookahead_test.py @@ -248,10 +248,7 @@ def test_save_load(optimizer, weights_only): else: new_model.save(ckpt_path) new_model = tf.keras.models.load_model( - ckpt_path, - custom_objects={ - "Lookahead": Lookahead, - } + ckpt_path, custom_objects={"Lookahead": Lookahead,} ) new_model.fit(x, y, epochs=1, shuffle=False) From 94a79c38ed55a57c3fa60b1a1ce2ebcc32eae7f5 Mon Sep 17 00:00:00 2001 From: BinyanHu Date: Fri, 21 Aug 2020 19:59:08 +0800 Subject: [PATCH 4/5] correct sorting of optimizer weights --- .../optimizers/tests/lookahead_test.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tensorflow_addons/optimizers/tests/lookahead_test.py b/tensorflow_addons/optimizers/tests/lookahead_test.py index a4f8419d1d..7fd8068b9c 100644 --- a/tensorflow_addons/optimizers/tests/lookahead_test.py +++ b/tensorflow_addons/optimizers/tests/lookahead_test.py @@ -213,8 +213,17 @@ def assert_same_optimizer_states(optimizer, new_optimizer): assert len(weights) == len(new_weights) - weights = sorted(weights, key=lambda w: w.name) - new_weights = sorted(new_weights, key=lambda w: w.name) + optimizer_name = optimizer._name + len_name_scope = len(optimizer_name) + 1 + + def _get_key(weight): + name = weight.name + if name.startswith(optimizer_name): + name = name[len_name_scope:] + return name + + weights = sorted(weights, key=_get_key) + new_weights = sorted(new_weights, key=_get_key) for weight, new_weight in zip(weights, new_weights): assert np.allclose(weight.numpy(), new_weight.numpy(), atol=1e-4) From 80f6cb6358fa4a638ac23e24009713b37a53446f Mon Sep 17 00:00:00 2001 From: BinyanHu Date: Sun, 23 Aug 2020 20:31:24 +0800 Subject: [PATCH 5/5] include check of 'iter' --- .../optimizers/tests/lookahead_test.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/tensorflow_addons/optimizers/tests/lookahead_test.py b/tensorflow_addons/optimizers/tests/lookahead_test.py index 7fd8068b9c..8bf9de046a 100644 --- a/tensorflow_addons/optimizers/tests/lookahead_test.py +++ b/tensorflow_addons/optimizers/tests/lookahead_test.py @@ -201,15 +201,8 @@ def _init_model(optimizer, init_w): def assert_same_optimizer_states(optimizer, new_optimizer): - # Remove the iteration variable - weights = [] - for weight in optimizer.weights: - if "iter" not in weight.name: - weights.append(weight) - new_weights = [] - for weight in new_optimizer.weights: - if "iter" not in weight.name: - new_weights.append(weight) + weights = optimizer.weights + new_weights = new_optimizer.weights assert len(weights) == len(new_weights) @@ -257,7 +250,7 @@ def test_save_load(optimizer, weights_only): else: new_model.save(ckpt_path) new_model = tf.keras.models.load_model( - ckpt_path, custom_objects={"Lookahead": Lookahead,} + ckpt_path, custom_objects={"Lookahead": Lookahead} ) new_model.fit(x, y, epochs=1, shuffle=False)