Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: initial work on endpoints for creating/updating workspace config #1107

Draft
wants to merge 14 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 42 additions & 20 deletions src/codegate/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,20 +248,16 @@ async def activate_workspace(request: v1_models.ActivateWorkspaceRequest, status

@v1.post("/workspaces", tags=["Workspaces"], generate_unique_id_function=uniq_name, status_code=201)
async def create_workspace(
request: v1_models.CreateOrRenameWorkspaceRequest,
) -> v1_models.Workspace:
request: v1_models.FullWorkspace,
) -> v1_models.FullWorkspace:
"""Create a new workspace."""
if request.rename_to is not None:
return await rename_workspace(request)
return await create_new_workspace(request)


async def create_new_workspace(
request: v1_models.CreateOrRenameWorkspaceRequest,
) -> v1_models.Workspace:
# Input validation is done in the model
try:
_ = await wscrud.add_workspace(request.name)
custom_instructions = request.config.custom_instructions if request.config else None
muxing_rules = request.config.muxing_rules if request.config else None

workspace_row, mux_rules = await wscrud.add_workspace(
request.name, custom_instructions, muxing_rules
)
except AlreadyExistsError:
raise HTTPException(status_code=409, detail="Workspace already exists")
except ValidationError:
Expand All @@ -277,18 +273,38 @@ async def create_new_workspace(
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")

return v1_models.Workspace(name=request.name, is_active=False)
return v1_models.FullWorkspace(
name=workspace_row.name,
config=v1_models.WorkspaceConfig(
custom_instructions=workspace_row.custom_instructions or "",
muxing_rules=[mux_models.MuxRule.from_db_mux_rule(mux_rule) for mux_rule in mux_rules],
),
)


async def rename_workspace(
request: v1_models.CreateOrRenameWorkspaceRequest,
) -> v1_models.Workspace:
@v1.put(
"/workspaces/{workspace_name}",
tags=["Workspaces"],
generate_unique_id_function=uniq_name,
status_code=201,
)
async def update_workspace(
workspace_name: str,
request: v1_models.FullWorkspace,
) -> v1_models.FullWorkspace:
"""Update a workspace."""
try:
_ = await wscrud.rename_workspace(request.name, request.rename_to)
custom_instructions = request.config.custom_instructions if request.config else None
muxing_rules = request.config.muxing_rules if request.config else None

workspace_row, mux_rules = await wscrud.update_workspace(
workspace_name,
request.name,
custom_instructions,
muxing_rules,
)
except crud.WorkspaceDoesNotExistError:
raise HTTPException(status_code=404, detail="Workspace does not exist")
except AlreadyExistsError:
raise HTTPException(status_code=409, detail="Workspace already exists")
except ValidationError:
raise HTTPException(
status_code=400,
Expand All @@ -302,7 +318,13 @@ async def rename_workspace(
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")

return v1_models.Workspace(name=request.rename_to, is_active=False)
return v1_models.FullWorkspace(
name=workspace_row.name,
config=v1_models.WorkspaceConfig(
custom_instructions=workspace_row.custom_instructions or "",
muxing_rules=[mux_models.MuxRule.from_db_mux_rule(mux_rule) for mux_rule in mux_rules],
),
)


@v1.delete(
Expand Down
9 changes: 1 addition & 8 deletions src/codegate/api/v1_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def from_db_workspaces(


class WorkspaceConfig(pydantic.BaseModel):
system_prompt: str
custom_instructions: str

muxing_rules: List[mux_models.MuxRule]

Expand All @@ -72,13 +72,6 @@ class FullWorkspace(pydantic.BaseModel):
config: Optional[WorkspaceConfig] = None


class CreateOrRenameWorkspaceRequest(FullWorkspace):
# If set, rename the workspace to this name. Note that
# the 'name' field is still required and the workspace
# workspace must exist.
rename_to: Optional[str] = None


class ActivateWorkspaceRequest(pydantic.BaseModel):
name: str

Expand Down
32 changes: 30 additions & 2 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from sqlalchemy import CursorResult, TextClause, event, text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker

from codegate.db.fim_cache import FimCache
from codegate.db.models import (
Expand Down Expand Up @@ -65,7 +66,7 @@ def __new__(cls, *args, **kwargs):
# It should only be used for testing
if "_no_singleton" in kwargs and kwargs["_no_singleton"]:
kwargs.pop("_no_singleton")
return super().__new__(cls, *args, **kwargs)
return super().__new__(cls)

if cls._instance is None:
cls._instance = super().__new__(cls)
Expand Down Expand Up @@ -895,6 +896,33 @@ async def get_muxes_by_workspace(self, workspace_id: str) -> List[MuxRule]:
return muxes


class DbTransaction:
def __init__(self):
self._session = None

async def __aenter__(self):
self._session = sessionmaker(
bind=DbCodeGate()._async_db_engine,
class_=AsyncSession,
expire_on_commit=False,
)()
await self._session.begin()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type:
await self._session.rollback()
else:
await self._session.commit()
await self._session.close()

async def commit(self):
await self._session.commit()

async def rollback(self):
await self._session.rollback()


def init_db_sync(db_path: Optional[str] = None):
"""DB will be initialized in the constructor in case it doesn't exist."""
current_dir = Path(__file__).parent
Expand Down
5 changes: 1 addition & 4 deletions src/codegate/pipeline/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def help(self) -> str:


class CodegateCommandSubcommand(CodegateCommand):

@property
@abstractmethod
def subcommands(self) -> Dict[str, Callable[[List[str]], Awaitable[str]]]:
Expand Down Expand Up @@ -174,7 +173,6 @@ async def run(self, args: List[str]) -> str:


class Workspace(CodegateCommandSubcommand):

def __init__(self):
self.workspace_crud = crud.WorkspaceCrud()

Expand Down Expand Up @@ -258,7 +256,7 @@ async def _rename_workspace(self, flags: Dict[str, str], args: List[str]) -> str
)

try:
await self.workspace_crud.rename_workspace(old_workspace_name, new_workspace_name)
await self.workspace_crud.update_workspace(old_workspace_name, new_workspace_name)
except crud.WorkspaceDoesNotExistError:
return f"Workspace **{old_workspace_name}** does not exist"
except AlreadyExistsError:
Expand Down Expand Up @@ -410,7 +408,6 @@ def help(self) -> str:


class CustomInstructions(CodegateCommandSubcommand):

def __init__(self):
self.workspace_crud = crud.WorkspaceCrud()

Expand Down
Loading
Loading