Skip to content

Commit

Permalink
refactor: use composition to wrap the pytorch DataLoader using LLMDataLoader, so that both LLMDataLoader and WebLoader inherit only from DataLoaderIF
Browse files Browse the repository at this point in the history
  • Loading branch information
sthoduka committed Sep 27, 2024
1 parent f1dbe91 commit 146682f
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 27 deletions.
10 changes: 5 additions & 5 deletions src/modalities/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
PydanticCheckpointSavingExecutionIFType,
PydanticCheckpointSavingStrategyIFType,
PydanticCollateFnIFType,
PydanticDataLoaderIFType,
PydanticDatasetIFType,
PydanticFSDPModuleType,
PydanticLLMDataLoaderIFType,
PydanticModelInitializationIFType,
PydanticOptimizerIFType,
PydanticPytorchDeviceType,
Expand Down Expand Up @@ -343,7 +343,7 @@ class WebLoaderConfig(BaseModel):


class RepeatingDataLoaderConfig(BaseModel):
dataloader: PydanticLLMDataLoaderIFType
dataloader: PydanticDataLoaderIFType
reshuffle_after_epoch: Optional[bool] = False
num_epochs: Annotated[int, Field(strict=True, ge=1)]

Expand All @@ -353,15 +353,15 @@ class DummyProgressSubscriberConfig(BaseModel):


class SimpleProgressSubscriberConfig(BaseModel):
train_dataloader: PydanticLLMDataLoaderIFType
eval_dataloaders: Optional[list[PydanticLLMDataLoaderIFType]] = Field(default_factory=list)
train_dataloader: PydanticDataLoaderIFType
eval_dataloaders: Optional[list[PydanticDataLoaderIFType]] = Field(default_factory=list)
world_size: int
global_num_seen_samples: int
local_rank: int


class RichProgressSubscriberConfig(BaseModel):
eval_dataloaders: Optional[list[PydanticLLMDataLoaderIFType]] = Field(default_factory=list)
eval_dataloaders: Optional[list[PydanticDataLoaderIFType]] = Field(default_factory=list)
train_dataloader_tag: str
num_seen_steps: Annotated[int, Field(strict=True, ge=0)]
num_target_steps: Annotated[int, Field(strict=True, gt=0)]
Expand Down
6 changes: 3 additions & 3 deletions src/modalities/config/instantiation_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

