# --- tosfs/core.py ---
import logging
import os
from typing import Optional, Tuple

import tos
from fsspec import AbstractFileSystem
from fsspec.utils import setup_logging as setup_logger
from tos.models import CommonPrefixInfo
from tos.models2 import ListedObjectVersion

from tosfs.exceptions import TosfsError
from tosfs.utils import find_bucket_key

# NOTE: ``logger`` used below is the module-level logger of core.py; its
# definition lives outside the part of the file shown in this patch hunk.


class TosFileSystem(AbstractFileSystem):
    """
    Tos file system. An implementation of AbstractFileSystem which is an
    abstract super-class for pythonic file-systems.

    Requests go through the ``tos.TosClientV2`` instance stored on
    ``self.tos_client``. Listings are memoised in ``self.dircache``, which
    this class reads and writes but never creates (presumably it is
    provided by ``AbstractFileSystem`` -- confirm against fsspec).
    """

    def __init__(
        self,
        endpoint_url=None,
        key="",
        secret="",
        region=None,
        version_aware=False,
        credentials_provider=None,
        **kwargs,
    ):
        """
        Create the TOS client and initialise the base file system.

        :param endpoint_url: Endpoint passed to ``tos.TosClientV2``.
        :param key: Credential key passed to ``tos.TosClientV2``.
        :param secret: Credential secret passed to ``tos.TosClientV2``.
        :param region: Region passed to ``tos.TosClientV2``.
        :param version_aware: Whether listings use the object-versions API.
        :param credentials_provider: Forwarded to ``tos.TosClientV2``.
        :param kwargs: Forwarded to ``AbstractFileSystem.__init__``.
        """
        self.tos_client = tos.TosClientV2(
            key,
            secret,
            endpoint_url,
            region,
            credentials_provider=credentials_provider,
        )
        self.version_aware = version_aware
        super().__init__(**kwargs)

    def ls(self, path, detail=False, refresh=False, versions=False, **kwargs):
        """
        List objects under the given path.

        :param path: The path to list; an empty path lists the buckets.
        :param detail: Whether to return detailed information (the info
            dicts) instead of the sorted ``name`` values.
        :param refresh: Whether to refresh the cache.
        :param versions: Whether to list object versions.
        :param kwargs: Additional arguments (not used by this method).
        :return: A list of objects under the given path.
        """
        # After stripping trailing slashes a root path ("/" or "") is empty.
        path = self._strip_protocol(path).rstrip("/")
        if not path:
            entries = self._lsbuckets(refresh)
        else:
            entries = self._lsdir(path, refresh, versions=versions)
            if not entries and "/" in path:
                # Nothing listed under ``path``: it may be a single file, so
                # look for it among the entries of its parent directory.
                try:
                    entries = self._lsdir(
                        self._parent(path), refresh=refresh, versions=versions
                    )
                except IOError:
                    # Best effort only: an unreadable parent simply means
                    # there is nothing to report for ``path``.
                    entries = []
                entries = [
                    o
                    for o in entries
                    if o["name"].rstrip("/") == path
                    and o["type"] != "directory"
                ]

        return entries if detail else sorted(o["name"] for o in entries)

    def _lsbuckets(self, refresh=False):
        """
        List all buckets in the account.

        The result is cached in ``self.dircache`` under the empty-string key.

        :param refresh: Whether to refresh the cache.
        :return: A list of bucket info dicts, each of type ``directory``.
        :raises tos.exceptions.TosClientError: Re-raised from the client.
        :raises tos.exceptions.TosServerError: Re-raised from the client.
        :raises TosfsError: Wraps any other exception from the client call.
        """
        if "" not in self.dircache or refresh:
            try:
                resp = self.tos_client.list_buckets()
            except tos.exceptions.TosClientError as e:
                logger.error("Tosfs failed with client error: %s", e)
                raise
            except tos.exceptions.TosServerError as e:
                logger.error("Tosfs failed with server error: %s", e)
                raise
            except Exception as e:
                logger.error("Tosfs failed with unknown error: %s", e)
                raise TosfsError(
                    f"Tosfs failed with unknown error: {e}"
                ) from e

            self.dircache[""] = [
                {
                    "Key": bucket.name,
                    "Size": 0,
                    "StorageClass": "BUCKET",
                    "size": 0,
                    "type": "directory",
                    "name": bucket.name,
                }
                for bucket in resp.buckets
            ]

        return self.dircache[""]

    def _lsdir(
        self,
        path,
        refresh=False,
        max_items: int = 1000,
        delimiter="/",
        prefix="",
        versions=False,
    ):
        """
        List objects in a bucket, here we use cache to improve performance.

        The cache is keyed by ``path`` only, so it is used (and filled) only
        for plain listings: a delimiter is set, no versions are requested
        and the caller did not pass an extra ``prefix`` filter.

        :param path: The path to list.
        :param refresh: Whether to refresh the cache.
        :param max_items: Page size forwarded to ``_listdir``.
        :param delimiter: The delimiter to use for grouping objects.
        :param prefix: The prefix to use for filtering objects; it is
            appended to the key part of ``path``.
        :param versions: Whether to list object versions.
        :return: A list of info dicts, files first and then directories.
        """
        bucket, key, _ = self._split_path(path)
        # A caller-supplied prefix narrows the listing; such a result must
        # neither be served from nor stored in the per-path cache.
        filtered = bool(prefix)
        prefix = prefix or ""
        if key:
            prefix = key.lstrip("/") + "/" + prefix

        cacheable = bool(delimiter) and not versions and not filtered
        if cacheable and not refresh and path in self.dircache:
            return self.dircache[path]

        logger.debug("Get directory listing for %s", path)
        dirs = []
        files = []
        for obj in self._listdir(
            bucket,
            max_items=max_items,
            delimiter=delimiter,
            prefix=prefix,
            versions=versions,
        ):
            if isinstance(obj, CommonPrefixInfo):
                dirs.append(self._fill_common_prefix_info(obj, bucket))
            else:
                files.append(self._fill_object_info(obj, bucket, versions))
        files += dirs

        # Empty listings are never cached.
        if cacheable and files:
            self.dircache[path] = files
        return files

    def _listdir(
        self,
        bucket,
        max_items: int = 1000,
        delimiter="/",
        prefix="",
        versions=False,
    ):
        """
        List objects in a bucket.

        :param bucket: The bucket name.
        :param max_items: Page size sent as ``max_keys`` on every request.
            All pages are followed, so this does not cap the total number
            of returned entries.
        :param delimiter: The delimiter to use for grouping objects.
        :param prefix: The prefix to use for filtering objects.
        :param versions: Whether to list object versions.
        :return: A list of the raw entries returned by the client.
        :raises ValueError: If ``versions`` is set on a file system that is
            not version aware.
        :raises tos.exceptions.TosClientError: Re-raised from the client.
        :raises tos.exceptions.TosServerError: Re-raised from the client.
        :raises TosfsError: Wraps any other exception from the client calls.
        """
        if versions and not self.version_aware:
            raise ValueError(
                "versions cannot be specified if the filesystem is "
                "not version aware."
            )

        all_results = []
        is_truncated = True

        try:
            if self.version_aware:
                key_marker, version_id_marker = None, None
                while is_truncated:
                    resp = self.tos_client.list_object_versions(
                        bucket,
                        prefix,
                        delimiter=delimiter,
                        max_keys=max_items,
                        key_marker=key_marker,
                        version_id_marker=version_id_marker,
                    )
                    is_truncated = resp.is_truncated
                    # NOTE(review): delete markers end up in
                    # _fill_object_info, which reads ``obj.size`` -- confirm
                    # that delete-marker entries expose that attribute.
                    all_results.extend(
                        resp.versions
                        + resp.common_prefixes
                        + resp.delete_markers
                    )
                    key_marker, version_id_marker = (
                        resp.next_key_marker,
                        resp.next_version_id_marker,
                    )
            else:
                continuation_token = ""
                while is_truncated:
                    resp = self.tos_client.list_objects_type2(
                        bucket,
                        prefix,
                        start_after=prefix,
                        delimiter=delimiter,
                        max_keys=max_items,
                        continuation_token=continuation_token,
                    )
                    is_truncated = resp.is_truncated
                    continuation_token = resp.next_continuation_token

                    all_results.extend(resp.contents + resp.common_prefixes)

            return all_results
        except tos.exceptions.TosClientError as e:
            logger.error(
                "Tosfs failed with client error, message:%s, cause: %s",
                e.message,
                e.cause,
            )
            raise
        except tos.exceptions.TosServerError as e:
            logger.error("Tosfs failed with server error: %s", e)
            raise
        except Exception as e:
            logger.error("Tosfs failed with unknown error: %s", e)
            raise TosfsError(f"Tosfs failed with unknown error: {e}") from e

    def _split_path(self, path) -> Tuple[str, str, Optional[str]]:
        """
        Normalise tos path string into bucket and key.

        Parameters
        ----------
        path : string
            Input path, like `tos://mybucket/path/to/file`

        Returns
        -------
        A ``(bucket, key, version_id)`` tuple. ``version_id`` is ``None``
        unless the file system is version aware and the path carries a
        ``?versionId=`` suffix.

        Examples
        --------
        >>> fs._split_path("tos://mybucket/path/to/file")
        ('mybucket', 'path/to/file', None)
        # pylint: disable=line-too-long
        >>> fs._split_path("tos://mybucket/path/to/versioned_file?versionId=some_version_id")  # noqa: E501
        ('mybucket', 'path/to/versioned_file', 'some_version_id')
        """
        path = self._strip_protocol(path)
        path = path.lstrip("/")
        if "/" not in path:
            # Bucket only: there is no key part to split off.
            return path, "", None

        bucket, keypart = find_bucket_key(path)
        key, _, version_id = keypart.partition("?versionId=")
        return (
            bucket,
            key,
            version_id if self.version_aware and version_id else None,
        )

    @staticmethod
    def _fill_common_prefix_info(common_prefix: CommonPrefixInfo, bucket):
        """
        Build the ``directory`` info dict for a common prefix.

        :param common_prefix: The common prefix entry returned by a listing.
        :param bucket: The bucket the prefix belongs to.
        :return: Info dict; ``name`` is the prefix without its last
            character and ``Key`` is ``bucket`` joined with the prefix.
        """
        return {
            "name": common_prefix.prefix[:-1],
            "Key": "/".join([bucket, common_prefix.prefix]),
            "Size": 0,
            # Lowercase ``size`` as well, matching the bucket and object
            # entries produced by this class.
            "size": 0,
            "type": "directory",
        }

    @staticmethod
    def _fill_object_info(obj, bucket, versions=False):
        """
        Build the ``file`` info dict for a listed object.

        :param obj: The object entry returned by a listing.
        :param bucket: The bucket the object belongs to.
        :param versions: Whether version ids should be appended to ``name``.
        :return: Info dict with ``Key``, ``size``, ``name`` and ``type``.
        """
        result = {
            "Key": f"{bucket}/{obj.key}",
            "size": obj.size,
            "name": f"{bucket}/{obj.key}",
            "type": "file",
        }
        # Only real version ids are appended; the literal "null" id is
        # treated like an absent one.
        if (
            isinstance(obj, ListedObjectVersion)
            and versions
            and obj.version_id
            and obj.version_id != "null"
        ):
            result["name"] += f"?versionId={obj.version_id}"
        return result
