diff --git a/bioimage_embed/tests/test_cli.py b/bioimage_embed/tests/test_cli.py index 19fa8752..37e0f5c9 100644 --- a/bioimage_embed/tests/test_cli.py +++ b/bioimage_embed/tests/test_cli.py @@ -123,15 +123,35 @@ def model(): return "dummy_model" +@pytest.fixture +def cfg_recipe(model): + return config.Recipe(model=model) + + +@pytest.fixture +def cfg_trainer(): + return config.Trainer(max_epochs=1, max_steps=1, fast_dev_run=True) + + +@pytest.fixture +def cfg_dataloader(): + return config.DataLoader(num_workers=0) + + # TODO double check this is sensible @pytest.fixture -def cfg(model): - cfg = config.Config() - cfg.dataloader.num_workers = 0 # This avoids processes being forked - cfg.trainer.max_epochs = 1 - cfg.trainer.max_steps = 1 - cfg.trainer.fast_dev_run = True - cfg.recipe.model = model +def cfg(cfg_recipe, cfg_trainer, cfg_dataloader): + cfg = config.Config( + recipe=cfg_recipe, trainer=cfg_trainer, dataloader=cfg_dataloader + ) + return cfg + # This is an alternative way to create a config object but it is less flexible and if the config object is changed in the future, this will break, i.e validation is not guaranteed + + # cfg.dataloader.num_workers = 0 # This avoids processes being forked + # cfg.trainer.max_epochs = 1 + # cfg.trainer.max_steps = 1 + # cfg.trainer.fast_dev_run = True + # cfg.recipe.model = model return cfg