diff --git a/tosfs/consts.py b/tosfs/consts.py
index 01cc726..fc82585 100644
--- a/tosfs/consts.py
+++ b/tosfs/consts.py
@@ -26,3 +26,9 @@
     "InternalError",
     "ServiceUnavailable",
 }
+
+MANAGED_COPY_THRESHOLD = 5 * 2**30
+
+RETRY_NUM = 5
+PART_MIN_SIZE = 5 * 2**20
+PART_MAX_SIZE = 5 * 2**30
diff --git a/tosfs/core.py b/tosfs/core.py
index 59aea2c..6f42b3b 100644
--- a/tosfs/core.py
+++ b/tosfs/core.py
@@ -33,9 +33,14 @@
     UploadPartOutput,
 )
 
-from tosfs.consts import TOS_SERVER_RESPONSE_CODE_NOT_FOUND
+from tosfs.consts import (
+    MANAGED_COPY_THRESHOLD,
+    PART_MAX_SIZE,
+    PART_MIN_SIZE,
+    RETRY_NUM,
+    TOS_SERVER_RESPONSE_CODE_NOT_FOUND,
+)
 from tosfs.exceptions import TosfsError
-from tosfs.utils import find_bucket_key, retryable_func_wrapper
+from tosfs.utils import find_bucket_key, get_brange, retryable_func_wrapper
 
 # environment variable names
 ENV_NAME_TOSFS_LOGGING_LEVEL = "TOSFS_LOGGING_LEVEL"
@@ -66,7 +71,6 @@ class TosFileSystem(AbstractFileSystem):
     """
 
     protocol = ("tos", "tosfs")
-    retries = 5
     default_block_size = 5 * 2**20
 
     def __init__(
@@ -125,7 +129,7 @@ def _open(
             best support random access. When reading only a few specific chunks
             out of a file, performance may be better if False.
         version_id : str
-            Explicit version of the object to open. This requires that the s3
+            Explicit version of the object to open. This requires that the tos
             filesystem is version aware and bucket versioning is enabled on the
             relevant bucket.
         cache_type : str
@@ -636,7 +640,7 @@ def _read_chunks(body: BinaryIO, f: BinaryIO) -> None:
                 chunk = body.read(2**16)
             except tos.exceptions.TosClientError as e:
                 failed_reads += 1
-                if failed_reads >= self.retries:
+                if failed_reads >= RETRY_NUM:
                     raise e
             try:
                 body.close()
@@ -772,6 +776,206 @@ def find(
         else:
             return [o["name"] for o in out]
 
+    def cp_file(
+        self,
+        path1: str,
+        path2: str,
+        preserve_etag: Optional[bool] = None,
+        managed_copy_threshold: Optional[int] = MANAGED_COPY_THRESHOLD,
+        **kwargs: Any,
+    ) -> None:
+        """Copy file between locations on tos.
+
+        Parameters
+        ----------
+        path1 : str
+            The source path of the file to copy.
+        path2 : str
+            The destination path of the file to copy.
+        preserve_etag : bool, optional
+            Whether to preserve the etag while copying. If the file was
+            uploaded as a single part, the etag is simply the md5 hash of
+            the file, so it is always preserved. If the file was uploaded
+            in multiple parts, this option reproduces the same multipart
+            upload while copying, preserving the generated etag.
+        managed_copy_threshold : int, optional
+            The file-size threshold for switching to a managed (multipart)
+            copy. Files larger than this threshold are copied in multiple
+            parts (default is 5GB, i.e. ``5 * 2**30``).
+        **kwargs : Any, optional
+            Additional arguments.
+
+        Raises
+        ------
+        FileNotFoundError
+            If the source file does not exist.
+        ValueError
+            If the destination is a versioned file.
+        TosClientError
+            If there is a client error while copying the file.
+        TosServerError
+            If there is a server error while copying the file.
+        TosfsError
+            If there is an unknown error while copying the file.
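+
+        Examples
+        --------
+        A sketch of typical usage; the bucket and key names below are
+        illustrative and client construction arguments are omitted:
+
+        >>> fs = TosFileSystem()  # doctest: +SKIP
+        >>> fs.cp_file("mybucket/src.bin", "mybucket/dst.bin")  # doctest: +SKIP
+        >>> # lower the threshold so larger objects use managed multipart copy
+        >>> fs.cp_file(
+        ...     "mybucket/src.bin",
+        ...     "mybucket/dst.bin",
+        ...     managed_copy_threshold=100 * 2**20,
+        ... )  # doctest: +SKIP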
+ + """ + path1 = self._strip_protocol(path1) + bucket, key, vers = self._split_path(path1) + + info = self.info(path1, bucket, key, version_id=vers) + size = info["size"] + + _, _, parts_suffix = info.get("ETag", "").strip('"').partition("-") + if preserve_etag and parts_suffix: + self._copy_etag_preserved(path1, path2, size, total_parts=int(parts_suffix)) + elif size <= min( + MANAGED_COPY_THRESHOLD, + ( + managed_copy_threshold + if managed_copy_threshold + else MANAGED_COPY_THRESHOLD + ), + ): + self._copy_basic(path1, path2, **kwargs) + else: + # if the preserve_etag is true, either the file is uploaded + # on multiple parts or the size is lower than 5GB + assert not preserve_etag + + # serial multipart copy + self._copy_managed(path1, path2, size, **kwargs) + + def _copy_basic(self, path1: str, path2: str, **kwargs: Any) -> None: + """Copy file between locations on tos. + + Not allowed where the origin is larger than 5GB. + """ + buc1, key1, ver1 = self._split_path(path1) + buc2, key2, ver2 = self._split_path(path2) + if ver2: + raise ValueError("Cannot copy to a versioned file!") + try: + self.tos_client.copy_object( + bucket=buc2, + key=key2, + src_bucket=buc1, + src_key=key1, + src_version_id=ver1, + ) + except tos.exceptions.TosClientError as e: + raise e + except tos.exceptions.TosServerError as e: + raise e + except Exception as e: + raise TosfsError("Copy failed (%r -> %r): %s" % (path1, path2, e)) from e + + def _copy_etag_preserved( + self, path1: str, path2: str, size: int, total_parts: int, **kwargs: Any + ) -> None: + """Copy file as multiple-part while preserving the etag.""" + bucket1, key1, version1 = self._split_path(path1) + bucket2, key2, version2 = self._split_path(path2) + + upload_id = None + + try: + mpu = self.tos_client.create_multipart_upload(bucket2, key2) + upload_id = mpu.upload_id + + parts = [] + brange_first = 0 + + for i in range(1, total_parts + 1): + part_size = min(size - brange_first, PART_MAX_SIZE) + brange_last = brange_first + part_size - 1 + if brange_last > size: + brange_last = size - 1 + + part = self.tos_client.upload_part_copy( + bucket=bucket2, + key=key2, + part_number=i, + upload_id=upload_id, + src_bucket=bucket1, + src_key=key1, + copy_source_range_start=brange_first, + copy_source_range_end=brange_last, + ) + parts.append( + PartInfo( + part_number=part.part_number, + etag=part.etag, + part_size=size, + offset=None, + hash_crc64_ecma=None, + is_completed=None, + ) + ) + brange_first += part_size + + self.tos_client.complete_multipart_upload(bucket2, key2, upload_id, parts) + except Exception as e: + self.tos_client.abort_multipart_upload(bucket2, key2, upload_id) + raise TosfsError(f"Copy failed ({path1} -> {path2}): {e}") from e + + def _copy_managed( + self, + path1: str, + path2: str, + size: int, + block: int = MANAGED_COPY_THRESHOLD, + **kwargs: Any, + ) -> None: + """Copy file between locations on tos as multiple-part. + + block: int + The size of the pieces, must be larger than 5MB and at + most MANAGED_COPY_THRESHOLD. + Smaller blocks mean more calls, only useful for testing. 
+ """ + if block < 5 * 2**20 or block > MANAGED_COPY_THRESHOLD: + raise ValueError("Copy block size must be 5MB<=block<=5GB") + + bucket1, key1, version1 = self._split_path(path1) + bucket2, key2, version2 = self._split_path(path2) + + upload_id = None + + try: + mpu = self.tos_client.create_multipart_upload(bucket2, key2) + upload_id = mpu.upload_id + out = [ + self.tos_client.upload_part_copy( + bucket=bucket2, + key=key2, + part_number=i + 1, + upload_id=upload_id, + src_bucket=bucket1, + src_key=key1, + copy_source_range_start=brange_first, + copy_source_range_end=brange_last, + ) + for i, (brange_first, brange_last) in enumerate(get_brange(size, block)) + ] + + parts = [ + PartInfo( + part_number=i + 1, + etag=o.etag, + part_size=size, + offset=None, + hash_crc64_ecma=None, + is_completed=None, + ) + for i, o in enumerate(out) + ] + + self.tos_client.complete_multipart_upload(bucket2, key2, upload_id, parts) + except Exception as e: + self.tos_client.abort_multipart_upload(bucket2, key2, upload_id) + raise TosfsError(f"Copy failed ({path1} -> {path2}): {e}") from e + def _find_file_dir( self, key: str, path: str, prefix: str, withdirs: bool, kwargs: Any ) -> List[dict]: @@ -1397,10 +1601,6 @@ def _fill_bucket_info(bucket_name: str) -> dict: class TosFile(AbstractBufferedFile): """File-like operations for TOS.""" - retries = 5 - part_min = 5 * 2**20 - part_max = 5 * 2**30 - def __init__( self, fs: TosFileSystem, @@ -1530,7 +1730,7 @@ def handle_remainder( # Use the helper function in the main code if 0 < current_chunk_size < self.blocksize: previous_chunk, current_chunk = handle_remainder( - previous_chunk, current_chunk, self.blocksize, self.part_max + previous_chunk, current_chunk, self.blocksize, PART_MAX_SIZE ) part = len(self.parts) + 1 if self.parts is not None else 1 @@ -1577,7 +1777,7 @@ def fetch() -> bytes: bucket, key, version_id, range_start=start, range_end=end ).read() - return retryable_func_wrapper(fetch, retries=self.fs.retries) + return retryable_func_wrapper(fetch, retries=RETRY_NUM) def commit(self) -> None: """Complete multipart upload or PUT.""" diff --git a/tosfs/tests/test_tosfs.py b/tosfs/tests/test_tosfs.py index c693ca5..aa1b527 100644 --- a/tosfs/tests/test_tosfs.py +++ b/tosfs/tests/test_tosfs.py @@ -459,6 +459,50 @@ def test_find(tosfs: TosFileSystem, bucket: str, temporary_workspace: str) -> No tosfs.rmdir(f"{bucket}/{temporary_workspace}") +def test_cp_file(tosfs: TosFileSystem, bucket: str, temporary_workspace: str) -> None: + file_name = random_str() + file_content = "hello world" + src_path = f"{bucket}/{temporary_workspace}/{file_name}" + dest_path = f"{bucket}/{temporary_workspace}/copy_{file_name}" + + with tosfs.open(src_path, "w") as f: + f.write(file_content) + + tosfs.cp_file(src_path, dest_path) + assert tosfs.exists(dest_path) + + with tosfs.open(dest_path, "r") as f: + assert f.read() == file_content + + with pytest.raises(FileNotFoundError): + tosfs.cp_file(f"{bucket}/{temporary_workspace}/nonexistent", dest_path) + + sub_dir_name = random_str() + dest_path = f"{bucket}/{temporary_workspace}/{sub_dir_name}" + tosfs.cp_file(src_path, dest_path) + assert tosfs.exists(dest_path) + with tosfs.open(dest_path, "r") as f: + assert f.read() == file_content + + file_content = "a" * 2048 # 2KB content + with tosfs.open(src_path, "w") as f: + f.write(file_content) + + tosfs.cp_file(src_path, dest_path, managed_copy_threshold=1024) + assert tosfs.exists(dest_path) + + with tosfs.open(dest_path, "r") as f: + assert f.read() == file_content + + # Test 
+
+    """
+    for offset in range(0, size, block):
+        yield offset, min(offset + block - 1, size - 1)
+
+
 def retryable_func_wrapper(
     func: Any, *, args: tuple[()] = (), kwargs: Optional[Any] = None, retries: int = 5
 ) -> Any: