-
Notifications
You must be signed in to change notification settings - Fork 383
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into feat/mock_integration_tests
- Loading branch information
Showing
101 changed files
with
6,261 additions
and
3,723 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,140 +1,35 @@ | ||
from enum import StrEnum | ||
|
||
from backend.config.settings import Settings | ||
from backend.model_deployments import ( | ||
AzureDeployment, | ||
BedrockDeployment, | ||
CohereDeployment, | ||
SageMakerDeployment, | ||
SingleContainerDeployment, | ||
) | ||
from backend.model_deployments.azure import AZURE_ENV_VARS | ||
from backend.model_deployments.base import BaseDeployment | ||
from backend.model_deployments.bedrock import BEDROCK_ENV_VARS | ||
from backend.model_deployments.cohere_platform import COHERE_ENV_VARS | ||
from backend.model_deployments.sagemaker import SAGE_MAKER_ENV_VARS | ||
from backend.model_deployments.single_container import SC_ENV_VARS | ||
from backend.schemas.deployment import Deployment | ||
from backend.services.logger.utils import LoggerFactory | ||
|
||
logger = LoggerFactory().get_logger() | ||
|
||
|
||
class ModelDeploymentName(StrEnum): | ||
CoherePlatform = "Cohere Platform" | ||
SageMaker = "SageMaker" | ||
Azure = "Azure" | ||
Bedrock = "Bedrock" | ||
SingleContainer = "Single Container" | ||
|
||
|
||
use_community_features = Settings().get('feature_flags.use_community_features') | ||
ALL_MODEL_DEPLOYMENTS = { d.name(): d for d in BaseDeployment.__subclasses__() } | ||
|
||
# TODO names in the map below should not be the display names but ids | ||
ALL_MODEL_DEPLOYMENTS = { | ||
ModelDeploymentName.CoherePlatform: Deployment( | ||
id="cohere_platform", | ||
name=ModelDeploymentName.CoherePlatform, | ||
deployment_class=CohereDeployment, | ||
models=CohereDeployment.list_models(), | ||
is_available=CohereDeployment.is_available(), | ||
env_vars=COHERE_ENV_VARS, | ||
), | ||
ModelDeploymentName.SingleContainer: Deployment( | ||
id="single_container", | ||
name=ModelDeploymentName.SingleContainer, | ||
deployment_class=SingleContainerDeployment, | ||
models=SingleContainerDeployment.list_models(), | ||
is_available=SingleContainerDeployment.is_available(), | ||
env_vars=SC_ENV_VARS, | ||
), | ||
ModelDeploymentName.SageMaker: Deployment( | ||
id="sagemaker", | ||
name=ModelDeploymentName.SageMaker, | ||
deployment_class=SageMakerDeployment, | ||
models=SageMakerDeployment.list_models(), | ||
is_available=SageMakerDeployment.is_available(), | ||
env_vars=SAGE_MAKER_ENV_VARS, | ||
), | ||
ModelDeploymentName.Azure: Deployment( | ||
id="azure", | ||
name=ModelDeploymentName.Azure, | ||
deployment_class=AzureDeployment, | ||
models=AzureDeployment.list_models(), | ||
is_available=AzureDeployment.is_available(), | ||
env_vars=AZURE_ENV_VARS, | ||
), | ||
ModelDeploymentName.Bedrock: Deployment( | ||
id="bedrock", | ||
name=ModelDeploymentName.Bedrock, | ||
deployment_class=BedrockDeployment, | ||
models=BedrockDeployment.list_models(), | ||
is_available=BedrockDeployment.is_available(), | ||
env_vars=BEDROCK_ENV_VARS, | ||
), | ||
} | ||
|
||
def get_available_deployments() -> list[type[BaseDeployment]]: | ||
installed_deployments = list(ALL_MODEL_DEPLOYMENTS.values()) | ||
|
||
def get_available_deployments() -> dict[ModelDeploymentName, Deployment]: | ||
if use_community_features: | ||
if Settings().get("feature_flags.use_community_features"): | ||
try: | ||
from community.config.deployments import ( | ||
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS_SETUP, | ||
) | ||
|
||
model_deployments = ALL_MODEL_DEPLOYMENTS.copy() | ||
model_deployments.update(COMMUNITY_DEPLOYMENTS_SETUP) | ||
return model_deployments | ||
except ImportError: | ||
installed_deployments.extend(COMMUNITY_DEPLOYMENTS_SETUP.values()) | ||
except ImportError as e: | ||
logger.warning( | ||
event="[Deployments] No available community deployments have been configured" | ||
event="[Deployments] No available community deployments have been configured", ex=e | ||
) | ||
|
||
deployments = Settings().get('deployments.enabled_deployments') | ||
if deployments is not None and len(deployments) > 0: | ||
return { | ||
key: value | ||
for key, value in ALL_MODEL_DEPLOYMENTS.items() | ||
if value.id in Settings().get('deployments.enabled_deployments') | ||
} | ||
|
||
return ALL_MODEL_DEPLOYMENTS | ||
|
||
|
||
def get_default_deployment(**kwargs) -> BaseDeployment: | ||
# Fallback to the first available deployment | ||
fallback = None | ||
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values(): | ||
if deployment.is_available: | ||
fallback = deployment.deployment_class(**kwargs) | ||
break | ||
|
||
default = Settings().get('deployments.default_deployment') | ||
if default: | ||
return next( | ||
( | ||
v.deployment_class(**kwargs) | ||
for k, v in AVAILABLE_MODEL_DEPLOYMENTS.items() | ||
if v.id == default | ||
), | ||
fallback, | ||
) | ||
else: | ||
return fallback | ||
|
||
|
||
def find_config_by_deployment_id(deployment_id: str) -> Deployment: | ||
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values(): | ||
if deployment.id == deployment_id: | ||
return deployment | ||
return None | ||
|
||
|
||
def find_config_by_deployment_name(deployment_name: str) -> Deployment: | ||
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values(): | ||
if deployment.name == deployment_name: | ||
return deployment | ||
return None | ||
enabled_deployment_ids = Settings().get("deployments.enabled_deployments") | ||
if enabled_deployment_ids: | ||
return [ | ||
deployment | ||
for deployment in installed_deployments | ||
if deployment.id() in enabled_deployment_ids | ||
] | ||
|
||
return installed_deployments | ||
|
||
AVAILABLE_MODEL_DEPLOYMENTS = get_available_deployments() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
25 changes: 25 additions & 0 deletions
25
src/backend/database_models/seeders/deployments_models_seed.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from sqlalchemy.orm import Session | ||
|
||
from backend.database_models import Deployment, Model, Organization | ||
|
||
|
||
def deployments_models_seed(op):
    """
    Do nothing; this seeder no longer inserts any rows.

    Args:
        op: Unused. Accepted so the signature stays the same for existing callers.

    Returns:
        None.
    """
    # Previously we would seed the default deployments and models here. We've changed this
    # behaviour during a refactor of the deployments module so that deployments and models
    # are inserted when they're first used. This solves an issue where seed data would
    # sometimes be inserted with invalid config data.
    return None
|
||
|
||
def delete_default_models(op):
    """
    Delete every deployment and model row, plus the organization with id "default".

    Args:
        op: Object exposing ``get_bind()``; the session is bound to whatever it
            returns (presumably the Alembic operations handle -- confirm against caller).

    Raises:
        Any exception raised by the deletes or the commit is propagated unchanged.
    """
    session = Session(op.get_bind())
    try:
        session.query(Deployment).delete()
        session.query(Model).delete()
        session.query(Organization).filter_by(id="default").delete()
        session.commit()
    finally:
        # Always release the session, including when a delete or the commit raises.
        session.close()
Oops, something went wrong.