Combined dataset feature #261
base: main
Conversation
…y default anymore.
LGTM :) Left a few minor comments.
""" | ||
Initializes the PackedMemMapDatasetBase object. | ||
|
||
Args: | ||
raw_data_path (Path): Path to a packed binary file (*.pbin). | ||
Use `modalities data pack_encoded_data` to create one based on a JSONL-file. | ||
sample_key (str): The key to access the sample in the BatchEncoding. | ||
load_index (bool, optional): Flag indicating whether to load the index. Defaults to True. |
Wouldn't it be more consistent if this defaulted to False, like in PackedMemMapDatasetContinuous (see line 308)?
In PackedMemMapDatasetContinuous we would never load the index (apart from debugging purposes), which is why I made it default to False; the "Continuous" implementation does not need an index. The PackedMemMapDatasetBase, however, uses the index for packing the data in its default implementation, which is why it defaults to True.
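To make the resolution concrete, the two constructors end up with different defaults. This is only a rough sketch with abbreviated signatures, not the verbatim code from the diff:

```python
from pathlib import Path


class PackedMemMapDatasetBase:
    # The base implementation packs samples via the index, hence load_index=True by default.
    def __init__(self, raw_data_path: Path, sample_key: str, load_index: bool = True):
        ...


class PackedMemMapDatasetContinuous(PackedMemMapDatasetBase):
    # The continuous implementation never needs the index (except for debugging),
    # hence load_index=False by default.
    def __init__(self, raw_data_path: Path, sample_key: str, load_index: bool = False):
        super().__init__(raw_data_path, sample_key, load_index=load_index)
```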
CHANGELOG_DEV.md
General comments on the changelog:
- The table in the beginning also needs to be updated
- It might be useful to reverse the order of PRs so that the newest ones come first
Co-authored-by: Felix Stollenwerk <[email protected]>
…es/modalities into combined_dataset_feature
What does this PR do?
This PR addresses issue #258 (inefficiencies in the dataloader) and additionally introduces a combined dataset, i.e., a dataset that comprises a list of datasets and iterates over them.
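For illustration, a combined dataset along these lines could look roughly as follows. This is a minimal sketch only, not the implementation in this PR; the class name and the cumulative-size bookkeeping are assumptions.

```python
import bisect
from typing import List

from torch.utils.data import Dataset


class CombinedDataset(Dataset):
    """Sketch: wraps a list of datasets and exposes them as one contiguous dataset."""

    def __init__(self, datasets: List[Dataset]):
        self.datasets = datasets
        # Cumulative sizes map a global index to (dataset index, local index).
        self.cumulative_sizes: List[int] = []
        total = 0
        for dataset in datasets:
            total += len(dataset)
            self.cumulative_sizes.append(total)

    def __len__(self) -> int:
        return self.cumulative_sizes[-1] if self.cumulative_sizes else 0

    def __getitem__(self, idx: int):
        if idx < 0 or idx >= len(self):
            raise IndexError(f"Index {idx} out of range for length {len(self)}")
        # Find the first dataset whose cumulative size exceeds idx.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        offset = self.cumulative_sizes[dataset_idx - 1] if dataset_idx > 0 else 0
        return self.datasets[dataset_idx][idx - offset]
```

Precomputing the cumulative sizes keeps the lookup in __getitem__ at a binary search rather than a scan over all member datasets.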
As part of fixing the dataloader inefficiencies, we now implement the sample skipping functionality in an adapted version of the PyTorch DistributedSampler rather than at the dataloader level. I reran a warm start and the learning is equivalent to a full, non-warm-started run.

General Changes
- ResumableDistributedSampler: a copy of the PyTorch DistributedSampler with the added feature to skip samples. It is from now on used for warm starts instead of skip_num_samples in the dataloader. In case of skipping samples, the dataloader had to instantiate a ResumableBatchSampler, which internally iterated over all the dataset indices. For small datasets this was fine, but for larger datasets (in the trillion-token range) this became a bottleneck at instantiation time (see modalities/src/modalities/dataloader/samplers.py, lines 25 to 28 in b79d04d). Skipping in the ResumableDistributedSampler is now O(1) (see the sampler sketch after this list). The ResumableBatchSampler was removed from the codebase.
- Replaced modalities/src/modalities/dataloader/dataset.py, lines 331 to 334 in b79d04d, with a vectorized version (sketched below).
- New NumberConversion routine num_samples_from_num_tokens (sketched below).
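To make the O(1) skipping concrete, here is a minimal sketch of the idea behind a resumable distributed sampler. It is not the actual ResumableDistributedSampler code from this PR; the class name, the skip_num_global_samples argument, and the per-rank split are assumptions.

```python
from torch.utils.data import Dataset, DistributedSampler


class SkippingDistributedSampler(DistributedSampler):
    """Sketch: resume training by skipping already-seen samples via slicing,
    instead of iterating over them batch by batch."""

    def __init__(self, dataset: Dataset, skip_num_global_samples: int = 0, **kwargs):
        super().__init__(dataset, **kwargs)
        # Samples already consumed across all ranks in the previous run.
        self.skip_num_global_samples = skip_num_global_samples

    def __iter__(self):
        # The parent yields this rank's (optionally shuffled) partition of indices.
        indices = list(super().__iter__())
        # Each rank drops its share of the globally seen samples; the skip is a
        # simple slice of this rank's index list, not an iteration over all
        # previously consumed batches.
        skip_per_rank = self.skip_num_global_samples // self.num_replicas
        return iter(indices[skip_per_rank:])


# Example usage (num_replicas/rank passed explicitly, so no process group is needed):
# sampler = SkippingDistributedSampler(my_dataset, skip_num_global_samples=2048,
#                                      num_replicas=8, rank=0, shuffle=True)
```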
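The vectorization mentioned above can be illustrated as follows; the function names and the (offset, length) index layout are assumptions for illustration, not the actual contents of dataset.py.

```python
import numpy as np


def build_index_loop(num_samples: int, block_size_bytes: int) -> list:
    # Python-loop version: builds one (offset, length) tuple per sample.
    return [(i * block_size_bytes, block_size_bytes) for i in range(num_samples)]


def build_index_vectorized(num_samples: int, block_size_bytes: int) -> np.ndarray:
    # Vectorized version: computes all offsets in a single NumPy operation.
    offsets = np.arange(num_samples, dtype=np.int64) * block_size_bytes
    lengths = np.full(num_samples, block_size_bytes, dtype=np.int64)
    return np.stack([offsets, lengths], axis=1)
```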
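The new NumberConversion routine presumably reduces to an integer division of the token count by the sequence length; a hedged sketch (the real signature and rounding behavior may differ):

```python
def num_samples_from_num_tokens(num_tokens: int, sequence_length: int) -> int:
    # Each training sample consumes `sequence_length` tokens; partial samples are dropped.
    return num_tokens // sequence_length


# e.g. 1,000,000 tokens at a sequence length of 4096 yield 244 full samples
assert num_samples_from_num_tokens(1_000_000, 4096) == 244
```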
Breaking Changes
- The training_progress section now has num_seen_samples instead of local_num_seen_batches, as skipping is now done on the sampler level and not on the dataloader level anymore.
- The batch_size and fast_forward_batch_id fields in the LLMDataLoader are not needed anymore and were removed.

Checklist before submitting final PR
- All tests pass (python tests/tests.py)
- The changelog is updated (CHANGELOG_DEV.md)