Skip to content

Commit

Permalink
Refactor unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
yanghua committed Aug 15, 2024
1 parent c9eaa77 commit d261732
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 117 deletions.
234 changes: 121 additions & 113 deletions tosfs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@
"""
import logging
import os
import re
from typing import Optional, Tuple

import tos
from fsspec import AbstractFileSystem
from fsspec.utils import setup_logging as setup_logger
from tos.models import CommonPrefixInfo
from tos.models2 import ListedObjectVersion

from tosfs.utils import find_bucket_key

# environment variable names
ENV_NAME_TOSFS_LOGGING_LEVEL = "TOSFS_LOGGING_LEVEL"
Expand Down Expand Up @@ -74,40 +77,49 @@ def __init__(
super().__init__(**kwargs)

def ls(self, path, detail=False, refresh=False, versions=False, **kwargs):
"""
List objects under the given path.
:param path: The path to list.
:param detail: Whether to return detailed information.
:param refresh: Whether to refresh the cache.
:param versions: Whether to list object versions.
:param kwargs: Additional arguments.
:return: A list of objects under the given path.
"""
path = self._strip_protocol(path).rstrip("/")
if path in ["", "/"]:
files = self._lsbuckets(refresh)
else:
files = self._lsdir(path, refresh, versions=versions)
if not files and "/" in path:
try:
files = self._lsdir(
self._parent(path), refresh=refresh, versions=versions
)
except IOError:
pass
files = [
o
for o in files
if o["name"].rstrip("/") == path
and o["type"] != "directory"
]
if not files:
raise FileNotFoundError(path)
if detail:
return files
return files if detail else sorted([o.name for o in files])
return files if detail else sorted([o.name for o in files])

files = self._lsdir(path, refresh, versions=versions)
if not files and "/" in path:
try:
files = self._lsdir(
self._parent(path), refresh=refresh, versions=versions
)
except IOError:
pass
files = [
o
for o in files
if o["name"].rstrip("/") == path and o["type"] != "directory"
]
if detail:
return files

return files if detail else sorted([o["name"] for o in files])

def _lsbuckets(self, refresh=False):
"""
List all buckets in the account.
:param refresh: Whether to refresh the cache.
:return: A list of buckets.
"""
if "" not in self.dircache or refresh:
try:
resp = self.tos_client.list_buckets()
except tos.exceptions.TosClientError as e:
logger.error(
"Tosfs failed with client error, message:%s, cause: %s",
e.message,
e.cause,
)
logger.error("Tosfs failed with client error: %s", e)
return []
except tos.exceptions.TosServerError as e:
logger.error("Tosfs failed with server error: %s", e)
Expand All @@ -117,18 +129,27 @@ def _lsbuckets(self, refresh=False):
return []

self.dircache[""] = resp.buckets
return resp.buckets
return self.dircache[""]

def _lsdir(
self,
path,
refresh=False,
max_items=None,
max_items: int = 1000,
delimiter="/",
prefix="",
versions=False,
):
"""
List objects in a bucket, here we use cache to improve performance.
:param path: The path to list.
:param refresh: Whether to refresh the cache.
:param max_items: The maximum number of items to return, default is 1000. # noqa: E501
:param delimiter: The delimiter to use for grouping objects.
:param prefix: The prefix to use for filtering objects.
:param versions: Whether to list object versions.
:return: A list of objects in the bucket.
"""
bucket, key, _ = self.split_path(path)
if not prefix:
prefix = ""
Expand All @@ -145,10 +166,10 @@ def _lsdir(
prefix=prefix,
versions=versions,
):
if obj["type"] == "directory":
dirs.append(obj)
if isinstance(obj, CommonPrefixInfo):
dirs.append(self._fill_common_prefix_info(obj, bucket))
else:
files.append(obj)
files.append(self._fill_object_info(obj, bucket, versions))
files += dirs

if delimiter and files and not versions:
Expand All @@ -157,60 +178,68 @@ def _lsdir(
return self.dircache[path]

def _listdir(
self, bucket, max_items=None, delimiter="/", prefix="", versions=False
self,
bucket,
max_items: int = 1000,
delimiter="/",
prefix="",
versions=False,
):
"""
List objects in a bucket.
:param bucket: The bucket name.
:param max_items: The maximum number of items to return, default is 1000. # noqa: E501
:param delimiter: The delimiter to use for grouping objects.
:param prefix: The prefix to use for filtering objects.
:param versions: Whether to list object versions.
:return: A list of objects in the bucket.
"""
if versions and not self.version_aware:
raise ValueError(
"versions cannot be specified if the filesystem is "
"not version aware."
)

paging_fetch = max_items is None
pag_size = 50

all_results = []
start_after = None
version_id_marker = None
is_truncated = True

try:
if self.version_aware:
if paging_fetch:
while True:
resp = self.tos_client.list_object_versions(
bucket,
prefix,
delimiter=delimiter,
key_marker=start_after,
version_id_marker=version_id_marker,
max_keys=pag_size,
)
if not resp.versions:
break
all_results.extend(resp.versions)
start_after = resp.versions[-1].key
version_id_marker = resp.versions[-1].version_id
key_marker, version_id_marker = None, None
while is_truncated:
resp = self.tos_client.list_object_versions(
bucket,
prefix,
delimiter=delimiter,
max_keys=max_items,
key_marker=key_marker,
version_id_marker=version_id_marker,
)
is_truncated = resp.is_truncated
all_results.extend(
resp.versions
+ resp.common_prefixes
+ resp.delete_markers
)
key_marker, version_id_marker = (
resp.next_key_marker,
resp.next_version_id_marker,
)
else:
if paging_fetch:
while True:
resp = self.tos_client.list_objects_type2(
bucket,
prefix,
delimiter=delimiter,
start_after=start_after,
max_keys=pag_size,
)
if not resp.contents:
break
all_results.extend(resp.contents)
start_after = resp.contents[-1].key
else:
continuation_token = ""
while is_truncated:
resp = self.tos_client.list_objects_type2(
bucket,
prefix,
start_after=start_after,
start_after=prefix,
delimiter=delimiter,
max_keys=max_items,
continuation_token=continuation_token,
)
all_results.extend(resp.contents)
is_truncated = resp.is_truncated
continuation_token = resp.next_continuation_token

all_results.extend(resp.contents + resp.common_prefixes)

return all_results
except tos.exceptions.TosClientError as e:
Expand Down Expand Up @@ -249,57 +278,36 @@ def split_path(self, path) -> Tuple[str, str, Optional[str]]:
if "/" not in path:
return path, "", None

bucket, keypart = self._find_bucket_key(path)
bucket, keypart = find_bucket_key(path)
key, _, version_id = keypart.partition("?versionId=")
return (
bucket,
key,
version_id if self.version_aware and version_id else None,
)

def _find_bucket_key(self, tos_path):
"""
This is a helper function that given an tos path such that the path
is of the form: bucket/key
It will return the bucket and the key represented by the tos path
"""

bucket_format_list = [
re.compile(
r"^(?P<bucket>:tos:[a-z\-0-9]*:[0-9]{12}:accesspoint[:/][^/]+)/?" # noqa: E501
r"(?P<key>.*)$"
),
re.compile(
r"^(?P<bucket>:tos-outposts:[a-z\-0-9]+:[0-9]{12}:outpost[/:]"
# pylint: disable=line-too-long
r"[a-zA-Z0-9\-]{1,63}[/:](bucket|accesspoint)[/:][a-zA-Z0-9\-]{1,63})[/:]?(?P<key>.*)$" # noqa: E501
),
re.compile(
r"^(?P<bucket>:tos-outposts:[a-z\-0-9]+:[0-9]{12}:outpost[/:]"
r"[a-zA-Z0-9\-]{1,63}[/:]bucket[/:]"
r"[a-zA-Z0-9\-]{1,63})[/:]?(?P<key>.*)$"
),
re.compile(
r"^(?P<bucket>:tos-object-lambda:[a-z\-0-9]+:[0-9]{12}:"
r"accesspoint[/:][a-zA-Z0-9\-]{1,63})[/:]?(?P<key>.*)$"
),
]
for bucket_format in bucket_format_list:
match = bucket_format.match(tos_path)
if match:
return match.group("bucket"), match.group("key")
tos_components = tos_path.split("/", 1)
bucket = tos_components[0]
tos_key = ""
if len(tos_components) > 1:
tos_key = tos_components[1]
return bucket, tos_key
@staticmethod
def _fill_common_prefix_info(common_prefix: CommonPrefixInfo, bucket):
return {
"name": common_prefix.prefix[:-1],
"Key": "/".join([bucket, common_prefix.prefix]),
"Size": 0,
"type": "directory",
}

@staticmethod
def _fill_info(f, bucket, versions=False):
f["size"] = f["Size"]
f["Key"] = "/".join([bucket, f["Key"]])
f["name"] = f["Key"]
version_id = f.get("VersionId")
if versions and version_id and version_id != "null":
f["name"] += f"?versionId={version_id}"
def _fill_object_info(obj, bucket, versions=False):
    """
    Build the file info entry for a listed object.
    :param obj: The listed object; its ``key`` and ``size`` are read, and ``version_id`` when it is a ListedObjectVersion.
    :param bucket: The name of the bucket the object was listed from.
    :param versions: Whether to append the version id to the entry name.
    :return: A dict with ``Key``, ``size``, ``name`` and ``type`` describing the file.
    """
    full_key = f"{bucket}/{obj.key}"
    info = {
        "Key": full_key,
        "size": obj.size,
        "name": full_key,
        "type": "file",
    }

    # Only versioned listings of version-carrying objects get a suffix.
    if not (versions and isinstance(obj, ListedObjectVersion)):
        return info

    version_id = obj.version_id
    # An empty id or the literal "null" id adds no suffix.
    if version_id and version_id != "null":
        info["name"] = f"{full_key}?versionId={version_id}"
    return info
18 changes: 18 additions & 0 deletions tosfs/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from tos import EnvCredentialsProvider

from tosfs.core import TosFileSystem
from tosfs.utils import random_path


@pytest.fixture(scope="module")
Expand All @@ -40,3 +41,20 @@ def tosfs(tosfs_env_prepare):
credentials_provider=EnvCredentialsProvider(),
)
yield tosfs


@pytest.fixture(scope="module")
def bucket():
    """Yield the bucket name used by the tests.

    The name is read from the ``TOS_BUCKET`` environment variable and falls
    back to ``proton-ci`` when the variable is not set.
    """
    bucket_name = os.getenv("TOS_BUCKET", "proton-ci")
    yield bucket_name


@pytest.fixture(autouse=True)
def temporary_workspace(tosfs, bucket):
    """Create a randomly named workspace directory, yield its name, then delete it."""
    workspace = random_path()
    dir_key = f"{workspace}/"

    # The directory object is created through the raw TOS client for now;
    # this will be replaced with tosfs.mkdir in the future.
    tosfs.tos_client.put_object(bucket=bucket, key=dir_key)

    yield workspace

    # Likewise removed through the raw TOS client;
    # this will be replaced with tosfs.rmdir in the future.
    tosfs.tos_client.delete_object(bucket=bucket, key=dir_key)
12 changes: 9 additions & 3 deletions tosfs/tests/test_tosfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,16 @@

import os

test_bucket = os.environ.get("TEST_BUCKET", "proton-ci")

def test_ls(tosfs):
test_bucket_name = os.environ.get("TEST_BUCKET", "proton-ci")
assert test_bucket_name in set(tosfs.ls("", detail=False))

def test_ls_bucket(tosfs, bucket):
    """The root listing contains the test bucket; an unknown path lists as empty."""
    root_listing = tosfs.ls("", detail=False)
    assert bucket in root_listing

    missing_listing = tosfs.ls("nonexistent")
    assert missing_listing == [], "Nonexistent path should return empty list"


def test_ls_dir(tosfs, bucket, temporary_workspace):
    """Listing the temporary workspace directory returns an empty list."""
    # NOTE: the check that the workspace itself appears in
    # tosfs.ls(bucket, detail=False) is currently disabled.
    workspace_path = f"{bucket}/{temporary_workspace}"
    listing = tosfs.ls(workspace_path, detail=False)
    assert listing == []
4 changes: 3 additions & 1 deletion tosfs/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

import pytest

from tosfs.utils import find_bucket_key


@pytest.mark.parametrize(
"input_str, expected_output",
Expand All @@ -27,4 +29,4 @@
],
)
def test_find_bucket_key(tosfs, input_str, expected_output):
assert tosfs._find_bucket_key(input_str) == expected_output
assert find_bucket_key(input_str) == expected_output
Loading

0 comments on commit d261732

Please sign in to comment.