Skip to content

Commit

Permalink
Add test cases for stability
Browse files (browse the repository at this point in the history)
  • Loading branch information
yanghua committed Sep 24, 2024
1 parent 45cee34 commit f408ce1
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 13 deletions.
23 changes: 16 additions & 7 deletions tosfs/mpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,9 @@ def upload_multiple_chunks(self, buffer: Optional[io.BytesIO]) -> None:
def _write_to_staging_buffer(self, chunk: bytes) -> None:
self.staging_buffer.write(chunk)
if self.staging_buffer.tell() >= self.part_size:
self._flush_staging_buffer()
self._flush_staging_buffer(False)

def _flush_staging_buffer(self) -> None:
def _flush_staging_buffer(self, final: bool = False) -> None:
if self.staging_buffer.tell() == 0:
return

Expand All @@ -93,13 +93,22 @@ def _flush_staging_buffer(self) -> None:
self.staging_files.append(tmp.name)
buffer_size -= self.part_size

# Move remaining data to a new buffer
remaining_data = self.staging_buffer.read()
self.staging_buffer = io.BytesIO()
self.staging_buffer.write(remaining_data)
if not final:
# Move remaining data to a new buffer
remaining_data = self.staging_buffer.read()
self.staging_buffer = io.BytesIO()
self.staging_buffer.write(remaining_data)
else:
staging_dir = next(self.staging_dirs)
with tempfile.NamedTemporaryFile(delete=False, dir=staging_dir) as tmp:
tmp.write(self.staging_buffer.read())
self.staging_files.append(tmp.name)
buffer_size -= self.part_size

self.staging_buffer = io.BytesIO()

def _upload_staged_files(self) -> None:
self._flush_staging_buffer()
self._flush_staging_buffer(True)
futures = []
for i, staging_file in enumerate(self.staging_files):
part_number = i + 1
Expand Down
16 changes: 10 additions & 6 deletions tosfs/tests/test_stability.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,33 +19,37 @@

def test_write_breakpoint_continuation(tosfs, bucket, temporary_workspace):
file_name = f"{random_str()}.txt"
first_part = random_str(9 * 1024 * 1024)
second_part = random_str(9 * 1024 * 1024)
first_part = random_str(10 * 1024 * 1024)
second_part = random_str(10 * 1024 * 1024)

with tosfs.open(f"{bucket}/{temporary_workspace}/{file_name}", "w") as f:
f.write(first_part)
# simulate a long blocking period (business processing or a network issue)
sleep(60)
f.write(second_part)

assert tosfs.info(f"{bucket}/{temporary_workspace}/{file_name}")["size"] == len(
first_part + second_part
)

with tosfs.open(f"{bucket}/{temporary_workspace}/{file_name}", "r") as f:
assert f.read() == first_part + second_part


def test_read_breakpoint_continuation(tosfs, bucket, temporary_workspace):
file_name = f"{random_str()}.txt"
first_part = random_str(9 * 1024 * 1024)
second_part = random_str(9 * 1024 * 1024)
first_part = random_str(10 * 1024 * 1024)
second_part = random_str(10 * 1024 * 1024)

with tosfs.open(f"{bucket}/{temporary_workspace}/{file_name}", "w") as f:
f.write(first_part)
f.write(second_part)

with tosfs.open(f"{bucket}/{temporary_workspace}/{file_name}", "r") as f:
read_first_part = f.read(9 * 1024 * 1024)
read_first_part = f.read(10 * 1024 * 1024)
assert read_first_part == first_part
# simulate a long blocking period (business processing or a network issue)
sleep(60)
read_second_part = f.read(9 * 1024 * 1024)
read_second_part = f.read(10 * 1024 * 1024)
assert read_second_part == second_part
assert read_first_part + read_second_part == first_part + second_part

0 comments on commit f408ce1

Please sign in to comment.