Skip to content

Commit

Permalink
Merge branch 'develop' into add/1034
Browse files Browse the repository at this point in the history
  • Loading branch information
PGijsbers committed Sep 29, 2024
2 parents f4bd174 + d37542b commit 26ca9b5
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
24 changes: 19 additions & 5 deletions openml/_api_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
import math
import random
import shutil
import time
import urllib.parse
import xml
Expand Down Expand Up @@ -186,14 +187,14 @@ def _download_minio_file(
def _download_minio_bucket(source: str, destination: str | Path) -> None:
"""Download file ``source`` from a MinIO Bucket and store it at ``destination``.
Does not redownload files which already exist.
Parameters
----------
source : str
URL to a MinIO bucket.
destination : str | Path
Path to a directory to store the bucket content in.
exists_ok : bool, optional (default=True)
If False, raise FileExists if a file already exists in ``destination``.
"""
destination = Path(destination)
parsed_url = urllib.parse.urlparse(source)
Expand All @@ -206,15 +207,28 @@ def _download_minio_bucket(source: str, destination: str | Path) -> None:

for file_object in client.list_objects(bucket, prefix=prefix, recursive=True):
if file_object.object_name is None:
raise ValueError("Object name is None.")
raise ValueError(f"Object name is None for object {file_object!r}")

with contextlib.suppress(FileExistsError): # Simply use cached version instead
marker = destination / file_object.etag
if marker.exists():
continue

file_destination = destination / file_object.object_name.rsplit("/", 1)[1]
if (file_destination.parent / file_destination.stem).exists():
# Marker is missing but archive exists means the server archive changed, force a refresh
shutil.rmtree(file_destination.parent / file_destination.stem)

with contextlib.suppress(FileExistsError):
_download_minio_file(
source=source.rsplit("/", 1)[0] + "/" + file_object.object_name.rsplit("/", 1)[1],
destination=Path(destination, file_object.object_name.rsplit("/", 1)[1]),
destination=file_destination,
exists_ok=False,
)

if file_destination.is_file() and file_destination.suffix == ".zip":
file_destination.unlink()
marker.touch()


def _download_text_file(
source: str,
Expand Down
9 changes: 6 additions & 3 deletions openml/tasks/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,12 @@ def _get_repr_body_fields(self) -> Sequence[tuple[str, str | int | list[str]]]:
]
return [(key, fields[key]) for key in order if key in fields]

def get_dataset(self) -> datasets.OpenMLDataset:
"""Download dataset associated with task."""
return datasets.get_dataset(self.dataset_id)
def get_dataset(self, **kwargs) -> datasets.OpenMLDataset:
"""Download dataset associated with task.
Accepts the same keyword arguments as the `openml.datasets.get_dataset`.
"""
return datasets.get_dataset(self.dataset_id, **kwargs)

def get_train_test_split_indices(
self,
Expand Down

0 comments on commit 26ca9b5

Please sign in to comment.