diff --git a/tensorflow_addons/optimizers/tests/lookahead_test.py b/tensorflow_addons/optimizers/tests/lookahead_test.py index 4802bf0fd5..8bf9de046a 100644 --- a/tensorflow_addons/optimizers/tests/lookahead_test.py +++ b/tensorflow_addons/optimizers/tests/lookahead_test.py @@ -14,9 +14,12 @@ # ============================================================================== """Tests for Lookahead optimizer.""" +import os + import numpy as np import pytest import tensorflow as tf +import tempfile from tensorflow_addons.optimizers import Lookahead from tensorflow_addons.utils import test_utils @@ -186,3 +189,70 @@ def test_serialization(): config = tf.keras.optimizers.serialize(optimizer) new_optimizer = tf.keras.optimizers.deserialize(config) assert new_optimizer.get_config() == optimizer.get_config() + + +def _init_model(optimizer, init_w): + model = tf.keras.models.Sequential() + dense = tf.keras.layers.Dense(input_shape=(3,), units=1) + model.add(dense) + model.compile(Lookahead(optimizer), loss="mse") + dense.set_weights([init_w, np.zeros(1,)]) + return model + + +def assert_same_optimizer_states(optimizer, new_optimizer): + weights = optimizer.weights + new_weights = new_optimizer.weights + + assert len(weights) == len(new_weights) + + optimizer_name = optimizer._name + len_name_scope = len(optimizer_name) + 1 + + def _get_key(weight): + name = weight.name + if name.startswith(optimizer_name): + name = name[len_name_scope:] + return name + + weights = sorted(weights, key=_get_key) + new_weights = sorted(new_weights, key=_get_key) + + for weight, new_weight in zip(weights, new_weights): + assert np.allclose(weight.numpy(), new_weight.numpy(), atol=1e-4) + + # Assert recursively + if hasattr(optimizer, "_optimizer"): + assert_same_optimizer_states(optimizer._optimizer, new_optimizer._optimizer) + + +@pytest.mark.parametrize("optimizer", ["sgd", "adam"]) +@pytest.mark.parametrize("weights_only", [False, True]) +def test_save_load(optimizer, weights_only): + x = np.random.standard_normal((10000, 3)) + w = np.random.standard_normal((3, 1)) + y = np.dot(x, w) + np.random.standard_normal((10000, 1)) * 1e-4 + + init_w = np.random.standard_normal((3, 1)) + + model = _init_model(optimizer, init_w) + model.fit(x, y, epochs=2, shuffle=False) + + with tempfile.TemporaryDirectory() as ckpt_dir: + new_model = _init_model(optimizer, init_w) + new_model.fit(x, y, epochs=1, shuffle=False) + + ckpt_path = os.path.join(ckpt_dir, "model.ckpt") + if weights_only: + new_model.save_weights(ckpt_path) + new_model = _init_model(optimizer, init_w) + new_model.load_weights(ckpt_path) + else: + new_model.save(ckpt_path) + new_model = tf.keras.models.load_model( + ckpt_path, custom_objects={"Lookahead": Lookahead} + ) + + new_model.fit(x, y, epochs=1, shuffle=False) + + assert_same_optimizer_states(model.optimizer, new_model.optimizer)