Core: Implement cp_file API (#55)
* Core: Implement cp_file API

* Reformat code
yanghua authored Sep 5, 2024
1 parent 5f8cdaa commit 9b403e3
Showing 4 changed files with 277 additions and 12 deletions.
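In short, the new `cp_file` chooses between three strategies: a single server-side `copy_object` for objects at or below the managed-copy threshold, a serial multipart copy for larger objects, and an etag-preserving multipart copy when `preserve_etag=True` and the source was uploaded in parts. A minimal usage sketch, not part of the commit (bucket names, keys, and constructor arguments are hypothetical; see the diff below for the real signatures):

from tosfs.core import TosFileSystem

# Constructor arguments are placeholders; endpoint/credential setup
# is not part of this diff.
fs = TosFileSystem(...)

# Small object: copied with a single server-side copy_object call.
fs.cp_file("bucket/src/a.txt", "bucket/dst/a.txt")

# Lower the threshold to force the multipart (managed) path, e.g. in tests.
fs.cp_file("bucket/src/big.bin", "bucket/dst/big.bin", managed_copy_threshold=8 * 2**20)

# Reproduce the source's multipart layout so the destination ETag matches.
fs.cp_file("bucket/src/big.bin", "bucket/dst/big2.bin", preserve_etag=True)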
6 changes: 6 additions & 0 deletions tosfs/consts.py
@@ -26,3 +26,9 @@
"InternalError",
"ServiceUnavailable",
}

MANAGED_COPY_THRESHOLD = 5 * 2**30

RETRY_NUM = 5
PART_MIN_SIZE = 5 * 2**20
PART_MAX_SIZE = 5 * 2**30
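For reference (not part of the commit), 2**20 bytes is 1 MiB and 2**30 bytes is 1 GiB, so the new constants work out as follows:

# 2**20 == 1024**2 (MiB), 2**30 == 1024**3 (GiB)
assert 5 * 2**20 == 5 * 1024**2   # PART_MIN_SIZE: 5 MiB, the multipart part floor
assert 5 * 2**30 == 5 * 1024**3   # PART_MAX_SIZE / MANAGED_COPY_THRESHOLD: 5 GiB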
222 changes: 211 additions & 11 deletions tosfs/core.py
@@ -33,9 +33,14 @@
UploadPartOutput,
)

from tosfs.consts import TOS_SERVER_RESPONSE_CODE_NOT_FOUND
from tosfs.consts import (
MANAGED_COPY_THRESHOLD,
PART_MAX_SIZE,
PART_MIN_SIZE,
RETRY_NUM,
TOS_SERVER_RESPONSE_CODE_NOT_FOUND,
)
from tosfs.exceptions import TosfsError
from tosfs.utils import find_bucket_key, retryable_func_wrapper
from tosfs.utils import find_bucket_key, get_brange, retryable_func_wrapper

# environment variable names
ENV_NAME_TOSFS_LOGGING_LEVEL = "TOSFS_LOGGING_LEVEL"
@@ -66,7 +71,6 @@ class TosFileSystem(AbstractFileSystem):
"""

protocol = ("tos", "tosfs")
retries = 5
default_block_size = 5 * 2**20

def __init__(
@@ -125,7 +129,7 @@ def _open(
best support random access. When reading only a few specific chunks
out of a file, performance may be better if False.
version_id : str
Explicit version of the object to open. This requires that the s3
Explicit version of the object to open. This requires that the tos
filesystem is version aware and bucket versioning is enabled on the
relevant bucket.
cache_type : str
@@ -636,7 +640,7 @@ def _read_chunks(body: BinaryIO, f: BinaryIO) -> None:
chunk = body.read(2**16)
except tos.exceptions.TosClientError as e:
failed_reads += 1
if failed_reads >= self.retries:
if failed_reads >= RETRY_NUM:
raise e
try:
body.close()
@@ -772,6 +776,206 @@ def find(
else:
return [o["name"] for o in out]

def cp_file(
self,
path1: str,
path2: str,
preserve_etag: Optional[bool] = None,
managed_copy_threshold: Optional[int] = MANAGED_COPY_THRESHOLD,
**kwargs: Any,
) -> None:
"""Copy file between locations on tos.
Parameters
----------
path1 : str
The source path of the file to copy.
path2 : str
The destination path of the file to copy.
preserve_etag : bool, optional
Whether to preserve the etag while copying. If the file was
uploaded as a single part, its etag is the md5 hash of the
content, so the etag is always preserved. If the file was
uploaded in multiple parts, this option tries to reproduce the
same multipart layout while copying so that the resulting etag
matches the source.
managed_copy_threshold : int, optional
The size above which the file is copied in multiple parts
(managed copy) rather than with a single server-side copy
(default is 5 * 2**30, i.e. 5GB).
**kwargs : Any, optional
Additional arguments.

Raises
------
FileNotFoundError
If the source file does not exist.
ValueError
If the destination is a versioned file.
TosClientError
If there is a client error while copying the file.
TosServerError
If there is a server error while copying the file.
TosfsError
If there is an unknown error while copying the file.
"""
path1 = self._strip_protocol(path1)
bucket, key, vers = self._split_path(path1)

info = self.info(path1, bucket, key, version_id=vers)
size = info["size"]

_, _, parts_suffix = info.get("ETag", "").strip('"').partition("-")
if preserve_etag and parts_suffix:
self._copy_etag_preserved(path1, path2, size, total_parts=int(parts_suffix))
elif size <= min(
MANAGED_COPY_THRESHOLD,
managed_copy_threshold or MANAGED_COPY_THRESHOLD,
):
self._copy_basic(path1, path2, **kwargs)
else:
# if preserve_etag were true, the copy would have been handled
# above: either the source was uploaded in multiple parts, or it
# is small enough that a basic copy preserves the etag anyway
assert not preserve_etag

# serial multipart copy
self._copy_managed(path1, path2, size, **kwargs)

def _copy_basic(self, path1: str, path2: str, **kwargs: Any) -> None:
"""Copy file between locations on tos.
Not allowed where the origin is larger than 5GB.
"""
buc1, key1, ver1 = self._split_path(path1)
buc2, key2, ver2 = self._split_path(path2)
if ver2:
raise ValueError("Cannot copy to a versioned file!")
try:
self.tos_client.copy_object(
bucket=buc2,
key=key2,
src_bucket=buc1,
src_key=key1,
src_version_id=ver1,
)
except tos.exceptions.TosClientError as e:
raise e
except tos.exceptions.TosServerError as e:
raise e
except Exception as e:
raise TosfsError("Copy failed (%r -> %r): %s" % (path1, path2, e)) from e

def _copy_etag_preserved(
self, path1: str, path2: str, size: int, total_parts: int, **kwargs: Any
) -> None:
"""Copy file as multiple-part while preserving the etag."""
bucket1, key1, version1 = self._split_path(path1)
bucket2, key2, version2 = self._split_path(path2)

upload_id = None

try:
mpu = self.tos_client.create_multipart_upload(bucket2, key2)
upload_id = mpu.upload_id

parts = []
brange_first = 0

for i in range(1, total_parts + 1):
part_size = min(size - brange_first, PART_MAX_SIZE)
brange_last = brange_first + part_size - 1
if brange_last > size:
brange_last = size - 1

part = self.tos_client.upload_part_copy(
bucket=bucket2,
key=key2,
part_number=i,
upload_id=upload_id,
src_bucket=bucket1,
src_key=key1,
copy_source_range_start=brange_first,
copy_source_range_end=brange_last,
)
parts.append(
PartInfo(
part_number=part.part_number,
etag=part.etag,
part_size=part_size,
offset=None,
hash_crc64_ecma=None,
is_completed=None,
)
)
brange_first += part_size

self.tos_client.complete_multipart_upload(bucket2, key2, upload_id, parts)
except Exception as e:
if upload_id:
self.tos_client.abort_multipart_upload(bucket2, key2, upload_id)
raise TosfsError(f"Copy failed ({path1} -> {path2}): {e}") from e

def _copy_managed(
self,
path1: str,
path2: str,
size: int,
block: int = MANAGED_COPY_THRESHOLD,
**kwargs: Any,
) -> None:
"""Copy file between locations on tos as multiple-part.
block: int
The size of the pieces, must be larger than 5MB and at
most MANAGED_COPY_THRESHOLD.
Smaller blocks mean more calls, only useful for testing.
"""
if block < PART_MIN_SIZE or block > MANAGED_COPY_THRESHOLD:
raise ValueError("Copy block size must be 5MB <= block <= 5GB")

bucket1, key1, version1 = self._split_path(path1)
bucket2, key2, version2 = self._split_path(path2)

upload_id = None

try:
mpu = self.tos_client.create_multipart_upload(bucket2, key2)
upload_id = mpu.upload_id
branges = list(get_brange(size, block))
out = [
self.tos_client.upload_part_copy(
bucket=bucket2,
key=key2,
part_number=i + 1,
upload_id=upload_id,
src_bucket=bucket1,
src_key=key1,
copy_source_range_start=brange_first,
copy_source_range_end=brange_last,
)
for i, (brange_first, brange_last) in enumerate(branges)
]

parts = [
PartInfo(
part_number=i + 1,
etag=o.etag,
part_size=brange_last - brange_first + 1,
offset=None,
hash_crc64_ecma=None,
is_completed=None,
)
for i, (o, (brange_first, brange_last)) in enumerate(zip(out, branges))
]

self.tos_client.complete_multipart_upload(bucket2, key2, upload_id, parts)
except Exception as e:
if upload_id:
self.tos_client.abort_multipart_upload(bucket2, key2, upload_id)
raise TosfsError(f"Copy failed ({path1} -> {path2}): {e}") from e

def _find_file_dir(
self, key: str, path: str, prefix: str, withdirs: bool, kwargs: Any
) -> List[dict]:
@@ -1397,10 +1601,6 @@ def _fill_bucket_info(bucket_name: str) -> dict:
class TosFile(AbstractBufferedFile):
"""File-like operations for TOS."""

retries = 5
part_min = 5 * 2**20
part_max = 5 * 2**30

def __init__(
self,
fs: TosFileSystem,
@@ -1530,7 +1730,7 @@ def handle_remainder(
# Use the helper function in the main code
if 0 < current_chunk_size < self.blocksize:
previous_chunk, current_chunk = handle_remainder(
previous_chunk, current_chunk, self.blocksize, self.part_max
previous_chunk, current_chunk, self.blocksize, PART_MAX_SIZE
)

part = len(self.parts) + 1 if self.parts is not None else 1
@@ -1577,7 +1777,7 @@ def fetch() -> bytes:
bucket, key, version_id, range_start=start, range_end=end
).read()

return retryable_func_wrapper(fetch, retries=self.fs.retries)
return retryable_func_wrapper(fetch, retries=RETRY_NUM)

def commit(self) -> None:
"""Complete multipart upload or PUT."""
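To make the managed-copy arithmetic concrete, here is a small sketch (assuming the `get_brange` helper added below in tosfs/utils.py) of how `_copy_managed` would split a hypothetical 12 GiB object with the default 5 GiB block:

from tosfs.utils import get_brange

size = 12 * 2**30                   # 12 GiB source object (hypothetical)
block = 5 * 2**30                   # MANAGED_COPY_THRESHOLD
ranges = list(get_brange(size, block))
print(len(ranges))                  # 3 parts: 5 GiB + 5 GiB + 2 GiB
print(ranges[-1])                   # (10737418240, 12884901887), inclusive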
44 changes: 44 additions & 0 deletions tosfs/tests/test_tosfs.py
@@ -459,6 +459,50 @@ def test_find(tosfs: TosFileSystem, bucket: str, temporary_workspace: str) -> None:
tosfs.rmdir(f"{bucket}/{temporary_workspace}")


def test_cp_file(tosfs: TosFileSystem, bucket: str, temporary_workspace: str) -> None:
file_name = random_str()
file_content = "hello world"
src_path = f"{bucket}/{temporary_workspace}/{file_name}"
dest_path = f"{bucket}/{temporary_workspace}/copy_{file_name}"

with tosfs.open(src_path, "w") as f:
f.write(file_content)

tosfs.cp_file(src_path, dest_path)
assert tosfs.exists(dest_path)

with tosfs.open(dest_path, "r") as f:
assert f.read() == file_content

with pytest.raises(FileNotFoundError):
tosfs.cp_file(f"{bucket}/{temporary_workspace}/nonexistent", dest_path)

sub_dir_name = random_str()
dest_path = f"{bucket}/{temporary_workspace}/{sub_dir_name}"
tosfs.cp_file(src_path, dest_path)
assert tosfs.exists(dest_path)
with tosfs.open(dest_path, "r") as f:
assert f.read() == file_content

file_content = "a" * 2048 # 2KB content
with tosfs.open(src_path, "w") as f:
f.write(file_content)

tosfs.cp_file(src_path, dest_path, managed_copy_threshold=1024)
assert tosfs.exists(dest_path)

with tosfs.open(dest_path, "r") as f:
assert f.read() == file_content

# Test cp_file with preserve_etag=True
dest_path_with_etag = f"{bucket}/{temporary_workspace}/etag_{file_name}"
tosfs.cp_file(dest_path, dest_path_with_etag, preserve_etag=True)
assert tosfs.exists(dest_path_with_etag)
with tosfs.open(dest_path_with_etag, "r") as f:
assert f.read() == file_content
assert tosfs.info(dest_path_with_etag)["ETag"] == tosfs.info(dest_path)["ETag"]


###########################################################
# File operation tests #
###########################################################
17 changes: 16 additions & 1 deletion tosfs/utils.py
@@ -19,7 +19,7 @@
import string
import tempfile
import time
from typing import Any, Optional, Tuple
from typing import Any, Generator, Optional, Tuple

import tos

@@ -91,6 +91,21 @@ def find_bucket_key(tos_path: str) -> Tuple[str, str]:
return bucket, tos_key


def get_brange(size: int, block: int) -> Generator[Tuple[int, int], None, None]:
"""Chunk up a file into zero-based byte ranges.
Parameters
----------
size : int
file size
block : int
block size
"""
for offset in range(0, size, block):
yield offset, min(offset + block - 1, size - 1)


def retryable_func_wrapper(
func: Any, *, args: tuple[()] = (), kwargs: Optional[Any] = None, retries: int = 5
) -> Any:
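A quick worked example of `get_brange` with toy numbers — the yielded ranges are zero-based and inclusive, and the final range may be shorter than the block:

from tosfs.utils import get_brange

print(list(get_brange(10, 4)))   # [(0, 3), (4, 7), (8, 9)]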
