Skip to content

Commit

Permalink
backend: (Part 1) Use DB config values for Deployments during runtime (
Browse files Browse the repository at this point in the history
…#918)

* Update DB for deployments config

* Default to empty string instead of None for config

* Remove whitespace

* Update interface clients

* Update schema

* default to empty dict

* default context config

* All extra validation

* Revert dotenv
  • Loading branch information
tianjing-li authored Jan 23, 2025
1 parent d966233 commit ec0c043
Show file tree
Hide file tree
Showing 17 changed files with 133 additions and 68 deletions.
2 changes: 0 additions & 2 deletions src/backend/chat/custom/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ async def chat(
Generator[StreamResponse, None, None]: Chat response.
"""
logger = ctx.get_logger()
# TODO Eugene: Discuss with Scott how to get agent here and use the Agent deployment
# Choose the deployment model - validation already performed by request validator
deployment_name = ctx.get_deployment_name()
deployment_model = get_deployment(deployment_name, ctx)

Expand Down
10 changes: 6 additions & 4 deletions src/backend/chat/custom/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from typing import Any

from backend.database_models.database import get_session
from backend.exceptions import DeploymentNotFoundError
from backend.model_deployments.base import BaseDeployment
from backend.schemas.context import Context
from backend.services import deployment as deployment_service


def get_deployment(name: str, ctx: Context, **kwargs: Any) -> BaseDeployment:
"""Get the deployment implementation.
"""
Get the deployment implementation instance.
Args:
deployment (str): Deployment name.
Expand All @@ -18,8 +20,8 @@ def get_deployment(name: str, ctx: Context, **kwargs: Any) -> BaseDeployment:
kwargs["ctx"] = ctx
try:
session = next(get_session())
deployment = deployment_service.get_deployment_by_name(session, name, **kwargs)
except Exception:
deployment = deployment_service.get_default_deployment(**kwargs)
deployment = deployment_service.get_deployment_instance_by_name(session, name, **kwargs)
except DeploymentNotFoundError:
deployment = deployment_service.get_default_deployment_instance(**kwargs)

return deployment
45 changes: 38 additions & 7 deletions src/backend/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,37 @@
# First create the nested structure in the YAML file
# Then add the env variables as an AliasChoices in the Field - these aren't nested

class DeploymentSettingsMixin:
"""
Formats deployment config, used prior to saving values to DB
"""

def to_dict(self) -> dict[str, str]:
def get_first_upper(strings: list[str]) -> str | None:
"""
Heuristic method to retrieve the first all upper-case string in a list of strings.
This is needed to match the var used for a deployment.
"""
return next((s for s in strings if s.isupper()), None)

config = dict(self)
fields = self.__fields__.items()

# Retrieve capitalized variable names
new_dict = {}
for old_field_name, field in fields:
choices = field.validation_alias.choices
env_var = get_first_upper(choices)

value = config.get(old_field_name)
if not value:
value = ""

new_dict[env_var] = value

return new_dict


class GoogleOAuthSettings(BaseSettings, BaseModel):
model_config = SETTINGS_CONFIG
Expand Down Expand Up @@ -274,7 +305,7 @@ class GoogleCloudSettings(BaseSettings, BaseModel):
)


class SageMakerSettings(BaseSettings, BaseModel):
class SageMakerSettings(BaseSettings, BaseModel, DeploymentSettingsMixin):
model_config = SETTINGS_CONFIG
endpoint_name: Optional[str] = Field(
default=None,
Expand All @@ -298,7 +329,7 @@ class SageMakerSettings(BaseSettings, BaseModel):
)


class AzureSettings(BaseSettings, BaseModel):
class AzureSettings(BaseSettings, BaseModel, DeploymentSettingsMixin):
model_config = SETTINGS_CONFIG
endpoint_url: Optional[str] = Field(
default=None,
Expand All @@ -309,14 +340,14 @@ class AzureSettings(BaseSettings, BaseModel):
)


class CoherePlatformSettings(BaseSettings, BaseModel):
class CoherePlatformSettings(BaseSettings, BaseModel, DeploymentSettingsMixin):
model_config = SETTINGS_CONFIG
api_key: Optional[str] = Field(
default=None, validation_alias=AliasChoices("COHERE_API_KEY", "api_key")
)


class SingleContainerSettings(BaseSettings, BaseModel):
class SingleContainerSettings(BaseSettings, BaseModel, DeploymentSettingsMixin):
model_config = SETTINGS_CONFIG
model: Optional[str] = Field(
default=None, validation_alias=AliasChoices("SINGLE_CONTAINER_MODEL", "model")
Expand All @@ -326,7 +357,7 @@ class SingleContainerSettings(BaseSettings, BaseModel):
)


class BedrockSettings(BaseSettings, BaseModel):
class BedrockSettings(BaseSettings, BaseModel, DeploymentSettingsMixin):
model_config = SETTINGS_CONFIG
region_name: Optional[str] = Field(
default=None,
Expand All @@ -349,15 +380,15 @@ class DeploymentSettings(BaseSettings, BaseModel):
default_deployment: Optional[str] = None
enabled_deployments: Optional[List[str]] = None

sagemaker: Optional[SageMakerSettings] = Field(default=SageMakerSettings())
azure: Optional[AzureSettings] = Field(default=AzureSettings())
bedrock: Optional[BedrockSettings] = Field(default=BedrockSettings())
cohere_platform: Optional[CoherePlatformSettings] = Field(
default=CoherePlatformSettings()
)
sagemaker: Optional[SageMakerSettings] = Field(default=SageMakerSettings())
single_container: Optional[SingleContainerSettings] = Field(
default=SingleContainerSettings()
)
bedrock: Optional[BedrockSettings] = Field(default=BedrockSettings())


class LoggerSettings(BaseSettings, BaseModel):
Expand Down
14 changes: 14 additions & 0 deletions src/backend/crud/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,20 @@ def get_deployment_by_name(db: Session, deployment_name: str) -> Deployment:
return db.query(Deployment).filter(Deployment.name == deployment_name).first()


def get_deployment_by_class_name(db: Session, deployment_class_name: str) -> Deployment:
"""
Get a deployment by deployment_class_name.
Args:
db (Session): Database session.
deployment_class_name (str): Deployment Class Name.
Returns:
Deployment: Deployment with the given class name.
"""
return db.query(Deployment).filter(Deployment.name == deployment_class_name).first()


def get_deployments(db: Session, offset: int = 0, limit: int = 100) -> list[Deployment]:
"""
List all deployments.
Expand Down
2 changes: 1 addition & 1 deletion src/backend/database_models/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __str__(self):

@property
def is_available(self) -> bool:
# check if the deployment has a default config
# Check if the deployment has a default config
if not self.default_deployment_config:
return False
return all(value != "" for value in self.default_deployment_config.values())
Expand Down
6 changes: 3 additions & 3 deletions src/backend/model_deployments/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from backend.chat.collate import to_dict
from backend.config.settings import Settings
from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.utils import get_model_config_var
from backend.model_deployments.utils import get_deployment_config_var
from backend.schemas.cohere_chat import CohereChatRequest
from backend.schemas.context import Context

Expand All @@ -30,10 +30,10 @@ class AzureDeployment(BaseDeployment):

def __init__(self, **kwargs: Any):
# Override the environment variable from the request
self.api_key = get_model_config_var(
self.api_key = get_deployment_config_var(
AZURE_API_KEY_ENV_VAR, AzureDeployment.default_api_key, **kwargs
)
self.chat_endpoint_url = get_model_config_var(
self.chat_endpoint_url = get_deployment_config_var(
AZURE_CHAT_URL_ENV_VAR, AzureDeployment.default_chat_endpoint_url, **kwargs
)

Expand Down
10 changes: 6 additions & 4 deletions src/backend/model_deployments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,12 @@ def is_community(cls) -> bool:
@classmethod
def config(cls) -> Dict[str, Any]:
config = Settings().get(f"deployments.{cls.id()}")
config_dict = {} if not config else dict(config)
for key, value in config_dict.items():
if value is None:
config_dict[key] = ""

if not config:
config_dict = {}
else:
config_dict = config.to_dict()

return config_dict

@classmethod
Expand Down
10 changes: 5 additions & 5 deletions src/backend/model_deployments/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from backend.chat.collate import to_dict
from backend.config.settings import Settings
from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.utils import get_model_config_var
from backend.model_deployments.utils import get_deployment_config_var
from backend.schemas.cohere_chat import CohereChatRequest
from backend.schemas.context import Context

Expand All @@ -26,18 +26,18 @@ class BedrockDeployment(BaseDeployment):

def __init__(self, **kwargs: Any):
self.client = cohere.BedrockClient(
aws_access_key=get_model_config_var(
aws_access_key=get_deployment_config_var(
BEDROCK_ACCESS_KEY_ENV_VAR, BedrockDeployment.access_key, **kwargs
),
aws_secret_key=get_model_config_var(
aws_secret_key=get_deployment_config_var(
BEDROCK_SECRET_KEY_ENV_VAR,
BedrockDeployment.secret_access_key,
**kwargs,
),
aws_session_token=get_model_config_var(
aws_session_token=get_deployment_config_var(
BEDROCK_SESSION_TOKEN_ENV_VAR, BedrockDeployment.session_token, **kwargs
),
aws_region=get_model_config_var(
aws_region=get_deployment_config_var(
BEDROCK_REGION_NAME_ENV_VAR, BedrockDeployment.region_name, **kwargs
),
)
Expand Down
6 changes: 3 additions & 3 deletions src/backend/model_deployments/cohere_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from backend.chat.collate import to_dict
from backend.config.settings import Settings
from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.utils import get_model_config_var
from backend.model_deployments.utils import get_deployment_config_var
from backend.schemas.cohere_chat import CohereChatRequest
from backend.schemas.context import Context
from backend.services.logger.utils import LoggerFactory
Expand All @@ -22,9 +22,9 @@ class CohereDeployment(BaseDeployment):
api_key = Settings().get('deployments.cohere_platform.api_key')

def __init__(self, **kwargs: Any):
# Override the environment variable from the request
super().__init__(**kwargs)
api_key = get_model_config_var(

api_key = get_deployment_config_var(
COHERE_API_KEY_ENV_VAR, CohereDeployment.api_key, **kwargs
)
self.client = cohere.Client(api_key, client_name=self.client_name)
Expand Down
12 changes: 6 additions & 6 deletions src/backend/model_deployments/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from backend.config.settings import Settings
from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.utils import get_model_config_var
from backend.model_deployments.utils import get_deployment_config_var
from backend.schemas.cohere_chat import CohereChatRequest
from backend.schemas.context import Context

Expand Down Expand Up @@ -37,29 +37,29 @@ def __init__(self, **kwargs: Any):
# Create the AWS client for the Bedrock runtime with boto3
self.client = boto3.client(
"sagemaker-runtime",
region_name=get_model_config_var(
region_name=get_deployment_config_var(
SAGE_MAKER_REGION_NAME_ENV_VAR,
SageMakerDeployment.region_name,
**kwargs,
),
aws_access_key_id=get_model_config_var(
aws_access_key_id=get_deployment_config_var(
SAGE_MAKER_ACCESS_KEY_ENV_VAR,
SageMakerDeployment.aws_access_key_id,
**kwargs,
),
aws_secret_access_key=get_model_config_var(
aws_secret_access_key=get_deployment_config_var(
SAGE_MAKER_SECRET_KEY_ENV_VAR,
SageMakerDeployment.aws_secret_access_key,
**kwargs,
),
aws_session_token=get_model_config_var(
aws_session_token=get_deployment_config_var(
SAGE_MAKER_SESSION_TOKEN_ENV_VAR,
SageMakerDeployment.aws_session_token,
**kwargs,
),
)
self.params = {
"EndpointName": get_model_config_var(
"EndpointName": get_deployment_config_var(
SAGE_MAKER_ENDPOINT_NAME_ENV_VAR, SageMakerDeployment.endpoint, **kwargs
),
"ContentType": "application/json",
Expand Down
6 changes: 3 additions & 3 deletions src/backend/model_deployments/single_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from backend.chat.collate import to_dict
from backend.config.settings import Settings
from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.utils import get_model_config_var
from backend.model_deployments.utils import get_deployment_config_var
from backend.schemas.cohere_chat import CohereChatRequest
from backend.schemas.context import Context

Expand All @@ -23,10 +23,10 @@ class SingleContainerDeployment(BaseDeployment):
default_model = sc_config.model

def __init__(self, **kwargs: Any):
self.url = get_model_config_var(
self.url = get_deployment_config_var(
SC_URL_ENV_VAR, SingleContainerDeployment.default_url, **kwargs
)
self.model = get_model_config_var(
self.model = get_deployment_config_var(
SC_MODEL_ENV_VAR, SingleContainerDeployment.default_model, **kwargs
)
self.client = cohere.Client(
Expand Down
38 changes: 26 additions & 12 deletions src/backend/model_deployments/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,40 @@ def class_name_validator(v: str):
return v


def get_model_config_var(var_name: str, default: str, **kwargs: Any) -> str:
"""Get the model config variable.
def get_deployment_config_var(var_name: str, default: str, **kwargs: Any) -> str:
"""
Get the Deployment's config, in order of priority:
1. Request header values
2. DB values
3. Default values (from config)
Args:
var_name (str): Variable name.
model_config (dict): Model config.
var_name (str): Variable name
default (str): Variable default value
Returns:
str: Model config variable value.
str: Deployment config value
"""
ctx = kwargs.get("ctx")
model_config = ctx.model_config if ctx else None
config = (
model_config[var_name]
if model_config and model_config.get(var_name)
else default
)
db_config = kwargs.get("db_config", {})
config = None

# Get Request Header value
ctx_deployment_config = ctx.deployment_config if ctx else {}

if ctx_deployment_config:
config = ctx_deployment_config.get(var_name)

if not config:
# Check if DB config exists, otherwise use default
config = db_config.get(var_name, default)

# After all fallbacks, if config is still invalid
if not config:
raise ValueError(f"Missing model config variable: {var_name}")
raise ValueError(f"Missing deployment config variable: {var_name}")

return config


Expand Down
4 changes: 2 additions & 2 deletions src/backend/routers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

def get_deployment_for_agent(session: DBSessionDep, deployment, model) -> tuple[CohereDeployment, str | None]:
try:
deployment = deployment_service.get_deployment_by_name(session, deployment)
deployment = deployment_service.get_deployment_instance_by_name(session, deployment)
except DeploymentNotFoundError:
deployment = deployment_service.get_default_deployment()
deployment = deployment_service.get_default_deployment_instance()

model = next((m for m in deployment.models() if m.name == model), None)

Expand Down
2 changes: 1 addition & 1 deletion src/backend/schemas/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def with_model(self, model: str) -> "Context":
self.model = model
return self

def with_deployment_config(self, deployment_config=None) -> "Context":
def with_deployment_config(self, deployment_config={}) -> "Context":
if deployment_config:
self.deployment_config = deployment_config
else:
Expand Down
Loading

0 comments on commit ec0c043

Please sign in to comment.