Support streaming zip file download
oeway committed Nov 24, 2024
1 parent cc4bf52 commit 82d9a3d
Showing 8 changed files with 248 additions and 31 deletions.
2 changes: 1 addition & 1 deletion hypha/VERSION
@@ -1,3 +1,3 @@
{
"version": "0.20.39.post18"
"version": "0.20.39.post19"
}
171 changes: 169 additions & 2 deletions hypha/artifact.py
@@ -26,15 +26,20 @@
AsyncSession,
)

from datetime import datetime
from stat import S_IFREG
from stream_zip import ZIP_32, async_stream_zip
import httpx

from hrid import HRID
from hypha.utils import remove_objects_async, list_objects_async, safe_join
from hypha.utils.zenodo import ZenodoClient
from botocore.exceptions import ClientError
from hypha.s3 import FSFileResponse
from aiobotocore.session import get_session

from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import RedirectResponse
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import RedirectResponse, StreamingResponse
from hypha.core import (
UserInfo,
UserPermission,
@@ -247,6 +252,168 @@ async def list_children(
status_code=500, detail=f"An unexpected error occurred: {str(e)}"
)

@router.get("/{workspace}/artifacts/{artifact_alias}/zip-files")
async def create_zip_file(
workspace: str,
artifact_alias: str,
files: List[str] = Query(None, alias="file"),
token: str = None,
version: str = None,
silent: bool = False,
user_info: self.store.login_optional = Depends(self.store.login_optional),
):
try:
# Validate artifact and permissions
artifact_id = self._validate_artifact_id(
artifact_alias, {"ws": workspace}
)
session = await self._get_session(read_only=True)
if token:
user_info = await self.store.parse_user_token(token)

async with session.begin():
# Fetch artifact and check permissions
(
artifact,
parent_artifact,
) = await self._get_artifact_with_permission(
user_info, artifact_id, "get_file", session
)
version_index = self._get_version_index(artifact, version)
s3_config = self._get_s3_config(artifact, parent_artifact)

async with self._create_client_async(s3_config) as s3_client:
if files is None:
# List all files in the artifact
root_dir_key = safe_join(
s3_config["prefix"],
workspace,
f"{self._artifacts_dir}/{artifact.id}/v{version_index}",
)

async def list_all_files(dir_path=""):
try:
dir_key = f"{root_dir_key}/{dir_path}".strip("/")
items = await list_objects_async(
s3_client,
s3_config["bucket"],
dir_key + "/",
)
for item in items:
item_path = f"{dir_path}/{item['name']}".strip(
"/"
)
if item["type"] == "file":
yield item_path
elif item["type"] == "directory":
async for sub_item in list_all_files(
item_path
):
yield sub_item
except Exception as e:
logger.error(f"Error listing files: {str(e)}")
raise HTTPException(
status_code=500, detail="Error listing files"
)

files = list_all_files()
else:

async def validate_files(files):
for file in files:
yield file

files = validate_files(files)

async def file_stream_generator(presigned_url: str):
"""Fetch file content from presigned URL in chunks."""
try:
async with httpx.AsyncClient() as client:
async with client.stream(
"GET", presigned_url
) as response:
if response.status_code != 200:
logger.error(
f"Failed to fetch file from URL: {presigned_url}, Status: {response.status_code}"
)
raise HTTPException(
status_code=404,
detail=f"Failed to fetch file: {presigned_url}",
)
async for chunk in response.aiter_bytes(
1024 * 64
): # 64KB chunks
yield chunk
except Exception as e:
logger.error(f"Error fetching file stream: {str(e)}")
raise HTTPException(
status_code=500,
detail="Error fetching file content",
)

async def member_files():
"""Yield file metadata and content for stream_zip."""
modified_at = datetime.now()
mode = S_IFREG | 0o600
async for path in files:
file_key = safe_join(
s3_config["prefix"],
workspace,
f"{self._artifacts_dir}/{artifact.id}/v{version_index}",
path,
)
try:
presigned_url = (
await s3_client.generate_presigned_url(
"get_object",
Params={
"Bucket": s3_config["bucket"],
"Key": file_key,
},
)
)
yield (
path,
modified_at,
mode,
ZIP_32,
file_stream_generator(presigned_url),
)
except Exception as e:
logger.error(
f"Error processing file {path}: {str(e)}"
)
raise HTTPException(
status_code=500,
detail=f"Error processing file: {path}",
)

# Return the ZIP file as a streaming response
return StreamingResponse(
async_stream_zip(member_files()),
media_type="application/zip",
headers={
"Content-Disposition": f"attachment; filename={artifact_alias}.zip"
},
)

await session.commit()

except KeyError:
raise HTTPException(status_code=404, detail="Artifact not found")
except PermissionError:
raise HTTPException(status_code=403, detail="Permission denied")
except HTTPException as e:
logger.error(f"HTTPException: {str(e)}")
raise e # Re-raise HTTPExceptions to be handled by FastAPI
except Exception as e:
logger.error(f"Unhandled exception in create_zip: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Internal server error: {str(e)}"
)
finally:
await session.close()

# HTTP endpoint for retrieving files within an artifact
@router.get("/{workspace}/artifacts/{artifact_alias}/files/{path:path}")
async def get_file(
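Taken together, the new /zip-files endpoint builds the archive on the fly: member paths are resolved to presigned S3 URLs and their bytes are pulled in 64 KB chunks only as the ZIP stream is consumed, so no file is buffered whole on the server. A minimal client-side sketch of how to call it (the base URL, workspace, and alias are illustrative placeholders, not values from this commit):

import asyncio
from io import BytesIO
from zipfile import ZipFile

import httpx

BASE_URL = "http://localhost:9527"   # assumed local Hypha server
WORKSPACE = "my-workspace"           # illustrative workspace name
ALIAS = "my-dataset"                 # illustrative artifact alias

async def download_zip():
    url = f"{BASE_URL}/{WORKSPACE}/artifacts/{ALIAS}/zip-files"
    async with httpx.AsyncClient(timeout=60) as client:
        # Repeat the `file` parameter to select members; omit it to zip
        # every file in the artifact version.
        response = await client.get(
            url, params=[("file", "example.txt"), ("file", "nested/example2.txt")]
        )
        response.raise_for_status()
        # The archive is produced incrementally on the server, but once
        # fully received it is an ordinary zip file.
        archive = ZipFile(BytesIO(response.content))
        print(archive.namelist())

asyncio.run(download_zip())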
2 changes: 1 addition & 1 deletion hypha/core/__init__.py
@@ -685,6 +685,6 @@ async def _subscribe_redis(self):
logger.info("Unknown channel: %s", channel)
except Exception as exp:
logger.exception(f"Error processing message: {exp}")
await asyncio.sleep(0.01)
await asyncio.sleep(0)
except Exception as exp:
self._ready.set_exception(exp)
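The one-line pub/sub change is a latency tweak: asyncio.sleep(0) still yields control to the event loop between processed messages, but without the 10 ms floor that asyncio.sleep(0.01) imposed per iteration. A tiny illustrative sketch of the difference:

import asyncio
import time

async def drain(n: int, delay: float) -> float:
    # Simulate handling n queued messages, yielding to the loop after each.
    start = time.perf_counter()
    for _ in range(n):
        await asyncio.sleep(delay)  # 0 yields immediately; 0.01 waits ~10 ms
    return time.perf_counter() - start

async def main():
    print(f"sleep(0):    {await drain(100, 0):.3f}s")     # microseconds in total
    print(f"sleep(0.01): {await drain(100, 0.01):.2f}s")  # at least ~1 s

asyncio.run(main())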
3 changes: 2 additions & 1 deletion hypha/core/store.py
@@ -14,6 +14,7 @@
from starlette.routing import Mount
from pydantic.fields import Field
from aiocache.backends.redis import RedisCache
from aiocache.serializers import PickleSerializer

from hypha import __version__
from hypha.core import (
@@ -180,7 +181,7 @@ def __init__(

self._redis = aioredis.FakeRedis.from_url("redis://localhost:9997/11")

self._redis_cache = RedisCache()
self._redis_cache = RedisCache(serializer=PickleSerializer())
self._redis_cache.client = self._redis

self._root_user = None
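Wiring a PickleSerializer into the cache matters because the values cached in hypha/s3.py are raw bytes (a ZIP central directory), which must round-trip through Redis unchanged. A standalone sketch of the behavior, using aiocache's in-memory backend so it runs without a Redis server (illustrative, not the store's exact wiring):

import asyncio
from aiocache import SimpleMemoryCache
from aiocache.serializers import PickleSerializer

async def main():
    # PickleSerializer round-trips arbitrary Python objects, including the
    # bytes payloads that hypha/s3.py caches under zip_tail:* keys.
    cache = SimpleMemoryCache(serializer=PickleSerializer())
    eocd = b"PK\x05\x06" + b"\x00" * 18  # minimal end-of-central-directory record
    await cache.set("zip_tail:bucket:key:1234", eocd)
    tail = await cache.get("zip_tail:bucket:key:1234")
    assert isinstance(tail, bytes) and tail == eocd

asyncio.run(main())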
5 changes: 4 additions & 1 deletion hypha/s3.py
@@ -120,6 +120,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if self.background is not None:
await self.background()


async def fetch_zip_tail(s3_client, workspace_bucket, s3_key, content_length):
"""
Fetch the tail part of the zip file that contains the central directory.
@@ -314,7 +315,9 @@ async def get_zip_file_content(
)

# Fetch the ZIP's central directory from cache or download if not cached
cache_key = f"zip_tail:{self.workspace_bucket}:{s3_key}:{content_length}"
cache_key = (
f"zip_tail:{self.workspace_bucket}:{s3_key}:{content_length}"
)
zip_tail = await cache.get(cache_key)
if zip_tail is None:
zip_tail = await fetch_zip_tail(
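fetch_zip_tail (shown truncated above) exists because a ZIP's table of contents, the central directory, sits at the end of the file, so only the tail needs downloading to list or pick members. A hedged sketch of the underlying idea using an S3 ranged GET (an assumption about the approach, not the exact Hypha implementation; the 64 KB tail size is illustrative):

async def fetch_tail_bytes(s3_client, bucket, key, content_length, tail_size=64 * 1024):
    # The end-of-central-directory record is at the very end of the object,
    # so a ranged GET over the last tail_size bytes is usually enough for
    # ZipFile to enumerate the archive's members.
    start = max(0, content_length - tail_size)
    response = await s3_client.get_object(
        Bucket=bucket, Key=key, Range=f"bytes={start}-{content_length - 1}"
    )
    async with response["Body"] as stream:
        return await stream.read()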
3 changes: 2 additions & 1 deletion requirements.txt
@@ -37,4 +37,5 @@ hrid==0.2.4
qdrant-client==1.12.1
ollama==0.3.3
fastembed==0.4.2
asgiproxy==0.1.1
asgiproxy==0.1.1
stream-zip==0.0.83
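The new stream-zip dependency is what the endpoint is built on: async_stream_zip consumes an async iterable of member tuples (name, modified_at, mode, method, async byte iterator) and emits compressed archive bytes incrementally, which is exactly the shape member_files() yields in hypha/artifact.py. A self-contained sketch, independent of Hypha:

import asyncio
from datetime import datetime
from stat import S_IFREG

from stream_zip import ZIP_32, async_stream_zip

async def members():
    mode = S_IFREG | 0o600  # regular file, owner read/write

    async def content(data: bytes):
        yield data  # member content is itself an async iterator of chunks

    yield ("hello.txt", datetime.now(), mode, ZIP_32, content(b"hello"))
    yield ("nested/world.txt", datetime.now(), mode, ZIP_32, content(b"world"))

async def main():
    total = 0
    # Chunks arrive as the archive is built, so they can be forwarded to a
    # consumer such as StreamingResponse without materializing the zip.
    async for chunk in async_stream_zip(members()):
        total += len(chunk)
    print(f"zip size: {total} bytes")

asyncio.run(main())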
1 change: 1 addition & 0 deletions setup.py
@@ -44,6 +44,7 @@
"alembic>=1.14.0",
"hrid>=0.2.4",
"asgiproxy>=0.1.1",
"stream-zip>=0.0.83",
]

ROOT_DIR = Path(__file__).parent.resolve()
92 changes: 68 additions & 24 deletions tests/test_artifact.py
@@ -5,6 +5,9 @@
import numpy as np
import random
from hypha_rpc import connect_to_server
from io import BytesIO
from zipfile import ZipFile
import httpx

from . import SERVER_URL, SERVER_URL_SQLITE, find_item

@@ -395,7 +398,7 @@ async def test_serve_artifact_endpoint(minio_server, fastapi_server, test_user_t
async def test_http_file_and_directory_endpoint(
minio_server, fastapi_server, test_user_token
):
"""Test the HTTP file serving and directory listing endpoint."""
"""Test the HTTP file serving and directory listing endpoint, including nested files."""

# Connect and get the artifact manager service
api = await connect_to_server(
@@ -433,50 +436,91 @@ async def test_http_file_and_directory_endpoint(
file_path="example.txt",
download_weight=1,
)
response = requests.put(put_url, data=file_contents)
assert response.ok
async with httpx.AsyncClient(timeout=20) as client:
response = await client.put(put_url, data=file_contents)
assert response.status_code == 200

# Add another file to the dataset artifact in a nested directory
file_contents2 = "file contents of nested/example2.txt"
nested_file_path = "nested/example2.txt"
put_url = await artifact_manager.put_file(
artifact_id=dataset.id,
file_path=nested_file_path,
download_weight=1,
)
async with httpx.AsyncClient(timeout=20) as client:
response = await client.put(put_url, data=file_contents2)
assert response.status_code == 200

# Commit the dataset artifact
await artifact_manager.commit(artifact_id=dataset.id)

files = await artifact_manager.list_files(artifact_id=dataset.id)
assert len(files) == 1
assert len(files) == 2

# Retrieve the file via HTTP
response = requests.get(
f"{SERVER_URL}/{api.config.workspace}/artifacts/{dataset.alias}/files/example.txt",
allow_redirects=True,
)
# Check if the connection has been redirected
assert response.history[0].status_code == 302
assert response.status_code == 200
assert response.text == file_contents
async with httpx.AsyncClient(timeout=20) as client:
response = await client.get(
f"{SERVER_URL}/{api.config.workspace}/artifacts/{dataset.alias}/files/example.txt",
follow_redirects=True,
)
assert response.status_code == 200
assert response.text == file_contents

# Check download count increment
artifact = await artifact_manager.read(
artifact_id=dataset.id,
)
assert artifact["download_count"] == 1

# Try to get it using http proxy
response = requests.get(
f"{SERVER_URL}/{api.config.workspace}/artifacts/{dataset.alias}/files/example.txt?use_proxy=1"
)
assert response.status_code == 200
assert response.text == file_contents
# Try to get it using HTTP proxy
async with httpx.AsyncClient(timeout=20) as client:
response = await client.get(
f"{SERVER_URL}/{api.config.workspace}/artifacts/{dataset.alias}/files/example.txt?use_proxy=1"
)
assert response.status_code == 200
assert response.text == file_contents

# Check download count increment
artifact = await artifact_manager.read(
artifact_id=dataset.id,
)
assert artifact["download_count"] == 2

# Attempt to list directory contents (should be successful after attempting file)
response = requests.get(
f"{SERVER_URL}/{api.config.workspace}/artifacts/{dataset.alias}/files/"
)
assert response.status_code == 200
assert "example.txt" in [file["name"] for file in response.json()]
# Attempt to list directory contents
async with httpx.AsyncClient(timeout=20) as client:
response = await client.get(
f"{SERVER_URL}/{api.config.workspace}/artifacts/{dataset.alias}/files/"
)
assert response.status_code == 200
assert "example.txt" in [file["name"] for file in response.json()]

# Get the zip file with specific files
async with httpx.AsyncClient(timeout=20) as client:
response = await client.get(
f"{SERVER_URL}/{api.config.workspace}/artifacts/{dataset.alias}/zip-files?file=example.txt&file={nested_file_path}"
)
assert response.status_code == 200
    # Load the zip into an io.BytesIO object, then check that the file contents are correct
zip_file = ZipFile(BytesIO(response.content))
assert sorted(zip_file.namelist()) == sorted(
["example.txt", "nested/example2.txt"]
)
assert zip_file.read("example.txt").decode() == "file contents of example.txt"
assert zip_file.read("nested/example2.txt").decode() == file_contents2

# Get the zip file with all files
async with httpx.AsyncClient(timeout=20) as client:
response = await client.get(
f"{SERVER_URL}/{api.config.workspace}/artifacts/{dataset.alias}/zip-files"
)
assert response.status_code == 200, response.text
zip_file = ZipFile(BytesIO(response.content))
assert sorted(zip_file.namelist()) == sorted(
["example.txt", "nested/example2.txt"]
)
assert zip_file.read("example.txt").decode() == "file contents of example.txt"
assert zip_file.read("nested/example2.txt").decode() == file_contents2


async def test_artifact_permissions(
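The tests buffer each archive in memory via BytesIO, which is fine for small fixtures; a client fetching a large artifact would instead stream the response body straight to disk. A hedged sketch (URL and paths are illustrative placeholders):

import asyncio
import httpx

async def save_zip(url: str, dest: str = "artifact.zip"):
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("GET", url) as response:
            response.raise_for_status()
            with open(dest, "wb") as f:
                # Write 64 KB chunks as they arrive instead of buffering.
                async for chunk in response.aiter_bytes(64 * 1024):
                    f.write(chunk)

asyncio.run(save_zip("http://localhost:9527/my-workspace/artifacts/my-dataset/zip-files"))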
