-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added a class which performs semantic routing
Related to: #1055

For the current implementation of muxing we only need to match a single persona at a time. For example:
1. mux1 -> persona Architect -> openai o1
2. mux2 -> catch all -> openai gpt4o

In the above case we only need to know whether the request matches the persona `Architect`. There is no need to match any other personas, even if they exist in the DB. This PR introduces what is necessary to do the above without actually wiring in muxing rules.

The PR:
- Creates the persona table in the DB
- Adds methods to write to and read from the new persona table
- Implements a function to check whether a query matches the specified persona

For more details about the personas and the queries, please check the unit tests.
- Loading branch information
1 parent
809c24a
commit 9dac0af
Showing
6 changed files
with
910 additions
and
2 deletions.
There are no files selected for viewing
50 changes: 50 additions & 0 deletions
50
migrations/versions/2025_03_03_1008-02b710eda156_add_persona_table.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
"""add persona table | ||
Revision ID: 02b710eda156 | ||
Revises: 5e5cd2288147 | ||
Create Date: 2025-03-03 10:08:16.206617+00:00 | ||
""" | ||
|
||
from typing import Sequence, Union | ||
|
||
from alembic import op | ||
|
||
# revision identifiers, used by Alembic. | ||
revision: str = "02b710eda156" | ||
down_revision: Union[str, None] = "5e5cd2288147" | ||
branch_labels: Union[str, Sequence[str], None] = None | ||
depends_on: Union[str, Sequence[str], None] = None | ||
|
||
|
||
def upgrade() -> None:
    """Create the ``personas`` table inside an explicit transaction.

    The table stores one row per persona: a textual UUID primary key, a
    unique name, the human-readable description and the description's
    embedding as a BLOB. ``IF NOT EXISTS`` makes re-running the upgrade safe.
    """
    create_personas_table = """
        CREATE TABLE IF NOT EXISTS personas (
            id TEXT PRIMARY KEY, -- UUID stored as TEXT
            name TEXT NOT NULL UNIQUE,
            description TEXT NOT NULL,
            description_embedding BLOB NOT NULL
        );
        """

    # Issued strictly in order: open the transaction, apply the DDL, commit.
    for statement in ("BEGIN TRANSACTION;", create_personas_table, "COMMIT;"):
        op.execute(statement)
|
||
|
||
def downgrade() -> None:
    """Drop the ``personas`` table inside an explicit transaction.

    ``IF EXISTS`` mirrors the ``IF NOT EXISTS`` used by ``upgrade`` so the
    downgrade is idempotent: a missing table no longer raises in the middle
    of the manually opened transaction (which would leave it uncommitted).
    """
    # Begin transaction
    op.execute("BEGIN TRANSACTION;")

    op.execute(
        """
        DROP TABLE IF EXISTS personas;
        """
    )

    # Finish transaction
    op.execute("COMMIT;")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
import unicodedata | ||
import uuid | ||
|
||
import numpy as np | ||
import regex as re | ||
import structlog | ||
|
||
from codegate.config import Config | ||
from codegate.db import models as db_models | ||
from codegate.db.connection import DbReader, DbRecorder | ||
from codegate.inference.inference_engine import LlamaCppInferenceEngine | ||
|
||
logger = structlog.get_logger("codegate") | ||
|
||
|
||
class PersonaDoesNotExistError(Exception):
    """Raised when a persona looked up by name is not found in the database."""
