diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py
index 63abfef3..77a0419a 100644
--- a/src/modalities/config/config.py
+++ b/src/modalities/config/config.py
@@ -16,9 +16,9 @@
     PydanticCheckpointSavingExecutionIFType,
     PydanticCheckpointSavingStrategyIFType,
     PydanticCollateFnIFType,
+    PydanticDataLoaderIFType,
     PydanticDatasetIFType,
     PydanticFSDPModuleType,
-    PydanticLLMDataLoaderIFType,
     PydanticModelInitializationIFType,
     PydanticOptimizerIFType,
     PydanticPytorchDeviceType,
@@ -343,7 +343,7 @@ class WebLoaderConfig(BaseModel):
 
 
 class RepeatingDataLoaderConfig(BaseModel):
-    dataloader: PydanticLLMDataLoaderIFType
+    dataloader: PydanticDataLoaderIFType
     reshuffle_after_epoch: Optional[bool] = False
     num_epochs: Annotated[int, Field(strict=True, ge=1)]
 
@@ -353,15 +353,15 @@ class DummyProgressSubscriberConfig(BaseModel):
 
 
 class SimpleProgressSubscriberConfig(BaseModel):
-    train_dataloader: PydanticLLMDataLoaderIFType
-    eval_dataloaders: Optional[list[PydanticLLMDataLoaderIFType]] = Field(default_factory=list)
+    train_dataloader: PydanticDataLoaderIFType
+    eval_dataloaders: Optional[list[PydanticDataLoaderIFType]] = Field(default_factory=list)
     world_size: int
     global_num_seen_samples: int
     local_rank: int
 
 
 class RichProgressSubscriberConfig(BaseModel):
-    eval_dataloaders: Optional[list[PydanticLLMDataLoaderIFType]] = Field(default_factory=list)
+    eval_dataloaders: Optional[list[PydanticDataLoaderIFType]] = Field(default_factory=list)
     train_dataloader_tag: str
     num_seen_steps: Annotated[int, Field(strict=True, ge=0)]
     num_target_steps: Annotated[int, Field(strict=True, gt=0)]
diff --git a/src/modalities/config/instantiation_models.py b/src/modalities/config/instantiation_models.py
index 91a30c25..4ac4a4bf 100644
--- a/src/modalities/config/instantiation_models.py
+++ b/src/modalities/config/instantiation_models.py
@@ -6,9 +6,9 @@
 from modalities.config.pydanctic_if_types import (
     PydanticCheckpointSavingIFType,
+    PydanticDataLoaderIFType,
     PydanticDatasetIFType,
     PydanticGradientClipperIFType,
-    PydanticLLMDataLoaderIFType,
     PydanticLossIFType,
     PydanticLRSchedulerIFType,
     PydanticMessageSubscriberIFType,
@@ -170,8 +170,8 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel
     scheduler: PydanticLRSchedulerIFType
     loss_fn: PydanticLossIFType | list[PydanticLossIFType]
     train_dataset: PydanticDatasetIFType
-    train_dataloader: PydanticLLMDataLoaderIFType
-    eval_dataloaders: list[PydanticLLMDataLoaderIFType]
+    train_dataloader: PydanticDataLoaderIFType
+    eval_dataloaders: list[PydanticDataLoaderIFType]
     progress_subscriber: PydanticMessageSubscriberIFType
     evaluation_subscriber: PydanticMessageSubscriberIFType
     checkpoint_saving: PydanticCheckpointSavingIFType
diff --git a/src/modalities/config/pydanctic_if_types.py b/src/modalities/config/pydanctic_if_types.py
index b6c5d555..25d87973 100644
--- a/src/modalities/config/pydanctic_if_types.py
+++ b/src/modalities/config/pydanctic_if_types.py
@@ -56,7 +56,7 @@ def __get_pydantic_core_schema__(
 PydanticDatasetIFType = Annotated[Dataset, PydanticThirdPartyTypeIF(Dataset)]
 PydanticSamplerIFType = Annotated[Sampler, PydanticThirdPartyTypeIF(Sampler)]
 PydanticCollateFnIFType = Annotated[CollateFnIF, PydanticThirdPartyTypeIF(CollateFnIF)]
-PydanticLLMDataLoaderIFType = Annotated[DataLoaderIF, PydanticThirdPartyTypeIF(DataLoaderIF)]
+PydanticDataLoaderIFType = Annotated[DataLoaderIF, PydanticThirdPartyTypeIF(DataLoaderIF)]
 PydanticOptimizerIFType = Annotated[Optimizer, PydanticThirdPartyTypeIF(Optimizer)]
 PydanticLRSchedulerIFType = Annotated[LRScheduler, PydanticThirdPartyTypeIF(LRScheduler)]
 PydanticLossIFType = Annotated[Loss, PydanticThirdPartyTypeIF(Loss)]
diff --git a/src/modalities/dataloader/dataloader.py b/src/modalities/dataloader/dataloader.py
index ef2cedc3..f9156358 100644
--- a/src/modalities/dataloader/dataloader.py
+++ b/src/modalities/dataloader/dataloader.py
@@ -1,3 +1,4 @@
+import multiprocessing
 from typing import Iterable, Optional
 
 import webdataset as wd
@@ -11,7 +12,7 @@ class DataLoaderIF:
     pass
 
 
-class LLMDataLoader(DataLoader[T_co], DataLoaderIF):
+class LLMDataLoader(DataLoaderIF):
     """LLMDataLoader is a custom DataLoader class that extends the PyTorch DataLoader class."""
 
     def __init__(
@@ -62,7 +63,9 @@ def __init__(
             None
         """
         assert batch_sampler is not None and batch_size == 1
-        super().__init__(
+        self._dataloader_tag = dataloader_tag
+        self._batch_size = batch_sampler.batch_size
+        self._torch_dataloader = DataLoader(
             dataset=dataset,
             batch_size=batch_size,
             shuffle=False,  # shuffling must be implemented on a dataset level
@@ -81,9 +84,6 @@ def __init__(
             pin_memory_device=pin_memory_device,
         )
 
-        self._dataloader_tag = dataloader_tag
-        self._batch_size = batch_sampler.batch_size
-
     @property
     def dataloader_tag(self) -> str:
         """
@@ -125,6 +125,47 @@ def batch_size(self, value: int):
         """
         self._batch_size = value
 
+    def __len__(self):
+        return self._torch_dataloader.__len__()
+
+    def __iter__(self):
+        return self._torch_dataloader.__iter__()
+
+    @property
+    def dataset(self) -> Dataset[T_co]:
+        return self._torch_dataloader.dataset
+
+    @property
+    def batch_sampler(self) -> ResumableBatchSampler:
+        return self._torch_dataloader.batch_sampler
+
+    @property
+    def sampler(self) -> Sampler | Iterable | None:
+        return self._torch_dataloader.sampler
+
+    @property
+    def collate_fn(self) -> _collate_fn_t:
+        return self._torch_dataloader.collate_fn
+
+    @property
+    def multiprocessing_context(self) -> str | multiprocessing.context.BaseContext:
+        return self._torch_dataloader.multiprocessing_context
+
+    @multiprocessing_context.setter
+    def multiprocessing_context(self, multiprocessing_context):
+        self._torch_dataloader.multiprocessing_context = multiprocessing_context
+
+    @property
+    def _auto_collation(self):
+        return self._torch_dataloader._auto_collation
+
+    @property
+    def _index_sampler(self):
+        return self._torch_dataloader._index_sampler
+
+    def check_worker_number_rationality(self):
+        return self._torch_dataloader.check_worker_number_rationality()
+
     @property
     def fast_forward_batch_id(self) -> int:
         """
@@ -133,15 +174,15 @@ def fast_forward_batch_id(self) -> int:
         Returns:
             int: fast forward batch ID
         """
-        return self.batch_sampler.start_index
+        return self._torch_dataloader.batch_sampler.start_index
 
 
-class RepeatingDataLoader(LLMDataLoader[T_co]):
+class RepeatingDataLoader(LLMDataLoader):
     """
     RepeatingDataLoader is a custom DataLoader class that repeats the given dataloader
     for the specified number of epochs."""
 
-    def __init__(self, dataloader: LLMDataLoader[T_co], num_epochs: int, reshuffle_after_epoch: bool = False):
+    def __init__(self, dataloader: LLMDataLoader, num_epochs: int, reshuffle_after_epoch: bool = False):
         """
         Initializes a RepeatingDataLoader object that repeats the given dataloader for the specified number of
         epochs. This is especially useful for DataLoader types that we wish to automatically restart upon completion.
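The dataloader.py hunk above is the core of the change: LLMDataLoader stops subclassing torch.utils.data.DataLoader and instead owns one, forwarding __len__, __iter__, and the attributes the rest of the codebase touches. A minimal, self-contained sketch of that composition pattern, assuming a hypothetical toy dataset and tag (only the _torch_dataloader naming mirrors the diff; the rest is illustrative):

from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in dataset, for illustration only."""

    def __init__(self, num_samples: int):
        self._data = list(range(num_samples))

    def __len__(self) -> int:
        return len(self._data)

    def __getitem__(self, idx: int) -> int:
        return self._data[idx]


class WrappedDataLoader:
    """Owns a torch DataLoader instead of subclassing it and delegates
    iteration and length, mirroring the refactoring in the hunk above."""

    def __init__(self, dataloader_tag: str, dataset: Dataset, batch_size: int):
        self._dataloader_tag = dataloader_tag
        self._torch_dataloader = DataLoader(dataset=dataset, batch_size=batch_size)

    @property
    def dataloader_tag(self) -> str:
        return self._dataloader_tag

    def __len__(self) -> int:
        # Forwarded to the wrapped loader: number of batches, not samples.
        return len(self._torch_dataloader)

    def __iter__(self):
        return iter(self._torch_dataloader)


if __name__ == "__main__":
    loader = WrappedDataLoader("train", ToyDataset(8), batch_size=2)
    assert len(loader) == 4  # 8 samples / batch size 2
    for batch in loader:
        print(loader.dataloader_tag, batch)

Because iteration and length are forwarded, call sites that only loop over the loader or query len() are unaffected by the switch from inheritance to composition, while the wrapper no longer has to satisfy DataLoader's __init__ invariants.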
diff --git a/tests/dataloader/distributed/test_distributed_dataloader.py b/tests/dataloader/distributed/test_distributed_dataloader.py
index 0038d04a..0d2b0b09 100644
--- a/tests/dataloader/distributed/test_distributed_dataloader.py
+++ b/tests/dataloader/distributed/test_distributed_dataloader.py
@@ -9,7 +9,7 @@
 
 from modalities.__main__ import Main
 from modalities.config.config import ProcessGroupBackendType
-from modalities.config.pydanctic_if_types import PydanticLLMDataLoaderIFType
+from modalities.config.pydanctic_if_types import PydanticDataLoaderIFType
 from modalities.running_env.cuda_env import CudaEnv
 from tests.dataloader.dummy_sequential_dataset import TestDataset, TestDatasetConfig
@@ -18,7 +18,7 @@
 
 
 class DataloaderInstantiationModel(BaseModel):
-    train_dataloader: PydanticLLMDataLoaderIFType
+    train_dataloader: PydanticDataLoaderIFType
 
 
 @pytest.mark.skipif(
diff --git a/tests/dataloader/distributed/test_distributed_repeating_dataloader.py b/tests/dataloader/distributed/test_distributed_repeating_dataloader.py
index 7f40cc97..418793a4 100644
--- a/tests/dataloader/distributed/test_distributed_repeating_dataloader.py
+++ b/tests/dataloader/distributed/test_distributed_repeating_dataloader.py
@@ -8,7 +8,7 @@
 from pydantic import BaseModel
 
 from modalities.__main__ import Main
-from modalities.config.config import ProcessGroupBackendType, PydanticLLMDataLoaderIFType
+from modalities.config.config import ProcessGroupBackendType, PydanticDataLoaderIFType
 from modalities.running_env.cuda_env import CudaEnv
 from tests.dataloader.dummy_sequential_dataset import TestDataset, TestDatasetConfig
@@ -17,7 +17,7 @@
 
 
 class DataloaderInstantiationModel(BaseModel):
-    train_dataloader: PydanticLLMDataLoaderIFType
+    train_dataloader: PydanticDataLoaderIFType
 
 
 @pytest.mark.skipif(
diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py
index 65139ce6..9d6f171a 100644
--- a/tests/dataloader/test_dataloader.py
+++ b/tests/dataloader/test_dataloader.py
@@ -10,7 +10,7 @@
 
 from modalities.config.component_factory import ComponentFactory
 from modalities.config.config import load_app_config_dict
-from modalities.config.pydanctic_if_types import PydanticLLMDataLoaderIFType
+from modalities.config.pydanctic_if_types import PydanticDataLoaderIFType
 from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader
 from modalities.dataloader.dataset import Dataset
 from modalities.dataloader.samplers import ResumableBatchSampler
@@ -49,7 +49,7 @@ def test_dataloader_from_config(dummy_config: dict):
     dummy_config["train_dataloader"]["config"]["skip_num_batches"] = start_index
 
     class DataloaderTestModel(BaseModel):
-        train_dataloader: PydanticLLMDataLoaderIFType
+        train_dataloader: PydanticDataLoaderIFType
 
     registry = Registry(COMPONENTS)
     component_factory = ComponentFactory(registry=registry)
@@ -167,7 +167,7 @@ def test_repeating_dataloader_with_shuffling():
 
 def test_skipped_and_distributed_dataloader_from_config():
     class DataloaderTestModel(BaseModel):
-        train_dataloader: PydanticLLMDataLoaderIFType
+        train_dataloader: PydanticDataLoaderIFType
         skip_num_batches: int
 
     root_dir = Path(__file__).parents[0]
@@ -244,7 +244,7 @@ class DataloaderTestModel(BaseModel):
 )
 def test_dataloader_with_fixed_num_batches(global_rank):
     class DataloaderTestModel(BaseModel):
-        train_dataloader: PydanticLLMDataLoaderIFType
+        train_dataloader: PydanticDataLoaderIFType
         fixed_num_batches: int
 
     class IdentityCollateFn(CollateFnIF):
diff --git a/tests/end2end_tests/test_fsdp_warmstart.py b/tests/end2end_tests/test_fsdp_warmstart.py
index 3261eb4b..dac0b402 100644
--- a/tests/end2end_tests/test_fsdp_warmstart.py
+++ b/tests/end2end_tests/test_fsdp_warmstart.py
@@ -11,7 +11,7 @@
 
 from modalities.__main__ import Main, load_app_config_dict
 from modalities.batch import EvaluationResultBatch
-from modalities.config.config import ProcessGroupBackendType, PydanticLLMDataLoaderIFType
+from modalities.config.config import ProcessGroupBackendType, PydanticDataLoaderIFType
 from modalities.config.instantiation_models import TrainingComponentsInstantiationModel
 from modalities.dataloader.dataloader import LLMDataLoader
 from modalities.logging_broker.messages import Message
@@ -46,7 +46,7 @@ class SaveAllResultSubscriberConfig(BaseModel):
 
 
 class TrainDataloaderInstantiationModel(BaseModel):
-    train_dataloader: PydanticLLMDataLoaderIFType
+    train_dataloader: PydanticDataLoaderIFType
 
 
 @pytest.mark.skipif(
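For context on the renamed alias: PydanticDataLoaderIFType is an Annotated type whose metadata implements __get_pydantic_core_schema__ (visible in the pydanctic_if_types.py hunk), which lets pydantic models such as the test models above validate arbitrary third-party component instances. A rough sketch of how such a wrapper can be written with pydantic v2; the validator body and the ThirdPartyTypeIF name are assumptions, not the project's exact implementation:

from typing import Annotated, Any

from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import core_schema


class DataLoaderIF:
    """Marker interface, as in dataloader.py above."""


class ThirdPartyTypeIF:
    """Hypothetical stand-in for PydanticThirdPartyTypeIF: accepts any
    instance of the wrapped interface type as a valid field value."""

    def __init__(self, interface_type: type):
        self.interface_type = interface_type

    def __get_pydantic_core_schema__(
        self, source_type: Any, handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        # Bypass pydantic's default schema building; validate with isinstance.
        return core_schema.no_info_plain_validator_function(self._validate)

    def _validate(self, value: Any) -> Any:
        if not isinstance(value, self.interface_type):
            raise ValueError(
                f"expected {self.interface_type.__name__}, got {type(value).__name__}"
            )
        return value


PydanticDataLoaderIFType = Annotated[DataLoaderIF, ThirdPartyTypeIF(DataLoaderIF)]


class DataloaderInstantiationModel(BaseModel):
    train_dataloader: PydanticDataLoaderIFType


class MyLoader(DataLoaderIF):
    pass


model = DataloaderInstantiationModel(train_dataloader=MyLoader())  # passes validation

Because the Annotated metadata supplies the core schema directly, pydantic never attempts to build a default schema for DataLoaderIF, so no arbitrary_types_allowed configuration is needed on the models.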