# --- tosfs/exceptions.py ---
# This module contains exceptions definition for the tosfs package.


class TosfsError(Exception):
    """Base class for all tosfs exceptions."""

    def __init__(self, message: str):
        # Exactly one positional ``message`` is accepted and handed to
        # ``Exception`` unchanged.
        super().__init__(message)


# --- tosfs/tests/conftest.py ---
import os

import pytest
from tos import EnvCredentialsProvider

from tosfs.core import TosFileSystem
from tosfs.utils import random_path


@pytest.fixture(scope="module")
def tosfs_env_prepare():
    """Fail early when the TOS credential variables are not exported."""
    for required in ("TOS_ACCESS_KEY", "TOS_SECRET_KEY"):
        if required not in os.environ:
            raise EnvironmentError(
                f"Can not find {required} in environment variables."
            )


@pytest.fixture(scope="module")
def tosfs(tosfs_env_prepare):
    """Provide one ``TosFileSystem`` per test module.

    Endpoint and region are read from ``TOS_ENDPOINT`` / ``TOS_REGION``;
    credentials come from ``EnvCredentialsProvider``.
    """
    endpoint = os.environ.get("TOS_ENDPOINT")
    region = os.environ.get("TOS_REGION")
    filesystem = TosFileSystem(
        endpoint_url=endpoint,
        region=region,
        credentials_provider=EnvCredentialsProvider(),
    )
    yield filesystem


