Skip to content

Commit

Permalink
Core: Implement glob api (#62)
Browse files Browse the repository at this point in the history
* Core: Implement glob api

* Core: Implement glob api

* Core: Implement glob api
  • Loading branch information
yanghua authored Sep 9, 2024
1 parent d9d24d7 commit 330abd7
Show file tree
Hide file tree
Showing 5 changed files with 315 additions and 3 deletions.
17 changes: 16 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,29 @@ ignore = [
#D107 Missing docstring in __init__
"D107"
]
"tosfs/fsspec_utils.py" = [
#SIM108 Use ternary operator xxx instead of `if`-`else`-block
"SIM108",
#N803 Argument name should be lowercase
"N803",
#PLR0915 Too many statements
"PLR0915",
#PLR0912 Too many branches
"PLR0912",
#C901 is too complex
"C901",
]

[tool.mypy]
python_version = "3.9"
ignore_missing_imports = true
disallow_untyped_calls = true
disallow_untyped_defs = true
strict_optional = true
exclude = "tosfs/tests/.*"
exclude = [
"tosfs/tests/.*",
"tosfs/fsspec_utils.py"
]

[build-system]
requires = ["poetry-core"]
Expand Down
77 changes: 76 additions & 1 deletion tosfs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import os
import time
from glob import has_magic
from typing import Any, BinaryIO, Generator, List, Optional, Tuple, Union
from typing import Any, BinaryIO, Collection, Generator, List, Optional, Tuple, Union

import tos
from fsspec import AbstractFileSystem
Expand All @@ -41,6 +41,7 @@
TOS_SERVER_RESPONSE_CODE_NOT_FOUND,
)
from tosfs.exceptions import TosfsError
from tosfs.fsspec_utils import glob_translate
from tosfs.utils import find_bucket_key, get_brange, retryable_func_wrapper

# environment variable names
Expand Down Expand Up @@ -912,6 +913,80 @@ def cp_file(
# serial multipart copy
self._copy_managed(path1, path2, size, **kwargs)

def glob(
self, path: str, maxdepth: Optional[int] = None, **kwargs: Any
) -> Collection[Any]:
"""Return list of paths matching a glob-like pattern.
Parameters
----------
path : str
The path to search.
maxdepth : int, optional
The maximum depth to search to (default is None).
**kwargs : Any, optional
Additional arguments.
"""
if path.startswith("*"):
raise ValueError("Cannot traverse all of tosfs")

if maxdepth is not None and maxdepth < 1:
raise ValueError("maxdepth must be at least 1")

import re

seps = (os.path.sep, os.path.altsep) if os.path.altsep else (os.path.sep,)
ends_with_sep = path.endswith(seps) # _strip_protocol strips trailing slash
path = self._strip_protocol(path)
append_slash_to_dirname = ends_with_sep or path.endswith(
tuple(sep + "**" for sep in seps)
)

idx_star = path.find("*") if path.find("*") >= 0 else len(path)
idx_qmark = path.find("?") if path.find("?") >= 0 else len(path)
idx_brace = path.find("[") if path.find("[") >= 0 else len(path)
min_idx = min(idx_star, idx_qmark, idx_brace)

detail = kwargs.pop("detail", False)

if not has_magic(path):
if self.exists(path, **kwargs):
return {path: self.info(path, **kwargs)} if detail else [path]
return {} if detail else []

depth: Optional[int] = None
root, depth = "", path[min_idx + 1 :].count("/") + 1
if "/" in path[:min_idx]:
min_idx = path[:min_idx].rindex("/")
root = path[: min_idx + 1]

if "**" in path:
if maxdepth is not None:
idx_double_stars = path.find("**")
depth_double_stars = path[idx_double_stars:].count("/") + 1
depth = depth - depth_double_stars + maxdepth
else:
depth = None

allpaths = self.find(root, maxdepth=depth, withdirs=True, detail=True, **kwargs)
pattern = re.compile(glob_translate(path + ("/" if ends_with_sep else "")))

if isinstance(allpaths, dict):
out = {
p: info
for p, info in sorted(allpaths.items())
if pattern.match(
p + "/"
if append_slash_to_dirname and info["type"] == "directory"
else p
)
}
else:
out = {}

return out if detail else list(out)

def _copy_basic(self, path1: str, path2: str, **kwargs: Any) -> None:
"""Copy file between locations on tos.
Expand Down
130 changes: 130 additions & 0 deletions tosfs/fsspec_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# ByteDance Volcengine EMR, Copyright 2024.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The module contains utility functions copied from fsspec package."""

import os
import re


def _translate(pat, STAR, QUESTION_MARK):
# Copied from: https://github.com/python/cpython/pull/106703.
res: list[str] = []
add = res.append
i, n = 0, len(pat)
while i < n:
c = pat[i]
i = i + 1
if c == "*":
# compress consecutive `*` into one
if (not res) or res[-1] is not STAR:
add(STAR)
elif c == "?":
add(QUESTION_MARK)
elif c == "[":
j = i
if j < n and pat[j] == "!":
j = j + 1
if j < n and pat[j] == "]":
j = j + 1
while j < n and pat[j] != "]":
j = j + 1
if j >= n:
add("\\[")
else:
stuff = pat[i:j]
if "-" not in stuff:
stuff = stuff.replace("\\", r"\\")
else:
chunks = []
k = i + 2 if pat[i] == "!" else i + 1
while True:
k = pat.find("-", k, j)
if k < 0:
break
chunks.append(pat[i:k])
i = k + 1
k = k + 3
chunk = pat[i:j]
if chunk:
chunks.append(chunk)
else:
chunks[-1] += "-"
# Remove empty ranges -- invalid in RE.
for k in range(len(chunks) - 1, 0, -1):
if chunks[k - 1][-1] > chunks[k][0]:
chunks[k - 1] = chunks[k - 1][:-1] + chunks[k][1:]
del chunks[k]
# Escape backslashes and hyphens for set difference (--).
# Hyphens that create ranges shouldn't be escaped.
stuff = "-".join(
s.replace("\\", r"\\").replace("-", r"\-") for s in chunks
)
# Escape set operations (&&, ~~ and ||).
stuff = re.sub(r"([&~|])", r"\\\1", stuff)
i = j + 1
if not stuff:
# Empty range: never match.
add("(?!)")
elif stuff == "!":
# Negated empty range: match any character.
add(".")
else:
if stuff[0] == "!":
stuff = "^" + stuff[1:]
elif stuff[0] in ("^", "["):
stuff = "\\" + stuff
add(f"[{stuff}]")
else:
add(re.escape(c))
assert i == n
return res


def glob_translate(pat: str):
# Copied from: https://github.com/python/cpython/pull/106703.
# The keyword parameters' values are fixed to:
# recursive=True, include_hidden=True, seps=None
"""Translate a pathname with shell wildcards to a regular expression."""
if os.path.altsep:
seps = os.path.sep + os.path.altsep
else:
seps = os.path.sep
escaped_seps = "".join(map(re.escape, seps))
any_sep = f"[{escaped_seps}]" if len(seps) > 1 else escaped_seps
not_sep = f"[^{escaped_seps}]"
one_last_segment = f"{not_sep}+"
one_segment = f"{one_last_segment}{any_sep}"
any_segments = f"(?:.+{any_sep})?"
any_last_segments = ".*"
results = []
parts = re.split(any_sep, pat)
last_part_idx = len(parts) - 1
for idx, part in enumerate(parts):
if part == "*":
results.append(one_segment if idx < last_part_idx else one_last_segment)
continue
if part == "**":
results.append(any_segments if idx < last_part_idx else any_last_segments)
continue
elif "**" in part:
raise ValueError(
"Invalid pattern: '**' can only be an entire path component"
)
if part:
results.extend(_translate(part, f"{not_sep}*", not_sep))
if idx < last_part_idx:
results.append(any_sep)
res = "".join(results)
return rf"(?s:{res})\Z"
93 changes: 93 additions & 0 deletions tosfs/tests/test_tosfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@
import os.path
import tempfile

import fsspec
import pytest
from tos.exceptions import TosServerError

from tosfs.core import TosFileSystem
from tosfs.exceptions import TosfsError
from tosfs.utils import create_temp_dir, random_str

fsspec_version = fsspec.__version__


def test_ls_bucket(tosfs: TosFileSystem, bucket: str) -> None:
assert bucket in tosfs.ls("", detail=False)
Expand Down Expand Up @@ -535,6 +538,96 @@ def test_expand_path(
)


def test_glob(tosfs: TosFileSystem, bucket: str, temporary_workspace: str) -> None:
dir_name = random_str()
sub_dir_name = random_str()
file_name = random_str()
sub_file_name = random_str()
nested_file_name = random_str()

tosfs.makedirs(f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}")
tosfs.touch(f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}")
tosfs.touch(
f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}/{sub_file_name}"
)
tosfs.touch(
f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}/{nested_file_name}"
)

