Skip to content

Commit

Permalink
Added a class which performs semantic routing
Browse files Browse the repository at this point in the history
Related to: #1055

For the current implementation of muxing we only need
to match a single Persona at a time. For example:
1. mux1 -> persona Architect -> openai o1
2. mux2 -> catch all -> openai gpt4o

In the above case we would only need to know if the request
matches the persona `Architect`. It's not needed to match
any extra personas even if they exist in DB.

This PR introduces what's necessary to do the above without
actually wiring in muxing rules. The PR:
- Creates the persona table in DB
- Adds methods to write and read to the new persona table
- Implements a function to check if a query matches to the specified persona

To check more about the personas and the queries please check the unit tests
  • Loading branch information
aponcedeleonch committed Mar 3, 2025
1 parent 809c24a commit 9dac0af
Show file tree
Hide file tree
Showing 6 changed files with 910 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""add persona table
Revision ID: 02b710eda156
Revises: 5e5cd2288147
Create Date: 2025-03-03 10:08:16.206617+00:00
"""

from typing import Sequence, Union

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "02b710eda156"
down_revision: Union[str, None] = "5e5cd2288147"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# Begin transaction
op.execute("BEGIN TRANSACTION;")

op.execute(
"""
CREATE TABLE IF NOT EXISTS personas (
id TEXT PRIMARY KEY, -- UUID stored as TEXT
name TEXT NOT NULL UNIQUE,
description TEXT NOT NULL,
description_embedding BLOB NOT NULL
);
"""
)

# Finish transaction
op.execute("COMMIT;")


def downgrade() -> None:
# Begin transaction
op.execute("BEGIN TRANSACTION;")

op.execute(
"""
DROP TABLE personas;
"""
)

# Finish transaction
op.execute("COMMIT;")
1 change: 1 addition & 0 deletions src/codegate/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class Config:
force_certs: bool = False

max_fim_hash_lifetime: int = 60 * 5 # Time in seconds. Default is 5 minutes.
persona_threshold = 0.75 # Min value is 0 (max similarity), max value is 2 (orthogonal)

# Provider URLs with defaults
provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy())
Expand Down
103 changes: 102 additions & 1 deletion src/codegate/db/connection.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import asyncio
import json
import sqlite3
import uuid
from pathlib import Path
from typing import Dict, List, Optional, Type

import numpy as np
import sqlite_vec_sl_tmp
import structlog
from alembic import command as alembic_command
from alembic.config import Config as AlembicConfig
Expand All @@ -22,6 +25,9 @@
IntermediatePromptWithOutputUsageAlerts,
MuxRule,
Output,
Persona,
PersonaDistance,
PersonaEmbedding,
Prompt,
ProviderAuthMaterial,
ProviderEndpoint,
Expand Down Expand Up @@ -65,7 +71,7 @@ def __new__(cls, *args, **kwargs):
# It should only be used for testing
if "_no_singleton" in kwargs and kwargs["_no_singleton"]:
kwargs.pop("_no_singleton")
return super().__new__(cls, *args, **kwargs)
return super().__new__(cls)

if cls._instance is None:
cls._instance = super().__new__(cls)
Expand All @@ -92,6 +98,22 @@ def __init__(self, sqlite_path: Optional[str] = None, **kwargs):
}
self._async_db_engine = create_async_engine(**engine_dict)

def _get_vec_db_connection(self):
"""
Vector database connection is a separate connection to the SQLite database. aiosqlite
does not support loading extensions, so we need to use the sqlite3 module to load the
vector extension.
"""
try:
conn = sqlite3.connect(self._db_path)
conn.enable_load_extension(True)
sqlite_vec_sl_tmp.load(conn)
conn.enable_load_extension(False)
return conn
except Exception:
logger.exception("Failed to initialize vector database connection")
raise

def does_db_exist(self):
return self._db_path.is_file()

Expand Down Expand Up @@ -523,6 +545,30 @@ async def add_mux(self, mux: MuxRule) -> MuxRule:
added_mux = await self._execute_update_pydantic_model(mux, sql, should_raise=True)
return added_mux

async def add_persona(self, persona: PersonaEmbedding) -> None:
"""Add a new Persona to the DB.
This handles validation and insertion of a new persona.
It may raise a AlreadyExistsError if the persona already exists.
"""
sql = text(
"""
INSERT INTO personas (id, name, description, description_embedding)
VALUES (:id, :name, :description, :description_embedding)
"""
)

try:
# For Pydantic we conver the numpy array to a string when serializing.
# We need to convert it back to a numpy array before inserting it into the DB.
persona_dict = persona.model_dump()
persona_dict["description_embedding"] = persona.description_embedding
await self._execute_with_no_return(sql, persona_dict)
except IntegrityError as e:
logger.debug(f"Exception type: {type(e)}")
raise AlreadyExistsError(f"Persona '{persona.name}' already exists.")


class DbReader(DbCodeGate):
def __init__(self, sqlite_path: Optional[str] = None, *args, **kwargs):
Expand Down Expand Up @@ -569,6 +615,18 @@ async def _exec_select_conditions_to_pydantic(
raise e
return None

async def _exec_vec_db_query(
self, sql_command: str, conditions: dict
) -> Optional[CursorResult]:
"""
Execute a query on the vector database. This is a separate connection to the SQLite
database that has the vector extension loaded.
"""
conn = self._get_vec_db_connection()
cursor = conn.cursor()
cursor.execute(sql_command, conditions)
return cursor

async def get_prompts_with_output(self, workpace_id: str) -> List[GetPromptWithOutputsRow]:
sql = text(
"""
Expand Down Expand Up @@ -893,6 +951,49 @@ async def get_muxes_by_workspace(self, workspace_id: str) -> List[MuxRule]:
)
return muxes

