diff --git a/pyproject.toml b/pyproject.toml index 777ac42..f6140d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,6 +83,18 @@ ignore = [ #D107 Missing docstring in __init__ "D107" ] +"tosfs/fsspec_utils.py" = [ + #SIM108 Use ternary operator xxx instead of `if`-`else`-block + "SIM108", + #N803 Argument name should be lowercase + "N803", + #PLR0915 Too many statements + "PLR0915", + #PLR0912 Too many branches + "PLR0912", + #C901 is too complex + "C901", +] [tool.mypy] python_version = "3.9" @@ -90,7 +102,10 @@ ignore_missing_imports = true disallow_untyped_calls = true disallow_untyped_defs = true strict_optional = true -exclude = "tosfs/tests/.*" +exclude = [ + "tosfs/tests/.*", + "tosfs/fsspec_utils.py" +] [build-system] requires = ["poetry-core"] diff --git a/tosfs/core.py b/tosfs/core.py index 1534909..c52dbb6 100644 --- a/tosfs/core.py +++ b/tosfs/core.py @@ -19,7 +19,7 @@ import os import time from glob import has_magic -from typing import Any, BinaryIO, Generator, List, Optional, Tuple, Union +from typing import Any, BinaryIO, Collection, Generator, List, Optional, Tuple, Union import tos from fsspec import AbstractFileSystem @@ -41,6 +41,7 @@ TOS_SERVER_RESPONSE_CODE_NOT_FOUND, ) from tosfs.exceptions import TosfsError +from tosfs.fsspec_utils import glob_translate from tosfs.utils import find_bucket_key, get_brange, retryable_func_wrapper # environment variable names @@ -912,6 +913,80 @@ def cp_file( # serial multipart copy self._copy_managed(path1, path2, size, **kwargs) + def glob( + self, path: str, maxdepth: Optional[int] = None, **kwargs: Any + ) -> Collection[Any]: + """Return list of paths matching a glob-like pattern. + + Parameters + ---------- + path : str + The path to search. + maxdepth : int, optional + The maximum depth to search to (default is None). + **kwargs : Any, optional + Additional arguments. + + """ + if path.startswith("*"): + raise ValueError("Cannot traverse all of tosfs") + + if maxdepth is not None and maxdepth < 1: + raise ValueError("maxdepth must be at least 1") + + import re + + seps = (os.path.sep, os.path.altsep) if os.path.altsep else (os.path.sep,) + ends_with_sep = path.endswith(seps) # _strip_protocol strips trailing slash + path = self._strip_protocol(path) + append_slash_to_dirname = ends_with_sep or path.endswith( + tuple(sep + "**" for sep in seps) + ) + + idx_star = path.find("*") if path.find("*") >= 0 else len(path) + idx_qmark = path.find("?") if path.find("?") >= 0 else len(path) + idx_brace = path.find("[") if path.find("[") >= 0 else len(path) + min_idx = min(idx_star, idx_qmark, idx_brace) + + detail = kwargs.pop("detail", False) + + if not has_magic(path): + if self.exists(path, **kwargs): + return {path: self.info(path, **kwargs)} if detail else [path] + return {} if detail else [] + + depth: Optional[int] = None + root, depth = "", path[min_idx + 1 :].count("/") + 1 + if "/" in path[:min_idx]: + min_idx = path[:min_idx].rindex("/") + root = path[: min_idx + 1] + + if "**" in path: + if maxdepth is not None: + idx_double_stars = path.find("**") + depth_double_stars = path[idx_double_stars:].count("/") + 1 + depth = depth - depth_double_stars + maxdepth + else: + depth = None + + allpaths = self.find(root, maxdepth=depth, withdirs=True, detail=True, **kwargs) + pattern = re.compile(glob_translate(path + ("/" if ends_with_sep else ""))) + + if isinstance(allpaths, dict): + out = { + p: info + for p, info in sorted(allpaths.items()) + if pattern.match( + p + "/" + if append_slash_to_dirname and info["type"] == "directory" + else p + ) + } + else: + out = {} + + return out if detail else list(out) + def _copy_basic(self, path1: str, path2: str, **kwargs: Any) -> None: """Copy file between locations on tos. diff --git a/tosfs/fsspec_utils.py b/tosfs/fsspec_utils.py new file mode 100644 index 0000000..2239aa9 --- /dev/null +++ b/tosfs/fsspec_utils.py @@ -0,0 +1,130 @@ +# ByteDance Volcengine EMR, Copyright 2024. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""The module contains utility functions copied from fsspec package.""" + +import os +import re + + +def _translate(pat, STAR, QUESTION_MARK): + # Copied from: https://github.com/python/cpython/pull/106703. + res: list[str] = [] + add = res.append + i, n = 0, len(pat) + while i < n: + c = pat[i] + i = i + 1 + if c == "*": + # compress consecutive `*` into one + if (not res) or res[-1] is not STAR: + add(STAR) + elif c == "?": + add(QUESTION_MARK) + elif c == "[": + j = i + if j < n and pat[j] == "!": + j = j + 1 + if j < n and pat[j] == "]": + j = j + 1 + while j < n and pat[j] != "]": + j = j + 1 + if j >= n: + add("\\[") + else: + stuff = pat[i:j] + if "-" not in stuff: + stuff = stuff.replace("\\", r"\\") + else: + chunks = [] + k = i + 2 if pat[i] == "!" else i + 1 + while True: + k = pat.find("-", k, j) + if k < 0: + break + chunks.append(pat[i:k]) + i = k + 1 + k = k + 3 + chunk = pat[i:j] + if chunk: + chunks.append(chunk) + else: + chunks[-1] += "-" + # Remove empty ranges -- invalid in RE. + for k in range(len(chunks) - 1, 0, -1): + if chunks[k - 1][-1] > chunks[k][0]: + chunks[k - 1] = chunks[k - 1][:-1] + chunks[k][1:] + del chunks[k] + # Escape backslashes and hyphens for set difference (--). + # Hyphens that create ranges shouldn't be escaped. + stuff = "-".join( + s.replace("\\", r"\\").replace("-", r"\-") for s in chunks + ) + # Escape set operations (&&, ~~ and ||). + stuff = re.sub(r"([&~|])", r"\\\1", stuff) + i = j + 1 + if not stuff: + # Empty range: never match. + add("(?!)") + elif stuff == "!": + # Negated empty range: match any character. + add(".") + else: + if stuff[0] == "!": + stuff = "^" + stuff[1:] + elif stuff[0] in ("^", "["): + stuff = "\\" + stuff + add(f"[{stuff}]") + else: + add(re.escape(c)) + assert i == n + return res + + +def glob_translate(pat: str): + # Copied from: https://github.com/python/cpython/pull/106703. + # The keyword parameters' values are fixed to: + # recursive=True, include_hidden=True, seps=None + """Translate a pathname with shell wildcards to a regular expression.""" + if os.path.altsep: + seps = os.path.sep + os.path.altsep + else: + seps = os.path.sep + escaped_seps = "".join(map(re.escape, seps)) + any_sep = f"[{escaped_seps}]" if len(seps) > 1 else escaped_seps + not_sep = f"[^{escaped_seps}]" + one_last_segment = f"{not_sep}+" + one_segment = f"{one_last_segment}{any_sep}" + any_segments = f"(?:.+{any_sep})?" + any_last_segments = ".*" + results = [] + parts = re.split(any_sep, pat) + last_part_idx = len(parts) - 1 + for idx, part in enumerate(parts): + if part == "*": + results.append(one_segment if idx < last_part_idx else one_last_segment) + continue + if part == "**": + results.append(any_segments if idx < last_part_idx else any_last_segments) + continue + elif "**" in part: + raise ValueError( + "Invalid pattern: '**' can only be an entire path component" + ) + if part: + results.extend(_translate(part, f"{not_sep}*", not_sep)) + if idx < last_part_idx: + results.append(any_sep) + res = "".join(results) + return rf"(?s:{res})\Z" diff --git a/tosfs/tests/test_tosfs.py b/tosfs/tests/test_tosfs.py index 9c2e89a..0f2e855 100644 --- a/tosfs/tests/test_tosfs.py +++ b/tosfs/tests/test_tosfs.py @@ -14,6 +14,7 @@ import os.path import tempfile +import fsspec import pytest from tos.exceptions import TosServerError @@ -21,6 +22,8 @@ from tosfs.exceptions import TosfsError from tosfs.utils import create_temp_dir, random_str +fsspec_version = fsspec.__version__ + def test_ls_bucket(tosfs: TosFileSystem, bucket: str) -> None: assert bucket in tosfs.ls("", detail=False) @@ -535,6 +538,96 @@ def test_expand_path( ) +def test_glob(tosfs: TosFileSystem, bucket: str, temporary_workspace: str) -> None: + dir_name = random_str() + sub_dir_name = random_str() + file_name = random_str() + sub_file_name = random_str() + nested_file_name = random_str() + + tosfs.makedirs(f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}") + tosfs.touch(f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}") + tosfs.touch( + f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}/{sub_file_name}" + ) + tosfs.touch( + f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}/{nested_file_name}" + ) + + # Test invalid inputs + with pytest.raises(ValueError, match="Cannot traverse all of tosfs"): + tosfs.glob("*") + + with pytest.raises(ValueError, match="maxdepth must be at least 1"): + tosfs.glob(f"{bucket}/{temporary_workspace}", maxdepth=0) + + # Test valid inputs + # No wildcards + assert tosfs.glob(f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}") == [ + f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}" + ] + + # Single wildcard * + assert sorted(tosfs.glob(f"{bucket}/{temporary_workspace}/{dir_name}/*")) == sorted( + [ + f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}", + f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}", + ] + ) + + # Single character wildcard ? + assert tosfs.glob( + f"{bucket}/{temporary_workspace}/{dir_name}/{file_name[:-1]}?" + ) == [f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}"] + + # Character class wildcard [] + assert tosfs.glob( + f"{bucket}/{temporary_workspace}/{dir_name}/{file_name[:-1]}[{file_name[-1]}]" + ) == [f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}"] + + # Recursive wildcard ** + assert sorted(tosfs.glob(f"{bucket}/{temporary_workspace}/**")) == sorted( + [ + f"{bucket}/{temporary_workspace}", + f"{bucket}/{temporary_workspace}/{dir_name}", + f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}", + f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}", + f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}/{sub_file_name}", + f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}/{nested_file_name}", + ] + ) + + # Test with maxdepth + assert ( + sorted(tosfs.glob(f"{bucket}/{temporary_workspace}/**", maxdepth=2)) + == sorted( + [ + f"{bucket}/{temporary_workspace}/{dir_name}", + f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}", + f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}", + ] + ) + if fsspec_version == "2023.5.0" + else sorted( + [ + f"{bucket}/{temporary_workspace}", + f"{bucket}/{temporary_workspace}/{dir_name}", + f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}", + f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}", + ] + ) + ) + + # Test with detail + result = tosfs.glob(f"{bucket}/{temporary_workspace}/**", detail=True) + assert isinstance(result, dict) + assert f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}" in result + assert ( + result[f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}"]["type"] + == "file" + ) + + ########################################################### # File operation tests # ########################################################### diff --git a/tosfs/utils.py b/tosfs/utils.py index d949877..0e96e96 100644 --- a/tosfs/utils.py +++ b/tosfs/utils.py @@ -13,7 +13,6 @@ # limitations under the License. """The module contains utility functions for the tosfs package.""" - import random import re import string