diff --git a/tests/test_training.py b/tests/test_training.py index 0730923d..07d23a25 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -203,7 +203,13 @@ def test_nargnn(): @pytest.mark.skipfif("numba" not in sys.modules, reason="Numba not installed") def test_deepaco(): env = TSPEnv(generator_params=dict(num_loc=20)) - model = DeepACO(env, train_data_size=10, val_data_size=10, test_data_size=10) + model = DeepACO( + env, + train_data_size=10, + val_data_size=10, + test_data_size=10, + policy_kwargs={"n_ants": 5}, + ) trainer = RL4COTrainer( max_epochs=1, gradient_clip_val=1, devices=1, accelerator=accelerator )