
Optimize: Introduce multiple disk writes for MPU
yanghua committed Sep 20, 2024
1 parent 93bd8b1 commit eb03197
Showing 1 changed file with 129 additions and 88 deletions.
217 changes: 129 additions & 88 deletions tosfs/core.py
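The headline change is a new disk-staging path for multipart uploads (MPU): instead of uploading parts straight from the in-memory buffer, the writer spills them to temporary files spread across one or more staging directories and uploads them from disk on a thread pool. The five new constructor options below control this path. A minimal configuration sketch, assuming TosFileSystem is imported from tosfs.core; endpoint and credential arguments are omitted, so this is illustrative rather than directly runnable, and the comments only paraphrase what the parameter names and defaults in the diff suggest:

# Hypothetical tuning of the new disk-staging MPU options.
mpu_settings = dict(
    multipart_staging_dirs="/mnt/ssd1/tmp,/mnt/ssd2/tmp",  # comma-separated, used round-robin
    multipart_size=8 * 1024 * 1024,        # bytes per staged part (default 8388608)
    multipart_thread_pool_size=96,         # workers uploading staged parts concurrently
    multipart_staging_buffer_size=4096,    # staging write buffer size in bytes
    multipart_threshold=10 * 1024 * 1024,  # presumably the size above which the MPU path kicks in (default 10485760)
)

from tosfs.core import TosFileSystem
fs = TosFileSystem(**mpu_settings)  # plus the usual endpoint/credential kwargs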
@@ -14,10 +14,13 @@

"""The core module of TOSFS."""
import io
import itertools
import logging
import mimetypes
import os
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from glob import has_magic
from typing import Any, BinaryIO, Collection, Generator, List, Optional, Tuple, Union

@@ -113,6 +116,11 @@ def __init__(
default_block_size: Optional[int] = None,
default_fill_cache: bool = True,
default_cache_type: str = "readahead",
multipart_staging_dirs: str = "/tmp",
multipart_size: int = 8388608,
multipart_thread_pool_size: int = 96,
multipart_staging_buffer_size: int = 4096,
multipart_threshold: int = 10485760,
**kwargs: Any,
) -> None:
"""Initialise the TosFileSystem.
@@ -184,6 +192,12 @@ def __init__(
self.default_cache_type = default_cache_type
self.max_retry_num = max_retry_num

self.multipart_staging_dirs = [d.strip() for d in multipart_staging_dirs.split(",")]
self.multipart_size = multipart_size
self.multipart_thread_pool_size = multipart_thread_pool_size
self.multipart_staging_buffer_size = multipart_staging_buffer_size
self.multipart_threshold = multipart_threshold

super().__init__(**kwargs)

def _open(
@@ -2002,6 +2016,14 @@ def __init__(
self.append_block = False
self.buffer: Optional[io.BytesIO] = io.BytesIO()

self.staging_dirs = itertools.cycle(fs.multipart_staging_dirs)
self.part_size = fs.multipart_size
self.thread_pool_size = fs.multipart_thread_pool_size
self.staging_buffer_size = fs.multipart_staging_buffer_size
self.multipart_threshold = fs.multipart_threshold
self.executor = ThreadPoolExecutor(max_workers=self.thread_pool_size)
self.staging_files = []

if "a" in mode and fs.exists(path):
head = retryable_func_executor(
lambda: self.fs.tos_client.head_object(bucket, key),
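The writer keeps an itertools.cycle over the configured staging directories, so successive spilled parts rotate across disks. A small sketch of that round-robin selection, using made-up directory names:

import itertools

staging_dirs = itertools.cycle(["/mnt/ssd1/tmp", "/mnt/ssd2/tmp"])  # hypothetical dirs
print([next(staging_dirs) for _ in range(4)])
# ['/mnt/ssd1/tmp', '/mnt/ssd2/tmp', '/mnt/ssd1/tmp', '/mnt/ssd2/tmp']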
@@ -2080,79 +2102,6 @@ def _upload_chunk(self, final: bool = False) -> bool:

return not final

def _upload_multiple_chunks(self, bucket: str, key: str) -> None:
if self.buffer:
self.buffer.seek(0)
current_chunk: Optional[bytes] = self.buffer.read(self.blocksize)

while current_chunk:
(previous_chunk, current_chunk) = (
current_chunk,
self.buffer.read(self.blocksize) if self.buffer else None,
)
current_chunk_size = len(current_chunk if current_chunk else b"")

# Define a helper function to handle the remainder logic
def handle_remainder(
previous_chunk: bytes,
current_chunk: Optional[bytes],
blocksize: int,
part_max: int,
) -> Tuple[bytes, Optional[bytes]]:
if current_chunk:
remainder = previous_chunk + current_chunk
else:
remainder = previous_chunk

remainder_size = (
blocksize + len(current_chunk) if current_chunk else blocksize
)

if remainder_size <= part_max:
return remainder, None
else:
partition = remainder_size // 2
return remainder[:partition], remainder[partition:]

# Use the helper function in the main code
if 0 < current_chunk_size < self.blocksize:
previous_chunk, current_chunk = handle_remainder(
previous_chunk, current_chunk, self.blocksize, PART_MAX_SIZE
)

part = len(self.parts) + 1 if self.parts is not None else 1
logger.debug("Upload chunk %s, %s", self, part)

def _call_upload_part(
part: int = part, previous_chunk: Optional[bytes] = previous_chunk
) -> UploadPartOutput:
return self.fs.tos_client.upload_part(
bucket=bucket,
key=key,
part_number=part,
upload_id=self.mpu.upload_id,
content=previous_chunk,
)

out = retryable_func_executor(
_call_upload_part, max_retry_num=self.fs.max_retry_num
)

(
self.parts.append(
PartInfo(
part_number=part,
etag=out.etag,
part_size=len(previous_chunk),
offset=None,
hash_crc64_ecma=None,
is_completed=None,
)
)
if self.parts is not None
else None
)

def _fetch_range(self, start: int, end: int) -> bytes:
if start == end:
logger.debug(
@@ -2176,6 +2125,111 @@ def fetch() -> bytes:

return retryable_func_executor(fetch, max_retry_num=self.fs.max_retry_num)

# def commit(self) -> None:
# """Complete multipart upload or PUT."""
# logger.debug("Commit %s", self)
# if self.tell() == 0:
# if self.buffer is not None:
# logger.debug("Empty file committed %s", self)
# self._abort_mpu()
# self.fs.touch(self.path, **self.kwargs)
# elif not self.parts:
# if self.buffer is not None:
# logger.debug("One-shot upload of %s", self)
# self.buffer.seek(0)
# data = self.buffer.read()
# write_result = retryable_func_executor(
# lambda: self.fs.tos_client.put_object(
# self.bucket, self.key, content=data
# ),
# max_retry_num=self.fs.max_retry_num,
# )
# else:
# raise RuntimeError
# else:
# logger.debug("Complete multi-part upload for %s ", self)
# write_result = retryable_func_executor(
# lambda: self.fs.tos_client.complete_multipart_upload(
# self.bucket,
# self.key,
# upload_id=self.mpu.upload_id,
# parts=self.parts,
# ),
# max_retry_num=self.fs.max_retry_num,
# )
#
# if self.fs.version_aware:
# self.version_id = write_result.version_id
#
# self.buffer = None

def discard(self) -> None:
"""Close the file without writing."""
self._abort_mpu()
self.buffer = None # file becomes unusable

def _abort_mpu(self) -> None:
if self.mpu:
retryable_func_executor(
lambda: self.fs.tos_client.abort_multipart_upload(
self.bucket, self.key, self.mpu.upload_id
),
max_retry_num=self.fs.max_retry_num,
)
self.mpu = None

def _upload_multiple_chunks(self, bucket: str, key: str) -> None:
if self.buffer:
self.buffer.seek(0)
while True:
chunk = self.buffer.read(self.part_size)
if not chunk:
break

staging_dir = next(self.staging_dirs)
with tempfile.NamedTemporaryFile(delete=False, dir=staging_dir) as tmp:
tmp.write(chunk)
self.staging_files.append(tmp.name)

def _upload_staged_files(self) -> None:
futures = []
for i, staging_file in enumerate(self.staging_files):
part_number = i + 1
futures.append(
self.executor.submit(self._upload_part_from_file, staging_file, part_number)
)

for future in futures:
part_info = future.result()
self.parts.append(part_info)

self.staging_files = []
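Part numbers are fixed by enumerate before the tasks are submitted, and the futures are drained in submission order, so self.parts ends up ordered by part number even though the uploads themselves finish in arbitrary order. A minimal sketch of that pattern with a stand-in upload function; the names here are illustrative only:

from concurrent.futures import ThreadPoolExecutor

def upload(part_number: int) -> int:
    # stand-in for _upload_part_from_file; just echoes its part number
    return part_number

with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(upload, i + 1) for i in range(5)]
    results = [f.result() for f in futures]  # collected in submission order

print(results)  # [1, 2, 3, 4, 5]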

def _upload_part_from_file(self, staging_file: str, part_number: int) -> PartInfo:
with open(staging_file, "rb") as f:
content = f.read()

out = retryable_func_executor(
lambda: self.fs.tos_client.upload_part(
bucket=self.bucket,
key=self.key,
part_number=part_number,
upload_id=self.mpu.upload_id,
content=content,
),
max_retry_num=self.fs.max_retry_num,
)

os.remove(staging_file)
return PartInfo(
part_number=part_number,
etag=out.etag,
part_size=len(content),
offset=None,
hash_crc64_ecma=None,
is_completed=None,
)

def commit(self) -> None:
"""Complete multipart upload or PUT."""
logger.debug("Commit %s", self)
Expand All @@ -2184,7 +2238,7 @@ def commit(self) -> None:
logger.debug("Empty file committed %s", self)
self._abort_mpu()
self.fs.touch(self.path, **self.kwargs)
elif not self.parts:
elif not self.staging_files:
if self.buffer is not None:
logger.debug("One-shot upload of %s", self)
self.buffer.seek(0)
@@ -2198,6 +2252,7 @@ def commit(self) -> None:
else:
raise RuntimeError
else:
self._upload_staged_files()
logger.debug("Complete multi-part upload for %s ", self)
write_result = retryable_func_executor(
lambda: self.fs.tos_client.complete_multipart_upload(
@@ -2214,17 +2269,3 @@ def commit(self) -> None:

self.buffer = None

def discard(self) -> None:
"""Close the file without writing."""
self._abort_mpu()
self.buffer = None # file becomes unusable

def _abort_mpu(self) -> None:
if self.mpu:
retryable_func_executor(
lambda: self.fs.tos_client.abort_multipart_upload(
self.bucket, self.key, self.mpu.upload_id
),
max_retry_num=self.fs.max_retry_num,
)
self.mpu = None
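Taken together, the new write path is: the in-memory buffer is spilled to temporary files that rotate across the staging directories, and commit() later calls _upload_staged_files(), which uploads each staged file as one MPU part on the thread pool, deletes it, and finally completes the multipart upload. A condensed, self-contained sketch of that flow with a stub in place of the real TOS client; everything below is illustrative and not the library's API:

import io
import itertools
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor

class StubClient:
    # Stand-in for the TOS client: pretends to upload a part and returns an etag.
    def upload_part(self, part_number: int, content: bytes) -> str:
        return f"etag-{part_number}"

def write_large_object(data: bytes, staging_dirs: list, part_size: int) -> list:
    client = StubClient()
    dirs = itertools.cycle(staging_dirs)
    buffer = io.BytesIO(data)

    # 1) Spill the buffer to staging files, round-robin across directories.
    staging_files = []
    while True:
        chunk = buffer.read(part_size)
        if not chunk:
            break
        with tempfile.NamedTemporaryFile(delete=False, dir=next(dirs)) as tmp:
            tmp.write(chunk)
            staging_files.append(tmp.name)

    # 2) Upload each staged file as one part, in parallel, then clean it up.
    def upload(path: str, part_number: int) -> tuple:
        with open(path, "rb") as f:
            content = f.read()
        etag = client.upload_part(part_number, content)
        os.remove(path)
        return (part_number, etag, len(content))

    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [
            executor.submit(upload, path, i + 1)
            for i, path in enumerate(staging_files)
        ]
        return [f.result() for f in futures]

# 20 MiB of data in 8 MiB parts, staged in the system temp dir.
parts = write_large_object(b"x" * (20 * 1024 * 1024), [tempfile.gettempdir()], 8 * 1024 * 1024)
print(parts)  # [(1, 'etag-1', 8388608), (2, 'etag-2', 8388608), (3, 'etag-3', 4194304)]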
