diff --git a/src/backend/chat/custom/custom.py b/src/backend/chat/custom/custom.py index 4f7c7f5079..a15482cdfc 100644 --- a/src/backend/chat/custom/custom.py +++ b/src/backend/chat/custom/custom.py @@ -48,8 +48,6 @@ async def chat( Generator[StreamResponse, None, None]: Chat response. """ logger = ctx.get_logger() - # TODO Eugene: Discuss with Scott how to get agent here and use the Agent deployment - # Choose the deployment model - validation already performed by request validator deployment_name = ctx.get_deployment_name() deployment_model = get_deployment(deployment_name, ctx) diff --git a/src/backend/chat/custom/utils.py b/src/backend/chat/custom/utils.py index 893ba221fc..9b9a46f263 100644 --- a/src/backend/chat/custom/utils.py +++ b/src/backend/chat/custom/utils.py @@ -1,13 +1,15 @@ from typing import Any from backend.database_models.database import get_session +from backend.exceptions import DeploymentNotFoundError from backend.model_deployments.base import BaseDeployment from backend.schemas.context import Context from backend.services import deployment as deployment_service def get_deployment(name: str, ctx: Context, **kwargs: Any) -> BaseDeployment: - """Get the deployment implementation. + """ + Get the deployment implementation instance. Args: deployment (str): Deployment name. @@ -18,8 +20,8 @@ def get_deployment(name: str, ctx: Context, **kwargs: Any) -> BaseDeployment: kwargs["ctx"] = ctx try: session = next(get_session()) - deployment = deployment_service.get_deployment_by_name(session, name, **kwargs) - except Exception: - deployment = deployment_service.get_default_deployment(**kwargs) + deployment = deployment_service.get_deployment_instance_by_name(session, name, **kwargs) + except DeploymentNotFoundError: + deployment = deployment_service.get_default_deployment_instance(**kwargs) return deployment diff --git a/src/backend/config/settings.py b/src/backend/config/settings.py index 2f0011de83..087d34a9ca 100644 --- a/src/backend/config/settings.py +++ b/src/backend/config/settings.py @@ -35,6 +35,37 @@ # First create the nested structure in the YAML file # Then add the env variables as an AliasChoices in the Field - these aren't nested +class DeploymentSettingsMixin: + """ + Formats deployment config, used prior to saving values to DB + """ + + def to_dict(self) -> dict[str, str]: + def get_first_upper(strings: list[str]) -> str | None: + """ + Heuristic method to retrieve the first all upper-case string in a list of strings. + + This is needed to match the var used for a deployment. + """ + return next((s for s in strings if s.isupper()), None) + + config = dict(self) + fields = self.__fields__.items() + + # Retrieve capitalized variable names + new_dict = {} + for old_field_name, field in fields: + choices = field.validation_alias.choices + env_var = get_first_upper(choices) + + value = config.get(old_field_name) + if not value: + value = "" + + new_dict[env_var] = value + + return new_dict + class GoogleOAuthSettings(BaseSettings, BaseModel): model_config = SETTINGS_CONFIG @@ -274,7 +305,7 @@ class GoogleCloudSettings(BaseSettings, BaseModel): ) -class SageMakerSettings(BaseSettings, BaseModel): +class SageMakerSettings(BaseSettings, BaseModel, DeploymentSettingsMixin): model_config = SETTINGS_CONFIG endpoint_name: Optional[str] = Field( default=None, @@ -298,7 +329,7 @@ class SageMakerSettings(BaseSettings, BaseModel): ) -class AzureSettings(BaseSettings, BaseModel): +class AzureSettings(BaseSettings, BaseModel, DeploymentSettingsMixin): model_config = SETTINGS_CONFIG endpoint_url: Optional[str] = Field( default=None, @@ -309,14 +340,14 @@ class AzureSettings(BaseSettings, BaseModel): ) -class CoherePlatformSettings(BaseSettings, BaseModel): +class CoherePlatformSettings(BaseSettings, BaseModel, DeploymentSettingsMixin): model_config = SETTINGS_CONFIG api_key: Optional[str] = Field( default=None, validation_alias=AliasChoices("COHERE_API_KEY", "api_key") ) -class SingleContainerSettings(BaseSettings, BaseModel): +class SingleContainerSettings(BaseSettings, BaseModel, DeploymentSettingsMixin): model_config = SETTINGS_CONFIG model: Optional[str] = Field( default=None, validation_alias=AliasChoices("SINGLE_CONTAINER_MODEL", "model") @@ -326,7 +357,7 @@ class SingleContainerSettings(BaseSettings, BaseModel): ) -class BedrockSettings(BaseSettings, BaseModel): +class BedrockSettings(BaseSettings, BaseModel, DeploymentSettingsMixin): model_config = SETTINGS_CONFIG region_name: Optional[str] = Field( default=None, @@ -349,15 +380,15 @@ class DeploymentSettings(BaseSettings, BaseModel): default_deployment: Optional[str] = None enabled_deployments: Optional[List[str]] = None - sagemaker: Optional[SageMakerSettings] = Field(default=SageMakerSettings()) azure: Optional[AzureSettings] = Field(default=AzureSettings()) + bedrock: Optional[BedrockSettings] = Field(default=BedrockSettings()) cohere_platform: Optional[CoherePlatformSettings] = Field( default=CoherePlatformSettings() ) + sagemaker: Optional[SageMakerSettings] = Field(default=SageMakerSettings()) single_container: Optional[SingleContainerSettings] = Field( default=SingleContainerSettings() ) - bedrock: Optional[BedrockSettings] = Field(default=BedrockSettings()) class LoggerSettings(BaseSettings, BaseModel): diff --git a/src/backend/crud/deployment.py b/src/backend/crud/deployment.py index 3f021e7f8b..09668c3cc6 100644 --- a/src/backend/crud/deployment.py +++ b/src/backend/crud/deployment.py @@ -60,6 +60,20 @@ def get_deployment_by_name(db: Session, deployment_name: str) -> Deployment: return db.query(Deployment).filter(Deployment.name == deployment_name).first() +def get_deployment_by_class_name(db: Session, deployment_class_name: str) -> Deployment: + """ + Get a deployment by deployment_class_name. + + Args: + db (Session): Database session. + deployment_class_name (str): Deployment Class Name. + + Returns: + Deployment: Deployment with the given class name. + """ + return db.query(Deployment).filter(Deployment.name == deployment_class_name).first() + + def get_deployments(db: Session, offset: int = 0, limit: int = 100) -> list[Deployment]: """ List all deployments. diff --git a/src/backend/database_models/deployment.py b/src/backend/database_models/deployment.py index 579a2441d6..e2522aa038 100644 --- a/src/backend/database_models/deployment.py +++ b/src/backend/database_models/deployment.py @@ -27,7 +27,7 @@ def __str__(self): @property def is_available(self) -> bool: - # check if the deployment has a default config + # Check if the deployment has a default config if not self.default_deployment_config: return False return all(value != "" for value in self.default_deployment_config.values()) diff --git a/src/backend/model_deployments/azure.py b/src/backend/model_deployments/azure.py index bea01b7743..f2c06c1ecb 100644 --- a/src/backend/model_deployments/azure.py +++ b/src/backend/model_deployments/azure.py @@ -5,7 +5,7 @@ from backend.chat.collate import to_dict from backend.config.settings import Settings from backend.model_deployments.base import BaseDeployment -from backend.model_deployments.utils import get_model_config_var +from backend.model_deployments.utils import get_deployment_config_var from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context @@ -30,10 +30,10 @@ class AzureDeployment(BaseDeployment): def __init__(self, **kwargs: Any): # Override the environment variable from the request - self.api_key = get_model_config_var( + self.api_key = get_deployment_config_var( AZURE_API_KEY_ENV_VAR, AzureDeployment.default_api_key, **kwargs ) - self.chat_endpoint_url = get_model_config_var( + self.chat_endpoint_url = get_deployment_config_var( AZURE_CHAT_URL_ENV_VAR, AzureDeployment.default_chat_endpoint_url, **kwargs ) diff --git a/src/backend/model_deployments/base.py b/src/backend/model_deployments/base.py index cae22e68fe..61147a7c36 100644 --- a/src/backend/model_deployments/base.py +++ b/src/backend/model_deployments/base.py @@ -52,10 +52,12 @@ def is_community(cls) -> bool: @classmethod def config(cls) -> Dict[str, Any]: config = Settings().get(f"deployments.{cls.id()}") - config_dict = {} if not config else dict(config) - for key, value in config_dict.items(): - if value is None: - config_dict[key] = "" + + if not config: + config_dict = {} + else: + config_dict = config.to_dict() + return config_dict @classmethod diff --git a/src/backend/model_deployments/bedrock.py b/src/backend/model_deployments/bedrock.py index 7241c79dd1..4125ddf21e 100644 --- a/src/backend/model_deployments/bedrock.py +++ b/src/backend/model_deployments/bedrock.py @@ -5,7 +5,7 @@ from backend.chat.collate import to_dict from backend.config.settings import Settings from backend.model_deployments.base import BaseDeployment -from backend.model_deployments.utils import get_model_config_var +from backend.model_deployments.utils import get_deployment_config_var from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context @@ -26,18 +26,18 @@ class BedrockDeployment(BaseDeployment): def __init__(self, **kwargs: Any): self.client = cohere.BedrockClient( - aws_access_key=get_model_config_var( + aws_access_key=get_deployment_config_var( BEDROCK_ACCESS_KEY_ENV_VAR, BedrockDeployment.access_key, **kwargs ), - aws_secret_key=get_model_config_var( + aws_secret_key=get_deployment_config_var( BEDROCK_SECRET_KEY_ENV_VAR, BedrockDeployment.secret_access_key, **kwargs, ), - aws_session_token=get_model_config_var( + aws_session_token=get_deployment_config_var( BEDROCK_SESSION_TOKEN_ENV_VAR, BedrockDeployment.session_token, **kwargs ), - aws_region=get_model_config_var( + aws_region=get_deployment_config_var( BEDROCK_REGION_NAME_ENV_VAR, BedrockDeployment.region_name, **kwargs ), ) diff --git a/src/backend/model_deployments/cohere_platform.py b/src/backend/model_deployments/cohere_platform.py index cbddb750ea..5581cfcf7c 100644 --- a/src/backend/model_deployments/cohere_platform.py +++ b/src/backend/model_deployments/cohere_platform.py @@ -6,7 +6,7 @@ from backend.chat.collate import to_dict from backend.config.settings import Settings from backend.model_deployments.base import BaseDeployment -from backend.model_deployments.utils import get_model_config_var +from backend.model_deployments.utils import get_deployment_config_var from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context from backend.services.logger.utils import LoggerFactory @@ -22,9 +22,9 @@ class CohereDeployment(BaseDeployment): api_key = Settings().get('deployments.cohere_platform.api_key') def __init__(self, **kwargs: Any): - # Override the environment variable from the request super().__init__(**kwargs) - api_key = get_model_config_var( + + api_key = get_deployment_config_var( COHERE_API_KEY_ENV_VAR, CohereDeployment.api_key, **kwargs ) self.client = cohere.Client(api_key, client_name=self.client_name) diff --git a/src/backend/model_deployments/sagemaker.py b/src/backend/model_deployments/sagemaker.py index b8de329230..eade99c4aa 100644 --- a/src/backend/model_deployments/sagemaker.py +++ b/src/backend/model_deployments/sagemaker.py @@ -6,7 +6,7 @@ from backend.config.settings import Settings from backend.model_deployments.base import BaseDeployment -from backend.model_deployments.utils import get_model_config_var +from backend.model_deployments.utils import get_deployment_config_var from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context @@ -37,29 +37,29 @@ def __init__(self, **kwargs: Any): # Create the AWS client for the Bedrock runtime with boto3 self.client = boto3.client( "sagemaker-runtime", - region_name=get_model_config_var( + region_name=get_deployment_config_var( SAGE_MAKER_REGION_NAME_ENV_VAR, SageMakerDeployment.region_name, **kwargs, ), - aws_access_key_id=get_model_config_var( + aws_access_key_id=get_deployment_config_var( SAGE_MAKER_ACCESS_KEY_ENV_VAR, SageMakerDeployment.aws_access_key_id, **kwargs, ), - aws_secret_access_key=get_model_config_var( + aws_secret_access_key=get_deployment_config_var( SAGE_MAKER_SECRET_KEY_ENV_VAR, SageMakerDeployment.aws_secret_access_key, **kwargs, ), - aws_session_token=get_model_config_var( + aws_session_token=get_deployment_config_var( SAGE_MAKER_SESSION_TOKEN_ENV_VAR, SageMakerDeployment.aws_session_token, **kwargs, ), ) self.params = { - "EndpointName": get_model_config_var( + "EndpointName": get_deployment_config_var( SAGE_MAKER_ENDPOINT_NAME_ENV_VAR, SageMakerDeployment.endpoint, **kwargs ), "ContentType": "application/json", diff --git a/src/backend/model_deployments/single_container.py b/src/backend/model_deployments/single_container.py index a9d69ab6a9..fcce3b3828 100644 --- a/src/backend/model_deployments/single_container.py +++ b/src/backend/model_deployments/single_container.py @@ -5,7 +5,7 @@ from backend.chat.collate import to_dict from backend.config.settings import Settings from backend.model_deployments.base import BaseDeployment -from backend.model_deployments.utils import get_model_config_var +from backend.model_deployments.utils import get_deployment_config_var from backend.schemas.cohere_chat import CohereChatRequest from backend.schemas.context import Context @@ -23,10 +23,10 @@ class SingleContainerDeployment(BaseDeployment): default_model = sc_config.model def __init__(self, **kwargs: Any): - self.url = get_model_config_var( + self.url = get_deployment_config_var( SC_URL_ENV_VAR, SingleContainerDeployment.default_url, **kwargs ) - self.model = get_model_config_var( + self.model = get_deployment_config_var( SC_MODEL_ENV_VAR, SingleContainerDeployment.default_model, **kwargs ) self.client = cohere.Client( diff --git a/src/backend/model_deployments/utils.py b/src/backend/model_deployments/utils.py index 79d74edda7..63919f12b5 100644 --- a/src/backend/model_deployments/utils.py +++ b/src/backend/model_deployments/utils.py @@ -18,26 +18,40 @@ def class_name_validator(v: str): return v -def get_model_config_var(var_name: str, default: str, **kwargs: Any) -> str: - """Get the model config variable. +def get_deployment_config_var(var_name: str, default: str, **kwargs: Any) -> str: + """ + Get the Deployment's config, in order of priority: + + 1. Request header values + 2. DB values + 3. Default values (from config) Args: - var_name (str): Variable name. - model_config (dict): Model config. + var_name (str): Variable name + default (str): Variable default value Returns: - str: Model config variable value. + str: Deployment config value """ ctx = kwargs.get("ctx") - model_config = ctx.model_config if ctx else None - config = ( - model_config[var_name] - if model_config and model_config.get(var_name) - else default - ) + db_config = kwargs.get("db_config", {}) + config = None + + # Get Request Header value + ctx_deployment_config = ctx.deployment_config if ctx else {} + + if ctx_deployment_config: + config = ctx_deployment_config.get(var_name) + + if not config: + # Check if DB config exists, otherwise use default + config = db_config.get(var_name, default) + + # After all fallbacks, if config is still invalid if not config: - raise ValueError(f"Missing model config variable: {var_name}") + raise ValueError(f"Missing deployment config variable: {var_name}") + return config diff --git a/src/backend/routers/utils.py b/src/backend/routers/utils.py index ffe13b5abf..92a87b1801 100644 --- a/src/backend/routers/utils.py +++ b/src/backend/routers/utils.py @@ -8,9 +8,9 @@ def get_deployment_for_agent(session: DBSessionDep, deployment, model) -> tuple[CohereDeployment, str | None]: try: - deployment = deployment_service.get_deployment_by_name(session, deployment) + deployment = deployment_service.get_deployment_instance_by_name(session, deployment) except DeploymentNotFoundError: - deployment = deployment_service.get_default_deployment() + deployment = deployment_service.get_default_deployment_instance() model = next((m for m in deployment.models() if m.name == model), None) diff --git a/src/backend/schemas/context.py b/src/backend/schemas/context.py index 6faaa3ec93..ade0185a55 100644 --- a/src/backend/schemas/context.py +++ b/src/backend/schemas/context.py @@ -95,7 +95,7 @@ def with_model(self, model: str) -> "Context": self.model = model return self - def with_deployment_config(self, deployment_config=None) -> "Context": + def with_deployment_config(self, deployment_config={}) -> "Context": if deployment_config: self.deployment_config = deployment_config else: diff --git a/src/backend/services/deployment.py b/src/backend/services/deployment.py index 3f838b58a6..aac7b77b6f 100644 --- a/src/backend/services/deployment.py +++ b/src/backend/services/deployment.py @@ -36,7 +36,7 @@ def create_db_deployment(session: DBSessionDep, deployment: DeploymentDefinition return DeploymentDefinition.from_db_deployment(db_deployment) -def get_default_deployment(**kwargs) -> BaseDeployment: +def get_default_deployment_instance(**kwargs) -> BaseDeployment: try: fallback = next(d for d in AVAILABLE_MODEL_DEPLOYMENTS.values() if d.is_available()) except StopIteration: @@ -55,17 +55,21 @@ def get_default_deployment(**kwargs) -> BaseDeployment: return fallback(**kwargs) -def get_deployment(session: DBSessionDep, deployment_id: str, **kwargs) -> BaseDeployment: +def get_deployment_instance_by_id(session: DBSessionDep, deployment_id: str, **kwargs) -> BaseDeployment: definition = get_deployment_definition(session, deployment_id) - return get_deployment_by_name(session, definition.name, **kwargs) + # TODO: What's the point of fetching by ID if we just fetch by name after? + return get_deployment_instance_by_name(session, definition.name, **kwargs) -def get_deployment_by_name(session: DBSessionDep, deployment_name: str, **kwargs) -> BaseDeployment: +def get_deployment_instance_by_name(session: DBSessionDep, deployment_name: str, **kwargs) -> BaseDeployment: definition = get_deployment_definition_by_name(session, deployment_name) try: - return next(d for d in AVAILABLE_MODEL_DEPLOYMENTS.values() if d.__name__ == definition.class_name)( - db_id=definition.id, **definition.config, **kwargs + deployment_class = next(d for d in AVAILABLE_MODEL_DEPLOYMENTS.values() if d.__name__ == definition.class_name) + deployment_instance = deployment_class( + db_id=definition.id, db_config=definition.config, **kwargs ) + + return deployment_instance except StopIteration: raise DeploymentNotFoundError(deployment_id=deployment_name) @@ -76,11 +80,10 @@ def get_deployment_definition(session: DBSessionDep, deployment_id: str) -> Depl try: deployment = next(d for d in AVAILABLE_MODEL_DEPLOYMENTS.values() if d.id() == deployment_id) + create_db_deployment(session, deployment.to_deployment_definition()) except StopIteration: raise DeploymentNotFoundError(deployment_id=deployment_id) - create_db_deployment(session, deployment.to_deployment_definition()) - return deployment.to_deployment_definition() def get_deployment_definition_by_name(session: DBSessionDep, deployment_name: str) -> DeploymentDefinition: @@ -90,6 +93,7 @@ def get_deployment_definition_by_name(session: DBSessionDep, deployment_name: st except StopIteration: raise DeploymentNotFoundError(deployment_id=deployment_name) + # Creates deployment in DB if it doesn't exist if definition.name not in [d.name for d in deployment_crud.get_deployments(session)]: definition = create_db_deployment(session, definition) diff --git a/src/backend/services/request_validators.py b/src/backend/services/request_validators.py index f932d26e12..d42a31b15d 100644 --- a/src/backend/services/request_validators.py +++ b/src/backend/services/request_validators.py @@ -219,7 +219,7 @@ async def validate_env_vars(session: DBSessionDep, request: Request): deployment_id = unquote_plus(request.path_params.get("deployment_id")) try: - deployment = deployment_service.get_deployment(session, deployment_id) + deployment = deployment_service.get_deployment_instance_by_id(session, deployment_id) except DeploymentNotFoundError: raise HTTPException( status_code=404, detail=f"Deployment {deployment_id} not found." diff --git a/src/backend/tests/unit/services/test_deployment.py b/src/backend/tests/unit/services/test_deployment.py index d3aae770d4..d7eae458c0 100644 --- a/src/backend/tests/unit/services/test_deployment.py +++ b/src/backend/tests/unit/services/test_deployment.py @@ -43,27 +43,27 @@ def test_all_tools_have_id() -> None: def test_get_default_deployment_none_available() -> None: with patch("backend.services.deployment.AVAILABLE_MODEL_DEPLOYMENTS", {}): with pytest.raises(NoAvailableDeploymentsError): - deployment_service.get_default_deployment() + deployment_service.get_default_deployment_instance() def test_get_default_deployment_no_settings(mock_available_model_deployments) -> None: - assert isinstance(deployment_service.get_default_deployment(), MockCohereDeployment) + assert isinstance(deployment_service.get_default_deployment_instance(), MockCohereDeployment) def test_get_default_deployment_with_settings(mock_available_model_deployments) -> None: with patch("backend.config.settings.Settings.get", return_value="azure") as mock_settings: - assert isinstance(deployment_service.get_default_deployment(), MockAzureDeployment) + assert isinstance(deployment_service.get_default_deployment_instance(), MockAzureDeployment) mock_settings.assert_called_once_with("deployments.default_deployment") def test_get_deployment(session, mock_available_model_deployments, db_deployment) -> None: - deployment = deployment_service.get_deployment(session, db_deployment.id) + deployment = deployment_service.get_deployment_instance_by_id(session, db_deployment.id) assert isinstance(deployment, MockCohereDeployment) def test_get_deployment_by_name(session, mock_available_model_deployments, clear_db_deployments) -> None: - deployment = deployment_service.get_deployment_by_name(session, MockCohereDeployment.name()) + deployment = deployment_service.get_deployment_instance_by_name(session, MockCohereDeployment.name()) assert isinstance(deployment, MockCohereDeployment) def test_get_deployment_by_name_wrong_name(session, mock_available_model_deployments) -> None: with pytest.raises(DeploymentNotFoundError): - deployment_service.get_deployment_by_name(session, "wrong-name") + deployment_service.get_deployment_instance_by_name(session, "wrong-name") def test_get_deployment_definition(session, mock_available_model_deployments, db_deployment) -> None: definition = deployment_service.get_deployment_definition(session, "db-mock-cohere-platform-id")