Skip to content

Commit

Permalink
toolkit: add file content viewer (#825)
Browse files Browse the repository at this point in the history
* Add icons

* Run format-web for MessageRow

* Add handlers

* Fetch conv files

* Fetch agent files

* Add unit tests

* Generate web client

* Add API calls

* Add content for FileViewer

* Add padding settings for modals

* Refactor styles

* Minor clean up

* Run format-web

* Merge AgentFileFull and ConversationFileFull

* Add error message
  • Loading branch information
danylo-boiko authored Nov 12, 2024
1 parent 535a77a commit 6bf5847
Show file tree
Hide file tree
Showing 16 changed files with 538 additions and 56 deletions.
9 changes: 7 additions & 2 deletions src/backend/crud/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def batch_create_files(db: Session, files: list[File]) -> list[File]:


@validate_transaction
def get_file(db: Session, file_id: str, user_id: str) -> File:
def get_file(db: Session, file_id: str, user_id: str | None = None) -> File:
"""
Get a file by ID.
Expand All @@ -47,7 +47,12 @@ def get_file(db: Session, file_id: str, user_id: str) -> File:
Returns:
File: File with the given ID.
"""
return db.query(File).filter(File.id == file_id, File.user_id == user_id).first()
filters = [File.id == file_id]

if user_id:
filters.append(File.user_id == user_id)

return db.query(File).filter(*filters).first()


@validate_transaction
Expand Down
57 changes: 55 additions & 2 deletions src/backend/routers/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from backend.config.routers import RouterName
from backend.crud import agent as agent_crud
from backend.crud import agent_tool_metadata as agent_tool_metadata_crud
from backend.crud import file as file_crud
from backend.crud import snapshot as snapshot_crud
from backend.database_models.agent import Agent as AgentModel
from backend.database_models.agent_tool_metadata import (
Expand All @@ -34,7 +35,11 @@
)
from backend.schemas.context import Context
from backend.schemas.deployment import Deployment as DeploymentSchema
from backend.schemas.file import DeleteAgentFileResponse, UploadAgentFileResponse
from backend.schemas.file import (
DeleteAgentFileResponse,
FileMetadata,
UploadAgentFileResponse,
)
from backend.services.agent import (
raise_db_error,
validate_agent_exists,
Expand Down Expand Up @@ -583,6 +588,54 @@ async def batch_upload_file(
return uploaded_files


@router.get("/{agent_id}/files/{file_id}", response_model=FileMetadata)
async def get_agent_file(
agent_id: str,
file_id: str,
session: DBSessionDep,
ctx: Context = Depends(get_context),
) -> FileMetadata:
"""
Get an agent file by ID.
Args:
agent_id (str): Agent ID.
file_id (str): File ID.
session (DBSessionDep): Database session.
ctx (Context): Context object.
Returns:
FileMetadata: File with the given ID.
Raises:
HTTPException: If the agent or file with the given ID is not found, or if the file does not belong to the agent.
"""
user_id = ctx.get_user_id()

if file_id not in get_file_service().get_file_ids_by_agent_id(session, user_id, agent_id, ctx):
raise HTTPException(
status_code=404,
detail=f"File with ID: {file_id} does not belong to the agent with ID: {agent_id}."
)

file = file_crud.get_file(session, file_id)

if not file:
raise HTTPException(
status_code=404,
detail=f"File with ID: {file_id} not found.",
)

return FileMetadata(
id=file.id,
file_name=file.file_name,
file_content=file.file_content,
file_size=file.file_size,
created_at=file.created_at,
updated_at=file.updated_at,
)


@router.delete("/{agent_id}/files/{file_id}")
async def delete_agent_file(
agent_id: str,
Expand All @@ -605,7 +658,7 @@ async def delete_agent_file(
HTTPException: If the agent with the given ID is not found.
"""
user_id = ctx.get_user_id()
_ = validate_agent_exists(session, agent_id)
_ = validate_agent_exists(session, agent_id, user_id)
validate_file(session, file_id, user_id)

# Delete the File DB object
Expand Down
45 changes: 43 additions & 2 deletions src/backend/routers/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from backend.schemas.file import (
DeleteConversationFileResponse,
FileMetadata,
ListConversationFile,
UploadConversationFileResponse,
)
Expand Down Expand Up @@ -461,6 +462,47 @@ async def list_files(
return files_with_conversation_id


@router.get("/{conversation_id}/files/{file_id}", response_model=FileMetadata)
async def get_file(
conversation_id: str, file_id: str, session: DBSessionDep, ctx: Context = Depends(get_context)
) -> FileMetadata:
"""
Get a conversation file by ID.
Args:
conversation_id (str): Conversation ID.
file_id (str): File ID.
session (DBSessionDep): Database session.
ctx (Context): Context object.
Returns:
FileMetadata: File with the given ID.
Raises:
HTTPException: If the conversation or file with the given ID is not found, or if the file does not belong to the conversation.
"""
user_id = ctx.get_user_id()

conversation = validate_conversation(session, conversation_id, user_id)

if file_id not in conversation.file_ids:
raise HTTPException(
status_code=404,
detail=f"File with ID: {file_id} does not belong to the conversation with ID: {conversation.id}."
)

file = validate_file(session, file_id, user_id)

return FileMetadata(
id=file.id,
file_name=file.file_name,
file_content=file.file_content,
file_size=file.file_size,
created_at=file.created_at,
updated_at=file.updated_at,
)


@router.delete("/{conversation_id}/files/{file_id}")
async def delete_file(
conversation_id: str,
Expand All @@ -484,8 +526,7 @@ async def delete_file(
"""
user_id = ctx.get_user_id()
_ = validate_conversation(session, conversation_id, user_id)
validate_file(session, file_id, user_id )

validate_file(session, file_id, user_id)
# Delete the File DB object
get_file_service().delete_conversation_file_by_id(
session, conversation_id, file_id, user_id, ctx
Expand Down
11 changes: 10 additions & 1 deletion src/backend/schemas/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ class ConversationFilePublic(BaseModel):
file_size: int = Field(default=0, ge=0)



class AgentFilePublic(BaseModel):
id: str
created_at: datetime.datetime
Expand All @@ -39,6 +38,16 @@ class AgentFilePublic(BaseModel):
file_name: str
file_size: int = Field(default=0, ge=0)


class FileMetadata(BaseModel):
id: str
file_name: str
file_content: str
file_size: int = Field(default=0, ge=0)
created_at: datetime.datetime
updated_at: datetime.datetime


class ListConversationFile(ConversationFilePublic):
pass

Expand Down
69 changes: 44 additions & 25 deletions src/backend/services/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,49 +119,66 @@ async def create_agent_files(

return uploaded_files

def get_files_by_agent_id(
def get_file_ids_by_agent_id(
self, session: DBSessionDep, user_id: str, agent_id: str, ctx: Context
) -> list[File]:
) -> list[str]:
"""
Get files by agent ID
Get file IDs associated with a specific agent ID
Args:
session (DBSessionDep): The database session
user_id (str): The user ID
agent_id (str): The agent ID
ctx (Context): Context object
Returns:
list[File]: The files that were created
list[str]: IDs of files that were created
"""
from backend.config.tools import Tool
from backend.tools.files import FileToolsArtifactTypes

agent = validate_agent_exists(session, agent_id, user_id)

files = []
agent_tool_metadata = agent.tools_metadata
if agent_tool_metadata is not None and len(agent_tool_metadata) > 0:
artifacts = next(
(
tool_metadata.artifacts
for tool_metadata in agent_tool_metadata
if tool_metadata.tool_name == Tool.Read_File.value.ID
or tool_metadata.tool_name == Tool.Search_File.value.ID
),
[], # Default value if the generator is empty
)
if not agent.tools_metadata:
return []

artifacts = next(
(
tool_metadata.artifacts
for tool_metadata in agent.tools_metadata
if tool_metadata.tool_name == Tool.Read_File.value.ID
or tool_metadata.tool_name == Tool.Search_File.value.ID
),
[], # Default value if the generator is empty
)

file_ids = list(
{
artifact.get("id")
for artifact in artifacts
if artifact.get("type") == FileToolsArtifactTypes.local_file
}
)
return [
artifact.get("id")
for artifact in artifacts
if artifact.get("type") == FileToolsArtifactTypes.local_file
]

files = file_crud.get_files_by_ids(session, file_ids, user_id)
def get_files_by_agent_id(
self, session: DBSessionDep, user_id: str, agent_id: str, ctx: Context
) -> list[File]:
"""
Get files by agent ID
return files
Args:
session (DBSessionDep): The database session
user_id (str): The user ID
agent_id (str): The agent ID
ctx (Context): Context object
Returns:
list[File]: The files that were created
"""
file_ids = self.get_file_ids_by_agent_id(session, user_id, agent_id, ctx)

if not file_ids:
return []

return file_crud.get_files_by_ids(session, file_ids, user_id)

def get_files_by_conversation_id(
self, session: DBSessionDep, user_id: str, conversation_id: str, ctx: Context
Expand Down Expand Up @@ -312,6 +329,8 @@ def validate_file(
detail=f"File with ID: {file_id} not found.",
)

return file


async def insert_files_in_db(
session: DBSessionDep,
Expand Down
51 changes: 51 additions & 0 deletions src/backend/tests/unit/routers/test_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,57 @@ def test_list_files_missing_user_id(
assert response.json() == {"detail": "User-Id required in request headers."}


def test_get_file(
session_client: TestClient, session: Session, user: User
) -> None:
conversation = get_factory("Conversation", session).create(user_id=user.id)
response = session_client.post(
"/v1/conversations/batch_upload_file",
headers={"User-Id": conversation.user_id},
files=[
("files", ("Mariana_Trench.pdf", open("src/backend/tests/unit/test_data/Mariana_Trench.pdf", "rb")))
],
data={"conversation_id": conversation.id},
)
assert response.status_code == 200
uploaded_file = response.json()[0]

response = session_client.get(
f"/v1/conversations/{conversation.id}/files/{uploaded_file['id']}",
headers={"User-Id": conversation.user_id},
)

assert response.status_code == 200
response_file = response.json()
assert response_file["id"] == uploaded_file["id"]
assert response_file["file_name"] == uploaded_file["file_name"]


def test_fail_get_file_nonexistent_conversation(
session_client: TestClient, session: Session, user: User
) -> None:
response = session_client.get(
"/v1/conversations/123/files/456",
headers={"User-Id": user.id},
)

assert response.status_code == 404
assert response.json() == {"detail": "Conversation with ID: 123 not found."}


def test_fail_get_file_nonbelong_file(
session_client: TestClient, session: Session, user: User
) -> None:
conversation = get_factory("Conversation", session).create(user_id=user.id)
response = session_client.get(
f"/v1/conversations/{conversation.id}/files/123",
headers={"User-Id": conversation.user_id},
)

assert response.status_code == 404
assert response.json() == {"detail": f"File with ID: 123 does not belong to the conversation with ID: {conversation.id}."}


def test_batch_upload_file_existing_conversation(
session_client: TestClient, session: Session, user
) -> None:
Expand Down
20 changes: 20 additions & 0 deletions src/interfaces/assistants_web/src/cohere-client/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,19 @@ export class CohereClient {
});
}

public getConversationFile({
conversationId,
fileId,
}: {
conversationId: string;
fileId: string;
}) {
return this.cohereService.default.getFileV1ConversationsConversationIdFilesFileIdGet({
conversationId,
fileId,
});
}

public batchUploadConversationFile(
formData: Body_batch_upload_file_v1_conversations_batch_upload_file_post
) {
Expand All @@ -61,6 +74,13 @@ export class CohereClient {
});
}

public getAgentFile({ agentId, fileId }: { agentId: string; fileId: string }) {
return this.cohereService.default.getAgentFileV1AgentsAgentIdFilesFileIdGet({
agentId,
fileId,
});
}

public batchUploadAgentFile(formData: Body_batch_upload_file_v1_agents_batch_upload_file_post) {
return this.cohereService.default.batchUploadFileV1AgentsBatchUploadFilePost({
formData,
Expand Down
Loading

0 comments on commit 6bf5847

Please sign in to comment.