diff --git a/tosfs/core.py b/tosfs/core.py
index c22279a..a7ced85 100644
--- a/tosfs/core.py
+++ b/tosfs/core.py
@@ -14,10 +14,13 @@
 """The core module of TOSFS."""
 
 import io
+import itertools
 import logging
 import mimetypes
 import os
+import tempfile
 import time
+from concurrent.futures import ThreadPoolExecutor
 from glob import has_magic
 from typing import Any, BinaryIO, Collection, Generator, List, Optional, Tuple, Union
 
@@ -35,7 +38,6 @@
     ListObjectVersionsOutput,
     PartInfo,
     UploadPartCopyOutput,
-    UploadPartOutput,
 )
 
 from tosfs.consts import (
@@ -113,6 +115,11 @@ def __init__(
         default_block_size: Optional[int] = None,
         default_fill_cache: bool = True,
         default_cache_type: str = "readahead",
+        multipart_staging_dirs: str = tempfile.mkdtemp(),
+        multipart_size: int = 8 << 20,
+        multipart_thread_pool_size: int = max(2, os.cpu_count() or 1),
+        multipart_staging_buffer_size: int = 4 << 10,
+        multipart_threshold: int = 10 << 20,
         **kwargs: Any,
     ) -> None:
         """Initialise the TosFileSystem.
@@ -157,6 +164,26 @@ def __init__(
             Whether to fill the cache (default is True).
         default_cache_type : str, optional
             The default cache type (default is 'readahead').
+        multipart_staging_dirs : str, optional
+            The staging directories for multipart uploads (default is a
+            temporary directory). Separate multiple staging directory paths
+            with commas.
+        multipart_size : int, optional
+            The multipart upload part size for the given object storage
+            (default is 8 MB).
+        multipart_thread_pool_size : int, optional
+            The size of the thread pool used to upload parts in parallel to
+            the given object storage (default is max(2, os.cpu_count() or 1)).
+        multipart_staging_buffer_size : int, optional
+            The maximum number of bytes of staging data buffered in memory
+            before being flushed to a staging file. This dramatically reduces
+            random writes to the local staging disk when writing many small
+            files (default is 4096).
+        multipart_threshold : int, optional
+            The threshold that controls whether multipart upload is used when
+            writing data to the given object storage; if the data written is
+            smaller than the threshold, it is uploaded with a simple put
+            instead of a multipart upload (default is 10 MB).
         kwargs : Any, optional
             Additional arguments.
 
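# --- Editor's note: usage sketch, not part of the patch ----------------------
# How the new multipart knobs above might be passed to TosFileSystem. Only the
# multipart_* parameter names come from this diff; the endpoint/region
# parameters are assumed from the surrounding project, and all values here are
# placeholders.
from tosfs.core import TosFileSystem

fs = TosFileSystem(
    endpoint_url="https://tos-cn-beijing.volces.com",  # placeholder endpoint
    region="cn-beijing",  # placeholder region
    multipart_staging_dirs="/mnt/disk0/stage,/mnt/disk1/stage",  # comma-separated
    multipart_size=8 << 20,  # 8 MiB per uploaded part
    multipart_thread_pool_size=8,  # upload up to 8 parts concurrently
    multipart_staging_buffer_size=4 << 10,  # buffer 4 KiB before disk flush
    multipart_threshold=10 << 20,  # below 10 MiB, one-shot PUT instead of MPU
)
# ------------------------------------------------------------------------------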
@@ -184,6 +211,14 @@ def __init__(
         self.default_cache_type = default_cache_type
         self.max_retry_num = max_retry_num
 
+        self.multipart_staging_dirs = [
+            d.strip() for d in multipart_staging_dirs.split(",")
+        ]
+        self.multipart_size = multipart_size
+        self.multipart_thread_pool_size = multipart_thread_pool_size
+        self.multipart_staging_buffer_size = multipart_staging_buffer_size
+        self.multipart_threshold = multipart_threshold
+
         super().__init__(**kwargs)
 
     def _open(
@@ -1998,10 +2033,19 @@ def __init__(
         self.mode = mode
         self.autocommit = autocommit
         self.mpu: CreateMultipartUploadOutput = None
-        self.parts: Optional[list] = None
+        self.parts: list = []
         self.append_block = False
         self.buffer: Optional[io.BytesIO] = io.BytesIO()
 
+        self.staging_dirs = itertools.cycle(fs.multipart_staging_dirs)
+        self.part_size = fs.multipart_size
+        self.thread_pool_size = fs.multipart_thread_pool_size
+        self.staging_buffer_size = fs.multipart_staging_buffer_size
+        self.multipart_threshold = fs.multipart_threshold
+        self.executor = ThreadPoolExecutor(max_workers=self.thread_pool_size)
+        self.staging_files: list[str] = []
+        self.staging_buffer: io.BytesIO = io.BytesIO()
+
         if "a" in mode and fs.exists(path):
             head = retryable_func_executor(
                 lambda: self.fs.tos_client.head_object(bucket, key),
@@ -2022,27 +2066,6 @@ def _initiate_upload(self) -> None:
             # only happens when closing small file, use one-shot PUT
             return
         logger.debug("Initiate upload for %s", self)
-        self.parts = []
-
-        self.mpu = retryable_func_executor(
-            lambda: self.fs.tos_client.create_multipart_upload(self.bucket, self.key),
-            max_retry_num=self.fs.max_retry_num,
-        )
-
-        if self.append_block:
-            # use existing data in key when appending,
-            # and block is big enough
-            out = retryable_func_executor(
-                lambda: self.fs.tos_client.upload_part_copy(
-                    bucket=self.bucket,
-                    key=self.key,
-                    part_number=1,
-                    upload_id=self.mpu.upload_id,
-                ),
-                max_retry_num=self.fs.max_retry_num,
-            )
-
-            self.parts.append({"PartNumber": out.part_number, "ETag": out.etag})
 
     def _upload_chunk(self, final: bool = False) -> bool:
         """Write one part of a multi-block file upload.
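# --- Editor's note: behavior sketch, not part of the patch -------------------
# The file-writer __init__ above wraps fs.multipart_staging_dirs in
# itertools.cycle, so successive staging files are spread across the configured
# directories round-robin. A self-contained illustration of that pattern, with
# hypothetical directory paths:
import itertools

staging_dirs = itertools.cycle(["/mnt/disk0/stage", "/mnt/disk1/stage"])
for part_number in range(1, 5):
    # alternates disk0, disk1, disk0, disk1, ...
    print(part_number, next(staging_dirs))
# ------------------------------------------------------------------------------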
@@ -2068,11 +2091,35 @@ def _upload_chunk(self, final: bool = False) -> bool:
             self.autocommit
             and not self.append_block
             and final
-            and self.tell() < self.blocksize
+            and self.tell() < max(self.blocksize, self.multipart_threshold)
         ):
             # only happens when closing small file, use one-shot PUT
             pass
         else:
+            self.parts = []
+
+            self.mpu = retryable_func_executor(
+                lambda: self.fs.tos_client.create_multipart_upload(
+                    self.bucket, self.key
+                ),
+                max_retry_num=self.fs.max_retry_num,
+            )
+
+            if self.append_block:
+                # use existing data in key when appending,
+                # and block is big enough
+                out = retryable_func_executor(
+                    lambda: self.fs.tos_client.upload_part_copy(
+                        bucket=self.bucket,
+                        key=self.key,
+                        part_number=1,
+                        upload_id=self.mpu.upload_id,
+                    ),
+                    max_retry_num=self.fs.max_retry_num,
+                )
+
+                self.parts.append({"PartNumber": out.part_number, "ETag": out.etag})
+
             self._upload_multiple_chunks(bucket, key)
 
         if self.autocommit and final:
@@ -2083,75 +2130,71 @@ def _upload_chunk(self, final: bool = False) -> bool:
     def _upload_multiple_chunks(self, bucket: str, key: str) -> None:
         if self.buffer:
             self.buffer.seek(0)
-            current_chunk: Optional[bytes] = self.buffer.read(self.blocksize)
+            while True:
+                chunk = self.buffer.read(self.part_size)
+                if not chunk:
+                    break
 
-        while current_chunk:
-            (previous_chunk, current_chunk) = (
-                current_chunk,
-                self.buffer.read(self.blocksize) if self.buffer else None,
-            )
-            current_chunk_size = len(current_chunk if current_chunk else b"")
-
-            # Define a helper function to handle the remainder logic
-            def handle_remainder(
-                previous_chunk: bytes,
-                current_chunk: Optional[bytes],
-                blocksize: int,
-                part_max: int,
-            ) -> Tuple[bytes, Optional[bytes]]:
-                if current_chunk:
-                    remainder = previous_chunk + current_chunk
-                else:
-                    remainder = previous_chunk
+                self._write_to_staging_buffer(chunk)
 
-                remainder_size = (
-                    blocksize + len(current_chunk) if current_chunk else blocksize
-                )
+    def _write_to_staging_buffer(self, chunk: bytes) -> None:
+        self.staging_buffer.write(chunk)
+        if self.staging_buffer.tell() >= self.staging_buffer_size:
+            self._flush_staging_buffer()
 
-                if remainder_size <= part_max:
-                    return remainder, None
-                else:
-                    partition = remainder_size // 2
-                    return remainder[:partition], remainder[partition:]
+    def _flush_staging_buffer(self) -> None:
+        if self.staging_buffer.tell() == 0:
+            return
 
-            # Use the helper function in the main code
-            if 0 < current_chunk_size < self.blocksize:
-                previous_chunk, current_chunk = handle_remainder(
-                    previous_chunk, current_chunk, self.blocksize, PART_MAX_SIZE
-                )
+        self.staging_buffer.seek(0)
+        staging_dir = next(self.staging_dirs)
+        with tempfile.NamedTemporaryFile(delete=False, dir=staging_dir) as tmp:
+            tmp.write(self.staging_buffer.read())
+            self.staging_files.append(tmp.name)
+
+        self.staging_buffer = io.BytesIO()
+
+    def _upload_staged_files(self) -> None:
+        self._flush_staging_buffer()
+        futures = []
+        for i, staging_file in enumerate(self.staging_files):
+            part_number = i + 1
+            futures.append(
+                self.executor.submit(
+                    self._upload_part_from_file, staging_file, part_number
+                )
+            )
 
-            part = len(self.parts) + 1 if self.parts is not None else 1
-            logger.debug("Upload chunk %s, %s", self, part)
+        for future in futures:
+            part_info = future.result()
+            self.parts.append(part_info)
 
-            def _call_upload_part(
-                part: int = part, previous_chunk: Optional[bytes] = previous_chunk
-            ) -> UploadPartOutput:
-                return self.fs.tos_client.upload_part(
-                    bucket=bucket,
-                    key=key,
-                    part_number=part,
-                    upload_id=self.mpu.upload_id,
-                    content=previous_chunk,
-                )
+        self.staging_files = []
 
-            out = retryable_func_executor(
-                _call_upload_part, max_retry_num=self.fs.max_retry_num
-            )
+    def _upload_part_from_file(self, staging_file: str, part_number: int) -> PartInfo:
+        with open(staging_file, "rb") as f:
+            content = f.read()
 
-            (
-                self.parts.append(
-                    PartInfo(
-                        part_number=part,
-                        etag=out.etag,
-                        part_size=len(previous_chunk),
-                        offset=None,
-                        hash_crc64_ecma=None,
-                        is_completed=None,
-                    )
-                )
-                if self.parts is not None
-                else None
-            )
+        out = retryable_func_executor(
+            lambda: self.fs.tos_client.upload_part(
+                bucket=self.bucket,
+                key=self.key,
+                part_number=part_number,
+                upload_id=self.mpu.upload_id,
+                content=content,
+            ),
+            max_retry_num=self.fs.max_retry_num,
+        )
+
+        os.remove(staging_file)
+        return PartInfo(
+            part_number=part_number,
+            etag=out.etag,
+            part_size=len(content),
+            offset=None,
+            hash_crc64_ecma=None,
+            is_completed=None,
+        )
 
     def _fetch_range(self, start: int, end: int) -> bytes:
         if start == end:
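# --- Editor's note: standalone sketch of the staging pipeline above, not part
# of the patch. upload_part here is a stand-in for the real
# tos_client.upload_part call; the buffer/flush logic mirrors the new
# _write_to_staging_buffer / _flush_staging_buffer / _upload_staged_files
# helpers: small writes accumulate in memory, spill to temp files once the
# buffer exceeds the staging size, and the files then upload in parallel.
import io
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor

STAGING_BUFFER_SIZE = 4 << 10  # mirrors multipart_staging_buffer_size

buffer = io.BytesIO()
staging_files: list[str] = []

def write_chunk(chunk: bytes) -> None:
    # small writes accumulate in memory first (_write_to_staging_buffer)
    buffer.write(chunk)
    if buffer.tell() >= STAGING_BUFFER_SIZE:
        flush_buffer()

def flush_buffer() -> None:
    # spill the in-memory buffer to a staging file (_flush_staging_buffer)
    global buffer
    if buffer.tell() == 0:
        return
    buffer.seek(0)
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(buffer.read())
        staging_files.append(tmp.name)
    buffer = io.BytesIO()

def upload_part(path: str, part_number: int) -> int:
    # stand-in for the real upload; returns the part size, then cleans up
    size = os.path.getsize(path)
    os.remove(path)
    return size

write_chunk(b"x" * 3000)
write_chunk(b"y" * 3000)  # crosses 4 KiB, spills to a staging file
flush_buffer()  # flush whatever is left (here: nothing)
with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [
        pool.submit(upload_part, path, i + 1)
        for i, path in enumerate(staging_files)
    ]
    print([f.result() for f in futures])  # -> [6000]
# ------------------------------------------------------------------------------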
@@ -2184,7 +2227,7 @@ def commit(self) -> None:
             logger.debug("Empty file committed %s", self)
             self._abort_mpu()
             self.fs.touch(self.path, **self.kwargs)
-        elif not self.parts:
+        elif not self.staging_files:
             if self.buffer is not None:
                 logger.debug("One-shot upload of %s", self)
                 self.buffer.seek(0)
@@ -2198,6 +2241,7 @@ def commit(self) -> None:
             else:
                 raise RuntimeError
         else:
+            self._upload_staged_files()
             logger.debug("Complete multi-part upload for %s ", self)
             write_result = retryable_func_executor(
                 lambda: self.fs.tos_client.complete_multipart_upload(
diff --git a/tosfs/stability.py b/tosfs/stability.py
index f8479ad..9293959 100644
--- a/tosfs/stability.py
+++ b/tosfs/stability.py
@@ -32,8 +32,10 @@
 from tosfs.exceptions import TosfsError
 
+CONFLICT_CODE = "409"
+
 TOS_SERVER_RETRYABLE_STATUS_CODES = {
-    "409",  # CONFLICT
+    CONFLICT_CODE,  # CONFLICT
     "429",  # TOO_MANY_REQUESTS
     "500",  # INTERNAL_SERVER_ERROR
 }
 
@@ -93,13 +95,17 @@ def retryable_func_executor(
                 raise e
 
             if is_retryable_exception(e):
-                logger.warn("Retry TOS request in the %d times, error: %s", attempt, e)
+                logger.warning(
+                    "Retrying TOS request (attempt %d), error: %s", attempt, e
+                )
                 try:
                     time.sleep(min(1.7**attempt * 0.1, 15))
                 except InterruptedError as ie:
                     raise TosfsError(f"Request {func} interrupted.") from ie
             else:
                 raise e
 
+        # Note: maybe not all retryable exceptions are wrapped in `TosError`;
+        # we will pay attention to those cases.
         except Exception as e:
             raise TosfsError(f"{e}") from e
 
@@ -112,13 +118,14 @@ def is_retryable_exception(e: TosError) -> bool:
 
 
 def _is_retryable_tos_server_exception(e: TosError) -> bool:
-    return (
-        isinstance(e, TosServerError)
-        and e.status_code in TOS_SERVER_RETRYABLE_STATUS_CODES
-        # exclude some special error code under 409(conflict) status code
-        # let it fast fail
-        and e.code not in TOS_SERVER_NOT_RETRYABLE_CONFLICT_ERROR_CODES
-    )
+    if not isinstance(e, TosServerError):
+        return False
+
+    # not all conflict errors are retryable
+    if e.status_code == CONFLICT_CODE:
+        return e.code not in TOS_SERVER_NOT_RETRYABLE_CONFLICT_ERROR_CODES
+
+    return e.status_code in TOS_SERVER_RETRYABLE_STATUS_CODES
 
 
 def _is_retryable_tos_client_exception(e: TosError) -> bool:
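# --- Editor's note: standalone sketch of the retry classification and backoff
# above, not part of the patch. is_retryable models the refactored
# _is_retryable_tos_server_exception: a 409 CONFLICT is screened against a
# deny-list before the generic retryable set is consulted. The deny-listed
# error code below is hypothetical; the backoff formula is copied from
# retryable_func_executor (attempt numbering assumed to start at 1).
CONFLICT_CODE = "409"
RETRYABLE_STATUS_CODES = {"409", "429", "500"}
NOT_RETRYABLE_CONFLICT_CODES = {"SomeFastFailCode"}  # hypothetical error code

def is_retryable(status_code: str, code: str) -> bool:
    # not all conflict errors are retryable: fast-fail the deny-listed ones
    if status_code == CONFLICT_CODE:
        return code not in NOT_RETRYABLE_CONFLICT_CODES
    return status_code in RETRYABLE_STATUS_CODES

print(is_retryable("409", "SomeFastFailCode"))  # False: fast fail
print(is_retryable("409", "OtherConflict"))     # True: retried
print(is_retryable("503", "whatever"))          # False: not in retryable set

# capped exponential backoff between retries: min(1.7**attempt * 0.1, 15)
delays = [min(1.7**attempt * 0.1, 15) for attempt in range(1, 6)]
print([round(d, 2) for d in delays])  # [0.17, 0.29, 0.49, 0.84, 1.42]
# ------------------------------------------------------------------------------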