Skip to content

Commit

Permalink
Support download statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
oeway committed Oct 10, 2024
1 parent c08ba6e commit 66882d7
Show file tree
Hide file tree
Showing 5 changed files with 264 additions and 79 deletions.
202 changes: 130 additions & 72 deletions hypha/artifact.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import logging
import sys
import copy
from sqlalchemy import (
event,
Column,
String,
Integer,
Float,
JSON,
UniqueConstraint,
select,
Expand All @@ -14,6 +16,7 @@
)
from hypha.utils import remove_objects_async, list_objects_async, safe_join
from botocore.exceptions import ClientError
from sqlalchemy import update
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.asyncio import (
async_sessionmaker,
Expand Down Expand Up @@ -50,6 +53,11 @@ class ArtifactModel(Base):
manifest = Column(JSON, nullable=True) # Store committed manifest
stage_manifest = Column(JSON, nullable=True) # Store staged manifest
stage_files = Column(JSON, nullable=True) # Store staged files during staging
download_weights = Column(
JSON, nullable=True
) # Store the weights for counting downloads; a dictionary of file paths and their weights 0-1
download_count = Column(Float, nullable=False, default=0.0) # New counter field
view_count = Column(Float, nullable=False, default=0.0) # New counter field
__table_args__ = (
UniqueConstraint("workspace", "prefix", name="_workspace_prefix_uc"),
)
Expand Down Expand Up @@ -127,58 +135,6 @@ def end_transaction(session, transaction):

return session

async def _read_manifest(self, workspace, prefix, stage=False):
    """Load an artifact's manifest, expanding collections with child summaries.

    Args:
        workspace: Workspace the artifact row is filtered by.
        prefix: Prefix of the artifact inside the workspace.
        stage: When true, read ``stage_manifest`` instead of ``manifest``
            and attach the artifact's ``stage_files`` to the result.

    Returns:
        A dict with the manifest content. If the manifest's "type" is
        "collection", a "collection" list is added holding one summary per
        artifact whose prefix starts with ``"<prefix>/"``. Each summary
        contains "_prefix" plus the non-None fields named in the manifest's
        "summary_fields" (default: id, name, description).

    Raises:
        KeyError: If no artifact exists under ``prefix`` or the requested
            manifest is empty.
    """
    session = await self._get_session()
    try:
        async with session.begin():
            query = select(ArtifactModel).filter(
                ArtifactModel.workspace == workspace,
                ArtifactModel.prefix == prefix,
            )
            result = await session.execute(query)
            artifact = result.scalar_one_or_none()

            if not artifact:
                raise KeyError(f"Artifact under prefix '{prefix}' does not exist.")

            manifest = artifact.stage_manifest if stage else artifact.manifest
            if not manifest:
                raise KeyError(f"No manifest found for artifact '{prefix}'.")

            # Work on a shallow copy so the keys added below ("collection",
            # "stage_files") are not written into the ORM-loaded attribute.
            manifest = dict(manifest)

            # If the artifact is a collection, dynamically populate the
            # 'collection' field.
            if manifest.get("type") == "collection":
                sub_prefix = f"{prefix}/"
                query = select(ArtifactModel).filter(
                    ArtifactModel.workspace == workspace,
                    ArtifactModel.prefix.like(f"{sub_prefix}%"),
                )
                result = await session.execute(query)
                sub_artifacts = result.scalars().all()

                summary_fields = manifest.get(
                    "summary_fields", ["id", "name", "description"]
                )
                collection = []
                # NOTE: the loop variable must not be called ``artifact``;
                # the parent row is still needed below for ``stage_files``.
                for child in sub_artifacts:
                    sub_manifest = child.manifest
                    if sub_manifest is None:
                        # Child has no committed manifest (e.g. it was only
                        # created in staging); there is nothing to summarize.
                        continue
                    summary = {"_prefix": child.prefix}
                    for field in summary_fields:
                        value = sub_manifest.get(field)
                        if value is not None:
                            summary[field] = value
                    collection.append(summary)

                manifest["collection"] = collection

            if stage:
                manifest["stage_files"] = artifact.stage_files
            return manifest
    finally:
        await session.close()

async def _get_artifact(self, session, workspace, prefix):
query = select(ArtifactModel).filter(
ArtifactModel.workspace == workspace,
Expand Down Expand Up @@ -253,6 +209,23 @@ async def _read_manifest(self, workspace, prefix, stage=False):

if stage:
manifest["stage_files"] = artifact.stage_files
else:
# increase view count
stmt = (
update(ArtifactModel)
.where(ArtifactModel.id == artifact.id)
# atomically increment the view count
.values(view_count=ArtifactModel.view_count + 1)
.execution_options(synchronize_session="fetch")
)
await session.execute(stmt)
await session.commit()
manifest["_stats"] = {
"download_count": artifact.download_count,
"view_count": artifact.view_count,
}
if manifest.get("type") == "collection":
manifest["_stats"]["child_count"] = len(collection)
return manifest
except Exception as e:
raise e
Expand Down Expand Up @@ -320,6 +293,7 @@ async def create(
manifest=None if stage else manifest,
stage_manifest=manifest if stage else None,
stage_files=[] if stage else None,
download_weights=None,
type=manifest["type"],
)
session.add(new_artifact)
Expand All @@ -332,6 +306,38 @@ async def create(

return manifest

async def reset_stats(self, prefix, context: dict):
    """Reset an artifact's download count and view count to zero.

    Args:
        prefix: Prefix of the artifact inside the workspace.
        context: Call context; must contain "ws" (the workspace) and
            "user" (validated into a ``UserInfo``).

    Raises:
        ValueError: If ``context`` is None or has no "ws" entry.
        PermissionError: If the user lacks read-write permission on the
            workspace.
        KeyError: If no artifact exists under ``prefix``.
    """
    if context is None or "ws" not in context:
        raise ValueError("Context must include 'ws' (workspace).")
    ws = context["ws"]

    user_info = UserInfo.model_validate(context["user"])
    if not user_info.check_permission(ws, UserPermission.read_write):
        raise PermissionError(
            "User does not have write permission to the workspace."
        )

    session = await self._get_session()
    try:
        async with session.begin():
            artifact = await self._get_artifact(session, ws, prefix)
            if not artifact:
                raise KeyError(f"Artifact under prefix '{prefix}' does not exist.")
            # Both counters are Float columns, so reset them with floats.
            stmt = (
                update(ArtifactModel)
                .where(ArtifactModel.id == artifact.id)
                .values(download_count=0.0, view_count=0.0)
                .execution_options(synchronize_session="fetch")
            )
            await session.execute(stmt)
            await session.commit()
        logger.info(f"Reset stats for artifact under prefix: {prefix}")
    finally:
        # Errors propagate unchanged; the session is always released.
        await session.close()

async def read(self, prefix, stage=False, context: dict = None):
"""Read the artifact's manifest from the database and populate collections dynamically."""
if context is None or "ws" not in context:
Expand Down Expand Up @@ -359,25 +365,27 @@ async def edit(self, prefix, manifest=None, context: dict = None):
"User does not have write permission to the workspace."
)

# Validate the manifest
if manifest["type"] == "collection":
CollectionArtifact.model_validate(manifest)
elif manifest["type"] == "application":
ApplicationArtifact.model_validate(manifest)
elif manifest["type"] == "workspace":
WorkspaceInfo.model_validate(manifest)
if manifest:
# Validate the manifest
if manifest["type"] == "collection":
CollectionArtifact.model_validate(manifest)
elif manifest["type"] == "application":
ApplicationArtifact.model_validate(manifest)
elif manifest["type"] == "workspace":
WorkspaceInfo.model_validate(manifest)

# Convert ObjectProxy to dict if necessary
if isinstance(manifest, ObjectProxy):
manifest = ObjectProxy.toDict(manifest)
# Convert ObjectProxy to dict if necessary
if isinstance(manifest, ObjectProxy):
manifest = ObjectProxy.toDict(manifest)

session = await self._get_session()
try:
async with session.begin():
artifact = await self._get_artifact(session, ws, prefix)
if not artifact:
raise KeyError(f"Artifact under prefix '{prefix}' does not exist.")

if manifest is None:
manifest = copy.deepcopy(artifact.manifest)
artifact.stage_manifest = manifest
flag_modified(artifact, "stage_manifest") # Mark JSON field as modified
session.add(artifact)
Expand Down Expand Up @@ -411,6 +419,7 @@ async def commit(self, prefix, context: dict):

manifest = artifact.stage_manifest

download_weights = {}
# Validate files exist in S3 if the staged files list is present
if artifact.stage_files:
async with self.s3_controller.create_client_async() as s3_client:
Expand All @@ -424,7 +433,10 @@ async def commit(self, prefix, context: dict):
raise FileNotFoundError(
f"File '{file_info['path']}' does not exist in the artifact."
)

if file_info.get("download_weight") is not None:
download_weights[file_info["path"]] = file_info[
"download_weight"
]
# Validate the schema if the artifact belongs to a collection
parent_prefix = "/".join(prefix.split("/")[:-1])
if parent_prefix:
Expand All @@ -447,7 +459,10 @@ async def commit(self, prefix, context: dict):
artifact.manifest = manifest
artifact.stage_manifest = None
artifact.stage_files = None
artifact.download_weights = download_weights
flag_modified(artifact, "manifest")
flag_modified(artifact, "stage_files")
flag_modified(artifact, "download_weights")
session.add(artifact)
await session.commit()
logger.info(f"Committed artifact under prefix: {prefix}")
Expand Down Expand Up @@ -636,7 +651,9 @@ async def search(
finally:
await session.close()

async def put_file(self, prefix, file_path, context: dict = None):
async def put_file(
self, prefix, file_path, options: dict = None, context: dict = None
):
"""Generate a pre-signed URL to upload a file to an artifact in S3 and update the manifest."""
ws = context["ws"]
user_info = UserInfo.model_validate(context["user"])
Expand All @@ -645,6 +662,8 @@ async def put_file(self, prefix, file_path, context: dict = None):
"User does not have write permission to the workspace."
)

options = options or {}

async with self.s3_controller.create_client_async() as s3_client:
file_key = safe_join(ws, f"{prefix}/{file_path}")
presigned_url = await s3_client.generate_presigned_url(
Expand All @@ -663,7 +682,12 @@ async def put_file(self, prefix, file_path, context: dict = None):
artifact.stage_files = artifact.stage_files or []

if not any(f["path"] == file_path for f in artifact.stage_files):
artifact.stage_files.append({"path": file_path})
artifact.stage_files.append(
{
"path": file_path,
"download_weight": options.get("download_weight"),
}
)
flag_modified(artifact, "stage_files")

session.add(artifact)
Expand All @@ -676,7 +700,7 @@ async def put_file(self, prefix, file_path, context: dict = None):

return presigned_url

async def get_file(self, prefix, path, context: dict):
async def get_file(self, prefix, path, options: dict = None, context: dict = None):
"""Generate a pre-signed URL to download a file from an artifact in S3."""
ws = context["ws"]

Expand All @@ -694,6 +718,30 @@ async def get_file(self, prefix, path, context: dict):
ExpiresIn=3600,
)
logger.info(f"Generated pre-signed URL for file download: {path}")

if options is None or not options.get("silent"):
session = await self._get_session()
try:
async with session.begin():
artifact = await self._get_artifact(session, ws, prefix)
if artifact.download_weights and path in artifact.download_weights:
# if it has download_weights, increment the download count by the weight
stmt = (
update(ArtifactModel)
.where(ArtifactModel.id == artifact.id)
# atomically increment the download count by the weight
.values(
download_count=ArtifactModel.download_count
+ artifact.download_weights[path]
)
.execution_options(synchronize_session="fetch")
)
await session.execute(stmt)
await session.commit()
except Exception as e:
raise e
finally:
await session.close()
return presigned_url

async def remove_file(self, prefix, file_path, context: dict):
Expand All @@ -714,11 +762,20 @@ async def remove_file(self, prefix, file_path, context: dict):
raise KeyError(
f"Artifact under prefix '{prefix}' is not in staging mode."
)
# remove the file from the staged files list
artifact.stage_files = [
f for f in artifact.stage_files if f["path"] != file_path
]
flag_modified(artifact, "stage_files")
if artifact.stage_files:
# remove the file from the staged files list
artifact.stage_files = [
f for f in artifact.stage_files if f["path"] != file_path
]
flag_modified(artifact, "stage_files")
if artifact.download_weights:
# remove the file from download_weights if it's there
artifact.download_weights = {
k: v
for k, v in artifact.download_weights.items()
if k != file_path
}
flag_modified(artifact, "download_weights")
session.add(artifact)
await session.commit()
except Exception as e:
Expand All @@ -740,6 +797,7 @@ def get_artifact_service(self):
"name": "Artifact Manager",
"description": "Manage artifacts in a workspace.",
"create": self.create,
"reset_stats": self.reset_stats,
"edit": self.edit,
"read": self.read,
"commit": self.commit,
Expand Down
5 changes: 4 additions & 1 deletion hypha/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,10 @@ async def disconnect(self, reason=None):

class RedisEventBus:
"""Represent a redis event bus."""
_counter = Counter("event_bus", "Counts the events on the redis event bus", ["event"])

_counter = Counter(
"event_bus", "Counts the events on the redis event bus", ["event"]
)

def __init__(self, redis) -> None:
"""Initialize the event bus."""
Expand Down
4 changes: 0 additions & 4 deletions hypha/core/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def __init__(
"""Initialize the redis store."""
self._s3_controller = None
self._artifact_manager = None
self._logging_service = None
self._app = app
self._codecs = {}
self._disconnected_plugins = []
Expand Down Expand Up @@ -392,8 +391,6 @@ async def init(self, reset_redis, startup_functions=None):
logger.warning("RESETTING ALL REDIS DATA!!!")
await self._redis.flushall()
await self._event_bus.init()
if self._logging_service:
await self._logging_service.init_db()
if self._artifact_manager:
await self._artifact_manager.init_db()
await self.setup_root_user()
Expand Down Expand Up @@ -640,7 +637,6 @@ async def register_workspace_manager(self):
self._sql_engine,
self._s3_controller,
self._artifact_manager,
self._logging_service,
)
await manager.setup()
return manager
Expand Down
2 changes: 0 additions & 2 deletions hypha/core/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ def __init__(
sql_engine: Optional[str] = None,
s3_controller: Optional[Any] = None,
artifact_manager: Optional[Any] = None,
logging_service: Optional[Any] = None,
):
self._redis = redis
self._initialized = False
Expand All @@ -119,7 +118,6 @@ def __init__(
self._client_id = client_id
self._s3_controller = s3_controller
self._artifact_manager = artifact_manager
self._logging_service = logging_service
self._sql_engine = sql_engine
if self._sql_engine:
self.SessionLocal = async_sessionmaker(
Expand Down
Loading

0 comments on commit 66882d7

Please sign in to comment.