
Commit 692297c

Refactor LoadHF class to streamline data loading and enhance error handling

Signed-off-by: elronbandel <[email protected]>
elronbandel committed Feb 10, 2025
1 parent 85c2cab commit 692297c
Showing 4 changed files with 50 additions and 54 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -190,7 +190,7 @@ keep-runtime-typing = true
 "src/unitxt/metric.py" = ["F811", "F401"]
 "src/unitxt/dataset.py" = ["F811", "F401"]
 "src/unitxt/blocks.py" = ["F811", "F401"]
-"tests/library/test_loaders.py" = ["N802", "N803"]
+"tests/library/test_loaders.py" = ["N802", "N803", "RUF015"]
 "tests/library/test_dataclass.py" = ["F811", "E731"]
 "src/unitxt/validate.py" = ["B024"]
 "src/unitxt/standard.py" = ["C901"]
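The new RUF015 suppression is tied to the test rewrite below: one rewritten test reads the first element of a stream with list(...)[0], which Ruff's RUF015 rule flags. A short illustration of the rule, using stand-in data rather than the tests' actual streams:

    items = ({"language": "eng"}, {"language": "arb"})  # stand-in data

    first = list(items)[0]     # RUF015: copies the whole iterable first
    first = next(iter(items))  # the lazy idiom Ruff suggests instead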
8 changes: 6 additions & 2 deletions src/unitxt/loaders.py
@@ -284,14 +284,14 @@ def load_dataset(
             data_files=self.data_files,
             revision=self.revision,
             streaming=streaming,
-            keep_in_memory=True,
             cache_dir=cache_dir,
             verification_mode="no_checks",
             split=split,
             trust_remote_code=settings.allow_unverified_code,
             num_proc=self.num_proc,
             download_config=DownloadConfig(
-                max_retries=settings.loaders_max_retries
+                max_retries=settings.loaders_max_retries,
+                # extract_on_the_fly=True,
             ),
         )
     except ValueError as e:
@@ -323,6 +323,10 @@ def get_splits(self):
                 path=self.path,
                 config_name=self.name,
                 trust_remote_code=settings.allow_unverified_code,
+                download_config=DownloadConfig(
+                    max_retries=settings.loaders_max_retries,
+                    extract_on_the_fly=True,
+                ),
             )
         except:
             UnitxtWarning(
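Both hunks route hub access through a retry-aware DownloadConfig. A minimal standalone sketch of the split-discovery call added to get_splits(), assuming a recent datasets release (extract_on_the_fly requires one); the dataset name and retry count are illustrative stand-ins for the loader's settings:

    from datasets import DownloadConfig, get_dataset_config_info

    info = get_dataset_config_info(
        path="GEM/xlsum",   # illustrative dataset, taken from the tests below
        config_name="igbo",
        download_config=DownloadConfig(
            max_retries=10,           # stands in for settings.loaders_max_retries
            extract_on_the_fly=True,  # decompress archives lazily, not up front
        ),
    )
    print(list(info.splits))  # e.g. ['train', 'validation', 'test']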
90 changes: 41 additions & 49 deletions tests/library/test_loaders.py
@@ -154,114 +154,106 @@ def test_load_from_ibm_cos(self):
 
     def test_load_from_HF_compressed(self):
         loader = LoadHF(path="GEM/xlsum", name="igbo")  # the smallest file
-        ms = loader.process()
-        dataset = ms.to_dataset()
+        ms = loader()
+        instance = next(iter(ms["train"]))
         self.assertEqual(
-            ms.to_dataset()["train"][0]["url"],
+            instance["url"],
             "https://www.bbc.com/igbo/afirika-43986554",
         )
-        assert set(dataset.keys()) == {
+        assert set(ms.keys()) == {
             "train",
             "validation",
             "test",
-        }, f"Unexpected fold {dataset.keys()}"
+        }, f"Unexpected fold {ms.keys()}"
 
     def test_load_from_HF_compressed_split(self):
-        loader = LoadHF(
-            path="GEM/xlsum", name="igbo", split="train"
-        )  # the smallest file
-        ms = loader.process()
-        dataset = ms.to_dataset()
+        loader = LoadHF(path="GEM/xlsum", name="igbo", split="train")  # the smallest file
+        ms = loader()
+        instance = next(iter(ms["train"]))
         self.assertEqual(
-            dataset["train"][0]["url"],
+            instance["url"],
             "https://www.bbc.com/igbo/afirika-43986554",
         )
-        assert list(dataset.keys()) == ["train"], f"Unexpected fold {dataset.keys()}"
+        assert list(ms.keys()) == ["train"], f"Unexpected fold {ms.keys()}"
 
     def test_load_from_HF(self):
-        loader = LoadHF(path="sst2", loader_limit=10)
-        ms = loader.process()
-        dataset = ms.to_dataset()
+        loader = LoadHF(path="sst2", loader_limit=10, split="train")
+        ms = loader()
+        instance = next(iter(ms["train"]))
         self.assertEqual(
-            dataset["train"][0]["sentence"],
+            instance["sentence"],
             "hide new secretions from the parental units ",
         )
         self.assertEqual(
-            dataset["train"][0]["data_classification_policy"],
-            ["public"],
-        )
-        self.assertEqual(
-            dataset["test"][0]["data_classification_policy"],
+            instance["data_classification_policy"],
             ["public"],
         )
-        assert set(dataset.keys()) == {
+        assert set(ms.keys()) == {
             "train",
             "validation",
             "test",
-        }, f"Unexpected fold {dataset.keys()}"
+        }, f"Unexpected fold {ms.keys()}"
 
     def test_load_from_HF_multiple_innvocation(self):
         loader = LoadHF(
             path="CohereForAI/aya_evaluation_suite",
             name="aya_human_annotated",
             # filtering_lambda='lambda instance: instance["language"]=="eng"',
         )
-        ms = loader.process()
-        dataset = ms.to_dataset()
+        ms = loader()
+        instance = next(iter(ms["test"]))
         self.assertEqual(
-            list(dataset.keys()), ["test"]
+            list(ms.keys()), ["test"]
         )  # that HF dataset only has the 'test' split
-        self.assertEqual(dataset["test"][0]["language"], "arb")
+        self.assertEqual(instance["language"], "arb")
 
-        ms = loader.process()
-        dataset = ms.to_dataset()
+        ms = loader()
+        instance = next(iter(ms["test"]))
         self.assertEqual(
-            list(dataset.keys()), ["test"]
+            list(ms.keys()), ["test"]
         )  # that HF dataset only has the 'test' split
-        self.assertEqual(dataset["test"][0]["language"], "arb")
+        self.assertEqual(instance["language"], "arb")
 
     def test_load_from_HF_multiple_innvocation_with_filter(self):
         loader = LoadHF(
             path="CohereForAI/aya_evaluation_suite",
             name="aya_human_annotated",
             filtering_lambda='lambda instance: instance["language"]=="eng"',
         )
-        ms = loader.process()
-        dataset = ms.to_dataset()
+        ms = loader()
+        instance = next(iter(ms["test"]))
         self.assertEqual(
-            list(dataset.keys()), ["test"]
+            list(ms.keys()), ["test"]
         )  # that HF dataset only has the 'test' split
-        self.assertEqual(dataset["test"][0]["language"], "eng")
+        self.assertEqual(instance["language"], "eng")
 
-        ms = loader.process()
-        dataset = ms.to_dataset()
+        ms = loader()
+        instance = next(iter(ms["test"]))
         self.assertEqual(
-            list(dataset.keys()), ["test"]
+            list(ms.keys()), ["test"]
         )  # that HF dataset only has the 'test' split
-        self.assertEqual(dataset["test"][0]["language"], "eng")
+        self.assertEqual(instance["language"], "eng")
 
     def test_load_from_HF_split(self):
         loader = LoadHF(path="sst2", split="train")
-        ms = loader.process()
-        dataset = ms.to_dataset()
+        ms = loader()
+        instance = next(iter(ms["train"]))
         self.assertEqual(
-            dataset["train"][0]["sentence"],
+            instance["sentence"],
             "hide new secretions from the parental units ",
         )
-        assert list(dataset.keys()) == ["train"], f"Unexpected fold {dataset.keys()}"
+        assert list(ms.keys()) == ["train"], f"Unexpected fold {ms.keys()}"
 
     def test_load_from_HF_filter(self):
         loader = LoadHF(
             path="CohereForAI/aya_evaluation_suite",
             name="aya_human_annotated",
             filtering_lambda='lambda instance: instance["language"]=="eng"',
         )
-        ms = loader.process()
-        dataset = ms.to_dataset()
+        ms = loader()
+        instance = list(ms["test"])[0]
         self.assertEqual(
-            list(dataset.keys()), ["test"]
+            list(ms.keys()), ["test"]
         )  # that HF dataset only has the 'test' split
-        self.assertEqual(dataset["test"][0]["language"], "eng")
+        self.assertEqual(instance["language"], "eng")
 
     def test_multiple_source_loader(self):
         # Using a context for the temporary directory
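The rewrite above converges on one calling convention: a LoadHF instance is invoked directly and returns lazy per-split streams, replacing the eager loader.process() plus ms.to_dataset() pair. A minimal sketch of that pattern as these tests exercise it; treat it as illustrative rather than a full API reference:

    from unitxt.loaders import LoadHF

    loader = LoadHF(path="GEM/xlsum", name="igbo", split="train")
    ms = loader()                       # split name -> lazily evaluated stream
    instance = next(iter(ms["train"]))  # one instance, nothing materialized
    print(instance["url"])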
4 changes: 2 additions & 2 deletions utils/.secrets.baseline
@@ -151,7 +151,7 @@
         "filename": "src/unitxt/loaders.py",
         "hashed_secret": "840268f77a57d5553add023cfa8a4d1535f49742",
         "is_verified": false,
-        "line_number": 585,
+        "line_number": 589,
         "is_secret": false
       }
     ],
@@ -184,5 +184,5 @@
       }
     ]
   },
-  "generated_at": "2025-02-10T11:27:26Z"
+  "generated_at": "2025-02-10T13:25:07Z"
 }
