-
Notifications
You must be signed in to change notification settings - Fork 1.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
70a9ccd
commit 99d8455
Showing
1 changed file
with
74 additions
and
85 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,106 +1,95 @@ | ||
import dspy | ||
import random | ||
from dspy.datasets import Dataset | ||
|
||
from typing import Union, List | ||
from datasets import load_dataset, ReadInstruction | ||
from datasets import load_dataset | ||
from typing import Union, List, Mapping, Tuple | ||
|
||
class DataLoader(Dataset):
    """Load datasets (Hugging Face hub or CSV files) as lists of ``dspy.Example``.

    Also provides stateless helpers to sample from, and split, such lists.
    Unlike other ``dspy.datasets.Dataset`` subclasses, this loader keeps no
    train/dev/test state of its own; every method simply returns its result.
    """

    def __init__(self):
        # Deliberately skips super().__init__(): no split bookkeeping is kept
        # on the instance, so the base-class setup is unnecessary here.
        pass

    @staticmethod
    def _to_examples(rows, fields, input_keys):
        """Convert an iterable of mapping rows into ``dspy.Example`` objects.

        Keeps only ``fields`` when provided, otherwise every column.
        Bug fix vs. the original: ``Example.with_inputs`` takes ``*keys``,
        so the ``input_keys`` tuple must be unpacked, not passed whole.
        """
        if fields:
            return [dspy.Example({field: row[field] for field in fields}).with_inputs(*input_keys) for row in rows]
        return [dspy.Example(dict(row)).with_inputs(*input_keys) for row in rows]

    def from_huggingface(
        self,
        dataset_name: str,
        *args,
        input_keys: Tuple[str, ...] = (),
        fields: List[str] = None,
        **kwargs,
    ) -> Union[Mapping[str, List[dspy.Example]], List[dspy.Example]]:
        """Load ``dataset_name`` via ``datasets.load_dataset`` and convert it.

        Args:
            dataset_name: Name/path understood by ``datasets.load_dataset``.
            *args, **kwargs: Forwarded verbatim to ``load_dataset`` (e.g. a
                config name, or ``split="train"``).
            input_keys: Example keys to mark as inputs on each ``dspy.Example``.
            fields: Columns to keep; all columns when omitted.

        Returns:
            A mapping of split name -> examples when a ``DatasetDict`` was
            loaded; a flat list of examples when a single split was loaded.

        Raises:
            ValueError: If ``fields`` is given but is not a list.
        """
        if fields and not isinstance(fields, list):
            # Plain string: the original f-string had no placeholders.
            raise ValueError("Invalid fields provided. Please provide a list of fields.")

        dataset = load_dataset(dataset_name, *args, **kwargs)

        try:
            # A DatasetDict exposes .keys(); a plain Dataset does not and
            # raises AttributeError, handled below as the single-split case.
            returned_split = {}
            for split in dataset.keys():
                returned_split[split] = self._to_examples(dataset[split], fields, input_keys)
            return returned_split
        except AttributeError:
            return self._to_examples(dataset, fields, input_keys)

    def from_csv(
        self,
        file_path: str,
        fields: List[str] = None,
        input_keys: Tuple[str, ...] = (),
    ) -> List[dspy.Example]:
        """Load a CSV file and convert each row into a ``dspy.Example``.

        Args:
            file_path: Path handed to ``load_dataset("csv", data_files=...)``.
            fields: Columns to keep; defaults to every column in the file.
            input_keys: Example keys to mark as inputs.

        Returns:
            One ``dspy.Example`` per CSV row.
        """
        dataset = load_dataset("csv", data_files=file_path)["train"]

        if not fields:
            fields = list(dataset.features)

        return self._to_examples(dataset, fields, input_keys)

    def sample(
        self,
        dataset: List[dspy.Example],
        n: int,
        *args,
        **kwargs,
    ) -> List[dspy.Example]:
        """Return ``n`` examples drawn without replacement from ``dataset``.

        Extra ``*args``/``**kwargs`` are forwarded to ``random.sample``.

        Raises:
            ValueError: If ``dataset`` is not a list (and, via
                ``random.sample``, if ``n`` exceeds ``len(dataset)``).
        """
        if not isinstance(dataset, list):
            raise ValueError(f"Invalid dataset provided of type {type(dataset)}. Please provide a list of examples.")

        return random.sample(dataset, n, *args, **kwargs)

    def train_test_split(
        self,
        dataset: List[dspy.Example],
        train_size: Union[int, float] = 0.75,
        test_size: Union[int, float] = None,
        random_state: int = None,
    ) -> Mapping[str, List[dspy.Example]]:
        """Shuffle ``dataset`` and split it into train and test lists.

        Args:
            dataset: Examples to split; never mutated (a copy is shuffled).
            train_size: Fraction in (0, 1) or absolute number of train rows.
            test_size: Fraction in (0, 1) or absolute number of test rows;
                defaults to everything left over after the train split.
            random_state: Seeds the *global* ``random`` module when given.

        Returns:
            ``{"train": [...], "test": [...]}``.

        Raises:
            ValueError: For sizes of the wrong type or range, or when
                ``train_size + test_size`` exceeds ``len(dataset)``.
        """
        if random_state is not None:
            # NOTE: seeds the shared global RNG, which affects other callers
            # of the random module as well.
            random.seed(random_state)

        dataset_shuffled = dataset.copy()
        random.shuffle(dataset_shuffled)

        if train_size is not None and isinstance(train_size, float) and (0 < train_size < 1):
            train_end = int(len(dataset_shuffled) * train_size)
        elif train_size is not None and isinstance(train_size, int):
            train_end = train_size
        else:
            raise ValueError("Invalid train_size. Please provide a float between 0 and 1 or an int.")

        if test_size is not None:
            if isinstance(test_size, float) and (0 < test_size < 1):
                test_end = int(len(dataset_shuffled) * test_size)
            elif isinstance(test_size, int):
                test_end = test_size
            else:
                raise ValueError("Invalid test_size. Please provide a float between 0 and 1 or an int.")
            if train_end + test_end > len(dataset_shuffled):
                raise ValueError("train_size + test_size cannot exceed the total number of samples.")
        else:
            # No explicit test size: the remainder of the shuffled data.
            test_end = len(dataset_shuffled) - train_end

        train_dataset = dataset_shuffled[:train_end]
        test_dataset = dataset_shuffled[train_end:train_end + test_end]

        return {'train': train_dataset, 'test': test_dataset}