Skip to content

Commit

Permalink
Added test case for evaluation epoch iterator
Browse files Browse the repository at this point in the history
  • Loading branch information
SuryanarayanaY committed Oct 23, 2023
1 parent e48e4f8 commit 8a468cf
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions keras/trainers/trainer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,49 @@ def call(self, inputs):
out = model.predict({"a": x1, "b": x2})
self.assertEqual(out.shape, (3, 4))

@pytest.mark.requires_trainable_backend
def test_for_eval_epoch_iterator(self):
model = ExampleModel(units=3)
model.compile(
optimizer="adam", loss="mse", metrics=["mean_absolute_error"]
)
x = np.ones((16, 4))
y = np.zeros((16, 3))
x_test = np.ones((16, 4))
y_test = np.zeros((16, 3))
model.fit(
x,
y,
batch_size=4,
validation_data=(x_test, y_test),
)
assert getattr(model, "_eval_epoch_iterator", None) is None

# Try model.fit with reshaped validation_data
# This will throw an exception which is intended
try:
model.fit(
x,
y,
batch_size=4,
validation_data=(
x_test.reshape((-1, 16, 4)),
y_test.reshape((-1, 16, 3)),
),
)
except:
pass

# Try model.fit with correct validation_data this should work.
# After successful training `_eval_epoch_iterator` should be None
model.fit(
x,
y,
batch_size=4,
validation_data=(x_test, y_test),
)
assert getattr(model, "_eval_epoch_iterator", None) is None

@pytest.mark.requires_trainable_backend
def test_callback_methods_keys(self):
class CustomCallback(Callback):
Expand Down

0 comments on commit 8a468cf

Please sign in to comment.