From 46b536fc4d13d5a1533506e57e6908307f6f5933 Mon Sep 17 00:00:00 2001 From: EugeneP Date: Mon, 4 Nov 2024 13:35:11 +0100 Subject: [PATCH] TLK-1864 agents deployments models refactoring - review fixes --- src/backend/routers/agent.py | 104 ++++++++++++++++++----------------- src/backend/routers/utils.py | 18 ++++++ 2 files changed, 71 insertions(+), 51 deletions(-) diff --git a/src/backend/routers/agent.py b/src/backend/routers/agent.py index b2730aafb6..7674eedb23 100644 --- a/src/backend/routers/agent.py +++ b/src/backend/routers/agent.py @@ -14,7 +14,7 @@ AgentToolMetadata as AgentToolMetadataModel, ) from backend.database_models.database import DBSessionDep -from backend.routers.utils import get_deployment_model_from_agent +from backend.routers.utils import get_deployment_model_from_agent, get_default_deployment_model from backend.schemas.agent import ( Agent, AgentPublic, @@ -53,6 +53,7 @@ ) router.name = RouterName.AGENT + @router.post( "", response_model=AgentPublic, @@ -62,9 +63,9 @@ ], ) async def create_agent( - session: DBSessionDep, - agent: CreateAgentRequest, - ctx: Context = Depends(get_context), + session: DBSessionDep, + agent: CreateAgentRequest, + ctx: Context = Depends(get_context), ) -> AgentPublic: """ Create an agent. @@ -83,6 +84,7 @@ async def create_agent( logger = ctx.get_logger() deployment_db, model_db = get_deployment_model_from_agent(agent, session) + default_deployment_db, default_model_db = get_default_deployment_model(session) try: if deployment_db and model_db: agent_data = AgentModel( @@ -94,8 +96,8 @@ async def create_agent( organization_id=agent.organization_id, tools=agent.tools, is_private=agent.is_private, - deployment_id=deployment_db.id if deployment_db else None, - model_id=model_db.id if model_db else None, + deployment_id=deployment_db.id if deployment_db else default_deployment_db.id if default_deployment_db else None, + model_id=model_db.id if model_db else default_model_db.id if default_model_db else None, ) created_agent = agent_crud.create_agent(session, agent_data) @@ -117,13 +119,13 @@ async def create_agent( @router.get("", response_model=list[AgentPublic]) async def list_agents( - *, - offset: int = 0, - limit: int = 100, - session: DBSessionDep, - visibility: AgentVisibility = AgentVisibility.ALL, - organization_id: Optional[str] = None, - ctx: Context = Depends(get_context), + *, + offset: int = 0, + limit: int = 100, + session: DBSessionDep, + visibility: AgentVisibility = AgentVisibility.ALL, + organization_id: Optional[str] = None, + ctx: Context = Depends(get_context), ) -> list[AgentPublic]: """ List all agents. @@ -161,7 +163,7 @@ async def list_agents( @router.get("/{agent_id}", response_model=AgentPublic) async def get_agent_by_id( - agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context) + agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context) ) -> Agent: """ Args: @@ -196,7 +198,7 @@ async def get_agent_by_id( @router.get("/{agent_id}/deployments", response_model=list[DeploymentSchema]) async def get_agent_deployment( - agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context) + agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context) ) -> DeploymentSchema: """ Args: @@ -228,10 +230,10 @@ async def get_agent_deployment( ], ) async def update_agent( - agent_id: str, - new_agent: UpdateAgentRequest, - session: DBSessionDep, - ctx: Context = Depends(get_context), + agent_id: str, + new_agent: UpdateAgentRequest, + session: DBSessionDep, + ctx: Context = Depends(get_context), ) -> AgentPublic: """ Update an agent by ID. @@ -285,9 +287,9 @@ async def update_agent( @router.delete("/{agent_id}", response_model=DeleteAgent) async def delete_agent( - agent_id: str, - session: DBSessionDep, - ctx: Context = Depends(get_context), + agent_id: str, + session: DBSessionDep, + ctx: Context = Depends(get_context), ) -> DeleteAgent: """ Delete an agent by ID. @@ -319,10 +321,10 @@ async def delete_agent( async def handle_tool_metadata_update( - agent: Agent, - new_agent: Agent, - session: DBSessionDep, - ctx: Context = Depends(get_context), + agent: Agent, + new_agent: Agent, + session: DBSessionDep, + ctx: Context = Depends(get_context), ) -> Agent: """Update or create tool metadata for an agent. @@ -360,10 +362,10 @@ async def handle_tool_metadata_update( async def update_or_create_tool_metadata( - agent: Agent, - new_tool_metadata: AgentToolMetadata, - session: DBSessionDep, - ctx: Context = Depends(get_context), + agent: Agent, + new_tool_metadata: AgentToolMetadata, + session: DBSessionDep, + ctx: Context = Depends(get_context), ) -> None: """Update or create tool metadata for an agent. @@ -389,7 +391,7 @@ async def update_or_create_tool_metadata( @router.get("/{agent_id}/tool-metadata", response_model=list[AgentToolMetadataPublic]) async def list_agent_tool_metadata( - agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context) + agent_id: str, session: DBSessionDep, ctx: Context = Depends(get_context) ) -> list[AgentToolMetadataPublic]: """ List all agent tool metadata by agent ID. @@ -421,10 +423,10 @@ async def list_agent_tool_metadata( response_model=AgentToolMetadataPublic, ) def create_agent_tool_metadata( - session: DBSessionDep, - agent_id: str, - agent_tool_metadata: CreateAgentToolMetadataRequest, - ctx: Context = Depends(get_context), + session: DBSessionDep, + agent_id: str, + agent_tool_metadata: CreateAgentToolMetadataRequest, + ctx: Context = Depends(get_context), ) -> AgentToolMetadataPublic: """ Create an agent tool metadata. @@ -470,11 +472,11 @@ def create_agent_tool_metadata( @router.put("/{agent_id}/tool-metadata/{agent_tool_metadata_id}") async def update_agent_tool_metadata( - agent_id: str, - agent_tool_metadata_id: str, - session: DBSessionDep, - new_agent_tool_metadata: UpdateAgentToolMetadataRequest, - ctx: Context = Depends(get_context), + agent_id: str, + agent_tool_metadata_id: str, + session: DBSessionDep, + new_agent_tool_metadata: UpdateAgentToolMetadataRequest, + ctx: Context = Depends(get_context), ) -> AgentToolMetadata: """ Update an agent tool metadata by ID. @@ -514,10 +516,10 @@ async def update_agent_tool_metadata( @router.delete("/{agent_id}/tool-metadata/{agent_tool_metadata_id}") async def delete_agent_tool_metadata( - agent_id: str, - agent_tool_metadata_id: str, - session: DBSessionDep, - ctx: Context = Depends(get_context), + agent_id: str, + agent_tool_metadata_id: str, + session: DBSessionDep, + ctx: Context = Depends(get_context), ) -> DeleteAgentToolMetadata: """ Delete an agent tool metadata by ID. @@ -556,9 +558,9 @@ async def delete_agent_tool_metadata( @router.post("/batch_upload_file", response_model=list[UploadAgentFileResponse]) async def batch_upload_file( - session: DBSessionDep, - files: list[FastAPIUploadFile] = RequestFile(...), - ctx: Context = Depends(get_context), + session: DBSessionDep, + files: list[FastAPIUploadFile] = RequestFile(...), + ctx: Context = Depends(get_context), ) -> UploadAgentFileResponse: user_id = ctx.get_user_id() @@ -580,10 +582,10 @@ async def batch_upload_file( @router.delete("/{agent_id}/files/{file_id}") async def delete_agent_file( - agent_id: str, - file_id: str, - session: DBSessionDep, - ctx: Context = Depends(get_context), + agent_id: str, + file_id: str, + session: DBSessionDep, + ctx: Context = Depends(get_context), ) -> DeleteAgentFileResponse: """ Delete an agent file by ID. diff --git a/src/backend/routers/utils.py b/src/backend/routers/utils.py index 4b554535f2..dada42e225 100644 --- a/src/backend/routers/utils.py +++ b/src/backend/routers/utils.py @@ -1,3 +1,4 @@ +from backend.config.deployments import ModelDeploymentName from backend.database_models.database import DBSessionDep from backend.schemas.agent import Agent @@ -21,3 +22,20 @@ def get_deployment_model_from_agent(agent: Agent, session: DBSessionDep): None, ) return deployment_db, model_db + + +def get_default_deployment_model(session: DBSessionDep): + from backend.crud import deployment as deployment_crud + + deployment_db = deployment_crud.get_deployment_by_name(session, ModelDeploymentName.CoherePlatform) + model_db = None + if deployment_db: + model_db = next( + ( + model + for model in deployment_db.models + if model.name == 'command-r-plus' + ), + None, + ) + return deployment_db, model_db