Combined dataset feature #261

Open
wants to merge 36 commits into main

Commits (36)
dcb866f
feat: combined datasets implementation
le1nux Sep 16, 2024
04c3189
chore: Merge branch 'warmstart_infrastructure_switch' into combined_d…
le1nux Sep 16, 2024
63b370e
feat: added DistributedSampler (unmodified pytorch version)
le1nux Sep 17, 2024
51f3fef
refactor: DistributedSampler from pytorch
le1nux Sep 18, 2024
a9a5d93
refactor: vectorized packed index generation
le1nux Sep 18, 2024
3198e4d
feat: added test coverage for CombinedDataset
le1nux Sep 24, 2024
299d2e6
refactor: moved sampler tests
le1nux Sep 24, 2024
915ec8f
feat: implemented distributed sampler tests
le1nux Sep 24, 2024
a25a788
refactor: refactored ResumableDistributedSampler
le1nux Sep 24, 2024
dd5316e
refactor: commented out old sample skipping in dataloader
le1nux Sep 24, 2024
c7a9cfc
feat: added new sampling strategy to config lorem ipsum
le1nux Sep 24, 2024
16f684e
feat: added more tests for the distributed sampler
le1nux Sep 25, 2024
d21e0df
chore: added documentation to ResumableDistributedSampler
le1nux Sep 25, 2024
fb9deea
refactor: the PackedMemMapDatasetContinuous does not load the index b…
le1nux Sep 25, 2024
2823745
feat: added test for dataset packing
le1nux Sep 25, 2024
e1091bf
chore: removed legacy code from DataloaderFactory
le1nux Sep 25, 2024
95a121e
refactor: upated configs
le1nux Sep 25, 2024
611c77b
feat: added number conversion routine
le1nux Sep 25, 2024
4a759f3
chore: updated tutorial configs
le1nux Sep 25, 2024
b5cc617
refactor: removed obsolete test test_dataloader_with_fixed_num_batches
le1nux Sep 25, 2024
053275e
refactor: adapted more failing test to the dataloader changes
le1nux Sep 25, 2024
e5ee5e9
refactor: removed RepeatingDataLoader
le1nux Sep 26, 2024
26b1152
refactor: removed ResumableBatchSampler
le1nux Sep 26, 2024
6a62709
refactor: removed legacy tests
le1nux Sep 26, 2024
b0ac334
refactor: fixed e2e tests
le1nux Sep 26, 2024
aed1d3d
chore: updated documentation
le1nux Sep 27, 2024
03fab0c
chore: Merge branch 'main' into combined_dataset_feature
le1nux Sep 27, 2024
4d27cae
chore: fixed minor path issue
le1nux Sep 27, 2024
2e8a880
refactor: improved test_skipped_and_distributed_dataloader_from_config
le1nux Oct 23, 2024
6d9b88a
Update src/modalities/dataloader/dataset.py
le1nux Oct 24, 2024
b067718
Update tests/dataloader/samplers/test_distributed_samplers.py
le1nux Oct 24, 2024
cc9ef97
Update tests/dataloader/samplers/test_distributed_samplers.py
le1nux Oct 24, 2024
09554b1
Update tests/dataloader/samplers/test_distributed_samplers.py
le1nux Oct 24, 2024
856fba7
Update tests/dataloader/samplers/test_distributed_samplers.py
le1nux Oct 24, 2024
c2e6b8c
refactor: fixed typos
le1nux Oct 24, 2024
9b6c0b0
chore: Merge branch 'combined_dataset_feature' of github.com:Modaliti…
le1nux Oct 24, 2024
25 changes: 24 additions & 1 deletion CHANGELOG_DEV.md
@@ -65,7 +65,7 @@ This [PR](https://github.com/Modalities/modalities/pull/236) removes all code re

**Breaking changes:**
* None
*


## PR #254 Warmstart infrastructure switch

@@ -85,3 +85,26 @@ This PR mainly addresses the warmstart of model training, e.g., after GPU crashes

**Breaking Changes**
* the settings part of the configs have been completely refactored


## PR #261 Dataloader inefficiencies fix and combined dataset feature

This PR addresses issue #258 (inefficiencies in the dataloader) and additionally introduces a combined dataset, where a single dataset can now comprise a list of datasets and iterate over them.
As part of fixing the dataloader inefficiencies, the sample skipping functionality is no longer implemented on the dataloader level but in an adapted version of the PyTorch `DistributedSampler`. I reran a warm start and the learning behavior is equivalent to a full, non-warmstarted run, as shown in the screenshot below.

<img width="1415" alt="Screenshot 2024-09-27 at 10 36 19" src="https://github.com/user-attachments/assets/65dfb1ed-e96b-4f50-a127-bc9d240ddff9">


**General Changes**
* Introduced `ResumableDistributedSampler`, a copy of the PyTorch `DistributedSampler` extended with the ability to skip samples. It is from now on used for warm starts instead of the `skip_num_batches` mechanism in the dataloader. Previously, in case of skipping samples, the dataloader had to instantiate a `ResumableBatchSampler`, which internally iterated over all the dataset indices. For small datasets this was fine, but for larger datasets (in the trillion-token range) this became a bottleneck at instantiation time (see the sketches after this list):
https://github.com/Modalities/modalities/blob/b79d04d3e92d0845c5ec91f8dd41176fd543cb23/src/modalities/dataloader/samplers.py#L25-L28
Skipping in the `ResumableDistributedSampler` is now done in O(1). The `ResumableBatchSampler` was removed from the codebase.
* Replaced the packed index generation routine (inefficient due to a Python for loop)
https://github.com/Modalities/modalities/blob/b79d04d3e92d0845c5ec91f8dd41176fd543cb23/src/modalities/dataloader/dataset.py#L331-L334
with a vectorized version (also sketched below).
* Added the new `NumberConversion` routine `num_samples_from_num_tokens`.
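
To make the two changes above concrete, here are two simplified sketches; they illustrate the approach only and are not the actual implementations in `src/modalities/dataloader/samplers.py` and `src/modalities/dataloader/dataset.py`. First, skipping at the sampler level can be expressed as a slice over the per-rank indices, so no already-consumed batches have to be iterated over at instantiation time:

```python
# Sketch: resumable distributed sampling with skip-by-slicing (illustrative only).
import torch
from torch.utils.data import Dataset, Sampler


class ResumableDistributedSamplerSketch(Sampler):
    def __init__(self, dataset: Dataset, rank: int, num_replicas: int,
                 seed: int = 42, shuffle: bool = True, skip_num_global_samples: int = 0):
        self.dataset = dataset
        self.rank = rank
        self.num_replicas = num_replicas
        self.seed = seed
        self.shuffle = shuffle
        # globally skipped samples translated into a per-rank offset
        self.skip_local = skip_num_global_samples // num_replicas

    def __iter__(self):
        if self.shuffle:
            g = torch.Generator().manual_seed(self.seed)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))
        # shard across ranks, then drop the already-seen prefix with a single slice
        local_indices = indices[self.rank::self.num_replicas]
        return iter(local_indices[self.skip_local:])

    def __len__(self) -> int:
        return len(self.dataset) // self.num_replicas - self.skip_local
```

Second, the packed index (one `(offset, length)` entry per sample) can be computed with NumPy instead of a Python for loop; the exact packing arithmetic here is an assumption for illustration:

```python
# Sketch: vectorized packed index generation (illustrative only).
import numpy as np


def packed_index_vectorized(num_samples: int, block_size_in_bytes: int,
                            header_size_in_bytes: int) -> np.ndarray:
    # offsets of all samples in one vectorized expression instead of a loop
    offsets = header_size_in_bytes + np.arange(num_samples, dtype=np.int64) * block_size_in_bytes
    lengths = np.full(num_samples, block_size_in_bytes, dtype=np.int64)
    return np.stack([offsets, lengths], axis=1)
```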

**Breaking Changes**
* Removed `RepeatingDataLoader`, a feature that was never actively used for running multiple epochs and that complicated maintenance when refactoring the sampling. If needed, we could reimplement it.
* In the settings, the `training_progress` section now has `num_seen_samples` instead of `local_num_seen_batches`, as skipping is now done on the sampler level and no longer on the dataloader level.
* The `batch_size` and `fast_forward_batch_id` fields in the `LLMDataLoader` are no longer needed and were removed.
14 changes: 7 additions & 7 deletions config_files/training/config_example_coca.yaml
@@ -46,7 +46,7 @@ settings:
training_progress:
global_num_seen_tokens: 0
num_seen_steps: 0
local_num_seen_batches: 0
num_seen_samples: 0
last_step: -1
coca_example_settings:
train_num_samples: 64
@@ -96,7 +96,6 @@ train_dataloader:
num_workers: 2
pin_memory: true
dataloader_tag: train
skip_num_batches: ${settings.training_progress.local_num_seen_batches}
dataset:
instance_key: train_dataset
pass_type: BY_REFERENCE
@@ -108,16 +107,17 @@ train_dataloader:
drop_last: true
sampler:
component_key: sampler
variant_key: distributed_sampler
variant_key: resumable_distributed_sampler
config:
dataset:
instance_key: train_dataset
pass_type: BY_REFERENCE
rank: ${settings.cuda_env.global_rank}
num_replicas: ${settings.cuda_env.world_size}
shuffle: true
drop_last: true
seed: 42
dataset:
instance_key: train_dataset
pass_type: BY_REFERENCE
drop_last: true
skip_num_global_samples: ${settings.training_progress.num_seen_samples}
collate_fn:
instance_key: collate_fn
pass_type: BY_REFERENCE
16 changes: 8 additions & 8 deletions config_files/training/config_lorem_ipsum.yaml
@@ -26,7 +26,7 @@ settings:
local_train_micro_batch_size: 1
sequence_length: 256
training_target:
num_target_tokens:
num_target_tokens:
component_key: number_conversion
variant_key: num_tokens_from_packed_mem_map_dataset_continuous
config:
@@ -47,7 +47,7 @@ settings:
training_progress:
global_num_seen_tokens: 0
num_seen_steps: 0
local_num_seen_batches: 0
num_seen_samples: 0
last_step: -1

collate_fn:
@@ -72,7 +72,6 @@ train_dataloader:
num_workers: 2
pin_memory: true
dataloader_tag: train
skip_num_batches: ${settings.training_progress.local_num_seen_batches}
dataset:
instance_key: train_dataset
pass_type: BY_REFERENCE
@@ -84,16 +83,17 @@ train_dataloader:
drop_last: true
sampler:
component_key: sampler
variant_key: distributed_sampler
variant_key: resumable_distributed_sampler
config:
dataset:
instance_key: train_dataset
pass_type: BY_REFERENCE
rank: ${settings.cuda_env.global_rank}
num_replicas: ${settings.cuda_env.world_size}
shuffle: true
drop_last: true
seed: 42
dataset:
instance_key: train_dataset
pass_type: BY_REFERENCE
drop_last: true
skip_num_global_samples: ${settings.training_progress.num_seen_samples}
collate_fn:
instance_key: collate_fn
pass_type: BY_REFERENCE
3 changes: 2 additions & 1 deletion docs/components/components.md
@@ -56,6 +56,7 @@
| dataset | mem_map_dataset | [DatasetFactory.get_mem_map_dataset](../../src/modalities/dataloader/dataset_factory.py)| [MemMapDatasetConfig](../../src/modalities/config/config.py) | [Dataset](../../src/modalities/dataloader/dataset.py) | MemMap Dataset |
| dataset | packed_mem_map_dataset_continuous | [DatasetFactory.get_packed_mem_map_dataset_continuous](../../src/modalities/dataloader/dataset_factory.py)| [PackedMemMapDatasetContinuousConfig](../../src/modalities/config/config.py) | [Dataset](../../src/modalities/dataloader/dataset.py) | Packed Memory Mapped Dataset Continuous |
| dataset | dummy_dataset | [DatasetFactory.get_dummy_dataset](../../src/modalities/dataloader/dataset_factory.py)| [DummyDatasetConfig](../../src/modalities/dataloader/dataset.py) | [Dataset](../../src/modalities/dataloader/dataset.py) | Dummy dataset creating random samples of specified shape |
| dataset | combined | [DatasetFactory.get_combined_dataset](../../src/modalities/dataloader/dataset_factory.py)| [CombinedDatasetConfig](../../src/modalities/dataloader/dataset.py) | [Dataset](../../src/modalities/dataloader/dataset.py) | Dataset implementation combining multiple datasets into one. |

## Data sampling

@@ -76,7 +77,6 @@
|Component type | Component Version | Implementation | Configuration | Component Interface | Description |
|---------------|--------------------|----------------|---------------|---------------------|-------------|
| data_loader | default | [DataloaderFactory.get_dataloader](../../src/modalities/dataloader/dataloader_factory.py)| [LLMDataLoaderConfig](s../../src/modalities/config/config.py) | [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) | LLM Data loader extending pytorch data loader functionality |
| data_loader | repeating_data_loader | [DataloaderFactory.get_repeating_dataloader](../../src/modalities/dataloader/dataloader_factory.py)| [RepeatingDataLoaderConfig](../../src/modalities/config/config.py) | [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) | Data loader that repeats the given dataloader for the specified number of epochs. |

## Checkpointing

@@ -118,6 +118,7 @@
|---------------|--------------------|----------------|---------------|---------------------|-------------|
| number_conversion | local_num_batches_from_num_samples | [NumberConversion.get_local_num_batches_from_num_samples](../../src/modalities/utils/number_conversion.py)| [LocalNumBatchesFromNumSamplesConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of local batches for each rank, given the global number of samples and number of ranks. |
| number_conversion | local_num_batches_from_num_tokens | [NumberConversion.get_local_num_batches_from_num_tokens](../../src/modalities/utils/number_conversion.py)| [LocalNumBatchesFromNumTokensConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of local batches for each rank, given the global number of tokens and number of ranks. |
| number_conversion | num_samples_from_num_tokens | [NumberConversion.get_num_samples_from_num_tokens](../../src/modalities/utils/number_conversion.py)| [NumSamplesFromNumTokensConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of global samples, given the global number of tokens and sequence length. |
| number_conversion | num_steps_from_num_samples | [NumberConversion.get_num_steps_from_num_samples](../../src/modalities/utils/number_conversion.py)| [NumStepsFromNumSamplesConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of steps given the global number of samples, local micro batch size and number of ranks. |
| number_conversion | num_steps_from_num_tokens | [NumberConversion.get_num_steps_from_num_tokens](../../src/modalities/utils/number_conversion.py)| [NumStepsFromNumTokensConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of steps given the global number of tokens, local micro batch size and number of ranks. |
| number_conversion | num_tokens_from_num_steps | [NumberConversion.get_num_tokens_from_num_steps](../../src/modalities/utils/number_conversion.py)| [NumTokensFromNumStepsConfig](../../src/modalities/utils/number_conversion.py) | -- | Calculates the number of tokens from the number of steps, number of ranks, local micro batch size, global number of tokens, squence length and gradient accumulation steps |
2 changes: 1 addition & 1 deletion src/modalities/__main__.py
@@ -148,7 +148,7 @@ def entry_point_data_create_raw_index(src_path: Path, index_path: Path):

index_path = LargeFileLinesReader.default_index_path(src_path, index_path)
if index_path.exists():
raise ValueError("index already exists. delete it or specify different output folder.")
raise ValueError(f"Index already exists in {index_path}. Delete it or specify different output folder.")

print(f"reading raw data from {src_path}")
print(f"writing index to {index_path}")
28 changes: 15 additions & 13 deletions src/modalities/config/config.py
@@ -265,6 +265,17 @@ class DistributedSamplerConfig(BaseModel):
drop_last: Literal[True] = True


class ResumableDistributedSamplerConfig(BaseModel):
dataset: PydanticDatasetIFType
rank: Annotated[int, Field(strict=True, ge=0)]
num_replicas: Annotated[int, Field(strict=True, ge=0)] = None
epoch: Annotated[int, Field(strict=True, ge=0)] = 0
shuffle: Optional[bool] = False
seed: Optional[int] = 0
drop_last: Literal[True] = True
skip_num_global_samples: Annotated[int, Field(strict=True, ge=0)] = 0


class MemMapDatasetConfig(BaseModel):
raw_data_path: FilePath
index_path: Optional[FilePath] = None
@@ -285,17 +296,16 @@ class PackedMemMapDatasetMegatronConfig(BaseModel):
sample_key: str


class CombinedDatasetConfig(BaseModel):
datasets: List[PydanticDatasetIFType]


class BatchSamplerConfig(BaseModel):
sampler: PydanticSamplerIFType
batch_size: Annotated[int, Field(strict=True, gt=0)]
drop_last: Literal[True] = True


class ResumableBatchSamplerConfig(BaseModel):
sampler: PydanticSamplerIFType
start_index: Annotated[int, Field(strict=True, gt=0)]


class GPT2LLMCollateFnConfig(BaseModel):
sample_key: str
target_key: str
@@ -308,14 +318,6 @@ class LLMDataLoaderConfig(BaseModel):
collate_fn: Optional[PydanticCollateFnIFType] = None
num_workers: Annotated[int, Field(strict=True, ge=0)]
pin_memory: bool
skip_num_batches: Optional[int] = 0
fixed_num_batches: Optional[int] = None


class RepeatingDataLoaderConfig(BaseModel):
dataloader: PydanticLLMDataLoaderIFType
reshuffle_after_epoch: Optional[bool] = False
num_epochs: Annotated[int, Field(strict=True, ge=1)]


class DummyProgressSubscriberConfig(BaseModel):
2 changes: 1 addition & 1 deletion src/modalities/config/instantiation_models.py
@@ -56,7 +56,7 @@ class TrainingTarget(BaseModel):
class TrainingProgress(BaseModel):
global_num_seen_tokens: Annotated[int, Field(strict=True, ge=0)]
num_seen_steps: Annotated[int, Field(strict=True, ge=0)]
local_num_seen_batches: Annotated[int, Field(strict=True, ge=0)]
num_seen_samples: Annotated[int, Field(strict=True, ge=0)]
last_step: Annotated[int, Field(strict=True, ge=-1)]


28 changes: 21 additions & 7 deletions src/modalities/dataloader/create_packed_data.py
@@ -323,12 +323,13 @@ class EmbeddedStreamData:
TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES = 4
HEADER_SIZE_IN_BYTES = DATA_SECTION_LENGTH_IN_BYTES + TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES

def __init__(self, data_path: Path):
def __init__(self, data_path: Path, load_index: Optional[bool] = True):
"""
Initializes an EmbeddedStreamData object.

Args:
data_path (Path): The path to the packed data file.
load_index (bool, optional): Whether to load the index. Defaults to True.

Raises:
FileNotFoundError: If the packed data file is not found at the specified path.
@@ -352,14 +353,27 @@ def __init__(self, data_path: Path):
self.token_size_in_bytes = int.from_bytes(token_size_as_bytes, byteorder="little", signed=False)

# get index
f.seek(self.HEADER_SIZE_IN_BYTES + self.data_len)
pkl_encoded_index = f.read()
# contains the start offset and length of each segment
# as byte positions in the data section
self.index_base: List[Tuple[int, int]] = pickle.loads(pkl_encoded_index)
if load_index:
f.seek(self.HEADER_SIZE_IN_BYTES + self.data_len)
pkl_encoded_index = f.read()
# contains the start offset and length of each segment
# as byte positions in the data section
self._index_base: List[Tuple[int, int]] = pickle.loads(pkl_encoded_index)
else:
self._index_base = None

# initialize memmapped data section
self.data = np.memmap(self._data_path, mode="r", offset=self.HEADER_SIZE_IN_BYTES, shape=(self.data_len,))
self._data = np.memmap(self._data_path, mode="r", offset=self.HEADER_SIZE_IN_BYTES, shape=(self.data_len,))

@property
def index_base(self) -> List[Tuple[int, int]]:
if self._index_base is None:
raise ValueError("Index was not loaded. Set `load_index=True` during initialization.")
return self._index_base

@property
def data(self) -> np.ndarray:
return self._data


def join_embedded_stream_data(stream_data: List[EmbeddedStreamData], target_file: Path, chunk_size: int = 2048):