Skip to content

Commit

Permalink
Fix a number of integration and unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
malexw committed Nov 19, 2024
1 parent 54da111 commit 6b8025e
Show file tree
Hide file tree
Showing 25 changed files with 220 additions and 198 deletions.
19 changes: 13 additions & 6 deletions src/backend/database_models/seeders/deplyments_models_seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,15 @@
from sqlalchemy import text
from sqlalchemy.orm import Session

from backend.config.deployments import ALL_MODEL_DEPLOYMENTS, ModelDeploymentName
from backend.config.deployments import ALL_MODEL_DEPLOYMENTS
from backend.database_models import Deployment, Model, Organization
from backend.model_deployments import (
CohereDeployment,
SingleContainerDeployment,
SageMakerDeployment,
AzureDeployment,
BedrockDeployment,
)
from community.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP,
)
Expand All @@ -18,7 +25,7 @@
model_deployments.update(COMMUNITY_DEPLOYMENTS_SETUP)

MODELS_NAME_MAPPING = {
ModelDeploymentName.CoherePlatform: {
CohereDeployment.name(): {
"command": {
"cohere_name": "command",
"is_default": False,
Expand Down Expand Up @@ -60,7 +67,7 @@
"is_default": False,
},
},
ModelDeploymentName.SingleContainer: {
SingleContainerDeployment.name(): {
"command": {
"cohere_name": "command",
"is_default": False,
Expand Down Expand Up @@ -102,19 +109,19 @@
"is_default": False,
},
},
ModelDeploymentName.SageMaker: {
SageMakerDeployment.name(): {
"sagemaker-command": {
"cohere_name": "command",
"is_default": True,
},
},
ModelDeploymentName.Azure: {
AzureDeployment.name(): {
"azure-command": {
"cohere_name": "command-r",
"is_default": True,
},
},
ModelDeploymentName.Bedrock: {
BedrockDeployment.name(): {
"cohere.command-r-plus-v1:0": {
"cohere_name": "command-r-plus",
"is_default": True,
Expand Down
2 changes: 1 addition & 1 deletion src/backend/model_deployments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def config(cls) -> Dict[str, Any]:
return config.dict() if config else {}

@classmethod
def to_deployment_info(cls) -> DeploymentDefinition:
def to_deployment_definition(cls) -> DeploymentDefinition:
return DeploymentDefinition(
id=cls.id(),
name=cls.name(),
Expand Down
6 changes: 3 additions & 3 deletions src/backend/model_deployments/single_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ class SingleContainerDeployment(BaseDeployment):
"""Single Container Deployment."""

client_name = "cohere-toolkit"
config = Settings().get('deployments.single_container')
default_url = config.url
default_model = config.model
sc_config = Settings().get('deployments.single_container')
default_url = sc_config.url
default_model = sc_config.model

def __init__(self, **kwargs: Any):
self.url = get_model_config_var(
Expand Down
4 changes: 2 additions & 2 deletions src/backend/services/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def get_deployment_definition(session: DBSessionDep, deployment_id: str) -> Depl
except StopIteration:
raise DeploymentNotFoundError(deployment_id=deployment_id)

return deployment.to_deployment_info()
return deployment.to_deployment_definition()

def get_deployment_definition_by_name(session: DBSessionDep, deployment_name: str) -> DeploymentDefinition:
definitions = get_deployment_definitions(session)
Expand All @@ -84,7 +84,7 @@ def get_deployment_definitions(session: DBSessionDep) -> list[DeploymentDefiniti
}

installed_deployments = [
deployment.to_deployment_info()
deployment.to_deployment_definition()
for deployment in AVAILABLE_MODEL_DEPLOYMENTS
if deployment.name() not in db_deployments
]
Expand Down
14 changes: 11 additions & 3 deletions src/backend/services/request_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,29 @@ def validate_deployment_model(deployment: str, model: str, session: DBSessionDep
HTTPException: If the deployment and model are not compatible
"""
found = deployment_service.get_deployment_info_by_name(session, deployment)
found = deployment_service.get_deployment_definition_by_name(session, deployment)
if not found:
found = deployment_service.get_deployment_info(session, deployment)
found = deployment_service.get_deployment_definition(session, deployment)
if not found:
raise HTTPException(
status_code=400,
detail=f"Deployment {deployment} not found or is not available in the Database.",
)

# Validate model
# deployment_model = next(
# (
# model_db
# for model_db in found.models
# if model_db.name == model or model_db.id == model
# ),
# None,
# )
deployment_model = next(
(
model_db
for model_db in found.models
if model_db.name == model or model_db.id == model
if model_db == model
),
None,
)
Expand Down
38 changes: 8 additions & 30 deletions src/backend/tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from backend.config.deployments import AVAILABLE_MODEL_DEPLOYMENTS, ModelDeploymentName
from backend.config.deployments import ALL_MODEL_DEPLOYMENTS
from backend.database_models import get_session
from backend.database_models.agent import Agent
from backend.database_models.deployment import Deployment
from backend.database_models.model import Model
from backend.main import app, create_app
from backend.schemas.deployment import DeploymentDefinition
# from backend.schemas.deployment import DeploymentDefinition
from backend.schemas.organization import Organization
from backend.schemas.user import User
from backend.tests.unit.factories import get_factory
Expand Down Expand Up @@ -184,35 +184,13 @@ def mock_available_model_deployments(request):
MockSageMakerDeployment,
)

is_available_values = getattr(request, "param", {})
# is_available_values = getattr(request, "param", {})
MOCKED_DEPLOYMENTS = {
ModelDeploymentName.CoherePlatform: DeploymentDefinition(
id="cohere_platform",
name=ModelDeploymentName.CoherePlatform,
models=MockCohereDeployment.list_models(),
is_available=is_available_values.get(
ModelDeploymentName.CoherePlatform, True
),
),
ModelDeploymentName.SageMaker: DeploymentDefinition(
id="sagemaker",
name=ModelDeploymentName.SageMaker,
models=MockSageMakerDeployment.list_models(),
is_available=is_available_values.get(ModelDeploymentName.SageMaker, True),
),
ModelDeploymentName.Azure: DeploymentDefinition(
id="azure",
name=ModelDeploymentName.Azure,
models=MockAzureDeployment.list_models(),
is_available=is_available_values.get(ModelDeploymentName.Azure, True),
),
ModelDeploymentName.Bedrock: DeploymentDefinition(
id="bedrock",
name=ModelDeploymentName.Bedrock,
models=MockBedrockDeployment.list_models(),
is_available=is_available_values.get(ModelDeploymentName.Bedrock, True),
),
MockCohereDeployment.name(): MockCohereDeployment,
MockAzureDeployment.name(): MockAzureDeployment,
MockSageMakerDeployment.name(): MockSageMakerDeployment,
MockBedrockDeployment.name(): MockBedrockDeployment,
}

with patch.dict(AVAILABLE_MODEL_DEPLOYMENTS, MOCKED_DEPLOYMENTS) as mock:
with patch.dict(ALL_MODEL_DEPLOYMENTS, MOCKED_DEPLOYMENTS) as mock:
yield mock
12 changes: 6 additions & 6 deletions src/backend/tests/integration/routers/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from backend.config.deployments import ModelDeploymentName
from backend.config.tools import ToolName
from backend.database_models.agent import Agent
from backend.database_models.agent_tool_metadata import AgentToolMetadata
from backend.model_deployments.cohere_platform import CohereDeployment
from backend.tests.unit.factories import get_factory


Expand All @@ -17,7 +17,7 @@ def test_create_agent(session_client: TestClient, session: Session, user) -> Non
"preamble": "test preamble",
"temperature": 0.5,
"model": "command-r-plus",
"deployment": ModelDeploymentName.CoherePlatform,
"deployment": CohereDeployment.name(),
"tools": [ToolName.Calculator, ToolName.Search_File, ToolName.Read_File],
}

Expand Down Expand Up @@ -58,7 +58,7 @@ def test_create_agent_with_tool_metadata(
"preamble": "test preamble",
"temperature": 0.5,
"model": "command-r-plus",
"deployment": ModelDeploymentName.CoherePlatform,
"deployment": CohereDeployment.name(),
"tools": [ToolName.Google_Drive, ToolName.Search_File],
"tools_metadata": [
{
Expand Down Expand Up @@ -112,7 +112,7 @@ def test_create_agent_missing_non_required_fields(
request_json = {
"name": "test agent",
"model": "command-r-plus",
"deployment": ModelDeploymentName.CoherePlatform,
"deployment": CohereDeployment.name(),
}

response = session_client.post(
Expand Down Expand Up @@ -155,7 +155,7 @@ def test_update_agent(session_client: TestClient, session: Session, user) -> Non
"preamble": "updated preamble",
"temperature": 0.7,
"model": "command-r",
"deployment": ModelDeploymentName.CoherePlatform,
"deployment": CohereDeployment.name(),
}

response = session_client.put(
Expand All @@ -172,4 +172,4 @@ def test_update_agent(session_client: TestClient, session: Session, user) -> Non
assert updated_agent["preamble"] == "updated preamble"
assert updated_agent["temperature"] == 0.7
assert updated_agent["model"] == "command-r"
assert updated_agent["deployment"] == ModelDeploymentName.CoherePlatform
assert updated_agent["deployment"] == CohereDeployment.name()
4 changes: 2 additions & 2 deletions src/backend/tests/integration/routers/test_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from sqlalchemy.orm import Session

from backend.config import Settings
from backend.config.deployments import ModelDeploymentName
from backend.database_models import Conversation
from backend.model_deployments.cohere_platform import CohereDeployment
from backend.schemas.user import User
from backend.tests.unit.factories import get_factory

Expand Down Expand Up @@ -54,7 +54,7 @@ def test_search_conversations_with_reranking(
"/v1/conversations:search",
headers={
"User-Id": user.id,
"Deployment-Name": ModelDeploymentName.CoherePlatform,
"Deployment-Name": CohereDeployment.name(),
},
params={"query": "color"},
)
Expand Down
19 changes: 13 additions & 6 deletions src/backend/tests/unit/configuration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,22 @@ deployments:
default_deployment:
enabled_deployments:
sagemaker:
region_name:
endpoint_name:
access_key: "sagemaker_access_key"
secret_key: "sagemaker_secret"
session_token: "sagemaker_session_token"
region_name: "sagemaker-region"
endpoint_name: "http://www.example.com/sagemaker"
azure:
endpoint_url:
api_key: "azure_api_key"
endpoint_url: "http://www.example.com/azure"
bedrock:
region_name:
region_name: "bedrock-region"
access_key: "bedrock_access_key"
secret_key: "bedrock_secret"
session_token: "bedrock_session_token"
single_container:
model:
url:
model: "single_container_model"
url: "http://www.example.com/single_container"
database:
url:
redis:
Expand Down
47 changes: 9 additions & 38 deletions src/backend/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from backend.config.deployments import AVAILABLE_MODEL_DEPLOYMENTS, ModelDeploymentName
from backend.config.deployments import ALL_MODEL_DEPLOYMENTS
from backend.database_models import get_session
from backend.database_models.base import CustomFilterQuery
from backend.main import app, create_app
from backend.schemas.deployment import Deployment
from backend.schemas.organization import Organization
from backend.schemas.user import User
from backend.tests.unit.factories import get_factory
Expand Down Expand Up @@ -165,43 +164,15 @@ def mock_available_model_deployments(request):
MockSageMakerDeployment,
)

is_available_values = getattr(request, "param", {})
# is_available_values = getattr(request, "param", {})
MOCKED_DEPLOYMENTS = {
ModelDeploymentName.CoherePlatform: Deployment(
id="cohere_platform",
name=ModelDeploymentName.CoherePlatform,
models=MockCohereDeployment.list_models(),
is_available=is_available_values.get(
ModelDeploymentName.CoherePlatform, True
),
deployment_class=MockCohereDeployment,
env_vars=["COHERE_VAR_1", "COHERE_VAR_2"],
),
ModelDeploymentName.SageMaker: Deployment(
id="sagemaker",
name=ModelDeploymentName.SageMaker,
models=MockSageMakerDeployment.list_models(),
is_available=is_available_values.get(ModelDeploymentName.SageMaker, True),
deployment_class=MockSageMakerDeployment,
env_vars=["SAGEMAKER_VAR_1", "SAGEMAKER_VAR_2"],
),
ModelDeploymentName.Azure: Deployment(
id="azure",
name=ModelDeploymentName.Azure,
models=MockAzureDeployment.list_models(),
is_available=is_available_values.get(ModelDeploymentName.Azure, True),
deployment_class=MockAzureDeployment,
env_vars=["SAGEMAKER_VAR_1", "SAGEMAKER_VAR_2"],
),
ModelDeploymentName.Bedrock: Deployment(
id="bedrock",
name=ModelDeploymentName.Bedrock,
models=MockBedrockDeployment.list_models(),
is_available=is_available_values.get(ModelDeploymentName.Bedrock, True),
deployment_class=MockBedrockDeployment,
env_vars=["BEDROCK_VAR_1", "BEDROCK_VAR_2"],
),
MockCohereDeployment.name(): MockCohereDeployment,
MockAzureDeployment.name(): MockAzureDeployment,
MockSageMakerDeployment.name(): MockSageMakerDeployment,
MockBedrockDeployment.name(): MockBedrockDeployment,
}

with patch.dict(AVAILABLE_MODEL_DEPLOYMENTS, MOCKED_DEPLOYMENTS) as mock:
# with patch.dict(AVAILABLE_MODEL_DEPLOYMENTS, MOCKED_DEPLOYMENTS) as mock:
with patch("backend.config.deployments.AVAILABLE_MODEL_DEPLOYMENTS", list(MOCKED_DEPLOYMENTS.values())) as mock:
# with patch.dict(ALL_MODEL_DEPLOYMENTS, MOCKED_DEPLOYMENTS) as mock:
yield mock
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,28 @@
from cohere.types import StreamedChatResponse

from backend.chat.enums import StreamEvent
from backend.model_deployments.base import BaseDeployment
from backend.schemas.cohere_chat import CohereChatRequest
from backend.schemas.context import Context
from backend.tests.unit.model_deployments.mock_deployments.mock_base import (
MockDeployment,
)


class MockAzureDeployment(BaseDeployment):
class MockAzureDeployment(MockDeployment):
"""Mocked Azure Deployment."""

DEFAULT_MODELS = ["azure-command"]

@property
def rerank_enabled(self) -> bool:
@classmethod
def name(cls) -> str:
return "Azure"

@classmethod
def env_vars(cls) -> List[str]:
return ["AZURE_API_KEY", "AZURE_CHAT_ENDPOINT_URL"]

@classmethod
def rerank_enabled(cls) -> bool:
return False

@classmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from backend.model_deployments.base import BaseDeployment


class MockDeployment(BaseDeployment): ...
Loading

0 comments on commit 6b8025e

Please sign in to comment.