Skip to content

Commit

Permalink
Dataloader return patch and doc fixes
Browse files: browse the repository at this point in the history
  • Loading branch information
krypticmouse committed Feb 14, 2024
1 parent a9a6aba commit a0143ac
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 75 deletions.
27 changes: 18 additions & 9 deletions dspy/datasets/dataloader.py
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
import dspy
from dspy.datasets import Dataset

from typing import Union, List
Expand All @@ -11,7 +12,7 @@ def _process_dataset(
self,
dataset: Dataset,
fields: List[str] = None
):
) -> List[dspy.Example]:
if not(self.train_size and self.dev_size and self.test_size):
self.train_size = 1.0

Expand Down Expand Up @@ -42,26 +43,34 @@ def _process_dataset(
dev_dataset = tmp_dataset["train"]
test_dataset = tmp_dataset["test"]

returned_split = {}
if train_split_size:
self._train = [{field:row[field] for field in fields} for row in train_dataset]
self.train_size = train_split_size

returned_split["train"] = self._shuffle_and_sample("train", self._train, self.train_size, self.train_seed)

if dev_split_size:
self._dev = [{field:row[field] for field in fields} for row in dev_dataset]
self.dev_size = dev_split_size

returned_split["dev"] = self._shuffle_and_sample("dev", self._dev, self.dev_size, self.dev_seed)

if test_split_size:
self._test = [{field:row[field] for field in fields} for row in test_dataset]

self.train_size = train_split_size
self.dev_size = dev_split_size
self.test_size = test_split_size
self.test_size = test_split_size

returned_split["test"] = self._shuffle_and_sample("test", self._test, self.test_size, self.test_seed)

return returned_split

def from_huggingface(
self,
dataset_name: str,
fields: List[str] = None,
splits: Union[str, List[str]] = None,
revision: str = None,
):
) -> List[dspy.Example]:
dataset = None
if splits:
if isinstance(splits, str):
Expand All @@ -86,12 +95,12 @@ def from_huggingface(
if not fields:
fields = list(dataset.features)

self._process_dataset(dataset, fields)
return self._process_dataset(dataset, fields)

def from_csv(self, file_path:str, fields: List[str] = None):
def from_csv(self, file_path:str, fields: List[str] = None) -> List[dspy.Example]:
dataset = load_dataset("csv", data_files=file_path)["train"]

if not fields:
fields = list(dataset.features)

self._process_dataset(dataset, fields)
return self._process_dataset(dataset, fields)
Loading

0 comments on commit a0143ac

Please sign in to comment.