diff --git a/s3fs/core.py b/s3fs/core.py
index 2da6f0bd..60753786 100644
--- a/s3fs/core.py
+++ b/s3fs/core.py
@@ -2086,8 +2086,16 @@ async def _invalidate_region_cache(self):
 
     async def open_async(self, path, mode="rb", **kwargs):
         if "b" not in mode or kwargs.get("compression"):
             raise ValueError
+        if "w" in mode or "a" in mode:
+            return await self._open_for_writing(path, mode, **kwargs)
         return S3AsyncStreamedFile(self, path, mode)
 
+    async def _open_for_writing(self, path, mode, **kwargs):
+        # NOTE(review): "a" (append) currently behaves like "w" — S3 has no
+        # native append; a true append would need to prefetch the old object.
+        # Parse the path to get bucket and key
+        bucket, key, _ = self.split_path(path)
+        return S3AsyncStreamWriter(self, bucket, key)
+
 class S3File(AbstractBufferedFile):
     """
@@ -2429,6 +2436,38 @@ def _abort_mpu(self):
         self.mpu = None
 
 
+# File-like object for asynchronous streaming writes. Data is buffered in
+# memory and uploaded as a single object on close(), because S3 objects are
+# immutable: a put_object per write() call would overwrite, not append.
+class S3AsyncStreamWriter:
+    def __init__(self, s3_fs, bucket, key):
+        self.s3_fs = s3_fs
+        self.bucket = bucket
+        self.key = key
+        self.closed = False
+        self.loc = 0  # number of bytes written so far
+        self._chunks = []  # buffered write() payloads, joined on close()
+
+    async def write(self, data):
+        # Buffer only; the actual upload happens in close().
+        if self.closed:
+            raise ValueError("I/O operation on closed file")
+        self._chunks.append(bytes(data))
+        self.loc += len(data)
+
+    async def close(self):
+        # Upload the accumulated buffer as one S3 object. Idempotent.
+        if self.closed:
+            return
+        await self.s3_fs._call_s3(
+            "put_object",
+            Bucket=self.bucket,
+            Key=self.key,
+            Body=b"".join(self._chunks),
+        )
+        self._chunks = []
+        self.closed = True
+
+
 class S3AsyncStreamedFile(AbstractAsyncStreamedFile):
     def __init__(self, fs, path, mode):
         self.fs = fs
@@ -2481,3 +2521,4 @@ async def _call_and_read():
             resp["Body"].close()
 
         return await _error_wrapper(_call_and_read, retries=fs.retries)
+
diff --git a/s3fs/tests/test_s3fs.py b/s3fs/tests/test_s3fs.py
index d3d90899..dd59e240 100644
--- a/s3fs/tests/test_s3fs.py
+++ b/s3fs/tests/test_s3fs.py
@@ -2696,6 +2696,18 @@ async def read_stream():
                 break
             out.append(got)
 
+    async def write_stream():
+        fs = S3FileSystem(
+            anon=False,
+            client_kwargs={"endpoint_url": endpoint_uri},
+            skip_instance_cache=True,
+        )
+        await fs._mkdir(test_bucket_name)
+        f = await fs.open_async(fn, mode="wb")
+        await f.write(data)
+        await f.close()
+
+    asyncio.run(write_stream())
     asyncio.run(read_stream())
     assert b"".join(out) == data