@pytest.fixture(scope="module")
def bucket():
    """Name of the bucket under test (``TOS_BUCKET``, else ``proton-ci``)."""
    yield os.getenv("TOS_BUCKET", "proton-ci")


@pytest.fixture(autouse=True)
def temporary_workspace(tosfs, bucket):
    """Create a random directory marker for each test and remove it after."""
    workspace = random_path()
    marker_key = f"{workspace}/"
    # currently, make dir via purely tos python client,
    # will replace with tosfs.mkdir in the future
    tosfs.tos_client.put_object(bucket=bucket, key=marker_key)
    yield workspace
    # currently, remove dir via purely tos python client,
    # will replace with tosfs.rmdir in the future
    tosfs.tos_client.delete_object(bucket=bucket, key=marker_key)
+from unittest.mock import MagicMock -def test_hello_world(): - assert True +import pytest +from tos.exceptions import TosServerError + + +def test_ls_bucket(tosfs, bucket): + assert bucket in tosfs.ls("", detail=False) + detailed_list = tosfs.ls("", detail=True) + assert detailed_list + for item in detailed_list: + assert "type" in item + assert item["type"] == "directory" + assert item["StorageClass"] == "BUCKET" + + with pytest.raises(TosServerError): + tosfs.ls("nonexistent", detail=False) + + +def test_ls_dir(tosfs, bucket, temporary_workspace): + assert temporary_workspace in tosfs.ls(bucket, detail=False) + detailed_list = tosfs.ls(bucket, detail=True) + assert detailed_list + for item in detailed_list: + if item["name"] == temporary_workspace: + assert item["type"] == "directory" + break + else: + assert ( + False + ), f"Directory {temporary_workspace} not found in {detailed_list}" + + assert tosfs.ls(f"{bucket}/{temporary_workspace}", detail=False) == [] + assert ( + tosfs.ls(f"{bucket}/{temporary_workspace}/nonexistent", detail=False) + == [] + ) + + +def test_ls_cache(tosfs, bucket): + tosfs.tos_client.list_objects_type2 = MagicMock( + return_value=MagicMock( + is_truncated=False, + contents=[MagicMock(key="mock_key", size=123)], + common_prefixes=[], + next_continuation_token=None, + ) + ) + + # Call ls method and get result from server + tosfs.ls(bucket, detail=False, refresh=True) + # Get result from cache + tosfs.ls(bucket, detail=False, refresh=False) + + # Verify that list_objects_type2 was called only once + assert tosfs.tos_client.list_objects_type2.call_count == 1 diff --git a/tosfs/tests/test_utils.py b/tosfs/tests/test_utils.py new file mode 100644 index 0000000..39a8684 --- /dev/null +++ b/tosfs/tests/test_utils.py @@ -0,0 +1,32 @@ +# ByteDance Volcengine EMR, Copyright 2024. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
import random
import re
import string

