edit abstraction for dataloader
krypticmouse committed Feb 15, 2024
1 parent 70a9ccd commit 99d8455
Showing 1 changed file with 74 additions and 85 deletions.
dspy/datasets/dataloader.py: 159 changes (74 additions, 85 deletions)
@@ -1,106 +1,95 @@
 import dspy
 import random
 from dspy.datasets import Dataset
 
-from typing import Union, List
-from datasets import load_dataset, ReadInstruction
+from datasets import load_dataset
+from typing import Union, List, Mapping, Tuple
 
 class DataLoader(Dataset):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self,):
+        pass
 
-    def _process_dataset(
-        self,
-        dataset: Dataset,
-        fields: List[str] = None
-    ) -> List[dspy.Example]:
-        if not(self.train_size and self.dev_size and self.test_size):
-            self.train_size = 1.0
-
-        train_split_size = self.train_size if self.train_size else 0
-        dev_split_size = self.dev_size if self.dev_size else 0
-        test_split_size = self.test_size if self.test_size else 0
+    def from_huggingface(
+        self,
+        dataset_name: str,
+        *args,
+        input_keys: Tuple[str] = (),
+        fields: List[str] = None,
+        **kwargs
+    ) -> Union[Mapping[str, List[dspy.Example]], List[dspy.Example]]:
+        if fields and not isinstance(fields, list):
+            raise ValueError(f"Invalid fields provided. Please provide a list of fields.")
 
-        if isinstance(train_split_size, float):
-            train_split_size = int(len(dataset) * train_split_size)
+        dataset = load_dataset(dataset_name, *args, **kwargs)
 
-        if train_split_size:
-            test_size = dev_split_size + test_split_size
-            if test_size > 0:
-                tmp_dataset = dataset.train_test_split(test_size=test_size)
-                train_dataset = tmp_dataset["train"]
-                dataset = tmp_dataset["test"]
-            else:
-                train_dataset = dataset
-
-        if isinstance(dev_split_size, float):
-            dev_split_size = int(len(dataset) * dev_split_size)
+        try:
+            returned_split = {}
+            for split in dataset.keys():
+                if fields:
+                    returned_split[split] = [dspy.Example({field:row[field] for field in fields}).with_inputs(input_keys) for row in dataset[split]]
+                else:
+                    returned_split[split] = [dspy.Example({field:row[field] for field in row.keys()}).with_inputs(input_keys) for row in dataset[split]]
 
-        if isinstance(test_split_size, float):
-            test_split_size = int(len(dataset) * test_split_size)
-
-        if dev_split_size or test_split_size:
-            tmp_dataset = dataset.train_test_split(test_size=dev_split_size)
-            dev_dataset = tmp_dataset["train"]
-            test_dataset = tmp_dataset["test"]
+            return returned_split
+        except AttributeError:
+            if fields:
+                return [dspy.Example({field:row[field] for field in fields}).with_inputs(input_keys) for row in dataset]
+            else:
+                return [dspy.Example({field:row[field] for field in row.keys()}).with_inputs(input_keys) for row in dataset]
 
-        returned_split = {}
-        if train_split_size:
-            self._train = [{field:row[field] for field in fields} for row in train_dataset]
-            self.train_size = train_split_size
-
-            returned_split["train"] = self._shuffle_and_sample("train", self._train, self.train_size, self.train_seed)
+    def from_csv(self, file_path:str, fields: List[str] = None, input_keys: Tuple[str] = ()) -> List[dspy.Example]:
+        dataset = load_dataset("csv", data_files=file_path)["train"]
 
-        if dev_split_size:
-            self._dev = [{field:row[field] for field in fields} for row in dev_dataset]
-            self.dev_size = dev_split_size
+        if not fields:
+            fields = list(dataset.features)
 
-            returned_split["dev"] = self._shuffle_and_sample("dev", self._dev, self.dev_size, self.dev_seed)
+        return [dspy.Example({field:row[field] for field in fields}).with_inputs(input_keys) for row in dataset]
+
+    def sample(
+        self,
+        dataset: List[dspy.Example],
+        n: int,
+        *args,
+        **kwargs
+    ) -> List[dspy.Example]:
+        if not isinstance(dataset, list):
+            raise ValueError(f"Invalid dataset provided of type {type(dataset)}. Please provide a list of examples.")
 
-        if test_split_size:
-            self._test = [{field:row[field] for field in fields} for row in test_dataset]
-            self.test_size = test_split_size
+        return random.sample(dataset, n, *args, **kwargs)
 
-            returned_split["test"] = self._shuffle_and_sample("test", self._test, self.test_size, self.test_seed)
+    def train_test_split(
+        self,
+        dataset: List[dspy.Example],
+        train_size: Union[int, float] = 0.75,
+        test_size: Union[int, float] = None,
+        random_state: int = None
+    ) -> Mapping[str, List[dspy.Example]]:
+        if random_state is not None:
+            random.seed(random_state)
 
-        return returned_split
+        dataset_shuffled = dataset.copy()
+        random.shuffle(dataset_shuffled)
 
-    def from_huggingface(
-        self,
-        dataset_name: str,
-        fields: List[str] = None,
-        splits: Union[str, List[str]] = None,
-        revision: str = None,
-    ) -> List[dspy.Example]:
-        dataset = None
-        if splits:
-            if isinstance(splits, str):
-                splits = [splits]
-
-            try:
-                ri = ReadInstruction(splits[0])
-                for split in splits[1:]:
-                    ri += ReadInstruction(split)
-                dataset = load_dataset(dataset_name, split=ri, revision=revision)
-            except:
-                raise ValueError("Invalid split name provided. Please provide a valid split name or list of split names.")
+        if train_size is not None and isinstance(train_size, float) and (0 < train_size < 1):
+            train_end = int(len(dataset_shuffled) * train_size)
+        elif train_size is not None and isinstance(train_size, int):
+            train_end = train_size
         else:
-            dataset = load_dataset(dataset_name, revision=revision)
-            if len(dataset.keys())==1:
-                split_name = next(iter(dataset.keys()))
-                dataset = dataset[split_name]
+            raise ValueError("Invalid train_size. Please provide a float between 0 and 1 or an int.")
+
+        if test_size is not None:
+            if isinstance(test_size, float) and (0 < test_size < 1):
+                test_end = int(len(dataset_shuffled) * test_size)
+            elif isinstance(test_size, int):
+                test_end = test_size
             else:
-                raise ValueError("No splits provided and dataset has more than one split. At this moment multiple splits will be concatenated into one single split.")
-
-        if not fields:
-            fields = list(dataset.features)
+                raise ValueError("Invalid test_size. Please provide a float between 0 and 1 or an int.")
+            if train_end + test_end > len(dataset_shuffled):
+                raise ValueError("train_size + test_size cannot exceed the total number of samples.")
+        else:
+            test_end = len(dataset_shuffled) - train_end
 
-        return self._process_dataset(dataset, fields)
+        train_dataset = dataset_shuffled[:train_end]
+        test_dataset = dataset_shuffled[train_end:train_end + test_end]
 
-    def from_csv(self, file_path:str, fields: List[str] = None) -> List[dspy.Example]:
-        dataset = load_dataset("csv", data_files=file_path)["train"]
-
-        if not fields:
-            fields = list(dataset.features)
-
-        return self._process_dataset(dataset, fields)
+        return {'train': train_dataset, 'test': test_dataset}
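For orientation, here is a minimal sketch of how the two refactored loaders might be called after this commit. The dataset name, config name, column names, and CSV path below are illustrative assumptions, not part of the commit itself.

from dspy.datasets.dataloader import DataLoader

dl = DataLoader()

# Load every split of a Hub dataset. Extra positional and keyword arguments
# are forwarded straight to datasets.load_dataset. With named splits this
# returns {split_name: [dspy.Example, ...]}; an object without .keys()
# (e.g. a single Dataset) comes back as a flat list of examples instead.
splits = dl.from_huggingface(
    "hotpot_qa",                    # illustrative dataset name
    "fullwiki",                     # illustrative config, forwarded via *args
    fields=["question", "answer"],  # keep only these columns
    input_keys=("question",),       # passed through to Example.with_inputs
)
train_examples = splits["train"]

# Load a local CSV file; every column becomes a field unless fields narrows it.
csv_examples = dl.from_csv("qa_pairs.csv", fields=["question", "answer"])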
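The two new list utilities, sample and train_test_split, can be sketched the same way; everything below is an assumed, self-contained example rather than code from the commit.

import dspy
from dspy.datasets.dataloader import DataLoader

dl = DataLoader()

# A small in-place list of examples so the sketch needs no downloads.
examples = [
    dspy.Example(question=f"q{i}", answer=f"a{i}").with_inputs("question")
    for i in range(10)
]

# Uniform random subsample without replacement (delegates to random.sample).
subset = dl.sample(examples, n=5)

# Seeded shuffle-and-split: float sizes are fractions of the list, int sizes
# are absolute counts, and an omitted test_size takes whatever remains.
parts = dl.train_test_split(examples, train_size=0.8, random_state=42)
trainset, testset = parts["train"], parts["test"]  # 8 and 2 examples here

Note that train_test_split seeds Python's global random module when random_state is given, so it also affects subsequent calls to sample.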