async def get_persona_by_name(self, persona_name: str) -> Optional[Persona]:
"""
Get a persona by name.
"""
sql = text(
"""
SELECT
id, name, description
FROM personas
WHERE name = :name
"""
)
conditions = {"name": persona_name}
personas = await self._exec_select_conditions_to_pydantic(
Persona, sql, conditions, should_raise=True
)
return personas[0] if personas else None

async def get_distance_to_persona(
self, persona_id: str, query_embedding: np.ndarray
) -> PersonaDistance:
"""
Get the distance between a persona and a query embedding.
"""
sql = """
SELECT
id,
name,
description,
vec_distance_cosine(description_embedding, :query_embedding) as distance
FROM personas
WHERE id = :id
"""
conditions = {"id": persona_id, "query_embedding": query_embedding}
persona_distance_cursor = await self._exec_vec_db_query(sql, conditions)
persona_distance_raw = persona_distance_cursor.fetchone()
return PersonaDistance(
id=persona_distance_raw[0],
name=persona_distance_raw[1],
description=persona_distance_raw[2],
distance=persona_distance_raw[3],
)


def init_db_sync(db_path: Optional[str] = None):
"""DB will be initialized in the constructor in case it doesn't exist."""
Expand Down
39 changes: 38 additions & 1 deletion src/codegate/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from enum import Enum
from typing import Annotated, Any, Dict, List, Optional

from pydantic import BaseModel, StringConstraints
import numpy as np
from pydantic import BaseModel, BeforeValidator, ConfigDict, PlainSerializer, StringConstraints


class AlertSeverity(str, Enum):
Expand Down Expand Up @@ -240,3 +241,39 @@ class MuxRule(BaseModel):
priority: int
created_at: Optional[datetime.datetime] = None
updated_at: Optional[datetime.datetime] = None


# Pydantic doesn't support numpy arrays out of the box. Defining a custom type
# Reference: https://github.com/pydantic/pydantic/issues/7017
def nd_array_custom_before_validator(x):
# custome before validation logic
return x


def nd_array_custom_serializer(x):
# custome serialization logic
return str(x)


NdArray = Annotated[
np.ndarray,
BeforeValidator(nd_array_custom_before_validator),
PlainSerializer(nd_array_custom_serializer, return_type=str),
]


class Persona(BaseModel):
id: str
name: str
description: str


class PersonaEmbedding(Persona):
description_embedding: NdArray # sqlite-vec will handle numpy arrays directly

# Part of the workaround to allow numpy arrays in pydantic models
model_config = ConfigDict(arbitrary_types_allowed=True)


class PersonaDistance(Persona):
distance: float
129 changes: 129 additions & 0 deletions src/codegate/muxing/semantic_router.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import unicodedata
import uuid

import numpy as np
import regex as re
import structlog

