From b851a070708e16daab84bb9afdd7def8c264e9ea Mon Sep 17 00:00:00 2001 From: Tianjing Li Date: Thu, 26 Sep 2024 11:39:00 -0400 Subject: [PATCH] backend: Remove metrics (#784) * wip * Remove metrics code * add workflow * test * test * test * test * test * fix --- .github/workflows/python-lint.yml | 6 +- .github/workflows/python-typecheck.yml | 15 +- helper_scripts/metrics_helper.py | 170 ------- pyproject.toml | 2 +- src/backend/config/tools.py | 1 - src/backend/main.py | 2 - src/backend/model_deployments/azure.py | 3 - src/backend/model_deployments/bedrock.py | 3 - .../model_deployments/cohere_platform.py | 3 - src/backend/model_deployments/sagemaker.py | 3 - .../model_deployments/single_container.py | 3 - src/backend/routers/agent.py | 31 -- src/backend/routers/chat.py | 12 - src/backend/routers/conversation.py | 7 - src/backend/routers/user.py | 5 - src/backend/schemas/context.py | 23 - src/backend/schemas/metrics.py | 112 ----- src/backend/services/metrics.py | 435 ------------------ .../tests/integration/routers/test_agent.py | 84 ---- .../integration/routers/test_conversation.py | 44 -- src/backend/tests/unit/routers/test_agent.py | 102 ---- src/backend/tests/unit/routers/test_chat.py | 48 -- src/backend/tests/unit/routers/test_user.py | 55 --- 23 files changed, 18 insertions(+), 1151 deletions(-) delete mode 100644 helper_scripts/metrics_helper.py delete mode 100644 src/backend/schemas/metrics.py delete mode 100644 src/backend/services/metrics.py diff --git a/.github/workflows/python-lint.yml b/.github/workflows/python-lint.yml index d8405658e3..b54e41f6ce 100644 --- a/.github/workflows/python-lint.yml +++ b/.github/workflows/python-lint.yml @@ -9,7 +9,11 @@ jobs: lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + clean: true - name: Run lint checks uses: chartboost/ruff-action@v1 with: diff --git a/.github/workflows/python-typecheck.yml b/.github/workflows/python-typecheck.yml index 5241aaea90..facef2aff4 100644 --- a/.github/workflows/python-typecheck.yml +++ b/.github/workflows/python-typecheck.yml @@ -1,4 +1,4 @@ -name: Typecheck +name: Typecheck newly added Python files on: push: @@ -13,10 +13,19 @@ jobs: - run: pipx install poetry - uses: actions/setup-python@v4 with: - python-version: '3.11' + python-version: '3.11' cache: 'poetry' - run: poetry install - run: echo "$(poetry env info --path)/bin" >> $GITHUB_PATH - - uses: jakebailey/pyright-action@v2 + - name: Get new files added in the PR + id: get_new_files + run: | + git fetch origin main + NEW_PY_FILES=$(git diff --name-only --diff-filter=A origin/main HEAD | grep '\.py$' | tr '\n' ' ') + echo "New files: $NEW_PY_FILES" + echo "new_py_files=$NEW_PY_FILES" >> $GITHUB_OUTPUT + - name: Typecheck new files + uses: jakebailey/pyright-action@v2 with: version: 1.1.311 + extra-args: ${{ steps.get_new_files.outputs.new_py_files }} \ No newline at end of file diff --git a/helper_scripts/metrics_helper.py b/helper_scripts/metrics_helper.py deleted file mode 100644 index 0d6c5345f3..0000000000 --- a/helper_scripts/metrics_helper.py +++ /dev/null @@ -1,170 +0,0 @@ -import re -from uuid import uuid4 - -import requests - - -def agents(): - print("Running Agents") - ## Agents - # Create Agent - response = requests.post( - f"{base_url}/agents", - headers=headers, - json={ - "name": str(uuid4()), - "version": 1, - "description": "test description", - "preamble": "test preamble", - "temperature": 0.5, - "model": "command-r-plus", - "deployment": "Cohere Platform", - "tools": [ - "search_file", - "read_document", - "toolkit_calculator", - "wikipedia", - ], - }, - ) - print("create agent") - print(response.status_code) - agent_id = response.json()["id"] - # # List Agents - response = requests.get(f"{base_url}/agents", headers=headers) - print("list agents") - print(response.status_code) - - # # Get Agent - response = requests.get(f"{base_url}/agents/{agent_id}", headers=headers) - print("get agent") - print(response.status_code) - - # # Update Agent - response = requests.put( - f"{base_url}/agents/{agent_id}", headers=headers, json={"name": str(uuid4())} - ) - print("update agent") - print(response.status_code) - # print(response.json()) - - return agent_id - - -## Users -# Create User -def users(): - print("running users") - # # List Users - response = requests.get(f"{base_url}/users", headers=headers) - print(response.status_code) - # # Get User - response = requests.get(f"{base_url}/users/{user_id}", headers=headers) - print(response.status_code) - # # Update User - response = requests.put( - f"{base_url}/users/{user_id}", headers=headers, json={"fullname": "new name"} - ) - print(response.status_code) - - -# Chat -def chat(agent_id): - print("Running chat") - - response = requests.post( - f"http://localhost:8000/v1/chat-stream?agent_id={agent_id}", - headers=headers, - json={"message": "who is bo burnham?", "tools": [{"name": "web_search"}]}, - ) - - print(response.status_code) - - conversation_id = None - for event in response.iter_lines(): - if not event: - continue - - str_event = str(event) - - if "stream-start" in str_event: - match = re.search(r'"conversation_id": "([^"]*)"', str_event) - if match: - conversation_id = match.group(1) - - return conversation_id - - -def tools(conversation_id): - print("Running tools") - ## Tools - # List Tools - res = requests.get(f"{base_url}/tools", headers=headers) - print(res.status_code) - # List Tools per Agent - res = requests.get(f"{base_url}/tools?agent_id={agent_id}", headers=headers) - print(res.status_code) - ## Conversations - # List Conversations - res = requests.get(f"{base_url}/conversations", headers=headers) - print(res.status_code) - # Get Conversation - res = requests.get(f"{base_url}/conversations/{conversation_id}", headers=headers) - print(res.status_code) - # Update Conversation - res = requests.put( - f"{base_url}/conversations/{conversation_id}", - headers=headers, - json={"title": "new_title"}, - ) - - # del conversation - res = requests.delete( - f"{base_url}/conversations/{conversation_id}", headers=headers - ) - print(res.status_code) - - -# Delete Everything -def cleanup(user_id, agent_id): - print("cleaning up") - response = requests.delete(f"{base_url}/users/{user_id}", headers=headers) - print(response.status_code) - if agent_id: - response = requests.delete(f"{base_url}/agents/{agent_id}", headers=headers) - print(response.status_code) - - -base_url = "http://localhost:8000/v1" -headers = { - "User-Id": "admin", - "Deployment-Name": "Cohere Platform", - "Content-Type": "application/json", -} - -# Notes: -# web_search implicitly calls rerank -# - TAVILY_API_KEY required for web search -# in case of issues, prune docker images and try again -# TODO: please do not use global variables :,( - -# initial setup -print("setting up") -response = requests.post( - f"{base_url}/users", headers=headers, json={"fullname": "qa tester"} -) -response_json = response.json() -user_id = response_json["id"] -# update user id with correct value going forward -headers["User-Id"] = user_id -print("Setup user info") -print(response_json) - - -# TODO: make these into tests -users() -agent_id = None -agent_id = agents() -conversation_id = chat(agent_id=agent_id) -tools(conversation_id=conversation_id) -cleanup(user_id=user_id, agent_id=agent_id) diff --git a/pyproject.toml b/pyproject.toml index 7809955de6..bb1a8709a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,7 +109,7 @@ build-backend = "poetry.core.masonry.api" include = [ "src/backend/services/metrics.py", "src/backend/tools/google_drive/sync/actions/", - ] +] defineConstant = { DEBUG = true } reportMissingImports = true reportMissingTypeStubs = false diff --git a/src/backend/config/tools.py b/src/backend/config/tools.py index 490d5aaf41..a6da90b087 100644 --- a/src/backend/config/tools.py +++ b/src/backend/config/tools.py @@ -31,7 +31,6 @@ Don't forget to add the implementation to this AVAILABLE_TOOLS dictionary! """ - class ToolName(StrEnum): Wiki_Retriever_LangChain = LangChainWikiRetriever.NAME Search_File = SearchFileTool.NAME diff --git a/src/backend/main.py b/src/backend/main.py index df43269ef4..f500101078 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -29,7 +29,6 @@ from backend.routers.user import router as user_router from backend.services.context import ContextMiddleware, get_context from backend.services.logger.middleware import LoggingMiddleware -from backend.services.metrics import MetricsMiddleware load_dotenv() @@ -81,7 +80,6 @@ def create_app(): allow_headers=["*"], ) app.add_middleware(LoggingMiddleware) - app.add_middleware(MetricsMiddleware) app.add_middleware(ContextMiddleware) # This should be the first middleware app.add_exception_handler(SCIMException, scim_exception_handler) diff --git a/src/backend/model_deployments/azure.py b/src/backend/model_deployments/azure.py index e9074dada1..4c373087f3 100644 --- a/src/backend/model_deployments/azure.py +++ b/src/backend/model_deployments/azure.py @@ -8,7 +8,6 @@ from backend.model_deployments.utils import get_model_config_var from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context -from backend.services.metrics import collect_metrics_chat_stream, collect_metrics_rerank AZURE_API_KEY_ENV_VAR = "AZURE_API_KEY" # Example URL: "https://..inference.ai.azure.com/v1" @@ -69,7 +68,6 @@ async def invoke_chat(self, chat_request: CohereChatRequest) -> Any: ) yield to_dict(response) - @collect_metrics_chat_stream async def invoke_chat_stream( self, chat_request: CohereChatRequest, ctx: Context, **kwargs ) -> AsyncGenerator[Any, Any]: @@ -80,7 +78,6 @@ async def invoke_chat_stream( for event in stream: yield to_dict(event) - @collect_metrics_rerank async def invoke_rerank( self, query: str, documents: List[Dict[str, Any]], ctx: Context ) -> Any: diff --git a/src/backend/model_deployments/bedrock.py b/src/backend/model_deployments/bedrock.py index dd288556bf..fa3eb5613b 100644 --- a/src/backend/model_deployments/bedrock.py +++ b/src/backend/model_deployments/bedrock.py @@ -8,7 +8,6 @@ from backend.model_deployments.utils import get_model_config_var from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context -from backend.services.metrics import collect_metrics_chat_stream, collect_metrics_rerank BEDROCK_ACCESS_KEY_ENV_VAR = "BEDROCK_ACCESS_KEY" BEDROCK_SECRET_KEY_ENV_VAR = "BEDROCK_SECRET_KEY" @@ -80,7 +79,6 @@ async def invoke_chat(self, chat_request: CohereChatRequest) -> Any: ) yield to_dict(response) - @collect_metrics_chat_stream async def invoke_chat_stream( self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any ) -> AsyncGenerator[Any, Any]: @@ -95,7 +93,6 @@ async def invoke_chat_stream( for event in stream: yield to_dict(event) - @collect_metrics_rerank async def invoke_rerank( self, query: str, documents: List[Dict[str, Any]], ctx: Context ) -> Any: diff --git a/src/backend/model_deployments/cohere_platform.py b/src/backend/model_deployments/cohere_platform.py index 81f667eceb..e6d08a1075 100644 --- a/src/backend/model_deployments/cohere_platform.py +++ b/src/backend/model_deployments/cohere_platform.py @@ -10,7 +10,6 @@ from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context from backend.services.logger.utils import LoggerFactory -from backend.services.metrics import collect_metrics_chat_stream, collect_metrics_rerank COHERE_API_KEY_ENV_VAR = "COHERE_API_KEY" COHERE_ENV_VARS = [COHERE_API_KEY_ENV_VAR] @@ -73,7 +72,6 @@ async def invoke_chat( ) yield to_dict(response) - @collect_metrics_chat_stream async def invoke_chat_stream( self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any ) -> Any: @@ -96,7 +94,6 @@ async def invoke_chat_stream( yield event_dict - @collect_metrics_rerank async def invoke_rerank( self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any ) -> Any: diff --git a/src/backend/model_deployments/sagemaker.py b/src/backend/model_deployments/sagemaker.py index b5e8288392..5eafbd763a 100644 --- a/src/backend/model_deployments/sagemaker.py +++ b/src/backend/model_deployments/sagemaker.py @@ -9,7 +9,6 @@ from backend.model_deployments.utils import get_model_config_var from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context -from backend.services.metrics import collect_metrics_chat_stream, collect_metrics_rerank SAGE_MAKER_ACCESS_KEY_ENV_VAR = "SAGE_MAKER_ACCESS_KEY" SAGE_MAKER_SECRET_KEY_ENV_VAR = "SAGE_MAKER_SECRET_KEY" @@ -93,7 +92,6 @@ def is_available(cls) -> bool: and SageMakerDeployment.aws_session_token is not None ) - @collect_metrics_chat_stream async def invoke_chat_stream( self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any ) -> AsyncGenerator[Any, Any]: @@ -115,7 +113,6 @@ async def invoke_chat_stream( stream_event["index"] = index yield stream_event - @collect_metrics_rerank async def invoke_rerank( self, query: str, documents: List[Dict[str, Any]], ctx: Context ) -> Any: diff --git a/src/backend/model_deployments/single_container.py b/src/backend/model_deployments/single_container.py index c77c34edaf..2cfc36cd31 100644 --- a/src/backend/model_deployments/single_container.py +++ b/src/backend/model_deployments/single_container.py @@ -8,7 +8,6 @@ from backend.model_deployments.utils import get_model_config_var from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context -from backend.services.metrics import collect_metrics_chat_stream, collect_metrics_rerank DEFAULT_RERANK_MODEL = "rerank-english-v2.0" SC_URL_ENV_VAR = "SINGLE_CONTAINER_URL" @@ -61,7 +60,6 @@ async def invoke_chat(self, chat_request: CohereChatRequest) -> Any: ) yield to_dict(response) - @collect_metrics_chat_stream async def invoke_chat_stream( self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any ) -> AsyncGenerator[Any, Any]: @@ -74,7 +72,6 @@ async def invoke_chat_stream( for event in stream: yield to_dict(event) - @collect_metrics_rerank async def invoke_rerank( self, query: str, documents: List[Dict[str, Any]], ctx: Context ) -> Any: diff --git a/src/backend/routers/agent.py b/src/backend/routers/agent.py index d15a2d4b24..3d5cea2403 100644 --- a/src/backend/routers/agent.py +++ b/src/backend/routers/agent.py @@ -31,12 +31,6 @@ from backend.schemas.context import Context from backend.schemas.deployment import Deployment as DeploymentSchema from backend.schemas.file import DeleteAgentFileResponse, UploadAgentFileResponse -from backend.schemas.metrics import ( - DEFAULT_METRICS_AGENT, - GenericResponseMessage, - MetricsMessageType, - agent_to_metrics_agent, -) from backend.services.agent import ( raise_db_error, validate_agent_exists, @@ -84,8 +78,6 @@ async def create_agent( Raises: HTTPException: If the agent creation fails. """ - # add user data into request state for metrics - ctx.with_event_type(MetricsMessageType.ASSISTANT_CREATED) ctx.with_user(session) user_id = ctx.get_user_id() logger = ctx.get_logger() @@ -127,7 +119,6 @@ async def create_agent( agent_schema = Agent.model_validate(created_agent) ctx.with_agent(agent_schema) - ctx.with_metrics_agent(agent_to_metrics_agent(agent_schema)) return created_agent except Exception as e: @@ -195,7 +186,6 @@ async def get_agent_by_id( Raises: HTTPException: If the agent with the given ID is not found. """ - ctx.with_event_type(MetricsMessageType.ASSISTANT_ACCESSED) user_id = ctx.get_user_id() agent = None @@ -212,7 +202,6 @@ async def get_agent_by_id( agent_schema = Agent.model_validate(agent) ctx.with_agent(agent_schema) - ctx.with_metrics_agent(agent_to_metrics_agent(agent)) return agent @@ -275,7 +264,6 @@ async def update_agent( """ user_id = ctx.get_user_id() ctx.with_user(session) - ctx.with_event_type(MetricsMessageType.ASSISTANT_UPDATED) agent = validate_agent_exists(session, agent_id, user_id) if new_agent.tools_metadata is not None: @@ -342,7 +330,6 @@ async def update_agent( ) agent_schema = Agent.model_validate(agent) ctx.with_agent(agent_schema) - ctx.with_metrics_agent(agent_to_metrics_agent(agent)) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) finally: @@ -374,11 +361,9 @@ async def delete_agent( HTTPException: If the agent with the given ID is not found. """ user_id = ctx.get_user_id() - ctx.with_event_type(MetricsMessageType.ASSISTANT_DELETED) agent = validate_agent_exists(session, agent_id, user_id) agent_schema = Agent.model_validate(agent) ctx.with_agent(agent_schema) - ctx.with_metrics_agent(agent_to_metrics_agent(agent)) deleted = agent_crud.delete_agent(session, agent_id, user_id) if not deleted: @@ -686,19 +671,3 @@ async def delete_agent_file( prefix="/v1/default_agent", ) default_agent_router.name = RouterName.DEFAULT_AGENT - - -@default_agent_router.get("/", response_model=GenericResponseMessage) -async def get_default_agent(ctx: Context = Depends(get_context)): - """Get the default agent - used for logging purposes. - - Args: - session (DBSessionDep): Database session. - ctx (Context): Context object. - - Returns: - GenericResponseMessage: OK message. - """ - ctx.with_event_type(MetricsMessageType.ASSISTANT_ACCESSED) - ctx.with_metrics_agent(DEFAULT_METRICS_AGENT) - return {"message": "OK"} diff --git a/src/backend/routers/chat.py b/src/backend/routers/chat.py index 1547016116..e5d939ee99 100644 --- a/src/backend/routers/chat.py +++ b/src/backend/routers/chat.py @@ -14,7 +14,6 @@ from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context from backend.schemas.langchain_chat import LangchainChatRequest -from backend.schemas.metrics import DEFAULT_METRICS_AGENT, agent_to_metrics_agent from backend.services.agent import validate_agent_exists from backend.services.chat import ( generate_chat_response, @@ -70,10 +69,6 @@ async def chat_stream( ] ctx.with_agent_tool_metadata(agent_tool_metadata_schema) - ctx.with_metrics_agent(agent_to_metrics_agent(agent)) - else: - ctx.with_metrics_agent(DEFAULT_METRICS_AGENT) - ( session, chat_request, @@ -144,10 +139,6 @@ async def regenerate_chat_stream( ] ctx.with_agent_tool_metadata(agent_tool_metadata_schema) - ctx.with_metrics_agent(agent_to_metrics_agent(agent)) - else: - ctx.with_metrics_agent(DEFAULT_METRICS_AGENT) - ( session, chat_request, @@ -216,9 +207,6 @@ async def chat( AgentToolMetadata.model_validate(x) for x in agent_tool_metadata ] ctx.with_agent_tool_metadata(agent_tool_metadata_schema) - ctx.with_metrics_agent(agent_to_metrics_agent(agent)) - else: - ctx.with_metrics_agent(DEFAULT_METRICS_AGENT) ( session, diff --git a/src/backend/routers/conversation.py b/src/backend/routers/conversation.py index cb85cf8d56..8da4e25622 100644 --- a/src/backend/routers/conversation.py +++ b/src/backend/routers/conversation.py @@ -25,7 +25,6 @@ ListConversationFile, UploadConversationFileResponse, ) -from backend.schemas.metrics import DEFAULT_METRICS_AGENT, agent_to_metrics_agent from backend.services.agent import validate_agent_exists from backend.services.context import get_context from backend.services.conversation import ( @@ -312,9 +311,6 @@ async def search_conversations( if agent_id: agent_schema = Agent.model_validate(agent) ctx.with_agent(agent_schema) - ctx.with_metrics_agent(agent_to_metrics_agent(agent)) - else: - ctx.with_metrics_agent(DEFAULT_METRICS_AGENT) conversations = conversation_crud.get_conversations( session, offset=offset, limit=limit, user_id=user_id, agent_id=agent_id @@ -530,9 +526,6 @@ async def generate_title( agent = agent_crud.get_agent_by_id(session, agent_id, user_id) agent_schema = Agent.model_validate(agent) ctx.with_agent(agent_schema) - ctx.with_metrics_agent(agent_to_metrics_agent(agent)) - else: - ctx.with_metrics_agent(DEFAULT_METRICS_AGENT) title, error = await generate_conversation_title( session, diff --git a/src/backend/routers/user.py b/src/backend/routers/user.py index 4a83706f7f..8b13d9649a 100644 --- a/src/backend/routers/user.py +++ b/src/backend/routers/user.py @@ -5,7 +5,6 @@ from backend.database_models import User as UserModel from backend.database_models.database import DBSessionDep from backend.schemas.context import Context -from backend.schemas.metrics import MetricsMessageType from backend.schemas.user import CreateUser, DeleteUser, UpdateUser, User from backend.schemas.user import User as UserSchema from backend.services.context import get_context @@ -31,8 +30,6 @@ async def create_user( Returns: User: Created user. """ - ctx.with_event_type(MetricsMessageType.USER_CREATED) - db_user = UserModel(**user.model_dump(exclude_none=True)) db_user = user_crud.create_user(session, db_user) @@ -122,7 +119,6 @@ async def update_user( HTTPException: If the user with the given ID is not found. """ user = user_crud.get_user(session, user_id) - ctx.with_event_type(MetricsMessageType.USER_UPDATED) if not user: raise HTTPException( @@ -156,7 +152,6 @@ async def delete_user( Raises: HTTPException: If the user with the given ID is not found. """ - ctx.with_event_type(MetricsMessageType.USER_DELETED) user = user_crud.get_user(session, user_id) if not user: diff --git a/src/backend/schemas/context.py b/src/backend/schemas/context.py index eefde4142a..19365e62a8 100644 --- a/src/backend/schemas/context.py +++ b/src/backend/schemas/context.py @@ -7,7 +7,6 @@ from backend.database_models.database import DBSessionDep from backend.schemas import Organization from backend.schemas.agent import Agent, AgentToolMetadata -from backend.schemas.metrics import MetricsAgent, MetricsMessageType, MetricsUser from backend.schemas.user import User from backend.services.logger.utils import LoggerFactory from backend.services.utils import get_deployment_config @@ -19,7 +18,6 @@ class Context(BaseModel): receive: Optional[dict] = {} trace_id: str = "default" user_id: str = "default" - event_type: MetricsMessageType = None user: Optional[User] = None agent: Optional[Agent] = None agent_tool_metadata: Optional[AgentToolMetadata] = None @@ -34,10 +32,6 @@ class Context(BaseModel): organization: Optional[Organization] = None use_global_filtering: Optional[bool] = False - # Metrics - metrics_user: Optional[MetricsUser] = None - metrics_agent: Optional[MetricsAgent] = None - def __init__(self): super().__init__() self.with_logger() @@ -69,10 +63,6 @@ def with_user_id(self, user_id: str): def with_deployment_name(self, deployment_name: str): self.deployment_name = deployment_name - def with_event_type(self, event_type: MetricsMessageType) -> "Context": - self.event_type = event_type - return self - def with_user( self, session: DBSessionDep | None = None, user: User | None = None ) -> "Context": @@ -84,9 +74,6 @@ def with_user( user = User.model_validate(user) if user: - self.metrics_user = MetricsUser( - id=user.id, email=user.email, fullname=user.fullname - ) self.user = user return self @@ -95,10 +82,6 @@ def with_agent(self, agent: Agent | None) -> "Context": self.agent = agent return self - def with_metrics_agent(self, metrics_agent: MetricsAgent) -> "Context": - self.metrics_agent = metrics_agent - return self - def with_agent_tool_metadata( self, agent_tool_metadata: AgentToolMetadata ) -> "Context": @@ -187,12 +170,6 @@ def get_user_id(self): def get_event_type(self): return self.event_type - def get_metrics_user(self): - return self.metrics_user - - def get_metrics_agent(self): - return self.metrics_agent - def get_model(self): return self.model diff --git a/src/backend/schemas/metrics.py b/src/backend/schemas/metrics.py deleted file mode 100644 index 46d79e6b11..0000000000 --- a/src/backend/schemas/metrics.py +++ /dev/null @@ -1,112 +0,0 @@ -from enum import Enum -from typing import Any - -from pydantic import BaseModel - -from backend.schemas.agent import Agent - - -class GenericResponseMessage(BaseModel): - message: str - - -class MetricsMessageType(str, Enum): - # users: implemented, has tests - USER_CREATED = "user_created" - USER_UPDATED = "user_updated" - USER_DELETED = "user_deleted" - # agents: implemented, has tests - ASSISTANT_CREATED = "assistant_created" - ASSISTANT_UPDATED = "assistant_updated" - ASSISTANT_DELETED = "assistant_deleted" - ASSISTANT_ACCESSED = "assistant_accessed" - # implemented, has tests - CHAT_API_SUCCESS = "chat_api_call_success" - # implemented, needs tests - CHAT_API_FAIL = "chat_api_call_failure" - # implemented, has tests - RERANK_API_SUCCESS = "rerank_api_call_success" - # implemented, needs tests - RERANK_API_FAIL = "rerank_api_call_failure" - # pending implementation - ENV_LIVENESS = "env_liveness" - UNKNOWN_SIGNAL = "unknown" - - -class MetricsDataBase(BaseModel): - id: str - user_id: str - trace_id: str - message_type: MetricsMessageType - timestamp: float - - -class MetricsUser(BaseModel): - id: str - fullname: str - email: str | None - - -class MetricsAgent(BaseModel): - id: str - version: int - name: str - temperature: float - model: str | None - deployment: str | None - preamble: str | None - description: str | None - - -class MetricsModelAttrs(BaseModel): - input_nb_tokens: int - output_nb_tokens: int - search_units: int - model: str - assistant_id: str - - -class MetricsData(MetricsDataBase): - input_nb_tokens: int | None = None - output_nb_tokens: int | None = None - search_units: int | None = None - model: str | None = None - error: str | None = None - duration_ms: float | None = None - meta: dict[str, Any] | None = None - assistant_id: str | None = None - assistant: MetricsAgent | None = None - user: MetricsUser | None = None - - -class MetricsSignal(BaseModel): - signal: MetricsData - - -DEFAULT_METRICS_AGENT = MetricsAgent( - id="9c300cfd-1506-408b-829d-a6464137a7c1", - version=1, - name="Default Agent", - temperature=0.3, - model="command-r-plus", - deployment="Cohere", - preamble="", - description="default", -) - - -def agent_to_metrics_agent(agent: Agent | None) -> MetricsAgent: - if not agent: - return None - # TODO Eugene: Check agent.model and agent.deployment after the refactor Agent deployment - # and model to object(if needed) - return MetricsAgent( - id=agent.id, - version=agent.version, - name=agent.name, - temperature=agent.temperature, - model=agent.model if agent.model else None, - deployment=agent.deployment if agent.deployment else None, - preamble=agent.preamble, - description=agent.description, - ) diff --git a/src/backend/services/metrics.py b/src/backend/services/metrics.py deleted file mode 100644 index dbb32a2f9a..0000000000 --- a/src/backend/services/metrics.py +++ /dev/null @@ -1,435 +0,0 @@ -import asyncio -import json -import os -import time -import uuid -from functools import wraps -from typing import Any, Callable, Dict, Optional - -from httpx import AsyncHTTPTransport -from httpx._client import AsyncClient -from starlette.background import BackgroundTask -from starlette.middleware.base import BaseHTTPMiddleware -from starlette.requests import Request -from starlette.responses import Response - -from backend.chat.collate import to_dict -from backend.chat.enums import StreamEvent -from backend.schemas.cohere_chat import CohereChatRequest -from backend.schemas.context import Context -from backend.schemas.metrics import ( - MetricsData, - MetricsMessageType, - MetricsModelAttrs, - MetricsSignal, -) -from backend.services.context import get_context -from backend.services.logger.utils import LoggerFactory - -REPORT_ENDPOINT = os.getenv("REPORT_ENDPOINT", None) -REPORT_SECRET = os.getenv("REPORT_SECRET", None) -METRICS_LOGS_CURLS = os.getenv("METRICS_LOGS_CURLS", None) -NUM_RETRIES = 0 -HEALTH_ENDPOINT = "health" -HEALTH_ENDPOINT_USER_ID = "health" -# TODO: fix this hack eventually -DEFAULT_RERANK_MODEL = "rerank-english-v2.0" - - -class MetricsMiddleware(BaseHTTPMiddleware): - """ - Middleware class for handling metrics in the application. - - This middleware is responsible for tracking and reporting select events for incoming requests. - It follows the fire and forget mechanism and should never throw exceptions. - For chat streams, and rerank events, additional decorators are also required. - - Attributes: - None - - Methods: - dispatch: Dispatches the request to the next middleware or application handler. - _init_req_state: Initializes the state of the request. - _confirm_env: Confirms the environment setup for reporting metrics. - _send_signal: Sends the metrics signal to the reporting endpoint. - _get_event_signal: Retrieves the metrics signal for the current request. - _get_user: Retrieves the user information from the request. - - """ - - async def dispatch(self, request: Request, call_next: Callable): - self._confirm_env() - - start_time = time.perf_counter() - response = await call_next(request) - duration_ms = time.perf_counter() - start_time - - ctx = get_context(request) - self._send_signal(request, response, duration_ms, ctx) - - return response - - def _confirm_env(self): - logger = LoggerFactory().get_logger() - if not REPORT_SECRET: - logger.warning(event="[Metrics] No report secret set") - if not REPORT_ENDPOINT: - logger.warning(event="[Metrics] No report endpoint set") - - def _should_send_signal( - self, - signal: Optional[MetricsSignal], - event_type: Optional[MetricsMessageType], - response: Response, - ) -> bool: - middleware_allowed_signals = { - MetricsMessageType.USER_CREATED, - MetricsMessageType.USER_UPDATED, - MetricsMessageType.USER_DELETED, - MetricsMessageType.ASSISTANT_CREATED, - MetricsMessageType.ASSISTANT_UPDATED, - MetricsMessageType.ASSISTANT_DELETED, - MetricsMessageType.ASSISTANT_ACCESSED, - } - - return ( - True - if ( - event_type in middleware_allowed_signals - and signal - # TODO: we may want to log failing reqeusts as well in the future - # right now we only track failures from chat streams and rerank - # through the decorators - and response.status_code >= 200 - and response.status_code < 300 - ) - else False - ) - - def _send_signal( - self, request: Request, response: Response, duration_ms: float, ctx: Context - ) -> None: - signal = self._get_event_signal(request, duration_ms, ctx) - event_type = ctx.get_event_type() - if self._should_send_signal(signal, event_type, response): - # signal is being checked in the condition above - response.background = BackgroundTask(report_metrics, signal, ctx) # type: ignore - - def _get_event_signal( - self, request: Request, duration_ms: float, ctx: Context - ) -> MetricsSignal | None: - if request.scope["type"] != "http": - return None - - message_type = ctx.get_event_type() - if not message_type: - return None - - logger = ctx.get_logger() - - user = ctx.get_metrics_user() - # when user is created, user_id is not in the header - trace_id = ctx.get_trace_id() - user_id = ctx.get_user_id() - agent = ctx.get_metrics_agent() - agent_id = agent.id if agent else None - event_id = str(uuid.uuid4()) - now_unix_seconds = time.time() - - try: - data = MetricsData( - id=event_id, - user_id=user_id, - timestamp=now_unix_seconds, - user=user, - message_type=message_type, - trace_id=trace_id, - assistant=agent, - assistant_id=agent_id, - duration_ms=duration_ms, - ) - signal = MetricsSignal(signal=data) - return signal - except Exception as e: - logger.warning(event=f"[Metrics] Failed to process event data: {e}") - return None - - -async def report_metrics(signal: MetricsSignal, ctx: Context) -> None: - """ - Reports the given metrics signal to the specified endpoint. - This is the key function for reporting metrics. It should never throw exceptions but log them. - - Args: - signal (MetricsSignal): The metrics signal to be reported. - - Returns: - None - """ - logger = ctx.get_logger() - - if METRICS_LOGS_CURLS == "true": - MetricsHelper.log_signal_curl(signal, ctx) - if not REPORT_SECRET: - return - if not REPORT_ENDPOINT: - return - - try: - signal = to_dict(signal) - transport = AsyncHTTPTransport(retries=NUM_RETRIES) - async with AsyncClient(transport=transport) as client: - headers = { - "Authorization": f"Bearer {REPORT_SECRET}", - "Content-Type": "application/json", - } - await client.post(REPORT_ENDPOINT, json=signal, headers=headers) - except Exception as e: - logger.error(event=f"[Metrics] Error posting report: {e}") - - -def collect_metrics_chat_stream(func: Callable) -> Callable: - """ - Decorator for collecting metrics for chat streams. - Use with the middleware as needed. - Args: - func (Callable): the original function to be decorated, must return an async generator - - Returns: - Callable: wrapped function that yields the original values - - Yields: - Iterator[Callable]: the original values from the stream - """ - - @wraps(func) - async def wrapper( - self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any - ) -> Any: - stream = func(self, chat_request, ctx, **kwargs) - async for v in stream: - ChatMetricHelper.report_streaming_chat_event(v, ctx, **kwargs) - yield v - - return wrapper - - -def collect_metrics_rerank(func: Callable) -> Callable: - """ - Decorator for collecting metrics for rerank events. - Use with the middleware as needed. - Args: - func (Callable): function to be decorated - - Raises: - e: original exception raised by the function - - Returns: - Callable: the wrapped function - """ - - @wraps(func) - async def wrapper( - self, query: str, documents: Dict[str, Any], ctx: Context, **kwargs: Any - ) -> Any: - start_time = time.perf_counter() - try: - response = await func(self, query, documents, ctx, **kwargs) - duration_ms = time.perf_counter() - start_time - RerankMetricsHelper.report_rerank_metrics( - response, duration_ms, ctx, **kwargs - ) - return response - except Exception as e: - duration_ms = time.perf_counter() - start_time - RerankMetricsHelper.report_rerank_failed_metrics( - duration_ms, e, ctx, **kwargs - ) - raise e - - return wrapper - - -class MetricsHelper: - # TODO: remove the logging once metrics are configured correctly - @staticmethod - def log_signal_curl(signal: MetricsSignal, ctx: Context) -> None: - logger = ctx.get_logger() - s = to_dict(signal) - json_signal = json.dumps(s) - # just general curl commands to test the endpoint for now - logger.info( - event=f"\n\ncurl -X POST -H \"Content-Type: application/json\" -d '{json_signal}' $ENDPOINT\n\n" - ) - - -# DO NOT THROW EXPCEPTIONS IN THIS FUNCTION -class ChatMetricHelper: - @staticmethod - def report_streaming_chat_event( - event: dict[str, Any], ctx: Context, **kwargs: Any - ) -> None: - logger = ctx.get_logger() - - try: - event_type = event["event_type"] - if event_type == StreamEvent.STREAM_START: - ctx.with_stream_start_ms(time.perf_counter()) - - if event_type != StreamEvent.STREAM_END: - return - - duration_ms = None - time_start = ctx.get_stream_start_ms() - if time_start: - duration_ms = time.perf_counter() - time_start - trace_id = ctx.get_trace_id() - model = ctx.get_model() - user_id = ctx.get_user_id() - agent = ctx.get_metrics_agent() - agent_id = agent.id if agent else None - event_dict = to_dict(event).get("response", {}) - input_tokens = ( - event_dict.get("meta", {}) - .get("billed_units", {}) - .get("input_tokens", 0) - ) - output_tokens = ( - event_dict.get("meta", {}) - .get("billed_units", {}) - .get("output_tokens", 0) - ) - search_units = ( - event_dict.get("meta", {}) - .get("billed_units", {}) - .get("search_units", 0) - ) - search_units = search_units if search_units else 0 - is_error = ( - event_dict.get("event_type") == StreamEvent.STREAM_END - and event_dict.get("finish_reason") != "COMPLETE" - and event_dict.get("finish_reason") != "MAX_TOKENS" - ) - - message_type = ( - MetricsMessageType.CHAT_API_FAIL - if is_error - else MetricsMessageType.CHAT_API_SUCCESS - ) - # validate successful event metrics, ignore type errors to rely on pydantic exceptions - if not is_error: - MetricsModelAttrs( - input_nb_tokens=input_tokens, - output_nb_tokens=output_tokens, - search_units=search_units, - model=model, # type: ignore - assistant_id=agent_id, # type: ignore - ) - - metrics = MetricsData( - id=str(uuid.uuid4()), - user_id=user_id, - trace_id=trace_id, - duration_ms=duration_ms, - message_type=message_type, - timestamp=time.time(), - input_nb_tokens=input_tokens, - output_nb_tokens=output_tokens, - search_units=search_units, - model=model, - assistant_id=agent_id, - assistant=agent, - error=event_dict.get("finish_reason", None) if is_error else None, - ) - signal = MetricsSignal(signal=metrics) - # do not await, fire and forget - asyncio.create_task(report_metrics(signal, ctx)) - - except Exception as e: - logger.error(event=f"Failed to report streaming event: {e}") - - -class RerankMetricsHelper: - # DO NOT THROW EXPCEPTIONS IN THIS FUNCTION - @staticmethod - def report_rerank_metrics( - response: Any, duration_ms: float, ctx: Context, **kwargs: Any - ): - logger = ctx.get_logger() - - try: - (trace_id, model, user_id, agent, agent_id) = ( - RerankMetricsHelper._get_init_data(ctx) - ) - response_dict = to_dict(response) - search_units = ( - response_dict.get("meta", {}) - .get("billed_units", {}) - .get("search_units") - ) - message_type = MetricsMessageType.RERANK_API_SUCCESS - # ensure valid MetricsChat object - _ = MetricsModelAttrs( - input_nb_tokens=0, - output_nb_tokens=0, - search_units=search_units, - model=model, - assistant_id=agent_id, - ) - - metrics_data = MetricsData( - id=str(uuid.uuid4()), - message_type=message_type, - trace_id=trace_id, - user_id=user_id, - assistant_id=agent_id, - assistant=agent, - model=model, - input_nb_tokens=0, - output_nb_tokens=0, - search_units=search_units, - timestamp=time.time(), - duration_ms=duration_ms, - ) - signal = MetricsSignal(signal=metrics_data) - asyncio.create_task(report_metrics(signal, ctx)) - except Exception as e: - logger.error(event=f"[Metrics] Error reporting rerank metrics: {e}") - - @staticmethod - def report_rerank_failed_metrics( - duration_ms: float, error: Exception, ctx: Context, **kwargs: Any - ): - logger = ctx.get_logger() - - try: - (trace_id, model, user_id, agent, agent_id) = ( - RerankMetricsHelper._get_init_data(ctx) - ) - message_type = MetricsMessageType.RERANK_API_FAIL - error_message = str(error) - metrics_data = MetricsData( - id=str(uuid.uuid4()), - message_type=message_type, - trace_id=trace_id, - user_id=user_id, - assistant_id=agent_id, - assistant=agent, - model=model, - duration_ms=duration_ms, - timestamp=time.time(), - error=error_message, - ) - signal = MetricsSignal(signal=metrics_data) - asyncio.create_task(report_metrics(signal, ctx)) - except Exception as e: - logger.error(event=f"Failed to report rerank metrics: {e}") - - @staticmethod - def _get_init_data(ctx: Context) -> tuple: - trace_id = ctx.get_trace_id() - model = DEFAULT_RERANK_MODEL - user_id = ctx.get_user_id() - agent = ctx.get_metrics_agent() - agent_id = agent.id if agent else ctx.get_agent_id() - return (trace_id, model, user_id, agent, agent_id) diff --git a/src/backend/tests/integration/routers/test_agent.py b/src/backend/tests/integration/routers/test_agent.py index 4a52db8671..02e6d1f62a 100644 --- a/src/backend/tests/integration/routers/test_agent.py +++ b/src/backend/tests/integration/routers/test_agent.py @@ -1,4 +1,3 @@ -from unittest.mock import patch from fastapi.testclient import TestClient from sqlalchemy.orm import Session @@ -7,7 +6,6 @@ from backend.config.tools import ToolName from backend.database_models.agent import Agent from backend.database_models.agent_tool_metadata import AgentToolMetadata -from backend.schemas.metrics import MetricsData, MetricsMessageType from backend.tests.unit.factories import get_factory @@ -139,88 +137,6 @@ def test_create_agent_missing_non_required_fields( assert agent.temperature == 0.3 assert agent.model == request_json["model"] - -def test_update_agent_metric(session_client: TestClient, session: Session) -> None: - user = get_factory("User", session).create(fullname="John Doe") - agent = get_factory("Agent", session).create( - name="test agent", - version=1, - description="test description", - preamble="test preamble", - temperature=0.5, - model="command-r-plus", - deployment=ModelDeploymentName.CoherePlatform, - user_id=user.id, - ) - - request_json = { - "name": "updated name", - "version": 2, - "description": "updated description", - "preamble": "updated preamble", - "temperature": 0.7, - "model": "command-r", - "deployment": ModelDeploymentName.CoherePlatform, - } - - with patch( - "backend.services.metrics.report_metrics", - return_value=None, - ) as mock_metrics: - response = session_client.put( - f"/v1/agents/{agent.id}", - json=request_json, - headers={"User-Id": user.id}, - ) - - assert response.status_code == 200 - m_args: MetricsData = mock_metrics.await_args.args[0].signal - assert m_args.message_type == MetricsMessageType.ASSISTANT_UPDATED - assert m_args.assistant.name == request_json["name"] - assert m_args.user.fullname == user.fullname - - -def test_update_agent_mock_metrics( - session_client: TestClient, session: Session, user -) -> None: - agent = get_factory("Agent", session).create( - name="test agent", - version=1, - description="test description", - preamble="test preamble", - temperature=0.5, - model="command-r-plus", - deployment=ModelDeploymentName.CoherePlatform, - user_id=user.id, - ) - - request_json = { - "name": "updated name", - "version": 2, - "description": "updated description", - "preamble": "updated preamble", - "temperature": 0.7, - "model": "command-r", - "deployment": ModelDeploymentName.CoherePlatform, - } - - with patch( - "backend.services.metrics.report_metrics", - return_value=None, - ) as mock_metrics: - response = session_client.put( - f"/v1/agents/{agent.id}", - json=request_json, - headers={"User-Id": user.id}, - ) - - assert response.status_code == 200 - m_args: MetricsData = mock_metrics.await_args.args[0].signal - assert m_args.message_type == MetricsMessageType.ASSISTANT_UPDATED - assert m_args.assistant.name == request_json["name"] - assert m_args.user.fullname == user.fullname - - def test_update_agent(session_client: TestClient, session: Session, user) -> None: agent = get_factory("Agent", session).create( name="test agent", diff --git a/src/backend/tests/integration/routers/test_conversation.py b/src/backend/tests/integration/routers/test_conversation.py index 9ea7024c3c..d7cce31caf 100644 --- a/src/backend/tests/integration/routers/test_conversation.py +++ b/src/backend/tests/integration/routers/test_conversation.py @@ -1,5 +1,4 @@ import os -from unittest.mock import patch import pytest from fastapi.testclient import TestClient @@ -7,7 +6,6 @@ from backend.config.deployments import ModelDeploymentName from backend.database_models import Conversation -from backend.schemas.metrics import MetricsData, MetricsMessageType from backend.schemas.user import User from backend.tests.unit.factories import get_factory @@ -84,48 +82,6 @@ def test_search_conversations_no_conversations( assert response.json() == [] -@pytest.mark.skipif( - os.environ.get("COHERE_API_KEY") is None, - reason="Cohere API key not set, skipping test", -) -def test_search_conversations_with_reranking_sends_metrics( - session_client: TestClient, - session: Session, - user: User, -) -> None: - _ = get_factory("Conversation", session).create( - title="Hello, how are you?", text_messages=[], user_id=user.id - ) - _ = get_factory("Conversation", session).create( - title="There are are seven colors in the rainbow", - text_messages=[], - user_id=user.id, - ) - with patch( - "backend.services.metrics.report_metrics", - return_value=None, - ) as mock_metrics: - response = session_client.get( - "/v1/conversations:search", - headers={ - "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, - }, - params={"query": "color"}, - ) - assert response.status_code == 200 - m_args: MetricsData = mock_metrics.await_args.args[0].signal - assert m_args.user_id == user.id - assert m_args.model == "rerank-english-v2.0" - - assert m_args.message_type == MetricsMessageType.RERANK_API_SUCCESS - assert m_args.duration_ms is not None and m_args.duration_ms > 0 - assert m_args.assistant_id is not None - assert m_args.assistant.name is not None - assert m_args.model is not None - assert m_args.search_units > 0 - - # MISC diff --git a/src/backend/tests/unit/routers/test_agent.py b/src/backend/tests/unit/routers/test_agent.py index 86652eb9b4..0e4645b42a 100644 --- a/src/backend/tests/unit/routers/test_agent.py +++ b/src/backend/tests/unit/routers/test_agent.py @@ -1,6 +1,4 @@ -from unittest.mock import patch -import pytest from fastapi.testclient import TestClient from sqlalchemy.orm import Session @@ -10,40 +8,9 @@ from backend.database_models.agent import Agent from backend.database_models.agent_tool_metadata import AgentToolMetadata from backend.database_models.snapshot import Snapshot -from backend.schemas.metrics import MetricsData, MetricsMessageType from backend.tests.unit.factories import get_factory -async def test_create_agent_metric( - session_client: TestClient, session: Session -) -> None: - user = get_factory("User", session).create(fullname="John Doe") - request_json = { - "name": "test agent", - "version": 1, - "description": "test description", - "preamble": "test preamble", - "temperature": 0.5, - "model": "command-r-plus", - "deployment": ModelDeploymentName.CoherePlatform, - "tools": [ToolName.Calculator, ToolName.Search_File, ToolName.Read_File], - } - - with patch( - "backend.services.metrics.report_metrics", - return_value=None, - ) as mock_metrics: - response = session_client.post( - "/v1/agents", json=request_json, headers={"User-Id": user.id} - ) - assert response.status_code == 200 - m_args: MetricsData = mock_metrics.await_args.args[0].signal - assert m_args.user_id == user.id - assert m_args.message_type == MetricsMessageType.ASSISTANT_CREATED - assert m_args.assistant.name == request_json["name"] - assert m_args.user.fullname == user.fullname - - def test_create_agent_missing_name( session_client: TestClient, session: Session, user ) -> None: @@ -368,58 +335,6 @@ def test_list_agents_with_pagination( response_agents = response.json() assert len(response_agents) == 1 - -@pytest.mark.asyncio -async def test_get_agent_metric( - session_client: TestClient, session: Session, user -) -> None: - agent = get_factory("Agent", session).create(name="test agent", user_id=user.id) - get_factory("AgentToolMetadata", session).create( - user_id=user.id, - agent_id=agent.id, - tool_name=ToolName.Google_Drive, - artifacts=[ - { - "name": "/folder1", - "ids": "folder1", - "type": "folder_id", - }, - { - "name": "file1.txt", - "ids": "file1", - "type": "file_id", - }, - ], - ) - - with patch( - "backend.services.metrics.report_metrics", - return_value=None, - ) as mock_metrics: - response = session_client.get( - f"/v1/agents/{agent.id}", headers={"User-Id": user.id} - ) - assert response.status_code == 200 - m_args: MetricsData = mock_metrics.await_args.args[0].signal - assert m_args.message_type == MetricsMessageType.ASSISTANT_ACCESSED - assert m_args.assistant.name == agent.name - - -@pytest.mark.asyncio -async def test_get_default_agent_metric( - session_client: TestClient, session: Session, user -) -> None: - with patch( - "backend.services.metrics.report_metrics", - return_value=None, - ) as mock_metrics: - response = session_client.get("/v1/default_agent", headers={"User-Id": user.id}) - assert response.status_code == 200 - m_args: MetricsData = mock_metrics.await_args.args[0].signal - assert m_args.message_type == MetricsMessageType.ASSISTANT_ACCESSED - assert m_args.assistant.name == "Default Agent" - - def test_get_agent(session_client: TestClient, session: Session, user) -> None: agent = get_factory("Agent", session).create(name="test agent", user_id=user.id) agent_tool_metadata = get_factory("AgentToolMetadata", session).create( @@ -1017,23 +932,6 @@ def test_update_agent_change_visibility_to_private_delete_snapshot( assert snapshot is None -def test_delete_agent_metric( - session_client: TestClient, session: Session, user -) -> None: - agent = get_factory("Agent", session).create(user=user) - with patch( - "backend.services.metrics.report_metrics", - return_value=None, - ) as mock_metrics: - response = session_client.delete( - f"/v1/agents/{agent.id}", headers={"User-Id": user.id} - ) - assert response.status_code == 200 - m_args: MetricsData = mock_metrics.await_args.args[0].signal - assert m_args.message_type == MetricsMessageType.ASSISTANT_DELETED - assert m_args.assistant_id == agent.id - - def test_delete_public_agent( session_client: TestClient, session: Session, user ) -> None: diff --git a/src/backend/tests/unit/routers/test_chat.py b/src/backend/tests/unit/routers/test_chat.py index 50706cdf26..6c61dc0080 100644 --- a/src/backend/tests/unit/routers/test_chat.py +++ b/src/backend/tests/unit/routers/test_chat.py @@ -2,7 +2,6 @@ import os import uuid from typing import Any -from unittest.mock import patch import pytest from fastapi.testclient import TestClient @@ -14,7 +13,6 @@ from backend.database_models.conversation import Conversation from backend.database_models.message import Message, MessageAgent from backend.database_models.user import User -from backend.schemas.metrics import MetricsData, MetricsMessageType from backend.schemas.tool import Category from backend.tests.unit.factories import get_factory @@ -86,52 +84,6 @@ def test_streaming_new_chat( ) -# TODO: add test case for when stream raises an error -@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") -def test_streaming_new_chat_metrics_with_agent( - session_client_chat: TestClient, session_chat: Session, user: User -): - agent = get_factory("Agent", session_chat).create(user=user) - deployment = get_factory("Deployment", session_chat).create() - model = get_factory("Model", session_chat).create(deployment=deployment) - get_factory("AgentDeploymentModel", session_chat).create( - agent=agent, - deployment=deployment, - model=model, - is_default_deployment=True, - is_default_model=True, - ) - with patch( - "backend.services.metrics.report_metrics", - return_value=None, - ) as mock_metrics: - response = session_client_chat.post( - "/v1/chat-stream", - headers={ - "User-Id": agent.user.id, - "Deployment-Name": agent.deployment, - }, - params={"agent_id": agent.id}, - json={ - "message": "Hello", - "max_tokens": 10, - "agent_id": agent.id, - }, - ) - # finish all the event stream - assert response.status_code == 200 - for line in response.iter_lines(): - continue - m_args: MetricsData = mock_metrics.await_args.args[0].signal - assert m_args.user_id == agent.user.id - assert m_args.message_type == MetricsMessageType.CHAT_API_SUCCESS - assert m_args.assistant_id == agent.id - assert m_args.assistant.name == agent.name - assert m_args.model is not None - assert m_args.input_nb_tokens > 0 - assert m_args.output_nb_tokens > 0 - - @pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_streaming_new_chat_with_agent( session_client_chat: TestClient, session_chat: Session, user: User diff --git a/src/backend/tests/unit/routers/test_user.py b/src/backend/tests/unit/routers/test_user.py index 0df9028253..29c439b528 100644 --- a/src/backend/tests/unit/routers/test_user.py +++ b/src/backend/tests/unit/routers/test_user.py @@ -1,11 +1,8 @@ -from unittest.mock import patch -import pytest from fastapi.testclient import TestClient from sqlalchemy.orm import Session from backend.database_models.user import User -from backend.schemas.metrics import MetricsData, MetricsMessageType from backend.services.auth import BasicAuthentication from backend.tests.unit.factories import get_factory @@ -51,24 +48,6 @@ def test_fail_get_nonexistent_user( assert response.json() == {"detail": "User with ID: 123 not found."} -@pytest.mark.asyncio -def test_create_user_metric(session_client: TestClient, session: Session) -> None: - user_data_req = { - "fullname": "John Doe", - "email": "john@email.com", - } - with patch( - "backend.services.metrics.report_metrics", - return_value=None, - ) as mock_metrics: - response = session_client.post("/v1/users", json=user_data_req) - assert response.status_code == 200 - response.json() - m_args: MetricsData = mock_metrics.await_args.args[0].signal - assert m_args.message_type == MetricsMessageType.USER_CREATED - assert m_args.user.fullname == user_data_req["fullname"] - - def test_create_user(session_client: TestClient, session: Session) -> None: user_data_req = { "fullname": "John Doe", @@ -131,25 +110,6 @@ def test_fail_create_user_missing_fullname( } -def test_update_user_metric(session_client: TestClient, session: Session) -> None: - user = get_factory("User", session).create(fullname="John Doe") - with patch( - "backend.services.metrics.report_metrics", - return_value=None, - ) as mock_metrics: - response = session_client.put( - f"/v1/users/{user.id}", - json={"fullname": "new name"}, - headers={"User-Id": user.id}, - ) - response.json() - assert response.status_code == 200 - m_args: MetricsData = mock_metrics.await_args.args[0].signal - assert m_args.message_type == MetricsMessageType.USER_UPDATED - assert m_args.user_id == user.id - assert m_args.user.fullname == "new name" - - def test_update_user(session_client: TestClient, session: Session) -> None: user = get_factory("User", session).create(fullname="John Doe") @@ -186,21 +146,6 @@ def test_fail_update_nonexistent_user( assert response.json() == {"detail": "User with ID: 123 not found."} -def test_delete_user_metric(session_client: TestClient, session: Session) -> None: - user = get_factory("User", session).create(fullname="John Doe") - with patch( - "backend.services.metrics.report_metrics", - return_value=None, - ) as mock_metrics: - response = session_client.delete( - f"/v1/users/{user.id}", headers={"User-Id": user.id} - ) - assert response.status_code == 200 - m_args: MetricsData = mock_metrics.await_args.args[0].signal - assert m_args.message_type == MetricsMessageType.USER_DELETED - assert m_args.user_id == user.id - - def test_delete_user(session_client: TestClient, session: Session) -> None: user = get_factory("User", session).create(fullname="John Doe")