# Test invalid inputs
with pytest.raises(ValueError, match="Cannot traverse all of tosfs"):
tosfs.glob("*")

with pytest.raises(ValueError, match="maxdepth must be at least 1"):
tosfs.glob(f"{bucket}/{temporary_workspace}", maxdepth=0)

# Test valid inputs
# No wildcards
assert tosfs.glob(f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}") == [
f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}"
]

# Single wildcard *
assert sorted(tosfs.glob(f"{bucket}/{temporary_workspace}/{dir_name}/*")) == sorted(
[
f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}",
f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}",
]
)

# Single character wildcard ?
assert tosfs.glob(
f"{bucket}/{temporary_workspace}/{dir_name}/{file_name[:-1]}?"
) == [f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}"]

# Character class wildcard []
assert tosfs.glob(
f"{bucket}/{temporary_workspace}/{dir_name}/{file_name[:-1]}[{file_name[-1]}]"
) == [f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}"]

# Recursive wildcard **
assert sorted(tosfs.glob(f"{bucket}/{temporary_workspace}/**")) == sorted(
[
f"{bucket}/{temporary_workspace}",
f"{bucket}/{temporary_workspace}/{dir_name}",
f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}",
f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}",
f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}/{sub_file_name}",
f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}/{nested_file_name}",
]
)

# Test with maxdepth
assert (
sorted(tosfs.glob(f"{bucket}/{temporary_workspace}/**", maxdepth=2))
== sorted(
[
f"{bucket}/{temporary_workspace}/{dir_name}",
f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}",
f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}",
]
)
if fsspec_version == "2023.5.0"
else sorted(
[
f"{bucket}/{temporary_workspace}",
f"{bucket}/{temporary_workspace}/{dir_name}",
f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}",
f"{bucket}/{temporary_workspace}/{dir_name}/{sub_dir_name}",
]
)
)

# Test with detail
result = tosfs.glob(f"{bucket}/{temporary_workspace}/**", detail=True)
assert isinstance(result, dict)
assert f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}" in result
assert (
result[f"{bucket}/{temporary_workspace}/{dir_name}/{file_name}"]["type"]
== "file"
)


###########################################################
# File operation tests #
###########################################################
Expand Down
1 change: 0 additions & 1 deletion tosfs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

"""The module contains utility functions for the tosfs package."""

import random
import re
import string
Expand Down

0 comments on commit 330abd7

Please sign in to comment.