Skip to content

Commit

Permalink
fix database queries
Browse files Browse the repository at this point in the history
  • Loading branch information
yrobla committed Mar 3, 2025
1 parent bba154c commit 769f611
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 34 deletions.
15 changes: 6 additions & 9 deletions src/codegate/api/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,34 +392,31 @@ async def get_workspace_alerts(
except Exception:
logger.exception("Error while getting workspace")
raise HTTPException(status_code=500, detail="Internal server error")

total_alerts = 0
fetched_alerts = []

offset = (page - 1) * page_size
batch_size = page_size * 2 # fetch more alerts per batch to allow deduplication
fetched_alerts = []

while len(fetched_alerts) < page_size:
alerts_batch, total_alerts = await dbreader.get_alerts_by_workspace(
alerts_batch = await dbreader.get_alerts_by_workspace(
ws.id, AlertSeverity.CRITICAL.value, page_size, offset
)
if not alerts_batch:
break

dedup_alerts = await v1_processing.remove_duplicate_alerts(alerts_batch)
fetched_alerts.extend(dedup_alerts)
offset += batch_size
offset += page_size

final_alerts = fetched_alerts[:page_size]
total_alerts = len(fetched_alerts)

prompt_ids = list({alert.prompt_id for alert in final_alerts if alert.prompt_id})
prompts_outputs = await dbreader.get_prompts_with_output(prompt_ids)
alert_conversations = await v1_processing.parse_get_alert_conversation(
final_alerts, prompts_outputs
)
return {
"page": page,
"page_size": page_size,
"total_alerts": total_alerts,
"total_pages": (total_alerts + page_size - 1) // page_size,
"alerts": alert_conversations,
}

Expand Down
38 changes: 17 additions & 21 deletions src/codegate/db/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from alembic import command as alembic_command
from alembic.config import Config as AlembicConfig
from pydantic import BaseModel
from sqlalchemy import CursorResult, TextClause, event, text
from sqlalchemy import CursorResult, TextClause, bindparam, event, text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.ext.asyncio import create_async_engine
Expand Down Expand Up @@ -587,11 +587,12 @@ async def get_prompts_with_output(self, prompt_ids: List[str]) -> List[GetPrompt
o.output_cost
FROM prompts p
LEFT JOIN outputs o ON p.id = o.prompt_id
WHERE p.id IN :prompt_ids
WHERE (p.id IN :prompt_ids)
ORDER BY o.timestamp DESC
"""
)
conditions = {"prompt_ids": tuple(prompt_ids)}
).bindparams(bindparam("prompt_ids", expanding=True))

conditions = {"prompt_ids": prompt_ids if prompt_ids else None}
prompts = await self._exec_select_conditions_to_pydantic(
GetPromptWithOutputsRow, sql, conditions, should_raise=True
)
Expand Down Expand Up @@ -659,13 +660,23 @@ async def get_prompts_with_output_alerts_usage_by_workspace_id(

return list(prompts_dict.values())

async def _exec_select_count(self, sql_command: str, conditions: dict) -> int:
"""Executes a COUNT SQL command and returns an integer result."""
async with self._async_db_engine.begin() as conn:
try:
result = await conn.execute(text(sql_command), conditions)
return result.scalar_one() # Ensures it returns exactly one integer value
except Exception as e:
logger.error(f"Failed to execute COUNT query.", error=str(e))
return 0 # Return 0 in case of failure to avoid crashes

async def get_alerts_by_workspace(
self,
workspace_id: str,
trigger_category: Optional[str] = None,
limit: int = API_DEFAULT_PAGE_SIZE,
offset: int = 0,
) -> Tuple[List[Alert], int]:
) -> List[Alert]:
sql = text(
"""
SELECT
Expand All @@ -691,25 +702,10 @@ async def get_alerts_by_workspace(
conditions["limit"] = limit
conditions["offset"] = offset

alerts = await self._exec_select_conditions_to_pydantic(
return await self._exec_select_conditions_to_pydantic(
Alert, sql, conditions, should_raise=True
)

# Count total alerts for pagination
count_sql = text(
"""
SELECT COUNT(*)
FROM alerts a
INNER JOIN prompts p ON p.id = a.prompt_id
WHERE p.workspace_id = :workspace_id
"""
)
if trigger_category:
count_sql = text(count_sql.text + " AND a.trigger_category = :trigger_category")

total_alerts = await self._exec_select_count(count_sql, conditions)
return alerts, total_alerts

async def get_workspaces(self) -> List[WorkspaceWithSessionInfo]:
sql = text(
"""
Expand Down
4 changes: 0 additions & 4 deletions src/codegate/workspaces/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,12 +213,8 @@ async def hard_delete_workspace(self, workspace_name: str):
return

async def get_workspace_by_name(self, workspace_name: str) -> db_models.WorkspaceRow:
print("i get by name")
workspace = await self._db_reader.get_workspace_by_name(workspace_name)
print("workspace is")
print(workspace)
if not workspace:
print("in not exist")
raise WorkspaceDoesNotExistError(f"Workspace {workspace_name} does not exist.")
return workspace

Expand Down

0 comments on commit 769f611

Please sign in to comment.