Skip to content

Commit

Permalink
DataLoader docs and split fixes
Browse files: browse the repository at this point in the history
  • Loading branch information
krypticmouse committed Feb 14, 2024
1 parent 3e0adb6 commit a34922c
Show file tree
Hide file tree
Showing 3 changed files with 827 additions and 3 deletions.
17 changes: 14 additions & 3 deletions dspy/datasets/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ def _process_dataset(
dataset: Dataset,
fields: List[str] = None
):
if not(self.train_size and self.dev_size and self.test_size):
self.train_size = 1.0

train_split_size = self.train_size if self.train_size else 0
dev_split_size = self.dev_size if self.dev_size else 0
test_split_size = self.test_size if self.test_size else 0
Expand All @@ -20,9 +23,13 @@ def _process_dataset(
train_split_size = int(len(dataset) * train_split_size)

if train_split_size:
tmp_dataset = dataset.train_test_split(test_size=(dev_split_size+test_split_size))
train_dataset = tmp_dataset["train"]
dataset = tmp_dataset["test"]
test_size = dev_split_size + test_split_size
if test_size > 0:
tmp_dataset = dataset.train_test_split(test_size=test_size)
train_dataset = tmp_dataset["train"]
dataset = tmp_dataset["test"]
else:
train_dataset = dataset

if isinstance(dev_split_size, float):
dev_split_size = int(len(dataset) * dev_split_size)
Expand All @@ -43,6 +50,10 @@ def _process_dataset(

if test_split_size:
self._test = [{field:row[field] for field in fields} for row in test_dataset]

self.train_size = train_split_size
self.dev_size = dev_split_size
self.test_size = test_split_size

def from_huggingface(
self,
Expand Down
Loading

0 comments on commit a34922c

Please sign in to comment.