Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

backend: (Part 1) Use DB config values for Deployments during runtime #918

Merged
merged 10 commits into from
Jan 23, 2025
2 changes: 0 additions & 2 deletions src/backend/chat/custom/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,6 @@ async def chat(
Generator[StreamResponse, None, None]: Chat response.
"""
logger = ctx.get_logger()
# TODO Eugene: Discuss with Scott how to get agent here and use the Agent deployment
# Choose the deployment model - validation already performed by request validator
deployment_name = ctx.get_deployment_name()
deployment_model = get_deployment(deployment_name, ctx)

Expand Down
10 changes: 6 additions & 4 deletions src/backend/chat/custom/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from typing import Any

from backend.database_models.database import get_session
from backend.exceptions import DeploymentNotFoundError
from backend.model_deployments.base import BaseDeployment
from backend.schemas.context import Context
from backend.services import deployment as deployment_service


def get_deployment(name: str, ctx: Context, **kwargs: Any) -> BaseDeployment:
"""Get the deployment implementation.
"""
Get the deployment implementation instance.

Args:
deployment (str): Deployment name.
Expand All @@ -18,8 +20,8 @@ def get_deployment(name: str, ctx: Context, **kwargs: Any) -> BaseDeployment:
kwargs["ctx"] = ctx
try:
session = next(get_session())
deployment = deployment_service.get_deployment_by_name(session, name, **kwargs)
except Exception:
deployment = deployment_service.get_default_deployment(**kwargs)
deployment = deployment_service.get_deployment_instance_by_name(session, name, **kwargs)
except DeploymentNotFoundError:
deployment = deployment_service.get_default_deployment_instance(**kwargs)

return deployment
45 changes: 38 additions & 7 deletions src/backend/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,37 @@
# First create the nested structure in the YAML file
# Then add the env variables as an AliasChoices in the Field - these aren't nested

class DeploymentSettingsMixin:
"""
Formats deployment config, used prior to saving values to DB
"""

def to_dict(self) -> dict[str, str]:
def get_first_upper(strings: list[str]) -> str | None:
"""
Heuristic method to retrieve the first all upper-case string in a list of strings.

This is needed to match the var used for a deployment.
"""
return next((s for s in strings if s.isupper()), None)

config = dict(self)
fields = self.__fields__.items()

# Retrieve capitalized variable names
new_dict = {}
for old_field_name, field in fields:
choices = field.validation_alias.choices
env_var = get_first_upper(choices)

value = config.get(old_field_name)
if not value:
value = ""

new_dict[env_var] = value

return new_dict


class GoogleOAuthSettings(BaseSettings, BaseModel):
model_config = SETTINGS_CONFIG
Expand Down Expand Up @@ -274,7 +305,7 @@ class GoogleCloudSettings(BaseSettings, BaseModel):
)


class SageMakerSettings(BaseSettings, BaseModel):
class SageMakerSettings(BaseSettings, BaseModel, DeploymentSettingsMixin):
model_config = SETTINGS_CONFIG
endpoint_name: Optional[str] = Field(
default=None,
Expand All @@ -298,7 +329,7 @@ class SageMakerSettings(BaseSettings, BaseModel):
)


class AzureSettings(BaseSettings, BaseModel):
class AzureSettings(BaseSettings, BaseModel, DeploymentSettingsMixin):
model_config = SETTINGS_CONFIG
endpoint_url: Optional[str] = Field(
default=None,
Expand All @@ -309,14 +340,14 @@ class AzureSettings(BaseSettings, BaseModel):
)


class CoherePlatformSettings(BaseSettings, BaseModel):
class CoherePlatformSettings(BaseSettings, BaseModel, DeploymentSettingsMixin):
model_config = SETTINGS_CONFIG
api_key: Optional[str] = Field(
default=None, validation_alias=AliasChoices("COHERE_API_KEY", "api_key")
)


class SingleContainerSettings(BaseSettings, BaseModel):
class SingleContainerSettings(BaseSettings, BaseModel, DeploymentSettingsMixin):
model_config = SETTINGS_CONFIG
model: Optional[str] = Field(
default=None, validation_alias=AliasChoices("SINGLE_CONTAINER_MODEL", "model")
Expand All @@ -326,7 +357,7 @@ class SingleContainerSettings(BaseSettings, BaseModel):
)


class BedrockSettings(BaseSettings, BaseModel):
class BedrockSettings(BaseSettings, BaseModel, DeploymentSettingsMixin):
model_config = SETTINGS_CONFIG
region_name: Optional[str] = Field(
default=None,
Expand All @@ -349,15 +380,15 @@ class DeploymentSettings(BaseSettings, BaseModel):
default_deployment: Optional[str] = None
enabled_deployments: Optional[List[str]] = None

sagemaker: Optional[SageMakerSettings] = Field(default=SageMakerSettings())
azure: Optional[AzureSettings] = Field(default=AzureSettings())
bedrock: Optional[BedrockSettings] = Field(default=BedrockSettings())
cohere_platform: Optional[CoherePlatformSettings] = Field(
default=CoherePlatformSettings()
)
sagemaker: Optional[SageMakerSettings] = Field(default=SageMakerSettings())
single_container: Optional[SingleContainerSettings] = Field(
default=SingleContainerSettings()
)
bedrock: Optional[BedrockSettings] = Field(default=BedrockSettings())


class LoggerSettings(BaseSettings, BaseModel):
Expand Down
14 changes: 14 additions & 0 deletions src/backend/crud/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,20 @@ def get_deployment_by_name(db: Session, deployment_name: str) -> Deployment:
return db.query(Deployment).filter(Deployment.name == deployment_name).first()


def get_deployment_by_class_name(db: Session, deployment_class_name: str) -> Deployment:
    """
    Get a deployment by deployment_class_name.

    Args:
        db (Session): Database session.
        deployment_class_name (str): Deployment Class Name.

    Returns:
        Deployment: First deployment with the given class name, or None when
            no row matches.
    """
    # Filter on the class-name column: matching against Deployment.name would
    # make this function a duplicate of get_deployment_by_name.
    # NOTE(review): assumes the Deployment model exposes a
    # `deployment_class_name` column - confirm against the model definition.
    return (
        db.query(Deployment)
        .filter(Deployment.deployment_class_name == deployment_class_name)
        .first()
    )


def get_deployments(db: Session, offset: int = 0, limit: int = 100) -> list[Deployment]:
"""
List all deployments.
Expand Down
2 changes: 1 addition & 1 deletion src/backend/database_models/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __str__(self):

@property
def is_available(self) -> bool:
# check if the deployment has a default config
# Check if the deployment has a default config
if not self.default_deployment_config:
return False
return all(value != "" for value in self.default_deployment_config.values())
Expand Down
6 changes: 3 additions & 3 deletions src/backend/model_deployments/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from backend.chat.collate import to_dict
from backend.config.settings import Settings
from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.utils import get_model_config_var
from backend.model_deployments.utils import get_deployment_config_var
from backend.schemas.cohere_chat import CohereChatRequest
from backend.schemas.context import Context

Expand All @@ -30,10 +30,10 @@ class AzureDeployment(BaseDeployment):

def __init__(self, **kwargs: Any):
# Override the environment variable from the request
self.api_key = get_model_config_var(
self.api_key = get_deployment_config_var(
AZURE_API_KEY_ENV_VAR, AzureDeployment.default_api_key, **kwargs
)
self.chat_endpoint_url = get_model_config_var(
self.chat_endpoint_url = get_deployment_config_var(
AZURE_CHAT_URL_ENV_VAR, AzureDeployment.default_chat_endpoint_url, **kwargs
)

Expand Down
10 changes: 6 additions & 4 deletions src/backend/model_deployments/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,12 @@ def is_community(cls) -> bool:
@classmethod
def config(cls) -> Dict[str, Any]:
config = Settings().get(f"deployments.{cls.id()}")
config_dict = {} if not config else dict(config)
for key, value in config_dict.items():
if value is None:
config_dict[key] = ""

if not config:
config_dict = {}
else:
config_dict = config.to_dict()

return config_dict

@classmethod
Expand Down
10 changes: 5 additions & 5 deletions src/backend/model_deployments/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from backend.chat.collate import to_dict
from backend.config.settings import Settings
from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.utils import get_model_config_var
from backend.model_deployments.utils import get_deployment_config_var
from backend.schemas.cohere_chat import CohereChatRequest
from backend.schemas.context import Context

Expand All @@ -26,18 +26,18 @@ class BedrockDeployment(BaseDeployment):

def __init__(self, **kwargs: Any):
self.client = cohere.BedrockClient(
aws_access_key=get_model_config_var(
aws_access_key=get_deployment_config_var(
BEDROCK_ACCESS_KEY_ENV_VAR, BedrockDeployment.access_key, **kwargs
),
aws_secret_key=get_model_config_var(
aws_secret_key=get_deployment_config_var(
BEDROCK_SECRET_KEY_ENV_VAR,
BedrockDeployment.secret_access_key,
**kwargs,
),
aws_session_token=get_model_config_var(
aws_session_token=get_deployment_config_var(
BEDROCK_SESSION_TOKEN_ENV_VAR, BedrockDeployment.session_token, **kwargs
),
aws_region=get_model_config_var(
aws_region=get_deployment_config_var(
BEDROCK_REGION_NAME_ENV_VAR, BedrockDeployment.region_name, **kwargs
),
)
Expand Down
6 changes: 3 additions & 3 deletions src/backend/model_deployments/cohere_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from backend.chat.collate import to_dict
from backend.config.settings import Settings
from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.utils import get_model_config_var
from backend.model_deployments.utils import get_deployment_config_var
from backend.schemas.cohere_chat import CohereChatRequest
from backend.schemas.context import Context
from backend.services.logger.utils import LoggerFactory
Expand All @@ -22,9 +22,9 @@ class CohereDeployment(BaseDeployment):
api_key = Settings().get('deployments.cohere_platform.api_key')

def __init__(self, **kwargs: Any):
# Override the environment variable from the request
super().__init__(**kwargs)
api_key = get_model_config_var(

api_key = get_deployment_config_var(
COHERE_API_KEY_ENV_VAR, CohereDeployment.api_key, **kwargs
)
self.client = cohere.Client(api_key, client_name=self.client_name)
Expand Down
12 changes: 6 additions & 6 deletions src/backend/model_deployments/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from backend.config.settings import Settings
from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.utils import get_model_config_var
from backend.model_deployments.utils import get_deployment_config_var
from backend.schemas.cohere_chat import CohereChatRequest
from backend.schemas.context import Context

Expand Down Expand Up @@ -37,29 +37,29 @@ def __init__(self, **kwargs: Any):
# Create the AWS client for the SageMaker runtime with boto3
self.client = boto3.client(
"sagemaker-runtime",
region_name=get_model_config_var(
region_name=get_deployment_config_var(
SAGE_MAKER_REGION_NAME_ENV_VAR,
SageMakerDeployment.region_name,
**kwargs,
),
aws_access_key_id=get_model_config_var(
aws_access_key_id=get_deployment_config_var(
SAGE_MAKER_ACCESS_KEY_ENV_VAR,
SageMakerDeployment.aws_access_key_id,
**kwargs,
),
aws_secret_access_key=get_model_config_var(
aws_secret_access_key=get_deployment_config_var(
SAGE_MAKER_SECRET_KEY_ENV_VAR,
SageMakerDeployment.aws_secret_access_key,
**kwargs,
),
aws_session_token=get_model_config_var(
aws_session_token=get_deployment_config_var(
SAGE_MAKER_SESSION_TOKEN_ENV_VAR,
SageMakerDeployment.aws_session_token,
**kwargs,
),
)
self.params = {
"EndpointName": get_model_config_var(
"EndpointName": get_deployment_config_var(
SAGE_MAKER_ENDPOINT_NAME_ENV_VAR, SageMakerDeployment.endpoint, **kwargs
),
"ContentType": "application/json",
Expand Down
6 changes: 3 additions & 3 deletions src/backend/model_deployments/single_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from backend.chat.collate import to_dict
from backend.config.settings import Settings
from backend.model_deployments.base import BaseDeployment
from backend.model_deployments.utils import get_model_config_var
from backend.model_deployments.utils import get_deployment_config_var
from backend.schemas.cohere_chat import CohereChatRequest
from backend.schemas.context import Context

Expand All @@ -23,10 +23,10 @@ class SingleContainerDeployment(BaseDeployment):
default_model = sc_config.model

def __init__(self, **kwargs: Any):
self.url = get_model_config_var(
self.url = get_deployment_config_var(
SC_URL_ENV_VAR, SingleContainerDeployment.default_url, **kwargs
)
self.model = get_model_config_var(
self.model = get_deployment_config_var(
SC_MODEL_ENV_VAR, SingleContainerDeployment.default_model, **kwargs
)
self.client = cohere.Client(
Expand Down
38 changes: 26 additions & 12 deletions src/backend/model_deployments/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,40 @@ def class_name_validator(v: str):
return v


def get_deployment_config_var(var_name: str, default: str, **kwargs: Any) -> str:
    """
    Get the Deployment's config, in order of priority:

    1. Request header values
    2. DB values
    3. Default values (from config)

    A source only wins when it holds a non-empty value, so an empty string
    stored in the DB config still falls through to the default.

    Args:
        var_name (str): Variable name
        default (str): Variable default value
        **kwargs: Optional ``ctx`` (object exposing ``deployment_config``) and
            optional ``db_config`` (dict of stored values).

    Returns:
        str: Deployment config value

    Raises:
        ValueError: If no source provides a non-empty value.
    """
    ctx = kwargs.get("ctx")

    # Either mapping may be missing or explicitly None; normalise to a dict.
    ctx_deployment_config = (ctx.deployment_config if ctx else None) or {}
    db_config = kwargs.get("db_config") or {}

    # `or` (rather than `.get(key, default)`) so that empty values are skipped
    # at every level instead of short-circuiting the fallback chain.
    config = (
        ctx_deployment_config.get(var_name)
        or db_config.get(var_name)
        or default
    )

    # After all fallbacks, if config is still invalid
    if not config:
        raise ValueError(f"Missing deployment config variable: {var_name}")

    return config


Expand Down
4 changes: 2 additions & 2 deletions src/backend/routers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

def get_deployment_for_agent(session: DBSessionDep, deployment, model) -> tuple[CohereDeployment, str | None]:
try:
deployment = deployment_service.get_deployment_by_name(session, deployment)
deployment = deployment_service.get_deployment_instance_by_name(session, deployment)
except DeploymentNotFoundError:
deployment = deployment_service.get_default_deployment()
deployment = deployment_service.get_default_deployment_instance()

model = next((m for m in deployment.models() if m.name == model), None)

Expand Down
2 changes: 1 addition & 1 deletion src/backend/schemas/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def with_model(self, model: str) -> "Context":
self.model = model
return self

def with_deployment_config(self, deployment_config=None) -> "Context":
def with_deployment_config(self, deployment_config={}) -> "Context":
if deployment_config:
self.deployment_config = deployment_config
else:
Expand Down
Loading
Loading