import pytest


# --- tosfs/tests/test_utils.py ---
# (in the test module ``find_bucket_key`` comes from ``tosfs.utils``)


@pytest.mark.parametrize(
    "input_str, expected_output",
    [
        ("bucket/key", ("bucket", "key")),
        ("bucket/", ("bucket", "")),
        ("/key", ("", "key")),
        ("bucket/key/with/slashes", ("bucket", "key/with/slashes")),
        ("bucket", ("bucket", "")),
        ("", ("", "")),
    ],
)
def test_find_bucket_key(input_str, expected_output):
    """``find_bucket_key`` is a pure function, so no file system is needed."""
    assert find_bucket_key(input_str) == expected_output


# --- tosfs/utils.py ---
# This module contains utility functions for the tosfs package.

# Patterns for bucket identifiers that may themselves contain "/" or ":".
# Each one captures the identifier as ``bucket`` and the rest as ``key``.
# Compiled once at import time; they are tried in this order.
_BUCKET_KEY_PATTERNS = (
    re.compile(
        r"^(?P<bucket>:tos:[a-z\-0-9]*:[0-9]{12}:accesspoint[:/][^/]+)/?"
        r"(?P<key>.*)$"
    ),
    re.compile(
        r"^(?P<bucket>:tos-outposts:[a-z\-0-9]+:[0-9]{12}:outpost[/:]"
        r"[a-zA-Z0-9\-]{1,63}[/:](bucket|accesspoint)[/:]"
        r"[a-zA-Z0-9\-]{1,63})[/:]?(?P<key>.*)$"
    ),
    re.compile(
        r"^(?P<bucket>:tos-outposts:[a-z\-0-9]+:[0-9]{12}:outpost[/:]"
        r"[a-zA-Z0-9\-]{1,63}[/:]bucket[/:]"
        r"[a-zA-Z0-9\-]{1,63})[/:]?(?P<key>.*)$"
    ),
    re.compile(
        r"^(?P<bucket>:tos-object-lambda:[a-z\-0-9]+:[0-9]{12}:"
        r"accesspoint[/:][a-zA-Z0-9\-]{1,63})[/:]?(?P<key>.*)$"
    ),
)


def random_path(length: int = 5) -> str:
    """
    Generate a random path(dir or file) of the given length.

    Characters are drawn from ASCII letters and digits with the ``random``
    module.

    Args:
        length (int): The length of the random string.

    Returns:
        str: The random string.
    """
    return "".join(
        random.choices(string.ascii_letters + string.digits, k=length)
    )


def find_bucket_key(tos_path):
    """
    Split a TOS path of the form ``bucket/key`` into bucket and key.

    The special bucket-identifier patterns are tried first; when none of
    them matches, the path is split at its first ``/``.

    Args:
        tos_path (str): The path to split, without protocol.

    Returns:
        tuple: ``(bucket, key)``; ``key`` is ``""`` when the path contains
        no key part.
    """
    for pattern in _BUCKET_KEY_PATTERNS:
        match = pattern.match(tos_path)
        if match:
            return match.group("bucket"), match.group("key")

    # Plain form: everything before the first "/" is the bucket.
    bucket, _, tos_key = tos_path.partition("/")
    return bucket, tos_key