Skip to content

Commit

Permalink
update validation epoch (#7121)
Browse files Browse the repository at this point in the history
- this allows for validation epoch at the very beginning of training
- fixes #7122


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli authored Oct 12, 2023
1 parent 4743945 commit e36982b
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
2 changes: 2 additions & 0 deletions monai/apps/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,8 @@ def download_url(
if urlparse(url).netloc == "drive.google.com":
if not has_gdown:
raise RuntimeError("To download files from Google Drive, please install the gdown dependency.")
if "fuzzy" not in gdown_kwargs:
gdown_kwargs["fuzzy"] = True # default to true for flexible url
gdown.download(url, f"{tmp_name}", quiet=not progress, **gdown_kwargs)
elif urlparse(url).netloc == "cloud-api.yandex.net":
with urlopen(url) as response:
Expand Down
2 changes: 1 addition & 1 deletion monai/engines/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def run(self, global_epoch: int = 1) -> None: # type: ignore[override]
"""
# init env value for current validation process
self.state.max_epochs = global_epoch
self.state.max_epochs = max(global_epoch, 1) # at least one epoch of validation
self.state.epoch = global_epoch - 1
self.state.iteration = 0
super().run()
Expand Down
6 changes: 4 additions & 2 deletions tests/test_handler_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@

class TestEvaluator(Evaluator):
def _iteration(self, engine, batchdata):
pass
engine.state.output = "called"
return engine.state.output


class TestHandlerValidation(unittest.TestCase):
Expand All @@ -42,8 +43,9 @@ def _train_func(engine, batch):
ValidationHandler(interval=2, validator=evaluator, exec_at_start=True).attach(engine)
# test execution at start
engine.run(data, max_epochs=1)
self.assertEqual(evaluator.state.max_epochs, 0)
self.assertEqual(evaluator.state.max_epochs, 1)
self.assertEqual(evaluator.state.epoch_length, 8)
self.assertEqual(evaluator.state.output, "called")

engine.run(data, max_epochs=5)
self.assertEqual(evaluator.state.max_epochs, 4)
Expand Down

0 comments on commit e36982b

Please sign in to comment.