From 8a468cf993090b6b30d950f25fcdf32ddfb8a546 Mon Sep 17 00:00:00 2001 From: SuryanarayanaY Date: Mon, 23 Oct 2023 12:49:10 +0530 Subject: [PATCH] Added test case for evaluation epoch iterator --- keras/trainers/trainer_test.py | 43 ++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/keras/trainers/trainer_test.py b/keras/trainers/trainer_test.py index ee96a0e5f8c..bd7071f605d 100644 --- a/keras/trainers/trainer_test.py +++ b/keras/trainers/trainer_test.py @@ -650,6 +650,49 @@ def call(self, inputs): out = model.predict({"a": x1, "b": x2}) self.assertEqual(out.shape, (3, 4)) + @pytest.mark.requires_trainable_backend + def test_for_eval_epoch_iterator(self): + model = ExampleModel(units=3) + model.compile( + optimizer="adam", loss="mse", metrics=["mean_absolute_error"] + ) + x = np.ones((16, 4)) + y = np.zeros((16, 3)) + x_test = np.ones((16, 4)) + y_test = np.zeros((16, 3)) + model.fit( + x, + y, + batch_size=4, + validation_data=(x_test, y_test), + ) + assert getattr(model, "_eval_epoch_iterator", None) is None + + # Try model.fit with reshaped validation_data + # This will throw an exception which is intended + try: + model.fit( + x, + y, + batch_size=4, + validation_data=( + x_test.reshape((-1, 16, 4)), + y_test.reshape((-1, 16, 3)), + ), + ) + except: + pass + + # Try model.fit with correct validation_data this should work. + # After successful training `_eval_epoch_iterator` should be None + model.fit( + x, + y, + batch_size=4, + validation_data=(x_test, y_test), + ) + assert getattr(model, "_eval_epoch_iterator", None) is None + @pytest.mark.requires_trainable_backend def test_callback_methods_keys(self): class CustomCallback(Callback):