From d341c24088028841ee0947d7bdca7f1a3bf2872a Mon Sep 17 00:00:00 2001 From: Eric Zawadski Date: Wed, 15 Jan 2025 06:12:59 -0800 Subject: [PATCH] chore(backend): fix issues from rebase --- src/backend/model_deployments/azure.py | 20 +++++------ src/backend/model_deployments/base.py | 23 ++++++------- src/backend/model_deployments/bedrock.py | 20 +++++------ .../model_deployments/cohere_platform.py | 18 +++++----- src/backend/model_deployments/sagemaker.py | 18 +++++----- .../model_deployments/single_container.py | 20 +++++------ .../tests/integration/routers/test_agent.py | 33 ++++++++++++------- .../mock_deployments/mock_azure.py | 18 +++++----- .../mock_deployments/mock_bedrock.py | 20 +++++------ .../mock_deployments/mock_cohere_platform.py | 15 +++++---- .../mock_deployments/mock_sagemaker.py | 20 +++++------ .../mock_deployments/mock_single_container.py | 20 +++++------ 12 files changed, 129 insertions(+), 116 deletions(-) diff --git a/src/backend/model_deployments/azure.py b/src/backend/model_deployments/azure.py index 2463f610d5..dce6660516 100644 --- a/src/backend/model_deployments/azure.py +++ b/src/backend/model_deployments/azure.py @@ -43,16 +43,16 @@ def __init__(self, **kwargs: Any): base_url=self.chat_endpoint_url, api_key=self.api_key ) - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Azure" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [AZURE_API_KEY_ENV_VAR, AZURE_CHAT_URL_ENV_VAR] - @classmethod - def rerank_enabled(cls) -> bool: + @staticmethod + def rerank_enabled() -> bool: return False @classmethod @@ -62,14 +62,14 @@ def list_models(cls) -> list[str]: return cls.DEFAULT_MODELS - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return ( AzureDeployment.default_api_key is not None and AzureDeployment.default_chat_endpoint_url is not None ) - async def invoke_chat(self, chat_request: CohereChatRequest) -> Any: + async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs) -> Any: response = self.client.chat( **chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}), ) @@ -86,6 +86,6 @@ async def invoke_chat_stream( yield to_dict(event) async def invoke_rerank( - self, query: str, documents: list[str], ctx: Context + self, query: str, documents: list[str], ctx: Context, **kwargs ) -> Any: return None diff --git a/src/backend/model_deployments/base.py b/src/backend/model_deployments/base.py index f578d15f80..f6bec64471 100644 --- a/src/backend/model_deployments/base.py +++ b/src/backend/model_deployments/base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator +from typing import Any from backend.config.settings import Settings from backend.schemas.cohere_chat import CohereChatRequest @@ -25,31 +25,32 @@ def __init__(self, db_id=None, **kwargs: Any): def id(cls) -> str: return cls.db_id if cls.db_id else cls.name().replace(" ", "_").lower() - @classmethod + @staticmethod @abstractmethod - def name(cls) -> str: ... + def name() -> str: ... - @classmethod + @staticmethod @abstractmethod - def env_vars(cls) -> List[str]: ... + def env_vars() -> list[str]: ... - @classmethod + @staticmethod @abstractmethod - def rerank_enabled(cls) -> bool: ... + def rerank_enabled() -> bool: ... @classmethod @abstractmethod def list_models(cls) -> list[str]: ... - @classmethod + @staticmethod @abstractmethod - def is_available(cls) -> bool: ... + def is_available() -> bool: ... @classmethod def is_community(cls) -> bool: return False - def config(cls) -> Dict[str, Any]: + @classmethod + def config(cls) -> dict[str, Any]: config = Settings().get(f"deployments.{cls.id()}") config_dict = {} if not config else dict(config) for key, value in config_dict.items(): @@ -78,7 +79,7 @@ async def invoke_chat( @abstractmethod async def invoke_chat_stream( self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any - ) -> AsyncGenerator[Any, Any]: ... + ) -> Any: ... @abstractmethod async def invoke_rerank( diff --git a/src/backend/model_deployments/bedrock.py b/src/backend/model_deployments/bedrock.py index aba4e5a091..9403deed47 100644 --- a/src/backend/model_deployments/bedrock.py +++ b/src/backend/model_deployments/bedrock.py @@ -42,12 +42,12 @@ def __init__(self, **kwargs: Any): ), ) - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Bedrock" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [ BEDROCK_ACCESS_KEY_ENV_VAR, BEDROCK_SECRET_KEY_ENV_VAR, @@ -55,8 +55,8 @@ def env_vars(cls) -> List[str]: BEDROCK_REGION_NAME_ENV_VAR, ] - @classmethod - def rerank_enabled(cls) -> bool: + @staticmethod + def rerank_enabled() -> bool: return False @classmethod @@ -66,8 +66,8 @@ def list_models(cls) -> list[str]: return cls.DEFAULT_MODELS - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return ( BedrockDeployment.access_key is not None and BedrockDeployment.secret_access_key is not None @@ -75,7 +75,7 @@ def is_available(cls) -> bool: and BedrockDeployment.region_name is not None ) - async def invoke_chat(self, chat_request: CohereChatRequest) -> Any: + async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: # bedrock accepts a subset of the chat request fields bedrock_chat_req = chat_request.model_dump( exclude={"tools", "conversation_id", "model", "stream"}, exclude_none=True @@ -101,6 +101,6 @@ async def invoke_chat_stream( yield to_dict(event) async def invoke_rerank( - self, query: str, documents: list[str], ctx: Context + self, query: str, documents: list[str], ctx: Context, **kwargs: Any ) -> Any: return None diff --git a/src/backend/model_deployments/cohere_platform.py b/src/backend/model_deployments/cohere_platform.py index dac13d1347..a718d0b68e 100644 --- a/src/backend/model_deployments/cohere_platform.py +++ b/src/backend/model_deployments/cohere_platform.py @@ -29,16 +29,16 @@ def __init__(self, **kwargs: Any): ) self.client = cohere.Client(api_key, client_name=self.client_name) - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Cohere Platform" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [COHERE_API_KEY_ENV_VAR] - @classmethod - def rerank_enabled(cls) -> bool: + @staticmethod + def rerank_enabled() -> bool: return True @classmethod @@ -64,12 +64,12 @@ def list_models(cls) -> list[str]: models = response.json()["models"] return [model["name"] for model in models if model.get("endpoints") and "chat" in model["endpoints"]] - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return CohereDeployment.api_key is not None async def invoke_chat( - self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any + self, chat_request: CohereChatRequest, **kwargs: Any ) -> Any: response = self.client.chat( **chat_request.model_dump(exclude={"stream", "file_ids", "agent_id"}), diff --git a/src/backend/model_deployments/sagemaker.py b/src/backend/model_deployments/sagemaker.py index 317ec70a1a..6a686a378c 100644 --- a/src/backend/model_deployments/sagemaker.py +++ b/src/backend/model_deployments/sagemaker.py @@ -65,12 +65,12 @@ def __init__(self, **kwargs: Any): "ContentType": "application/json", } - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "SageMaker" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [ SAGE_MAKER_ACCESS_KEY_ENV_VAR, SAGE_MAKER_SECRET_KEY_ENV_VAR, @@ -79,8 +79,8 @@ def env_vars(cls) -> List[str]: SAGE_MAKER_ENDPOINT_NAME_ENV_VAR, ] - @classmethod - def rerank_enabled(cls) -> bool: + @staticmethod + def rerank_enabled() -> bool: return False @classmethod @@ -90,8 +90,8 @@ def list_models(cls) -> list[str]: return cls.DEFAULT_MODELS - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return ( SageMakerDeployment.region_name is not None and SageMakerDeployment.aws_access_key_id is not None @@ -121,7 +121,7 @@ async def invoke_chat_stream( yield stream_event async def invoke_rerank( - self, query: str, documents: list[str], ctx: Context + self, query: str, documents: list[str], ctx: Context, **kwargs ) -> Any: return None diff --git a/src/backend/model_deployments/single_container.py b/src/backend/model_deployments/single_container.py index 78c7bf0a0a..4ddcb2e174 100644 --- a/src/backend/model_deployments/single_container.py +++ b/src/backend/model_deployments/single_container.py @@ -33,16 +33,16 @@ def __init__(self, **kwargs: Any): base_url=self.url, client_name=self.client_name, api_key="none" ) - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Single Container" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [SC_URL_ENV_VAR, SC_MODEL_ENV_VAR] - @classmethod - def rerank_enabled(cls) -> bool: + @staticmethod + def rerank_enabled() -> bool: return SingleContainerDeployment.default_model.startswith("rerank") @classmethod @@ -52,14 +52,14 @@ def list_models(cls) -> list[str]: return [SingleContainerDeployment.default_model] - @classmethod - def is_available(cls) -> bool: + @staticmethod + def is_available() -> bool: return ( SingleContainerDeployment.default_model is not None and SingleContainerDeployment.default_url is not None ) - async def invoke_chat(self, chat_request: CohereChatRequest) -> Any: + async def invoke_chat(self, chat_request: CohereChatRequest, **kwargs) -> Any: response = self.client.chat( **chat_request.model_dump( exclude={"stream", "file_ids", "model", "agent_id"} @@ -80,7 +80,7 @@ async def invoke_chat_stream( yield to_dict(event) async def invoke_rerank( - self, query: str, documents: list[str], ctx: Context + self, query: str, documents: list[str], ctx: Context, **kwargs ) -> Any: return self.client.rerank( query=query, documents=documents, model=DEFAULT_RERANK_MODEL diff --git a/src/backend/tests/integration/routers/test_agent.py b/src/backend/tests/integration/routers/test_agent.py index cb8f863386..af276bbfda 100644 --- a/src/backend/tests/integration/routers/test_agent.py +++ b/src/backend/tests/integration/routers/test_agent.py @@ -21,7 +21,12 @@ ) -def test_create_agent(session_client: TestClient, session: Session, user: User, mock_cohere_list_models) -> None: +def test_create_agent( + session_client: TestClient, + session: Session, + user: User, + mock_cohere_list_models, +) -> None: request_json = { "name": "test agent", "version": 1, @@ -297,7 +302,7 @@ def test_create_agent_deployment_not_in_db( def test_create_agent_invalid_tool( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user: User, ) -> None: request_json = { "name": "test agent", @@ -314,7 +319,7 @@ def test_create_agent_invalid_tool( def test_create_existing_agent( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user: User, ) -> None: agent = get_factory("Agent", session).create(name="test agent") request_json = { @@ -336,7 +341,9 @@ def test_list_agents_empty_returns_default_agent(session_client: TestClient, ses assert len(response_agents) == 1 -def test_list_agents(session_client: TestClient, session: Session, user) -> None: +def test_list_agents( + session_client: TestClient, session: Session, user: User, +) -> None: num_agents = 3 for _ in range(num_agents): _ = get_factory("Agent", session).create(user=user) @@ -350,7 +357,7 @@ def test_list_agents(session_client: TestClient, session: Session, user) -> None def test_list_organization_agents( session_client: TestClient, session: Session, - user, + user: User, ) -> None: num_agents = 3 organization = get_factory("Organization", session).create() @@ -379,7 +386,7 @@ def test_list_organization_agents( def test_list_organization_agents_query_param( session_client: TestClient, session: Session, - user, + user: User, ) -> None: num_agents = 3 organization = get_factory("Organization", session).create() @@ -408,7 +415,7 @@ def test_list_organization_agents_query_param( def test_list_organization_agents_nonexistent_organization( session_client: TestClient, session: Session, - user, + user: User, ) -> None: response = session_client.get( "/v1/agents", headers={"User-Id": user.id, "Organization-Id": "123"} @@ -418,7 +425,7 @@ def test_list_organization_agents_nonexistent_organization( def test_list_private_agents( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user: User, ) -> None: for _ in range(3): _ = get_factory("Agent", session).create(user=user, is_private=True) @@ -438,7 +445,9 @@ def test_list_private_agents( assert len(response_agents) == 3 -def test_list_public_agents(session_client: TestClient, session: Session, user) -> None: +def test_list_public_agents( + session_client: TestClient, session: Session, user: User, +) -> None: for _ in range(3): _ = get_factory("Agent", session).create(user=user, is_private=True) @@ -451,6 +460,7 @@ def test_list_public_agents(session_client: TestClient, session: Session, user) ) assert response.status_code == 200 + breakpoint() response_agents = filter_default_agent(response.json()) # Only the agents created by user should be returned @@ -458,7 +468,7 @@ def test_list_public_agents(session_client: TestClient, session: Session, user) def list_public_and_private_agents( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user: User, ) -> None: for _ in range(3): _ = get_factory("Agent", session).create(user=user, is_private=True) @@ -479,7 +489,7 @@ def list_public_and_private_agents( def test_list_agents_with_pagination( - session_client: TestClient, session: Session, user + session_client: TestClient, session: Session, user: User, ) -> None: for _ in range(5): _ = get_factory("Agent", session).create(user=user) @@ -495,6 +505,7 @@ def test_list_agents_with_pagination( "/v1/agents?limit=2&offset=4", headers={"User-Id": user.id} ) assert response.status_code == 200 + breakpoint() response_agents = filter_default_agent(response.json()) assert len(response_agents) == 1 diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py index 2d4129cf9b..12279f1c23 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Generator, List +from typing import Any, Generator from cohere.types import StreamedChatResponse @@ -18,20 +18,20 @@ class MockAzureDeployment(MockDeployment): def __init__(self, **kwargs: Any): pass - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Azure" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return ["AZURE_API_KEY", "AZURE_CHAT_ENDPOINT_URL"] - @classmethod - def rerank_enabled(cls) -> bool: + @staticmethod + def rerank_enabled() -> bool: return False @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: if not cls.is_available(): return [] @@ -42,7 +42,7 @@ def is_available() -> bool: return True def invoke_chat( - self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any + self, chat_request: CohereChatRequest, **kwargs: Any ) -> Any: event = { "text": "Hi! Hello there! How's it going?", diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py index f2fd3eb9a2..53fa171faa 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Generator, List +from typing import Any, Generator from cohere.types import StreamedChatResponse @@ -18,20 +18,20 @@ class MockBedrockDeployment(MockDeployment): def __init__(self, **kwargs: Any): pass - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Bedrock" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [] - @property - def rerank_enabled(self) -> bool: + @staticmethod + def rerank_enabled() -> bool: return False @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: return cls.DEFAULT_MODELS @staticmethod @@ -39,7 +39,7 @@ def is_available() -> bool: return True def invoke_chat( - self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any + self, chat_request: CohereChatRequest, **kwargs: Any ) -> Generator[StreamedChatResponse, None, None]: event = { "text": "Hi! Hello there! How's it going?", @@ -93,6 +93,6 @@ def invoke_chat_stream( yield event def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any + self, query: str, documents: list[str], ctx: Context, **kwargs: Any ) -> Any: return None diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py index 02120cbe97..dd2e95b03a 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py @@ -6,6 +6,7 @@ from backend.chat.enums import StreamEvent from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context +from backend.services.conversation import SEARCH_RELEVANCE_THRESHOLD from backend.tests.unit.model_deployments.mock_deployments.mock_base import ( MockDeployment, ) @@ -19,16 +20,16 @@ class MockCohereDeployment(MockDeployment): def __init__(self, **kwargs: Any): pass - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Cohere Platform" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return ["COHERE_API_KEY"] - @property - def rerank_enabled(self) -> bool: + @staticmethod + def rerank_enabled() -> bool: return True @classmethod @@ -40,7 +41,7 @@ def is_available() -> bool: return True @classmethod - def config(cls) -> Dict[str, Any]: + def config(cls) -> dict[str, Any]: return {"COHERE_API_KEY": "fake-api-key"} def invoke_chat( diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_sagemaker.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_sagemaker.py index 0b7353ff53..2f0f577562 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_sagemaker.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_sagemaker.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Generator, List +from typing import Any, Generator from cohere.types import StreamedChatResponse @@ -18,20 +18,20 @@ class MockSageMakerDeployment(MockDeployment): def __init__(self, **kwargs: Any): pass - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "SageMaker" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [] - @property - def rerank_enabled(self) -> bool: + @staticmethod + def rerank_enabled() -> bool: return False @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: return cls.DEFAULT_MODELS @staticmethod @@ -39,7 +39,7 @@ def is_available() -> bool: return True def invoke_chat( - self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any + self, chat_request: CohereChatRequest, **kwargs: Any ) -> Generator[StreamedChatResponse, None, None]: pass @@ -72,6 +72,6 @@ def invoke_chat_stream( yield event def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any + self, query: str, documents: list[str], ctx: Context, **kwargs: Any ) -> Any: return None diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py index 7451af8541..2ed0464ff4 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Generator, List +from typing import Any, Generator from cohere.types import StreamedChatResponse @@ -18,20 +18,20 @@ class MockSingleContainerDeployment(MockDeployment): def __init__(self, **kwargs: Any): pass - @classmethod - def name(cls) -> str: + @staticmethod + def name() -> str: return "Single Container" - @classmethod - def env_vars(cls) -> List[str]: + @staticmethod + def env_vars() -> list[str]: return [] - @property - def rerank_enabled(self) -> bool: + @staticmethod + def rerank_enabled() -> bool: return False @classmethod - def list_models(cls) -> List[str]: + def list_models(cls) -> list[str]: return cls.DEFAULT_MODELS @staticmethod @@ -39,7 +39,7 @@ def is_available() -> bool: return True def invoke_chat( - self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any + self, chat_request: CohereChatRequest, **kwargs: Any ) -> Generator[StreamedChatResponse, None, None]: event = { "text": "Hi! Hello there! How's it going?", @@ -93,7 +93,7 @@ def invoke_chat_stream( yield event def invoke_rerank( - self, query: str, documents: List[Dict[str, Any]], ctx: Context, **kwargs: Any + self, query: str, documents: list[str], ctx: Context, **kwargs: Any ) -> Any: # TODO: Add pass