from modalities.config.pydanctic_if_types import (
PydanticCheckpointSavingIFType,
PydanticDataLoaderIFType,
PydanticDatasetIFType,
PydanticGradientClipperIFType,
PydanticLLMDataLoaderIFType,
PydanticLossIFType,
PydanticLRSchedulerIFType,
PydanticMessageSubscriberIFType,
Expand Down Expand Up @@ -170,8 +170,8 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel
scheduler: PydanticLRSchedulerIFType
loss_fn: PydanticLossIFType | list[PydanticLossIFType]
train_dataset: PydanticDatasetIFType
train_dataloader: PydanticLLMDataLoaderIFType
eval_dataloaders: list[PydanticLLMDataLoaderIFType]
train_dataloader: PydanticDataLoaderIFType
eval_dataloaders: list[PydanticDataLoaderIFType]
progress_subscriber: PydanticMessageSubscriberIFType
evaluation_subscriber: PydanticMessageSubscriberIFType
checkpoint_saving: PydanticCheckpointSavingIFType
Expand Down
2 changes: 1 addition & 1 deletion src/modalities/config/pydanctic_if_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __get_pydantic_core_schema__(
PydanticDatasetIFType = Annotated[Dataset, PydanticThirdPartyTypeIF(Dataset)]
PydanticSamplerIFType = Annotated[Sampler, PydanticThirdPartyTypeIF(Sampler)]
PydanticCollateFnIFType = Annotated[CollateFnIF, PydanticThirdPartyTypeIF(CollateFnIF)]
PydanticLLMDataLoaderIFType = Annotated[DataLoaderIF, PydanticThirdPartyTypeIF(DataLoaderIF)]
PydanticDataLoaderIFType = Annotated[DataLoaderIF, PydanticThirdPartyTypeIF(DataLoaderIF)]
PydanticOptimizerIFType = Annotated[Optimizer, PydanticThirdPartyTypeIF(Optimizer)]
PydanticLRSchedulerIFType = Annotated[LRScheduler, PydanticThirdPartyTypeIF(LRScheduler)]
PydanticLossIFType = Annotated[Loss, PydanticThirdPartyTypeIF(Loss)]
Expand Down
57 changes: 49 additions & 8 deletions src/modalities/dataloader/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import multiprocessing
from typing import Iterable, Optional

import webdataset as wd
Expand All @@ -11,7 +12,7 @@ class DataLoaderIF:
pass


class LLMDataLoader(DataLoader[T_co], DataLoaderIF):
class LLMDataLoader(DataLoaderIF):
"""LLMDataLoader is a custom DataLoader class that extends the PyTorch DataLoader class."""

def __init__(
Expand Down Expand Up @@ -62,7 +63,9 @@ def __init__(
None
"""
assert batch_sampler is not None and batch_size == 1
super().__init__(
self._dataloader_tag = dataloader_tag
self._batch_size = batch_sampler.batch_size
self._torch_dataloader = DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=False, # shuffling must be implemented on a dataset level
Expand All @@ -81,9 +84,6 @@ def __init__(
pin_memory_device=pin_memory_device,
)

self._dataloader_tag = dataloader_tag
self._batch_size = batch_sampler.batch_size

@property
def dataloader_tag(self) -> str:
"""
Expand Down Expand Up @@ -125,6 +125,47 @@ def batch_size(self, value: int):
"""
self._batch_size = value

def __len__(self):
return self._torch_dataloader.__len__()

def __iter__(self):
return self._torch_dataloader.__iter__()

@property
def dataset(self) -> Dataset[T_co]:
return self._torch_dataloader.dataset

@property
def batch_sampler(self) -> ResumableBatchSampler:
return self._torch_dataloader.batch_sampler

@property
def sampler(self) -> Sampler | Iterable | None:
return self._torch_dataloader.sampler

@property
def collate_fn(self) -> _collate_fn_t:
return self._torch_dataloader.collate_fn

@property
def multiprocessing_context(self) -> str | multiprocessing.context.BaseContext:
return self._torch_dataloader.multiprocessing_context

@multiprocessing_context.setter
def multiprocessing_context(self, multiprocessing_context):
self._torch_dataloader.multiprocessing_context = multiprocessing_context

@property
def _auto_collation(self):
return self._torch_dataloader._auto_collation

@property
def _index_sampler(self):
return self._torch_dataloader._index_sampler

def check_worker_number_rationality(self):
return self._torch_dataloader.check_worker_number_rationality()

@property
def fast_forward_batch_id(self) -> int:
"""
Expand All @@ -133,15 +174,15 @@ def fast_forward_batch_id(self) -> int:
Returns:
int: fast forward batch ID
"""
return self.batch_sampler.start_index
return self._torch_dataloader.batch_sampler.start_index


class RepeatingDataLoader(LLMDataLoader[T_co]):
class RepeatingDataLoader(LLMDataLoader):
"""
RepeatingDataLoader is a custom DataLoader class that repeats the given dataloader
for the specified number of epochs."""

def __init__(self, dataloader: LLMDataLoader[T_co], num_epochs: int, reshuffle_after_epoch: bool = False):
def __init__(self, dataloader: LLMDataLoader, num_epochs: int, reshuffle_after_epoch: bool = False):
"""
Initializes a RepeatingDataLoader object that repeats the given dataloader for the specified number of epochs.
This is especially useful for DataLoader types that we wish to automatically restart upon completion.
Expand Down
4 changes: 2 additions & 2 deletions tests/dataloader/distributed/test_distributed_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from modalities.__main__ import Main
from modalities.config.config import ProcessGroupBackendType
from modalities.config.pydanctic_if_types import PydanticLLMDataLoaderIFType
from modalities.config.pydanctic_if_types import PydanticDataLoaderIFType
from modalities.running_env.cuda_env import CudaEnv
from tests.dataloader.dummy_sequential_dataset import TestDataset, TestDatasetConfig

Expand All @@ -18,7 +18,7 @@


class DataloaderInstantiationModel(BaseModel):
train_dataloader: PydanticLLMDataLoaderIFType
train_dataloader: PydanticDataLoaderIFType


@pytest.mark.skipif(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pydantic import BaseModel

from modalities.__main__ import Main
from modalities.config.config import ProcessGroupBackendType, PydanticLLMDataLoaderIFType
from modalities.config.config import ProcessGroupBackendType, PydanticDataLoaderIFType
from modalities.running_env.cuda_env import CudaEnv
from tests.dataloader.dummy_sequential_dataset import TestDataset, TestDatasetConfig

Expand All @@ -17,7 +17,7 @@


class DataloaderInstantiationModel(BaseModel):
train_dataloader: PydanticLLMDataLoaderIFType
train_dataloader: PydanticDataLoaderIFType


@pytest.mark.skipif(
Expand Down
8 changes: 4 additions & 4 deletions tests/dataloader/test_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from modalities.config.component_factory import ComponentFactory
from modalities.config.config import load_app_config_dict
from modalities.config.pydanctic_if_types import PydanticLLMDataLoaderIFType
from modalities.config.pydanctic_if_types import PydanticDataLoaderIFType
from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader
from modalities.dataloader.dataset import Dataset
from modalities.dataloader.samplers import ResumableBatchSampler
Expand Down Expand Up @@ -49,7 +49,7 @@ def test_dataloader_from_config(dummy_config: dict):
dummy_config["train_dataloader"]["config"]["skip_num_batches"] = start_index

class DataloaderTestModel(BaseModel):
train_dataloader: PydanticLLMDataLoaderIFType
train_dataloader: PydanticDataLoaderIFType

registry = Registry(COMPONENTS)
component_factory = ComponentFactory(registry=registry)
Expand Down Expand Up @@ -167,7 +167,7 @@ def test_repeating_dataloader_with_shuffling():

def test_skipped_and_distributed_dataloader_from_config():
class DataloaderTestModel(BaseModel):
train_dataloader: PydanticLLMDataLoaderIFType
train_dataloader: PydanticDataLoaderIFType
skip_num_batches: int

root_dir = Path(__file__).parents[0]
Expand Down Expand Up @@ -244,7 +244,7 @@ class DataloaderTestModel(BaseModel):
)
def test_dataloader_with_fixed_num_batches(global_rank):
class DataloaderTestModel(BaseModel):
train_dataloader: PydanticLLMDataLoaderIFType
train_dataloader: PydanticDataLoaderIFType
fixed_num_batches: int

class IdentityCollateFn(CollateFnIF):
Expand Down
4 changes: 2 additions & 2 deletions tests/end2end_tests/test_fsdp_warmstart.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from modalities.__main__ import Main, load_app_config_dict
from modalities.batch import EvaluationResultBatch
from modalities.config.config import ProcessGroupBackendType, PydanticLLMDataLoaderIFType
from modalities.config.config import ProcessGroupBackendType, PydanticDataLoaderIFType
from modalities.config.instantiation_models import TrainingComponentsInstantiationModel
from modalities.dataloader.dataloader import LLMDataLoader
from modalities.logging_broker.messages import Message
Expand Down Expand Up @@ -46,7 +46,7 @@ class SaveAllResultSubscriberConfig(BaseModel):


class TrainDataloaderInstantiationModel(BaseModel):
train_dataloader: PydanticLLMDataLoaderIFType
train_dataloader: PydanticDataLoaderIFType


@pytest.mark.skipif(
Expand Down

0 comments on commit 146682f

Please sign in to comment.