From 66882d7303c46f3c419cf104d7628526556a2d4b Mon Sep 17 00:00:00 2001 From: Wei Ouyang Date: Thu, 10 Oct 2024 10:44:10 -0700 Subject: [PATCH] Support download statistics --- hypha/artifact.py | 202 ++++++++++++++++++++++++++-------------- hypha/core/__init__.py | 5 +- hypha/core/store.py | 4 - hypha/core/workspace.py | 2 - tests/test_artifact.py | 130 ++++++++++++++++++++++++++ 5 files changed, 264 insertions(+), 79 deletions(-) diff --git a/hypha/artifact.py b/hypha/artifact.py index 3bdb542e..ecf4c668 100644 --- a/hypha/artifact.py +++ b/hypha/artifact.py @@ -1,10 +1,12 @@ import logging import sys +import copy from sqlalchemy import ( event, Column, String, Integer, + Float, JSON, UniqueConstraint, select, @@ -14,6 +16,7 @@ ) from hypha.utils import remove_objects_async, list_objects_async, safe_join from botocore.exceptions import ClientError +from sqlalchemy import update from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.asyncio import ( async_sessionmaker, @@ -50,6 +53,11 @@ class ArtifactModel(Base): manifest = Column(JSON, nullable=True) # Store committed manifest stage_manifest = Column(JSON, nullable=True) # Store staged manifest stage_files = Column(JSON, nullable=True) # Store staged files during staging + download_weights = Column( + JSON, nullable=True + ) # Store the weights for counting downloads; a dictionary of file paths and their weights 0-1 + download_count = Column(Float, nullable=False, default=0.0) # New counter field + view_count = Column(Float, nullable=False, default=0.0) # New counter field __table_args__ = ( UniqueConstraint("workspace", "prefix", name="_workspace_prefix_uc"), ) @@ -127,58 +135,6 @@ def end_transaction(session, transaction): return session - async def _read_manifest(self, workspace, prefix, stage=False): - session = await self._get_session() - try: - async with session.begin(): - query = select(ArtifactModel).filter( - ArtifactModel.workspace == workspace, - ArtifactModel.prefix == prefix, - ) 
- result = await session.execute(query) - artifact = result.scalar_one_or_none() - - if not artifact: - raise KeyError(f"Artifact under prefix '{prefix}' does not exist.") - - manifest = artifact.stage_manifest if stage else artifact.manifest - if not manifest: - raise KeyError(f"No manifest found for artifact '{prefix}'.") - - # If the artifact is a collection, dynamically populate the 'collection' field - if manifest.get("type") == "collection": - sub_prefix = f"{prefix}/" - query = select(ArtifactModel).filter( - ArtifactModel.workspace == workspace, - ArtifactModel.prefix.like(f"{sub_prefix}%"), - ) - result = await session.execute(query) - sub_artifacts = result.scalars().all() - - # Populate the 'collection' field with summary_fields for each sub-artifact - summary_fields = manifest.get( - "summary_fields", ["id", "name", "description"] - ) - collection = [] - for artifact in sub_artifacts: - sub_manifest = artifact.manifest - summary = {"_prefix": artifact.prefix} - for field in summary_fields: - value = sub_manifest.get(field) - if value is not None: - summary[field] = value - collection.append(summary) - - manifest["collection"] = collection - - if stage: - manifest["stage_files"] = artifact.stage_files - return manifest - except Exception as e: - raise e - finally: - await session.close() - async def _get_artifact(self, session, workspace, prefix): query = select(ArtifactModel).filter( ArtifactModel.workspace == workspace, @@ -253,6 +209,23 @@ async def _read_manifest(self, workspace, prefix, stage=False): if stage: manifest["stage_files"] = artifact.stage_files + else: + # increase view count + stmt = ( + update(ArtifactModel) + .where(ArtifactModel.id == artifact.id) + # atomically increment the view count + .values(view_count=ArtifactModel.view_count + 1) + .execution_options(synchronize_session="fetch") + ) + await session.execute(stmt) + await session.commit() + manifest["_stats"] = { + "download_count": artifact.download_count, + "view_count": 
artifact.view_count, + } + if manifest.get("type") == "collection": + manifest["_stats"]["child_count"] = len(collection) return manifest except Exception as e: raise e @@ -320,6 +293,7 @@ async def create( manifest=None if stage else manifest, stage_manifest=manifest if stage else None, stage_files=[] if stage else None, + download_weights=None, type=manifest["type"], ) session.add(new_artifact) @@ -332,6 +306,38 @@ async def create( return manifest + async def reset_stats(self, prefix, context: dict): + """Reset the artifact's manifest's download count and view count.""" + if context is None or "ws" not in context: + raise ValueError("Context must include 'ws' (workspace).") + ws = context["ws"] + + user_info = UserInfo.model_validate(context["user"]) + if not user_info.check_permission(ws, UserPermission.read_write): + raise PermissionError( + "User does not have write permission to the workspace." + ) + + session = await self._get_session() + try: + async with session.begin(): + artifact = await self._get_artifact(session, ws, prefix) + if not artifact: + raise KeyError(f"Artifact under prefix '{prefix}' does not exist.") + stmt = ( + update(ArtifactModel) + .where(ArtifactModel.id == artifact.id) + .values(download_count=0, view_count=0) + .execution_options(synchronize_session="fetch") + ) + await session.execute(stmt) + await session.commit() + logger.info(f"Reset artifact under prefix: {prefix}") + except Exception as e: + raise e + finally: + await session.close() + async def read(self, prefix, stage=False, context: dict = None): """Read the artifact's manifest from the database and populate collections dynamically.""" if context is None or "ws" not in context: @@ -359,17 +365,18 @@ async def edit(self, prefix, manifest=None, context: dict = None): "User does not have write permission to the workspace." 
) - # Validate the manifest - if manifest["type"] == "collection": - CollectionArtifact.model_validate(manifest) - elif manifest["type"] == "application": - ApplicationArtifact.model_validate(manifest) - elif manifest["type"] == "workspace": - WorkspaceInfo.model_validate(manifest) + if manifest: + # Validate the manifest + if manifest["type"] == "collection": + CollectionArtifact.model_validate(manifest) + elif manifest["type"] == "application": + ApplicationArtifact.model_validate(manifest) + elif manifest["type"] == "workspace": + WorkspaceInfo.model_validate(manifest) - # Convert ObjectProxy to dict if necessary - if isinstance(manifest, ObjectProxy): - manifest = ObjectProxy.toDict(manifest) + # Convert ObjectProxy to dict if necessary + if isinstance(manifest, ObjectProxy): + manifest = ObjectProxy.toDict(manifest) session = await self._get_session() try: @@ -377,7 +384,8 @@ async def edit(self, prefix, manifest=None, context: dict = None): artifact = await self._get_artifact(session, ws, prefix) if not artifact: raise KeyError(f"Artifact under prefix '{prefix}' does not exist.") - + if manifest is None: + manifest = copy.deepcopy(artifact.manifest) artifact.stage_manifest = manifest flag_modified(artifact, "stage_manifest") # Mark JSON field as modified session.add(artifact) @@ -411,6 +419,7 @@ async def commit(self, prefix, context: dict): manifest = artifact.stage_manifest + download_weights = {} # Validate files exist in S3 if the staged files list is present if artifact.stage_files: async with self.s3_controller.create_client_async() as s3_client: @@ -424,7 +433,10 @@ async def commit(self, prefix, context: dict): raise FileNotFoundError( f"File '{file_info['path']}' does not exist in the artifact." 
) - + if file_info.get("download_weight") is not None: + download_weights[file_info["path"]] = file_info[ + "download_weight" + ] # Validate the schema if the artifact belongs to a collection parent_prefix = "/".join(prefix.split("/")[:-1]) if parent_prefix: @@ -447,7 +459,10 @@ async def commit(self, prefix, context: dict): artifact.manifest = manifest artifact.stage_manifest = None artifact.stage_files = None + artifact.download_weights = download_weights flag_modified(artifact, "manifest") + flag_modified(artifact, "stage_files") + flag_modified(artifact, "download_weights") session.add(artifact) await session.commit() logger.info(f"Committed artifact under prefix: {prefix}") @@ -636,7 +651,9 @@ async def search( finally: await session.close() - async def put_file(self, prefix, file_path, context: dict = None): + async def put_file( + self, prefix, file_path, options: dict = None, context: dict = None + ): """Generate a pre-signed URL to upload a file to an artifact in S3 and update the manifest.""" ws = context["ws"] user_info = UserInfo.model_validate(context["user"]) @@ -645,6 +662,8 @@ async def put_file(self, prefix, file_path, context: dict = None): "User does not have write permission to the workspace." 
) + options = options or {} + async with self.s3_controller.create_client_async() as s3_client: file_key = safe_join(ws, f"{prefix}/{file_path}") presigned_url = await s3_client.generate_presigned_url( @@ -663,7 +682,12 @@ async def put_file(self, prefix, file_path, context: dict = None): artifact.stage_files = artifact.stage_files or [] if not any(f["path"] == file_path for f in artifact.stage_files): - artifact.stage_files.append({"path": file_path}) + artifact.stage_files.append( + { + "path": file_path, + "download_weight": options.get("download_weight"), + } + ) flag_modified(artifact, "stage_files") session.add(artifact) @@ -676,7 +700,7 @@ async def put_file(self, prefix, file_path, context: dict = None): return presigned_url - async def get_file(self, prefix, path, context: dict): + async def get_file(self, prefix, path, options: dict = None, context: dict = None): """Generate a pre-signed URL to download a file from an artifact in S3.""" ws = context["ws"] @@ -694,6 +718,30 @@ async def get_file(self, prefix, path, context: dict): ExpiresIn=3600, ) logger.info(f"Generated pre-signed URL for file download: {path}") + + if options is None or not options.get("silent"): + session = await self._get_session() + try: + async with session.begin(): + artifact = await self._get_artifact(session, ws, prefix) + if artifact.download_weights and path in artifact.download_weights: + # if it has download_weights, increment the download count by the weight + stmt = ( + update(ArtifactModel) + .where(ArtifactModel.id == artifact.id) + # atomically increment the download count by the weight + .values( + download_count=ArtifactModel.download_count + + artifact.download_weights[path] + ) + .execution_options(synchronize_session="fetch") + ) + await session.execute(stmt) + await session.commit() + except Exception as e: + raise e + finally: + await session.close() return presigned_url async def remove_file(self, prefix, file_path, context: dict): @@ -714,11 +762,20 @@ async def 
remove_file(self, prefix, file_path, context: dict): raise KeyError( f"Artifact under prefix '{prefix}' is not in staging mode." ) - # remove the file from the staged files list - artifact.stage_files = [ - f for f in artifact.stage_files if f["path"] != file_path - ] - flag_modified(artifact, "stage_files") + if artifact.stage_files: + # remove the file from the staged files list + artifact.stage_files = [ + f for f in artifact.stage_files if f["path"] != file_path + ] + flag_modified(artifact, "stage_files") + if artifact.download_weights: + # remove the file from download_weights if it's there + artifact.download_weights = { + k: v + for k, v in artifact.download_weights.items() + if k != file_path + } + flag_modified(artifact, "download_weights") session.add(artifact) await session.commit() except Exception as e: @@ -740,6 +797,7 @@ def get_artifact_service(self): "name": "Artifact Manager", "description": "Manage artifacts in a workspace.", "create": self.create, + "reset_stats": self.reset_stats, "edit": self.edit, "read": self.read, "commit": self.commit, diff --git a/hypha/core/__init__.py b/hypha/core/__init__.py index d436125c..e5b79a55 100644 --- a/hypha/core/__init__.py +++ b/hypha/core/__init__.py @@ -444,7 +444,10 @@ async def disconnect(self, reason=None): class RedisEventBus: """Represent a redis event bus.""" - _counter = Counter("event_bus", "Counts the events on the redis event bus", ["event"]) + + _counter = Counter( + "event_bus", "Counts the events on the redis event bus", ["event"] + ) def __init__(self, redis) -> None: """Initialize the event bus.""" diff --git a/hypha/core/store.py b/hypha/core/store.py index d222eacb..b7fa1c49 100644 --- a/hypha/core/store.py +++ b/hypha/core/store.py @@ -96,7 +96,6 @@ def __init__( """Initialize the redis store.""" self._s3_controller = None self._artifact_manager = None - self._logging_service = None self._app = app self._codecs = {} self._disconnected_plugins = [] @@ -392,8 +391,6 @@ async def 
init(self, reset_redis, startup_functions=None): logger.warning("RESETTING ALL REDIS DATA!!!") await self._redis.flushall() await self._event_bus.init() - if self._logging_service: - await self._logging_service.init_db() if self._artifact_manager: await self._artifact_manager.init_db() await self.setup_root_user() @@ -640,7 +637,6 @@ async def register_workspace_manager(self): self._sql_engine, self._s3_controller, self._artifact_manager, - self._logging_service, ) await manager.setup() return manager diff --git a/hypha/core/workspace.py b/hypha/core/workspace.py index 4043eb86..2fdaa04b 100644 --- a/hypha/core/workspace.py +++ b/hypha/core/workspace.py @@ -108,7 +108,6 @@ def __init__( sql_engine: Optional[str] = None, s3_controller: Optional[Any] = None, artifact_manager: Optional[Any] = None, - logging_service: Optional[Any] = None, ): self._redis = redis self._initialized = False @@ -119,7 +118,6 @@ def __init__( self._client_id = client_id self._s3_controller = s3_controller self._artifact_manager = artifact_manager - self._logging_service = logging_service self._sql_engine = sql_engine if self._sql_engine: self.SessionLocal = async_sessionmaker( diff --git a/tests/test_artifact.py b/tests/test_artifact.py index db18bc08..1d6fb534 100644 --- a/tests/test_artifact.py +++ b/tests/test_artifact.py @@ -168,6 +168,11 @@ async def test_edit_existing_artifact(minio_server, fastapi_server): "_prefix", "collections/edit-test-collection/edit-test-dataset", ) + initial_view_count = collection["_stats"]["view_count"] + assert initial_view_count > 0 + assert collection["_stats"]["child_count"] > 0 + collection = await artifact_manager.read(prefix="collections/edit-test-collection") + assert collection["_stats"]["view_count"] == initial_view_count + 1 # Edit the artifact's manifest edited_manifest = { @@ -725,3 +730,128 @@ async def test_artifact_search_with_filters(minio_server, fastapi_server): prefix=f"collections/search-test-collection/test-dataset-{i}" ) await 
async def test_download_count(minio_server, fastapi_server):
    """Test the download count functionality for artifacts.

    Verifies that:
    - downloading a file with a ``download_weight`` increments the artifact's
      download count by that weight,
    - ``silent`` downloads and files without a weight do not change the count,
    - ``reset_stats`` zeroes the counters.
    """
    api = await connect_to_server({"name": "test-client", "server_url": SERVER_URL})
    artifact_manager = await api.get_service("public/artifact-manager")

    collection_prefix = "collections/download-test-collection"
    dataset_prefix = f"{collection_prefix}/download-test-dataset"

    # Create a collection for testing download count
    collection_manifest = {
        "id": "download-test-collection",
        "name": "Download Test Collection",
        "description": "A collection to test download count functionality",
        "type": "collection",
        "collection": [],
    }
    await artifact_manager.create(
        prefix=collection_prefix,
        manifest=collection_manifest,
        stage=False,
    )

    # Create an artifact (in staging mode) inside the collection
    dataset_manifest = {
        "id": "download-test-dataset",
        "name": "Download Test Dataset",
        "description": "A test dataset for download count",
        "type": "dataset",
    }
    await artifact_manager.create(
        prefix=dataset_prefix,
        manifest=dataset_manifest,
        stage=True,
    )

    # Put a weighted file in the artifact: each non-silent download of this
    # file will be counted as 0.5 downloads.
    put_url = await artifact_manager.put_file(
        prefix=dataset_prefix,
        file_path="example.txt",
        options={"download_weight": 0.5},
    )
    source = "file contents of example.txt"
    response = requests.put(put_url, data=source)
    assert response.ok

    # Put another file in the artifact without setting a download weight;
    # downloading it should NOT affect the download count.
    put_url = await artifact_manager.put_file(
        prefix=dataset_prefix,
        file_path="example2.txt",
    )
    source = "file contents of example2.txt"
    response = requests.put(put_url, data=source)
    assert response.ok

    # Commit the artifact
    await artifact_manager.commit(prefix=dataset_prefix)

    # Ensure that the download count is initially zero
    artifact = await artifact_manager.read(prefix=dataset_prefix)
    assert artifact["_stats"]["download_count"] == 0

    # Increment the download count of the artifact by downloading the
    # weighted file
    get_url = await artifact_manager.get_file(
        prefix=dataset_prefix,
        path="example.txt",
    )
    response = requests.get(get_url)
    assert response.ok

    # Ensure that the download count is incremented by the file's weight
    artifact = await artifact_manager.read(prefix=dataset_prefix)
    assert artifact["_stats"]["download_count"] == 0.5

    # Getting the file in silent mode must not increment the download count
    get_url = await artifact_manager.get_file(
        prefix=dataset_prefix,
        path="example.txt",
        options={"silent": True},
    )
    response = requests.get(get_url)
    assert response.ok

    # Ensure that the download count is not incremented
    artifact = await artifact_manager.read(prefix=dataset_prefix)
    assert artifact["_stats"]["download_count"] == 0.5

    # Downloading example2.txt (no download_weight) must not increment the
    # download count either
    get_url = await artifact_manager.get_file(
        prefix=dataset_prefix,
        path="example2.txt",
    )
    response = requests.get(get_url)
    assert response.ok

    # Ensure that the download count is unchanged
    # (original comment wrongly said "incremented" — the assertion checks 0.5)
    artifact = await artifact_manager.read(prefix=dataset_prefix)
    assert artifact["_stats"]["download_count"] == 0.5

    # Call reset_stats to reset the download count
    await artifact_manager.reset_stats(prefix=dataset_prefix)

    # Ensure that the download count is reset
    artifact = await artifact_manager.read(prefix=dataset_prefix)
    assert artifact["_stats"]["download_count"] == 0

    # Clean up by deleting the dataset and the collection
    await artifact_manager.delete(prefix=dataset_prefix)
    await artifact_manager.delete(prefix=collection_prefix)