diff --git a/src/backend/database_models/seeders/deplyments_models_seed.py b/src/backend/database_models/seeders/deplyments_models_seed.py index 0b8cef3685..0a53ff0786 100644 --- a/src/backend/database_models/seeders/deplyments_models_seed.py +++ b/src/backend/database_models/seeders/deplyments_models_seed.py @@ -6,8 +6,15 @@ from sqlalchemy import text from sqlalchemy.orm import Session -from backend.config.deployments import ALL_MODEL_DEPLOYMENTS, ModelDeploymentName +from backend.config.deployments import ALL_MODEL_DEPLOYMENTS from backend.database_models import Deployment, Model, Organization +from backend.model_deployments import ( + CohereDeployment, + SingleContainerDeployment, + SageMakerDeployment, + AzureDeployment, + BedrockDeployment, +) from community.config.deployments import ( AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP, ) @@ -18,7 +25,7 @@ model_deployments.update(COMMUNITY_DEPLOYMENTS_SETUP) MODELS_NAME_MAPPING = { - ModelDeploymentName.CoherePlatform: { + CohereDeployment.name(): { "command": { "cohere_name": "command", "is_default": False, @@ -60,7 +67,7 @@ "is_default": False, }, }, - ModelDeploymentName.SingleContainer: { + SingleContainerDeployment.name(): { "command": { "cohere_name": "command", "is_default": False, @@ -102,19 +109,19 @@ "is_default": False, }, }, - ModelDeploymentName.SageMaker: { + SageMakerDeployment.name(): { "sagemaker-command": { "cohere_name": "command", "is_default": True, }, }, - ModelDeploymentName.Azure: { + AzureDeployment.name(): { "azure-command": { "cohere_name": "command-r", "is_default": True, }, }, - ModelDeploymentName.Bedrock: { + BedrockDeployment.name(): { "cohere.command-r-plus-v1:0": { "cohere_name": "command-r-plus", "is_default": True, diff --git a/src/backend/model_deployments/base.py b/src/backend/model_deployments/base.py index 9bb208ad4d..88eaa1382d 100644 --- a/src/backend/model_deployments/base.py +++ b/src/backend/model_deployments/base.py @@ -50,7 +50,7 @@ def config(cls) -> Dict[str, Any]: return config.dict() if config else {} @classmethod - def to_deployment_info(cls) -> DeploymentDefinition: + def to_deployment_definition(cls) -> DeploymentDefinition: return DeploymentDefinition( id=cls.id(), name=cls.name(), diff --git a/src/backend/model_deployments/single_container.py b/src/backend/model_deployments/single_container.py index 64466fc0dd..a9d69ab6a9 100644 --- a/src/backend/model_deployments/single_container.py +++ b/src/backend/model_deployments/single_container.py @@ -18,9 +18,9 @@ class SingleContainerDeployment(BaseDeployment): """Single Container Deployment.""" client_name = "cohere-toolkit" - config = Settings().get('deployments.single_container') - default_url = config.url - default_model = config.model + sc_config = Settings().get('deployments.single_container') + default_url = sc_config.url + default_model = sc_config.model def __init__(self, **kwargs: Any): self.url = get_model_config_var( diff --git a/src/backend/services/deployment.py b/src/backend/services/deployment.py index befb00895a..f6f63fa184 100644 --- a/src/backend/services/deployment.py +++ b/src/backend/services/deployment.py @@ -68,7 +68,7 @@ def get_deployment_definition(session: DBSessionDep, deployment_id: str) -> Depl except StopIteration: raise DeploymentNotFoundError(deployment_id=deployment_id) - return deployment.to_deployment_info() + return deployment.to_deployment_definition() def get_deployment_definition_by_name(session: DBSessionDep, deployment_name: str) -> DeploymentDefinition: definitions = get_deployment_definitions(session) @@ -84,7 +84,7 @@ def get_deployment_definitions(session: DBSessionDep) -> list[DeploymentDefiniti } installed_deployments = [ - deployment.to_deployment_info() + deployment.to_deployment_definition() for deployment in AVAILABLE_MODEL_DEPLOYMENTS if deployment.name() not in db_deployments ] diff --git a/src/backend/services/request_validators.py b/src/backend/services/request_validators.py index ca4f3b31de..a8bae5b61e 100644 --- a/src/backend/services/request_validators.py +++ b/src/backend/services/request_validators.py @@ -33,9 +33,9 @@ def validate_deployment_model(deployment: str, model: str, session: DBSessionDep HTTPException: If the deployment and model are not compatible """ - found = deployment_service.get_deployment_info_by_name(session, deployment) + found = deployment_service.get_deployment_definition_by_name(session, deployment) if not found: - found = deployment_service.get_deployment_info(session, deployment) + found = deployment_service.get_deployment_definition(session, deployment) if not found: raise HTTPException( status_code=400, @@ -43,11 +43,19 @@ def validate_deployment_model(deployment: str, model: str, session: DBSessionDep ) # Validate model + # deployment_model = next( + # ( + # model_db + # for model_db in found.models + # if model_db.name == model or model_db.id == model + # ), + # None, + # ) deployment_model = next( ( model_db for model_db in found.models - if model_db.name == model or model_db.id == model + if model_db == model ), None, ) diff --git a/src/backend/tests/integration/conftest.py b/src/backend/tests/integration/conftest.py index 8c9020e999..646e7487f1 100644 --- a/src/backend/tests/integration/conftest.py +++ b/src/backend/tests/integration/conftest.py @@ -9,13 +9,13 @@ from sqlalchemy import create_engine from sqlalchemy.orm import Session -from backend.config.deployments import AVAILABLE_MODEL_DEPLOYMENTS, ModelDeploymentName +from backend.config.deployments import ALL_MODEL_DEPLOYMENTS from backend.database_models import get_session from backend.database_models.agent import Agent from backend.database_models.deployment import Deployment from backend.database_models.model import Model from backend.main import app, create_app -from backend.schemas.deployment import DeploymentDefinition +# from backend.schemas.deployment import DeploymentDefinition from backend.schemas.organization import Organization from backend.schemas.user import User from backend.tests.unit.factories import get_factory @@ -184,35 +184,13 @@ def mock_available_model_deployments(request): MockSageMakerDeployment, ) - is_available_values = getattr(request, "param", {}) + # is_available_values = getattr(request, "param", {}) MOCKED_DEPLOYMENTS = { - ModelDeploymentName.CoherePlatform: DeploymentDefinition( - id="cohere_platform", - name=ModelDeploymentName.CoherePlatform, - models=MockCohereDeployment.list_models(), - is_available=is_available_values.get( - ModelDeploymentName.CoherePlatform, True - ), - ), - ModelDeploymentName.SageMaker: DeploymentDefinition( - id="sagemaker", - name=ModelDeploymentName.SageMaker, - models=MockSageMakerDeployment.list_models(), - is_available=is_available_values.get(ModelDeploymentName.SageMaker, True), - ), - ModelDeploymentName.Azure: DeploymentDefinition( - id="azure", - name=ModelDeploymentName.Azure, - models=MockAzureDeployment.list_models(), - is_available=is_available_values.get(ModelDeploymentName.Azure, True), - ), - ModelDeploymentName.Bedrock: DeploymentDefinition( - id="bedrock", - name=ModelDeploymentName.Bedrock, - models=MockBedrockDeployment.list_models(), - is_available=is_available_values.get(ModelDeploymentName.Bedrock, True), - ), + MockCohereDeployment.name(): MockCohereDeployment, + MockAzureDeployment.name(): MockAzureDeployment, + MockSageMakerDeployment.name(): MockSageMakerDeployment, + MockBedrockDeployment.name(): MockBedrockDeployment, } - with patch.dict(AVAILABLE_MODEL_DEPLOYMENTS, MOCKED_DEPLOYMENTS) as mock: + with patch.dict(ALL_MODEL_DEPLOYMENTS, MOCKED_DEPLOYMENTS) as mock: yield mock diff --git a/src/backend/tests/integration/routers/test_agent.py b/src/backend/tests/integration/routers/test_agent.py index 9661606fe2..8ec03c70e6 100644 --- a/src/backend/tests/integration/routers/test_agent.py +++ b/src/backend/tests/integration/routers/test_agent.py @@ -2,10 +2,10 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from backend.config.deployments import ModelDeploymentName from backend.config.tools import ToolName from backend.database_models.agent import Agent from backend.database_models.agent_tool_metadata import AgentToolMetadata +from backend.model_deployments.cohere_platform import CohereDeployment from backend.tests.unit.factories import get_factory @@ -17,7 +17,7 @@ def test_create_agent(session_client: TestClient, session: Session, user) -> Non "preamble": "test preamble", "temperature": 0.5, "model": "command-r-plus", - "deployment": ModelDeploymentName.CoherePlatform, + "deployment": CohereDeployment.name(), "tools": [ToolName.Calculator, ToolName.Search_File, ToolName.Read_File], } @@ -58,7 +58,7 @@ def test_create_agent_with_tool_metadata( "preamble": "test preamble", "temperature": 0.5, "model": "command-r-plus", - "deployment": ModelDeploymentName.CoherePlatform, + "deployment": CohereDeployment.name(), "tools": [ToolName.Google_Drive, ToolName.Search_File], "tools_metadata": [ { @@ -112,7 +112,7 @@ def test_create_agent_missing_non_required_fields( request_json = { "name": "test agent", "model": "command-r-plus", - "deployment": ModelDeploymentName.CoherePlatform, + "deployment": CohereDeployment.name(), } response = session_client.post( @@ -155,7 +155,7 @@ def test_update_agent(session_client: TestClient, session: Session, user) -> Non "preamble": "updated preamble", "temperature": 0.7, "model": "command-r", - "deployment": ModelDeploymentName.CoherePlatform, + "deployment": CohereDeployment.name(), } response = session_client.put( @@ -172,4 +172,4 @@ def test_update_agent(session_client: TestClient, session: Session, user) -> Non assert updated_agent["preamble"] == "updated preamble" assert updated_agent["temperature"] == 0.7 assert updated_agent["model"] == "command-r" - assert updated_agent["deployment"] == ModelDeploymentName.CoherePlatform + assert updated_agent["deployment"] == CohereDeployment.name() diff --git a/src/backend/tests/integration/routers/test_conversation.py b/src/backend/tests/integration/routers/test_conversation.py index 7d48fc4305..3a471b6eed 100644 --- a/src/backend/tests/integration/routers/test_conversation.py +++ b/src/backend/tests/integration/routers/test_conversation.py @@ -5,8 +5,8 @@ from sqlalchemy.orm import Session from backend.config import Settings -from backend.config.deployments import ModelDeploymentName from backend.database_models import Conversation +from backend.model_deployments.cohere_platform import CohereDeployment from backend.schemas.user import User from backend.tests.unit.factories import get_factory @@ -54,7 +54,7 @@ def test_search_conversations_with_reranking( "/v1/conversations:search", headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, params={"query": "color"}, ) diff --git a/src/backend/tests/unit/configuration.yaml b/src/backend/tests/unit/configuration.yaml index a620a18a20..e36097ecd7 100644 --- a/src/backend/tests/unit/configuration.yaml +++ b/src/backend/tests/unit/configuration.yaml @@ -2,15 +2,22 @@ deployments: default_deployment: enabled_deployments: sagemaker: - region_name: - endpoint_name: + access_key: "sagemaker_access_key" + secret_key: "sagemaker_secret" + session_token: "sagemaker_session_token" + region_name: "sagemaker-region" + endpoint_name: "http://www.example.com/sagemaker" azure: - endpoint_url: + api_key: "azure_api_key" + endpoint_url: "http://www.example.com/azure" bedrock: - region_name: + region_name: "bedrock-region" + access_key: "bedrock_access_key" + secret_key: "bedrock_secret" + session_token: "bedrock_session_token" single_container: - model: - url: + model: "single_container_model" + url: "http://www.example.com/single_container" database: url: redis: diff --git a/src/backend/tests/unit/conftest.py b/src/backend/tests/unit/conftest.py index 3e0c6c8ca9..4e3a804199 100644 --- a/src/backend/tests/unit/conftest.py +++ b/src/backend/tests/unit/conftest.py @@ -9,11 +9,10 @@ from sqlalchemy import create_engine from sqlalchemy.orm import Session -from backend.config.deployments import AVAILABLE_MODEL_DEPLOYMENTS, ModelDeploymentName +from backend.config.deployments import ALL_MODEL_DEPLOYMENTS from backend.database_models import get_session from backend.database_models.base import CustomFilterQuery from backend.main import app, create_app -from backend.schemas.deployment import Deployment from backend.schemas.organization import Organization from backend.schemas.user import User from backend.tests.unit.factories import get_factory @@ -165,43 +164,15 @@ def mock_available_model_deployments(request): MockSageMakerDeployment, ) - is_available_values = getattr(request, "param", {}) + # is_available_values = getattr(request, "param", {}) MOCKED_DEPLOYMENTS = { - ModelDeploymentName.CoherePlatform: Deployment( - id="cohere_platform", - name=ModelDeploymentName.CoherePlatform, - models=MockCohereDeployment.list_models(), - is_available=is_available_values.get( - ModelDeploymentName.CoherePlatform, True - ), - deployment_class=MockCohereDeployment, - env_vars=["COHERE_VAR_1", "COHERE_VAR_2"], - ), - ModelDeploymentName.SageMaker: Deployment( - id="sagemaker", - name=ModelDeploymentName.SageMaker, - models=MockSageMakerDeployment.list_models(), - is_available=is_available_values.get(ModelDeploymentName.SageMaker, True), - deployment_class=MockSageMakerDeployment, - env_vars=["SAGEMAKER_VAR_1", "SAGEMAKER_VAR_2"], - ), - ModelDeploymentName.Azure: Deployment( - id="azure", - name=ModelDeploymentName.Azure, - models=MockAzureDeployment.list_models(), - is_available=is_available_values.get(ModelDeploymentName.Azure, True), - deployment_class=MockAzureDeployment, - env_vars=["SAGEMAKER_VAR_1", "SAGEMAKER_VAR_2"], - ), - ModelDeploymentName.Bedrock: Deployment( - id="bedrock", - name=ModelDeploymentName.Bedrock, - models=MockBedrockDeployment.list_models(), - is_available=is_available_values.get(ModelDeploymentName.Bedrock, True), - deployment_class=MockBedrockDeployment, - env_vars=["BEDROCK_VAR_1", "BEDROCK_VAR_2"], - ), + MockCohereDeployment.name(): MockCohereDeployment, + MockAzureDeployment.name(): MockAzureDeployment, + MockSageMakerDeployment.name(): MockSageMakerDeployment, + MockBedrockDeployment.name(): MockBedrockDeployment, } - with patch.dict(AVAILABLE_MODEL_DEPLOYMENTS, MOCKED_DEPLOYMENTS) as mock: + # with patch.dict(AVAILABLE_MODEL_DEPLOYMENTS, MOCKED_DEPLOYMENTS) as mock: + with patch("backend.config.deployments.AVAILABLE_MODEL_DEPLOYMENTS", list(MOCKED_DEPLOYMENTS.values())) as mock: + # with patch.dict(ALL_MODEL_DEPLOYMENTS, MOCKED_DEPLOYMENTS) as mock: yield mock diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py index 7104e5c603..4dde6d8d86 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_azure.py @@ -3,18 +3,28 @@ from cohere.types import StreamedChatResponse from backend.chat.enums import StreamEvent -from backend.model_deployments.base import BaseDeployment from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context +from backend.tests.unit.model_deployments.mock_deployments.mock_base import ( + MockDeployment, +) -class MockAzureDeployment(BaseDeployment): +class MockAzureDeployment(MockDeployment): """Mocked Azure Deployment.""" DEFAULT_MODELS = ["azure-command"] - @property - def rerank_enabled(self) -> bool: + @classmethod + def name(cls) -> str: + return "Azure" + + @classmethod + def env_vars(cls) -> List[str]: + return ["AZURE_API_KEY", "AZURE_CHAT_ENDPOINT_URL"] + + @classmethod + def rerank_enabled(cls) -> bool: return False @classmethod diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_base.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_base.py new file mode 100644 index 0000000000..584f36e399 --- /dev/null +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_base.py @@ -0,0 +1,4 @@ +from backend.model_deployments.base import BaseDeployment + + +class MockDeployment(BaseDeployment): ... diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py index 798d235070..6a7fe4e09c 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_bedrock.py @@ -3,16 +3,26 @@ from cohere.types import StreamedChatResponse from backend.chat.enums import StreamEvent -from backend.model_deployments.base import BaseDeployment from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context +from backend.tests.unit.model_deployments.mock_deployments.mock_base import ( + MockDeployment, +) -class MockBedrockDeployment(BaseDeployment): +class MockBedrockDeployment(MockDeployment): """Bedrock Deployment""" DEFAULT_MODELS = ["cohere.command-r-plus-v1:0"] + @classmethod + def name(cls) -> str: + return "Bedrock" + + @classmethod + def env_vars(cls) -> List[str]: + return [] + @property def rerank_enabled(self) -> bool: return False diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py index 3fe818d497..3839974cdb 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_cohere_platform.py @@ -3,16 +3,26 @@ from cohere.types import StreamedChatResponse from backend.chat.enums import StreamEvent -from backend.model_deployments.base import BaseDeployment from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context +from backend.tests.unit.model_deployments.mock_deployments.mock_base import ( + MockDeployment, +) -class MockCohereDeployment(BaseDeployment): +class MockCohereDeployment(MockDeployment): """Mocked Cohere Platform Deployment.""" DEFAULT_MODELS = ["command", "command-r"] + @classmethod + def name(cls) -> str: + return "Cohere Platform" + + @classmethod + def env_vars(cls) -> List[str]: + return ["COHERE_API_KEY"] + @property def rerank_enabled(self) -> bool: return True diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_sagemaker.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_sagemaker.py index b68e312518..2f64aebd91 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_sagemaker.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_sagemaker.py @@ -3,16 +3,26 @@ from cohere.types import StreamedChatResponse from backend.chat.enums import StreamEvent -from backend.model_deployments.base import BaseDeployment from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context +from backend.tests.unit.model_deployments.mock_deployments.mock_base import ( + MockDeployment, +) -class MockSageMakerDeployment(BaseDeployment): +class MockSageMakerDeployment(MockDeployment): """SageMaker Deployment""" DEFAULT_MODELS = ["command-r"] + @classmethod + def name(cls) -> str: + return "SageMaker" + + @classmethod + def env_vars(cls) -> List[str]: + return [] + @property def rerank_enabled(self) -> bool: return False @@ -25,6 +35,11 @@ def list_models(cls) -> List[str]: def is_available(cls) -> bool: return True + def invoke_chat( + self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any + ) -> Generator[StreamedChatResponse, None, None]: + pass + def invoke_chat_stream( self, chat_request: CohereChatRequest, ctx: Context, **kwargs: Any ) -> Generator[StreamedChatResponse, None, None]: diff --git a/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py b/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py index c64f7f5f94..85c2279d8f 100644 --- a/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py +++ b/src/backend/tests/unit/model_deployments/mock_deployments/mock_single_container.py @@ -3,16 +3,26 @@ from cohere.types import StreamedChatResponse from backend.chat.enums import StreamEvent -from backend.model_deployments.base import BaseDeployment from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context +from backend.tests.unit.model_deployments.mock_deployments.mock_base import ( + MockDeployment, +) -class MockSingleContainerDeployment(BaseDeployment): +class MockSingleContainerDeployment(MockDeployment): """Mocked Single Container Deployment.""" DEFAULT_MODELS = ["command-r"] + @classmethod + def name(cls) -> str: + return "Single Container" + + @classmethod + def env_vars(cls) -> List[str]: + return [] + @property def rerank_enabled(self) -> bool: return False diff --git a/src/backend/tests/unit/model_deployments/test_azure.py b/src/backend/tests/unit/model_deployments/test_azure.py index c55cab4e36..afefd12e6a 100644 --- a/src/backend/tests/unit/model_deployments/test_azure.py +++ b/src/backend/tests/unit/model_deployments/test_azure.py @@ -1,7 +1,7 @@ from fastapi.testclient import TestClient -from backend.config.deployments import ModelDeploymentName from backend.database_models.user import User +from backend.model_deployments.azure import AzureDeployment from backend.tests.unit.model_deployments.mock_deployments import MockAzureDeployment @@ -16,7 +16,7 @@ def test_streamed_chat( "/v1/chat-stream", headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.Azure, + "Deployment-Name": AzureDeployment.name(), }, json={"message": "Hello", "max_tokens": 10}, ) @@ -35,7 +35,7 @@ def test_non_streamed_chat( "/v1/chat", headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.Azure, + "Deployment-Name": AzureDeployment.name(), }, json={"message": "Hello", "max_tokens": 10}, ) diff --git a/src/backend/tests/unit/model_deployments/test_bedrock.py b/src/backend/tests/unit/model_deployments/test_bedrock.py index 645b00a779..fa3f77fdea 100644 --- a/src/backend/tests/unit/model_deployments/test_bedrock.py +++ b/src/backend/tests/unit/model_deployments/test_bedrock.py @@ -1,7 +1,7 @@ from fastapi.testclient import TestClient -from backend.config.deployments import ModelDeploymentName from backend.database_models.user import User +from backend.model_deployments.bedrock import BedrockDeployment from backend.tests.unit.model_deployments.mock_deployments import MockBedrockDeployment @@ -16,7 +16,7 @@ def test_streamed_chat( "/v1/chat-stream", headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.Bedrock, + "Deployment-Name": BedrockDeployment.name(), }, json={"message": "Hello", "max_tokens": 10}, ) @@ -33,7 +33,7 @@ def test_non_streamed_chat( mock_bedrock_deployment.return_value response = session_client_chat.post( "/v1/chat", - headers={"User-Id": user.id, "Deployment-Name": ModelDeploymentName.Bedrock}, + headers={"User-Id": user.id, "Deployment-Name": BedrockDeployment.name(),}, json={"message": "Hello", "max_tokens": 10}, ) diff --git a/src/backend/tests/unit/model_deployments/test_cohere_platform.py b/src/backend/tests/unit/model_deployments/test_cohere_platform.py index 2ab82cfe56..2041a27f8c 100644 --- a/src/backend/tests/unit/model_deployments/test_cohere_platform.py +++ b/src/backend/tests/unit/model_deployments/test_cohere_platform.py @@ -1,7 +1,7 @@ from fastapi.testclient import TestClient -from backend.config.deployments import ModelDeploymentName from backend.database_models.user import User +from backend.model_deployments.cohere_platform import CohereDeployment from backend.tests.unit.model_deployments.mock_deployments import MockCohereDeployment @@ -16,7 +16,7 @@ def test_streamed_chat( "/v1/chat-stream", headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, json={"message": "Hello", "max_tokens": 10}, ) @@ -35,7 +35,7 @@ def test_non_streamed_chat( "/v1/chat", headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, json={"message": "Hello", "max_tokens": 10}, ) diff --git a/src/backend/tests/unit/model_deployments/test_sagemaker.py b/src/backend/tests/unit/model_deployments/test_sagemaker.py index db499498a9..8498329188 100644 --- a/src/backend/tests/unit/model_deployments/test_sagemaker.py +++ b/src/backend/tests/unit/model_deployments/test_sagemaker.py @@ -1,8 +1,8 @@ import pytest from fastapi.testclient import TestClient -from backend.config.deployments import ModelDeploymentName from backend.database_models.user import User +from backend.model_deployments.sagemaker import SageMakerDeployment from backend.tests.unit.model_deployments.mock_deployments import ( MockSageMakerDeployment, ) @@ -17,7 +17,7 @@ def test_streamed_chat( deployment = mock_sagemaker_deployment.return_value response = session_client_chat.post( "/v1/chat-stream", - headers={"User-Id": user.id, "Deployment-Name": ModelDeploymentName.SageMaker}, + headers={"User-Id": user.id, "Deployment-Name": SageMakerDeployment.name()}, json={"message": "Hello", "max_tokens": 10}, ) @@ -32,7 +32,7 @@ def test_non_streamed_chat( mock_sagemaker_deployment.return_value response = session_client_chat.post( "/v1/chat", - headers={"User-Id": user.id, "Deployment-Name": ModelDeploymentName.SageMaker}, + headers={"User-Id": user.id, "Deployment-Name": SageMakerDeployment.name()}, json={"message": "Hello", "max_tokens": 10}, ) diff --git a/src/backend/tests/unit/model_deployments/test_single_container.py b/src/backend/tests/unit/model_deployments/test_single_container.py index f74a761bf7..be602f00eb 100644 --- a/src/backend/tests/unit/model_deployments/test_single_container.py +++ b/src/backend/tests/unit/model_deployments/test_single_container.py @@ -1,7 +1,7 @@ from fastapi.testclient import TestClient -from backend.config.deployments import ModelDeploymentName from backend.database_models.user import User +from backend.model_deployments.single_container import SingleContainerDeployment from backend.tests.unit.model_deployments.mock_deployments import ( MockSingleContainerDeployment, ) @@ -18,7 +18,7 @@ def test_streamed_chat( "/v1/chat-stream", headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.SingleContainer, + "Deployment-Name": SingleContainerDeployment.name(), }, json={"message": "Hello", "max_tokens": 10}, ) @@ -35,7 +35,7 @@ def test_non_streamed_chat( "/v1/chat", headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": SingleContainerDeployment.name(), }, json={"message": "Hello", "max_tokens": 10}, ) diff --git a/src/backend/tests/unit/routers/test_agent.py b/src/backend/tests/unit/routers/test_agent.py index b047318a82..6a87562c5a 100644 --- a/src/backend/tests/unit/routers/test_agent.py +++ b/src/backend/tests/unit/routers/test_agent.py @@ -4,13 +4,14 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from backend.config.deployments import ModelDeploymentName from backend.config.tools import ToolName from backend.crud import agent as agent_crud from backend.crud import deployment as deployment_crud from backend.database_models.agent import Agent from backend.database_models.agent_tool_metadata import AgentToolMetadata from backend.database_models.snapshot import Snapshot +from backend.exceptions import DeploymentNotFoundError +from backend.model_deployments.cohere_platform import CohereDeployment from backend.tests.unit.factories import get_factory is_cohere_env_set = ( @@ -26,7 +27,7 @@ def test_create_agent_missing_name( "preamble": "test preamble", "temperature": 0.5, "model": "command-r-plus", - "deployment": ModelDeploymentName.CoherePlatform, + "deployment": CohereDeployment.name(), } response = session_client.post( "/v1/agents", json=request_json, headers={"User-Id": user.id} @@ -43,7 +44,7 @@ def test_create_agent_missing_model( "description": "test description", "preamble": "test preamble", "temperature": 0.5, - "deployment": ModelDeploymentName.CoherePlatform, + "deployment": CohereDeployment.name(), } response = session_client.post( "/v1/agents", json=request_json, headers={"User-Id": user.id} @@ -75,7 +76,7 @@ def test_create_agent_missing_user_id_header( request_json = { "name": "test agent", "model": "command-r-plus", - "deployment": ModelDeploymentName.CoherePlatform, + "deployment": CohereDeployment.name(), } response = session_client.post("/v1/agents", json=request_json) assert response.status_code == 401 @@ -94,13 +95,10 @@ def test_create_agent_invalid_deployment( "deployment": "not a real deployment", } - response = session_client.post( - "/v1/agents", json=request_json, headers={"User-Id": user.id} - ) - assert response.status_code == 400 - assert response.json() == { - "detail": "Deployment not a real deployment not found or is not available in the Database." - } + with pytest.raises(DeploymentNotFoundError): + session_client.post( + "/v1/agents", json=request_json, headers={"User-Id": user.id} + ) @pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") @@ -113,14 +111,14 @@ def test_create_agent_deployment_not_in_db( "preamble": "test preamble", "temperature": 0.5, "model": "command-r-plus", - "deployment": ModelDeploymentName.CoherePlatform, + "deployment": CohereDeployment.name(), } - cohere_deployment = deployment_crud.get_deployment_by_name(session, ModelDeploymentName.CoherePlatform) + cohere_deployment = deployment_crud.get_deployment_by_name(session, CohereDeployment.name()) deployment_crud.delete_deployment(session, cohere_deployment.id) response = session_client.post( "/v1/agents", json=request_json, headers={"User-Id": user.id} ) - cohere_deployment = deployment_crud.get_deployment_by_name(session, ModelDeploymentName.CoherePlatform) + cohere_deployment = deployment_crud.get_deployment_by_name(session, CohereDeployment.name()) deployment_models = cohere_deployment.models deployment_models_list = [model.name for model in deployment_models] assert response.status_code == 200 @@ -134,7 +132,7 @@ def test_create_agent_invalid_tool( request_json = { "name": "test agent", "model": "command-r-plus", - "deployment": ModelDeploymentName.CoherePlatform, + "deployment": CohereDeployment.name(), "tools": [ToolName.Calculator, "not a real tool"], } @@ -470,7 +468,7 @@ def test_update_agent(session_client: TestClient, session: Session, user) -> Non "preamble": "updated preamble", "temperature": 0.7, "model": "command-r", - "deployment": ModelDeploymentName.CoherePlatform, + "deployment": CohereDeployment.name(), } response = session_client.put( @@ -487,7 +485,7 @@ def test_update_agent(session_client: TestClient, session: Session, user) -> Non assert updated_agent["preamble"] == "updated preamble" assert updated_agent["temperature"] == 0.7 assert updated_agent["model"] == "command-r" - assert updated_agent["deployment"] == ModelDeploymentName.CoherePlatform + assert updated_agent["deployment"] == CohereDeployment.name() def test_partial_update_agent(session_client: TestClient, session: Session) -> None: @@ -756,7 +754,7 @@ def test_update_agent_invalid_model( request_json = { "model": "not a real model", - "deployment": ModelDeploymentName.CoherePlatform, + "deployment": CohereDeployment.name(), } response = session_client.put( @@ -785,13 +783,10 @@ def test_update_agent_invalid_deployment( "deployment": "not a real deployment", } - response = session_client.put( - f"/v1/agents/{agent.id}", json=request_json, headers={"User-Id": user.id} - ) - assert response.status_code == 400 - assert response.json() == { - "detail": "Deployment not a real deployment not found or is not available in the Database." - } + with pytest.raises(DeploymentNotFoundError): + session_client.put( + f"/v1/agents/{agent.id}", json=request_json, headers={"User-Id": user.id} + ) def test_update_agent_invalid_tool( diff --git a/src/backend/tests/unit/routers/test_chat.py b/src/backend/tests/unit/routers/test_chat.py index 7e8d06ea2e..d615c92672 100644 --- a/src/backend/tests/unit/routers/test_chat.py +++ b/src/backend/tests/unit/routers/test_chat.py @@ -8,11 +8,11 @@ from sqlalchemy.orm import Session from backend.chat.enums import StreamEvent -from backend.config.deployments import ModelDeploymentName from backend.database_models import Agent from backend.database_models.conversation import Conversation from backend.database_models.message import Message, MessageAgent from backend.database_models.user import User +from backend.model_deployments.cohere_platform import CohereDeployment from backend.schemas.tool import Category from backend.tests.unit.factories import get_factory @@ -73,7 +73,7 @@ def test_streaming_new_chat( "/v1/chat-stream", headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, json={"message": "Hello", "max_tokens": 10}, ) @@ -202,7 +202,7 @@ def test_streaming_chat_with_existing_conversation_from_other_agent( "/v1/chat-stream", headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, params={"agent_id": agent.id}, json={"message": "Hello", "max_tokens": 10, "conversation_id": conversation.id, "agent_id": agent.id}, @@ -263,7 +263,8 @@ def test_streaming_chat_with_agent_tools_and_empty_request_tools( "/v1/chat-stream", headers={ "User-Id": agent.user.id, - "Deployment-Name": agent.deployment, + # "Deployment-Name": agent.deployment, + "Deployment-Name": CohereDeployment.name(), }, json={ "message": "Who is a tallest NBA player", @@ -306,7 +307,7 @@ def test_streaming_existing_chat( "/v1/chat-stream", headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, json={ "message": "How are you doing?", @@ -328,7 +329,7 @@ def test_fail_chat_missing_user_id( response = session_client_chat.post( "/v1/chat", json={"message": "Hello"}, - headers={"Deployment-Name": ModelDeploymentName.CoherePlatform}, + headers={"Deployment-Name": CohereDeployment.name()}, ) assert response.status_code == 401 @@ -356,7 +357,7 @@ def test_streaming_fail_chat_missing_message( "/v1/chat-stream", headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, json={}, ) @@ -390,7 +391,7 @@ def test_streaming_chat_with_custom_tools(session_client_chat, session_chat, use }, headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, ) @@ -413,7 +414,7 @@ def test_streaming_chat_with_managed_tools(session_client_chat, session_chat, us json={"message": "Hello", "tools": [{"name": tool}]}, headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, ) @@ -432,7 +433,7 @@ def test_streaming_chat_with_invalid_tool( json={"message": "Hello", "tools": [{"name": "invalid_tool"}]}, headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, ) @@ -464,7 +465,7 @@ def test_streaming_chat_with_managed_and_custom_tools( }, headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, ) @@ -484,7 +485,7 @@ def test_streaming_chat_with_search_queries_only( }, headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, ) @@ -515,7 +516,7 @@ def test_streaming_chat_with_chat_history( }, headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, ) @@ -542,7 +543,7 @@ def test_streaming_existing_chat_with_files_attaches_to_user_message( "/v1/chat-stream", headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, json={ "message": "How are you doing?", @@ -598,7 +599,7 @@ def test_streaming_existing_chat_with_attached_files_does_not_attach( "/v1/chat-stream", headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, json={ "message": "How are you doing?", @@ -633,7 +634,7 @@ def test_streaming_chat_private_agent( "/v1/chat-stream", headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, params={"agent_id": agent.id}, json={"message": "Hello", "max_tokens": 10, "agent_id": agent.id}, @@ -656,7 +657,7 @@ def test_streaming_chat_public_agent( "/v1/chat-stream", headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, params={"agent_id": agent.id}, json={"message": "Hello", "max_tokens": 10, "agent_id": agent.id}, @@ -679,7 +680,7 @@ def test_streaming_chat_private_agent_by_another_user( "/v1/chat-stream", headers={ "User-Id": other_user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, params={"agent_id": agent.id}, json={"message": "Hello", "max_tokens": 10, "agent_id": agent.id}, @@ -719,7 +720,7 @@ def test_stream_regenerate_existing_chat( "/v1/chat-stream/regenerate", headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, json={ "message": "", @@ -744,7 +745,7 @@ def test_stream_regenerate_not_existing_chat( "/v1/chat-stream/regenerate", headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, json={ "message": "", @@ -769,7 +770,7 @@ def test_stream_regenerate_existing_chat_not_existing_user_messages( "/v1/chat-stream/regenerate", headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, json={ "message": "", @@ -792,7 +793,7 @@ def test_non_streaming_chat( json={"message": "Hello", "max_tokens": 10}, headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, ) @@ -815,7 +816,7 @@ def test_non_streaming_chat_with_managed_tools(session_client_chat, session_chat json={"message": "Hello", "tools": [{"name": tool}]}, headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, ) @@ -849,7 +850,7 @@ def test_non_streaming_chat_with_managed_and_custom_tools( }, headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, ) @@ -872,7 +873,7 @@ def test_non_streaming_chat_with_custom_tools(session_client_chat, session_chat, }, headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, ) @@ -892,7 +893,7 @@ def test_non_streaming_chat_with_search_queries_only( }, headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, ) @@ -918,7 +919,7 @@ def test_non_streaming_chat_with_chat_history( }, headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, ) @@ -941,7 +942,7 @@ def test_non_streaming_existing_chat_with_files_attaches_to_user_message( "/v1/chat", headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, json={ "message": "How are you doing?", @@ -988,7 +989,7 @@ def test_non_streaming_existing_chat_with_attached_files_does_not_attach( "/v1/chat", headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, json={ "message": "How are you doing?", @@ -1090,7 +1091,7 @@ def test_streaming_chat_with_files( }, headers={ "User-Id": user.id, - "Deployment-Name": ModelDeploymentName.CoherePlatform, + "Deployment-Name": CohereDeployment.name(), }, ) diff --git a/src/backend/tests/unit/routers/test_deployment.py b/src/backend/tests/unit/routers/test_deployment.py index 7d7888b0e3..a8f5be0f92 100644 --- a/src/backend/tests/unit/routers/test_deployment.py +++ b/src/backend/tests/unit/routers/test_deployment.py @@ -4,8 +4,9 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session -from backend.config.deployments import AVAILABLE_MODEL_DEPLOYMENTS, ModelDeploymentName +from backend.config.deployments import AVAILABLE_MODEL_DEPLOYMENTS from backend.database_models import Deployment +from backend.model_deployments.cohere_platform import CohereDeployment def test_create_deployment(session_client: TestClient) -> None: @@ -22,13 +23,13 @@ def test_create_deployment(session_client: TestClient) -> None: assert response.status_code == 200 deployment = response.json() assert deployment["name"] == request_json["name"] - assert deployment["env_vars"] == ["COHERE_API_KEY"] + assert deployment["config"] == {"COHERE_API_KEY": 'test-api-key'} assert deployment["is_available"] def test_create_deployment_unique(session_client: TestClient) -> None: request_json = { - "name": ModelDeploymentName.CoherePlatform, + "name": CohereDeployment.name(), "default_deployment_config": {"COHERE_API_KEY": "test-api-key"}, "deployment_class_name": "CohereDeployment", } @@ -38,7 +39,7 @@ def test_create_deployment_unique(session_client: TestClient) -> None: ) assert response.status_code == 400 assert ( - f"Deployment {ModelDeploymentName.CoherePlatform} already exists." + f"Deployment {CohereDeployment.name()} already exists." in response.json()["detail"] ) @@ -67,13 +68,7 @@ def test_list_deployments_has_all_option( response = session_client.get("/v1/deployments?all=1") assert response.status_code == 200 deployments = response.json() - db_deployments = session.query(Deployment).all() - # If no deployments are found in the database, then all available deployments from settings should be returned - if not db_deployments or len(deployments) != len(db_deployments): - db_deployments = [ - deployment for _, deployment in AVAILABLE_MODEL_DEPLOYMENTS.items() - ] - assert len(deployments) == len(db_deployments) + assert len(deployments) == len(AVAILABLE_MODEL_DEPLOYMENTS) def test_list_deployments_no_available_models_404( @@ -112,7 +107,7 @@ def test_update_deployment(session_client: TestClient, session: Session) -> None assert response.status_code == 200 updated_deployment = response.json() assert updated_deployment["name"] == request_json["name"] - assert updated_deployment["env_vars"] == ["COHERE_API_KEY"] + assert updated_deployment["config"] == {"COHERE_API_KEY": 'test-api-key'} assert updated_deployment["is_available"] assert updated_deployment["description"] == request_json["description"] assert updated_deployment["is_community"] == request_json["is_community"] @@ -120,6 +115,7 @@ def test_update_deployment(session_client: TestClient, session: Session) -> None def test_delete_deployment(session_client: TestClient, session: Session) -> None: deployment = session.query(Deployment).first() + assert deployment is not None response = session_client.delete("/v1/deployments/" + deployment.id) deleted = session.query(Deployment).filter(Deployment.id == deployment.id).first() assert response.status_code == 200 @@ -132,10 +128,10 @@ def test_set_env_vars( ) -> None: with patch("backend.services.env.set_key") as mock_set_key: response = client.post( - "/v1/deployments/Cohere+Platform/set_env_vars", + "/v1/deployments/cohere_platform/update_config", json={ "env_vars": { - "COHERE_VAR_1": "TestCohereValue", + "COHERE_API_KEY": "TestCohereValue", }, }, ) @@ -147,7 +143,7 @@ def __eq__(self, other): mock_set_key.assert_called_with( EnvPathMatcher(), - "COHERE_VAR_1", + "COHERE_API_KEY", "TestCohereValue", ) @@ -155,7 +151,7 @@ def __eq__(self, other): def test_set_env_vars_with_invalid_deployment_name( client: TestClient, mock_available_model_deployments: Mock ): - response = client.post("/v1/deployments/unknown/set_env_vars", json={}) + response = client.post("/v1/deployments/unknown/update_config", json={}) assert response.status_code == 404 @@ -163,7 +159,7 @@ def test_set_env_vars_with_var_for_other_deployment( client: TestClient, mock_available_model_deployments: Mock ) -> None: response = client.post( - "/v1/deployments/Cohere+Platform/set_env_vars", + "/v1/deployments/cohere_platform/update_config", json={ "env_vars": { "SAGEMAKER_VAR_1": "TestSageMakerValue", @@ -180,7 +176,7 @@ def test_set_env_vars_with_invalid_var( client: TestClient, mock_available_model_deployments: Mock ) -> None: response = client.post( - "/v1/deployments/Cohere+Platform/set_env_vars", + "/v1/deployments/cohere_platform/update_config", json={ "env_vars": { "API_KEY": "12345", diff --git a/src/backend/tests/unit/config/test_deployments.py b/src/backend/tests/unit/services/test_deployment.py similarity index 89% rename from src/backend/tests/unit/config/test_deployments.py rename to src/backend/tests/unit/services/test_deployment.py index bb6bac146f..c2b9a26dc9 100644 --- a/src/backend/tests/unit/config/test_deployments.py +++ b/src/backend/tests/unit/services/test_deployment.py @@ -1,6 +1,6 @@ from unittest.mock import Mock -from backend.config.deployments import ( +from backend.services.deployment import ( get_default_deployment, ) from backend.tests.unit.model_deployments.mock_deployments.mock_cohere_platform import (