from codegate.config import Config
from codegate.db import models as db_models
from codegate.db.connection import DbReader, DbRecorder
from codegate.inference.inference_engine import LlamaCppInferenceEngine

logger = structlog.get_logger("codegate")


class PersonaDoesNotExistError(Exception):
pass


class SemanticRouter:

def __init__(self):
self._inference_engine = LlamaCppInferenceEngine()
conf = Config.get_config()
self._embeddings_model = f"{conf.model_base_path}/{conf.embedding_model}"
self._n_gpu = conf.chat_model_n_gpu_layers
self._persona_threshold = conf.persona_threshold
self._db_recorder = DbRecorder()
self._db_reader = DbReader()

def _clean_text_for_embedding(self, text: str) -> str:
"""
Clean the text for embedding. This function should be used to preprocess the text
before embedding.
Performs the following operations:
1. Replaces newlines and carriage returns with spaces
2. Removes extra whitespace
3. Converts to lowercase
4. Removes URLs and email addresses
5. Removes code block markers and other markdown syntax
6. Normalizes Unicode characters
7. Handles special characters and punctuation
8. Normalizes numbers
"""
if not text:
return ""

# Replace newlines and carriage returns with spaces
text = text.replace("\n", " ").replace("\r", " ")

# Normalize Unicode characters (e.g., convert accented characters to ASCII equivalents)
text = unicodedata.normalize("NFKD", text)
text = "".join([c for c in text if not unicodedata.combining(c)])

# Remove URLs
text = re.sub(r"https?://\S+|www\.\S+", " ", text)

# Remove email addresses
text = re.sub(r"\S+@\S+", " ", text)

# Remove code block markers and other markdown/code syntax
text = re.sub(r"```[\s\S]*?```", " ", text) # Code blocks
text = re.sub(r"`[^`]*`", " ", text) # Inline code

# Remove HTML/XML tags
text = re.sub(r"<[^>]+>", " ", text)

# Normalize numbers (replace with placeholder)
text = re.sub(r"\b\d+\.\d+\b", " NUM ", text) # Decimal numbers
text = re.sub(r"\b\d+\b", " NUM ", text) # Integer numbers

# Replace punctuation with spaces (keeping apostrophes for contractions)
text = re.sub(r"[^\w\s\']", " ", text)

# Normalize whitespace (replace multiple spaces with a single space)
text = re.sub(r"\s+", " ", text)

# Convert to lowercase and strip
text = text.strip()

return text

async def _embed_text(self, text: str) -> np.ndarray:
"""
Helper function to embed text using the inference engine.
"""
cleaned_text = self._clean_text_for_embedding(text)
# .embed returns a list of embeddings
embed_list = await self._inference_engine.embed(
self._embeddings_model, [cleaned_text], n_gpu_layers=self._n_gpu
)
# Use only the first entry in the list and make sure we have the appropriate type
logger.debug("Text embedded in semantic routing", text=cleaned_text[:100])
return np.array(embed_list[0], dtype=np.float32)

async def add_persona(self, persona_name: str, persona_desc: str) -> None:
"""
Add a new persona to the database. The persona description is embedded
and stored in the database.
"""
emb_persona_desc = await self._embed_text(persona_desc)
new_persona = db_models.PersonaEmbedding(
id=str(uuid.uuid4()),
name=persona_name,
description=persona_desc,
description_embedding=emb_persona_desc,
)
await self._db_recorder.add_persona(new_persona)
logger.info(f"Added persona {persona_name} to the database.")

async def check_persona_match(self, persona_name: str, query: str) -> bool:
"""
Check if the query matches the persona description. A vector similarity
search is performed between the query and the persona description.
0 means the vectors are identical, 2 means they are orthogonal.
See
[sqlite docs](https://alexgarcia.xyz/sqlite-vec/api-reference.html#vec_distance_cosine)
"""
persona = await self._db_reader.get_persona_by_name(persona_name)
if not persona:
raise PersonaDoesNotExistError(f"Persona {persona_name} does not exist.")

emb_query = await self._embed_text(query)
persona_distance = await self._db_reader.get_distance_to_persona(persona.id, emb_query)
logger.info(f"Persona distance to {persona_name}", distance=persona_distance.distance)
if persona_distance.distance < self._persona_threshold:
return True
return False
Loading

0 comments on commit 9dac0af

Please sign in to comment.