Skip to content

Commit

Permalink
fix: support proper iteration over MemMapDatasets
Browse files Browse the repository at this point in the history
  • Loading branch information
luzian-hahn committed Jan 15, 2024
1 parent 0b699f2 commit 16b4522
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
7 changes: 7 additions & 0 deletions src/modalities/dataloader/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ def __init__(self, raw_data_path: Path, block_size: int, sample_key: str):
self.block_size = block_size
self.sample_key = sample_key

def _check_if_inbounds(self, idx: int):
if not 0 <= idx < len(self):
raise IndexError


class MemMapDataset(Dataset):
def __init__(
Expand Down Expand Up @@ -54,6 +58,7 @@ def __len__(self) -> int:
return len(self.reader)

def __getitem__(self, idx: int) -> BatchEncoding:
self._check_if_inbounds(idx)
return self.tokenizer(
self.jq_filter.input_text(self.reader[idx]).first(),
max_length=self.block_size,
Expand Down Expand Up @@ -135,6 +140,7 @@ def __len__(self) -> int:
return self._num_samples

def __getitem__(self, idx: int) -> BatchEncoding:
self._check_if_inbounds(idx)
tokens_as_byte_strings = np.memmap(
self.raw_data_path,
mode="r",
Expand Down Expand Up @@ -192,6 +198,7 @@ def __len__(self) -> int:
return len(self._index)

def __getitem__(self, idx: int) -> BatchEncoding:
self._check_if_inbounds(idx)
offset, length = self._index[idx]
tokens_as_byte_strings = np.memmap(
self.raw_data_path,
Expand Down
4 changes: 1 addition & 3 deletions tests/dataloader/test_packed_dataset.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest

from modalities.dataloader.create_packed_data import PackedDataGenerator
from modalities.dataloader.dataloader import LLMDataLoader
from modalities.dataloader.dataset import PackedMemMapDatasetContinuous, PackedMemMapDatasetMegatron


Expand All @@ -26,8 +25,7 @@ def test_packed_megatron_dataset_loading(dummy_packed_data_path, block_size, exp
def test_packed_continuous_dataset_loading(dummy_packed_data_path, block_size, expected_length, expected_output):
    """Check that a PackedMemMapDatasetContinuous yields the expected samples.

    The dataset is iterated directly (``for batch in ds``) rather than being
    wrapped in a DataLoader; this exercises the IndexError-based sequence
    iteration protocol of ``__getitem__`` — the behavior this commit fixes.

    :param dummy_packed_data_path: fixture path to a small packed data file.
    :param block_size: sample block size used to construct the dataset.
    :param expected_length: expected ``len(ds)`` for this parametrization.
    :param expected_output: expected token ids per sample, in order.
    """
    ds = PackedMemMapDatasetContinuous(dummy_packed_data_path, block_size, sample_key="input_ids")
    assert len(ds) == expected_length
    # Plain iteration stops when __getitem__ raises IndexError at idx == len(ds).
    retrieved_input_ids = [list(batch["input_ids"]) for batch in ds]
    assert retrieved_input_ids == expected_output


Expand Down

0 comments on commit 16b4522

Please sign in to comment.