
Avoid using skip() in hf_datasets #838

Open
wants to merge 6 commits into main

Conversation

Contributor

@mori360 mori360 commented Feb 12, 2025

Fix issue #809

The current self._data.skip(self._sample_idx) does not return the correct data for the c4 dataset.
Until a proper fix for .skip() lands, we switch to advancing the iterator with next() instead.
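
A minimal sketch of the next()-based approach (illustrative, not the exact diff; it assumes the _get_data_iter method of the HuggingFace dataset wrapper quoted in the review below):

def _get_data_iter(self):
    # A map-style Dataset that is already exhausted resumes as an empty iterator.
    if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
        return iter([])

    it = iter(self._data)
    # Replay the first self._sample_idx samples one by one. This is O(sample_idx),
    # hence the warning below about slow resumes, but it yields exactly the samples
    # the original run saw.
    for _ in range(self._sample_idx):
        next(it)
    return it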

Test plan:
We reproduce #809 by resuming from a checkpoint at step 500, then compare the loss curves under 3 conditions:

  1. the original curve, running from step 0 to 750
  2. the resumed curve, still using .skip()
  3. the resumed curve, switched to next() with this PR's change
[Screenshot (2025-02-12): loss curves for the three conditions above]

Warning
For the c4 dataset, if we resume from a large enough step, next() is called self._sample_idx times, so resuming from a checkpoint is much slower than using .skip().

Next step:
add unit tests:

  1. check that the state_dict round-trips identically through dcp.save/load and torch.save/load
  2. test the difference between next() and .skip() (see the sketch after this list)
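
For item 2, a hypothetical test along these lines (the synthetic dataset and test name are illustrative, not the actual torchtitan test suite) could compare the two resume strategies on a small in-memory dataset:

from datasets import Dataset

def test_next_matches_skip_on_iterable_dataset():
    # Tiny synthetic dataset so the expected resumed stream is known exactly.
    ds = Dataset.from_dict({"text": [f"sample {i}" for i in range(10)]})
    iterable = ds.to_iterable_dataset()

    sample_idx = 4  # pretend 4 samples were consumed before checkpointing

    # Resume by replaying with next(), as in this PR.
    it = iter(iterable)
    for _ in range(sample_idx):
        next(it)
    resumed_with_next = [row["text"] for row in it]

    # Resume with IterableDataset.skip(), the previous approach.
    resumed_with_skip = [row["text"] for row in iterable.skip(sample_idx)]

    expected = [f"sample {i}" for i in range(sample_idx, 10)]
    assert resumed_with_next == expected
    assert resumed_with_skip == expected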

@facebook-github-bot facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) Feb 12, 2025
@mori360 mori360 marked this pull request as ready for review February 12, 2025 20:34
@mori360 mori360 requested review from tianyu-l and fegin February 12, 2025 20:34
if isinstance(self._data, Dataset) and self._sample_idx == len(self._data):
    return iter([])

return iter(self._data.skip(self._sample_idx))
Contributor

I think we need to understand whether skip() causes errors for both map-style and iterable datasets, or only in the newly added IterableDataset case.
If it's the latter, we should just revert #521 rather than universally use next() for both, because that would make the healthy case slow too.
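
To illustrate the distinction raised here (assumed datasets usage, not code from this PR): a map-style Dataset can resume cheaply by indexing into the remaining rows, so replaying with next() mainly penalizes the streaming IterableDataset path.

from datasets import Dataset

ds = Dataset.from_dict({"text": [f"sample {i}" for i in range(10)]})
sample_idx = 4

# Map-style: resume by selecting the remaining indices; no replay needed.
remaining = ds.select(range(sample_idx, len(ds)))

# Streaming: must walk past the first sample_idx rows one by one.
it = iter(ds.to_iterable_dataset())
for _ in range(sample_idx):
    next(it)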

Contributor

I would suggest that we land this PR first. It is better to have a slower checkpoint resume than a silent accuracy failure, and this is blocking several accuracy verifications. At the very least, we should make the default C4 dataset work for now.

@tianyu-l tianyu-l linked an issue Feb 13, 2025 that may be closed by this pull request
Contributor

@tianyu-l tianyu-l left a comment

stamp to unblock, but we should follow up with more robust tests.

Successfully merging this pull request may close these issues.

Loss metrics dramatically change after resuming from checkpoint