|
||
|
||
class SemanticRouter:
    """
    Check whether a free-text query semantically matches a stored persona.

    Persona descriptions are embedded and persisted through ``DbRecorder``;
    queries are embedded the same way and compared against a persona's stored
    embedding through ``DbReader``.
    """

    def __init__(self) -> None:
        """Create the inference engine and DB accessors and read the configuration."""
        self._inference_engine = LlamaCppInferenceEngine()
        conf = Config.get_config()
        self._embeddings_model = f"{conf.model_base_path}/{conf.embedding_model}"
        self._n_gpu = conf.chat_model_n_gpu_layers
        self._persona_threshold = conf.persona_threshold
        self._db_recorder = DbRecorder()
        self._db_reader = DbReader()

    def _clean_text_for_embedding(self, text: str) -> str:
        """
        Clean the text for embedding. This function should be used to preprocess the text
        before embedding.

        Performs the following operations, in this order:
        1. Replaces newlines and carriage returns with spaces
        2. Normalizes Unicode characters (NFKD) and drops combining marks
        3. Removes fenced code blocks and inline code spans
        4. Removes URLs and email addresses
        5. Removes HTML/XML tags
        6. Replaces decimal and integer numbers with the placeholder ``NUM``
        7. Replaces punctuation (except apostrophes) with spaces
        8. Collapses runs of whitespace and strips the ends

        Letter case is preserved; no lowercasing is applied.

        Returns an empty string for empty or ``None`` input.
        """
        if not text:
            return ""

        # Replace newlines and carriage returns with spaces
        text = text.replace("\n", " ").replace("\r", " ")

        # Normalize Unicode characters (e.g., convert accented characters to ASCII equivalents)
        text = unicodedata.normalize("NFKD", text)
        text = "".join(c for c in text if not unicodedata.combining(c))

        # Remove code first. The URL/email patterns below consume every
        # non-space character, so running them earlier would swallow a closing
        # backtick that directly follows a URL (e.g. `https://a.com`) and
        # mis-pair the remaining backticks.
        text = re.sub(r"```[\s\S]*?```", " ", text)  # Code blocks
        text = re.sub(r"`[^`]*`", " ", text)  # Inline code

        # Remove URLs
        text = re.sub(r"https?://\S+|www\.\S+", " ", text)

        # Remove email addresses
        text = re.sub(r"\S+@\S+", " ", text)

        # Remove HTML/XML tags
        text = re.sub(r"<[^>]+>", " ", text)

        # Normalize numbers (replace with placeholder)
        text = re.sub(r"\b\d+\.\d+\b", " NUM ", text)  # Decimal numbers
        text = re.sub(r"\b\d+\b", " NUM ", text)  # Integer numbers

        # Replace punctuation with spaces (keeping apostrophes for contractions)
        text = re.sub(r"[^\w\s\']", " ", text)

        # Normalize whitespace (replace multiple spaces with a single space)
        text = re.sub(r"\s+", " ", text)

        # Strip leading/trailing whitespace.
        # NOTE(review): the text is not lowercased here; confirm whether the
        # embedding model needs case folding before adding it, since the
        # persona threshold is tuned against the current behaviour.
        return text.strip()

    async def _embed_text(self, text: str) -> np.ndarray:
        """
        Helper function to embed text using the inference engine.

        The text is cleaned with ``_clean_text_for_embedding`` first. Returns
        the first embedding produced by the engine as a ``float32`` array.
        """
        cleaned_text = self._clean_text_for_embedding(text)
        # .embed returns a list of embeddings
        embed_list = await self._inference_engine.embed(
            self._embeddings_model, [cleaned_text], n_gpu_layers=self._n_gpu
        )
        # Only the first 100 characters are logged to keep log lines short.
        logger.debug("Text embedded in semantic routing", text=cleaned_text[:100])
        # Use only the first entry in the list and make sure we have the appropriate type
        return np.array(embed_list[0], dtype=np.float32)

    async def add_persona(self, persona_name: str, persona_desc: str) -> None:
        """
        Add a new persona to the database. The persona description is embedded
        and stored in the database together with a freshly generated UUID.
        """
        emb_persona_desc = await self._embed_text(persona_desc)
        new_persona = db_models.PersonaEmbedding(
            id=str(uuid.uuid4()),
            name=persona_name,
            description=persona_desc,
            description_embedding=emb_persona_desc,
        )
        await self._db_recorder.add_persona(new_persona)
        logger.info(f"Added persona {persona_name} to the database.")

    async def check_persona_match(self, persona_name: str, query: str) -> bool:
        """
        Check if the query matches the persona description. A vector similarity
        search is performed between the query and the persona description.

        The distance is a cosine distance: 0 means the vectors point in the
        same direction, 1 means they are orthogonal and 2 means they are
        opposite (assuming the usual ``1 - cosine similarity`` definition). See
        [sqlite docs](https://alexgarcia.xyz/sqlite-vec/api-reference.html#vec_distance_cosine)

        Returns True when the distance is strictly below the configured
        persona threshold.

        Raises:
            PersonaDoesNotExistError: if no persona named ``persona_name`` exists.
        """
        persona = await self._db_reader.get_persona_by_name(persona_name)
        if not persona:
            raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.")

        emb_query = await self._embed_text(query)
        persona_distance = await self._db_reader.get_distance_to_persona(persona.id, emb_query)
        logger.info(f"Persona distance to {persona_name}", distance=persona_distance.distance)
        return bool(persona_distance.distance < self._persona_threshold)
Oops, something went wrong.