Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

backend: Update deployment config on app start #921

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/backend/crud/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def update_deployment(
db: Session, deployment: Deployment, new_deployment: DeploymentUpdate
) -> Deployment:
"""
Update a deployment by ID.
Update a deployment.

Args:
db (Session): Database session.
Expand All @@ -125,8 +125,10 @@ def update_deployment(
"""
for attr, value in new_deployment.model_dump(exclude_none=True).items():
setattr(deployment, attr, value)

db.commit()
db.refresh(deployment)

return deployment


Expand Down
5 changes: 5 additions & 0 deletions src/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from backend.config.routers import ROUTER_DEPENDENCIES, RouterName
from backend.config.settings import Settings
from backend.database_models.database import get_session
from backend.exceptions import DeploymentNotFoundError
from backend.routers.agent import router as agent_router
from backend.routers.auth import router as auth_router
Expand All @@ -31,6 +32,7 @@
from backend.routers.snapshot import router as snapshot_router
from backend.routers.tool import router as tool_router
from backend.routers.user import router as user_router
from backend.services import deployment as deployment_service
from backend.services.context import ContextMiddleware, get_context
from backend.services.logger.middleware import LoggingMiddleware

Expand Down Expand Up @@ -108,6 +110,9 @@ def create_app() -> FastAPI:
app.add_middleware(ContextMiddleware) # This should be the first middleware
app.add_exception_handler(SCIMException, scim_exception_handler) # pyright: ignore

# Update Deployments config
deployment_service.update_db_config_from_env(next(get_session()))

return app


Expand Down
4 changes: 2 additions & 2 deletions src/backend/routers/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ async def delete_deployment(


@router.post("/{deployment_id}/update_config", response_model=DeploymentDefinition)
async def update_config(
async def update_db_config(
*,
deployment_id: DeploymentIdPathParam,
env_vars: UpdateDeploymentEnv,
Expand All @@ -155,7 +155,7 @@ async def update_config(
Set environment variables for the deployment.
"""
return mask_deployment_secrets(
deployment_service.update_config(session, deployment_id, valid_env_vars)
deployment_service.update_db_config(session, deployment_id, valid_env_vars)
)


Expand Down
30 changes: 29 additions & 1 deletion src/backend/services/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def get_deployment_definitions(session: DBSessionDep) -> list[DeploymentDefiniti

return [*db_deployments.values(), *installed_deployments]

def update_config(session: DBSessionDep, deployment_id: str, env_vars: dict[str, str]) -> DeploymentDefinition:
def update_db_config(session: DBSessionDep, deployment_id: str, env_vars: dict[str, str]) -> DeploymentDefinition:
logger.debug(event="update_config", deployment_id=deployment_id, env_vars=env_vars)

db_deployment = deployment_crud.get_deployment(session, deployment_id)
Expand All @@ -128,3 +128,31 @@ def update_config(session: DBSessionDep, deployment_id: str, env_vars: dict[str,
updated_deployment = get_deployment_definition(session, deployment_id)

return updated_deployment

def update_db_config_from_env(session: DBSessionDep):
try:
for deployment_name, deployment in AVAILABLE_MODEL_DEPLOYMENTS.items():
# Fetch local config
env_config = deployment.config()
# Fetch DB entity
db_deployment = deployment_crud.get_deployment_by_name(session, deployment_name)

# Skip to next if no config or no DB deployment found
if not env_config or not db_deployment:
logger.debug(event="Updating DB deployment config, no config or no DB deployment found.")
continue

db_config = dict(db_deployment.default_deployment_config)

for key, value in env_config.items():
db_config[key] = value

deployment_crud.update_deployment(
session,
db_deployment,
DeploymentUpdate(
default_deployment_config=db_config
)
)
except Exception as e:
logger.error(event=f"Error while updating DB deployment config: {e}")
4 changes: 2 additions & 2 deletions src/backend/tests/unit/services/test_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,12 @@ def test_get_deployment_definitions_with_db_deployments(session, mock_available_
assert any(d.id == "db-mock-cohere-platform-id" for d in definitions)

def test_update_config_db(session, db_deployment) -> None:
deployment_service.update_config(session, db_deployment.id, {"COHERE_API_KEY": "new-db-test-api-key"})
deployment_service.update_db_config(session, db_deployment.id, {"COHERE_API_KEY": "new-db-test-api-key"})
updated_deployment = session.query(Deployment).get("db-mock-cohere-platform-id")
assert updated_deployment.default_deployment_config == {"COHERE_API_KEY": "new-db-test-api-key"}

def test_update_config_no_db_deployments(session, mock_available_model_deployments, clear_db_deployments) -> None:
with patch("backend.services.deployment.update_env_file") as mock_update_env_file:
with patch("backend.services.deployment.get_deployment_definition", return_value=MockCohereDeployment.to_deployment_definition()):
deployment_service.update_config(session, "some-deployment-id", {"API_KEY": "new-api-key"})
deployment_service.update_db_config(session, "some-deployment-id", {"API_KEY": "new-api-key"})
mock_update_env_file.assert_called_with({"API_KEY": "new-api-key"})
Loading