Skip to content

Commit

Permalink
Makes CustomPostgresConnection async-able
Browse files Browse the repository at this point in the history
  • Loading branch information
vmesel committed Sep 28, 2024
1 parent d6aa495 commit f6147b1
Show file tree
Hide file tree
Showing 7 changed files with 1,181 additions and 922 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ COPY poetry.lock pyproject.toml README.md /app/
COPY pytest.ini /app/dialog_lib/

USER root
RUN apt update -y && apt upgrade -y && apt install gcc libpq-dev -y
RUN apt update -y && apt upgrade -y && apt install gcc libpq-dev postgresql-client -y
RUN pip install -U pip poetry

COPY /etc /app/etc
Expand Down
13 changes: 11 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
.PHONY: test
.PHONY: test bump-beta bump-major bump-minor

test:
poetry run pytest --cov=dialog_lib dialog_lib/tests/
poetry run pytest --cov=dialog_lib dialog_lib/tests/

bump-prepatch:
poetry version --next-phase prepatch

bump-preminor:
poetry version --next-phase preminor

bump-premajor:
poetry version --next-phase premajor
141 changes: 124 additions & 17 deletions dialog_lib/db/memory.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import psycopg
from .session import get_session
from langchain_postgres import PostgresChatMessageHistory
from langchain.schema.messages import BaseMessage, _message_to_dict

from typing import List
from psycopg import sql

from .models import Chat, ChatMessages
from .session import get_session, get_async_session, get_async_psycopg_connection

from langchain_postgres import PostgresChatMessageHistory
from langchain.schema.messages import BaseMessage, _message_to_dict


class CustomPostgresChatMessageHistory(PostgresChatMessageHistory):
Expand All @@ -16,45 +20,127 @@ def __init__(
*args,
parent_session_id=None,
dbsession=get_session,
async_dbsession=get_async_session,
chats_model=Chat,
chat_messages_model=ChatMessages,
ssl_mode=None,
**kwargs,
):
self.parent_session_id = parent_session_id
self.dbsession = dbsession
self.async_dbsession = async_dbsession
self.chats_model = chats_model
self.chat_messages_model = chat_messages_model
self._connection = psycopg.connect(
kwargs.pop("connection_string"), sslmode=ssl_mode
)
self._async_connection = None # Will be initialized when needed
self._session_id = kwargs.pop("session_id")
self._table_name = kwargs.pop("table_name")

self._table_name = kwargs.pop("table_name", chat_messages_model.__tablename__)

self.cursor = self._connection.cursor()

async def _initialize_async_connection(self):
if self._async_connection is None:
self._async_connection = await get_async_psycopg_connection()
return self._async_connection

def _create_tables_queries(self, table_name):
index_name = f"idx_{table_name}_session_id"
return [
sql.SQL(
"""
CREATE TABLE IF NOT EXISTS {table_name} (
id SERIAL PRIMARY KEY,
session_id TEXT NOT NULL,
message JSONB NOT NULL,
timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);"""
).format(table_name=sql.Identifier(table_name)),
sql.SQL(
"""
CREATE INDEX IF NOT EXISTS {index_name} ON {table_name} (session_id);
"""
).format(
index_name=sql.Identifier(index_name),
table_name=sql.Identifier(table_name)
)
]

def _get_messages_query(self, table_name):
return [
sql.SQL(
"""
SELECT message FROM {table_name} WHERE session_id = {session_id};
"""
).format(
table_name=sql.Identifier(table_name),
session_id=sql.Literal(self._session_id)
)
]

def create_tables(self) -> None:
"""
create table if it does not exist
add a new column for timestamp
Create table if it does not exist
Add a new column for timestamp
"""
create_table_queries = self._create_tables_queries(self._table_name)
for query in create_table_queries:
self.cursor.execute(query)
self._connection.commit()

async def acreate_tables(self) -> None:
"""
Asynchronously create tables.
"""
create_table_queries = self._create_tables_queries(self._table_name)
async_conn = await self._initialize_async_connection()
async with async_conn.cursor() as cursor:
for query in create_table_queries:
await cursor.execute(query)
await async_conn.commit()

def get_messages(self):
"""
Retrieve messages synchronously.
"""
get_messages_query = self._get_messages_query(self._table_name)
for query in get_messages_query:
self.cursor.execute(query)
return self.cursor.fetchall()

async def aget_messages(self):
"""
Retrieve messages asynchronously.
"""
create_table_query = f"""CREATE TABLE IF NOT EXISTS {self.table_name} (
id SERIAL PRIMARY KEY,
session_id TEXT NOT NULL,
message JSONB NOT NULL,
timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);"""
self.cursor.execute(create_table_query)
self.connection.commit()
get_messages_query = self._get_messages_query(self._table_name)
async_conn = await self._initialize_async_connection()
async with async_conn.cursor() as cursor:
for query in get_messages_query:
await cursor.execute(query)
return await cursor.fetchall()

def add_tags(self, tags: str) -> None:
"""Add tags for a given session_id/uuid on chats table"""
"""
Add tags for a given session_id/uuid on chats table.
"""
with self.dbsession() as session:
session.query(self.chats_model).where(
self.chats_model.session_id == self._session_id
).update({getattr(self.chats_model, "tags"): tags})
session.commit()

def add_messages(self, messages: List[BaseMessage]) -> None:
"""
Add messages to the record in PostgreSQL.
"""
for message in messages:
self.add_message(message)

def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in PostgreSQL"""
"""
Append the message to the record in PostgreSQL.
"""
message = self.chat_messages_model(
session_id=self._session_id, message=_message_to_dict(message)
)
Expand All @@ -63,6 +149,27 @@ def add_message(self, message: BaseMessage) -> None:
self.dbsession.add(message)
self.dbsession.commit()

async def aadd_messages(self, messages: List[BaseMessage]) -> None:
"""
Asynchronously add messages to the record in PostgreSQL.
"""
for message in messages:
await self.aadd_message(message)

async def aadd_message(self, message: BaseMessage) -> None:
"""
Asynchronously append the message to the record in PostgreSQL.
"""
async_conn = await self._initialize_async_connection()
async with async_conn.cursor() as cursor:
await cursor.execute(
sql.SQL("INSERT INTO {table_name} (session_id, message) VALUES (%s, %s)").format(
table_name=sql.Identifier(self._table_name)
),
(self._session_id, _message_to_dict(message))
)
await async_conn.commit()


def generate_memory_instance(
session_id,
Expand Down
59 changes: 51 additions & 8 deletions dialog_lib/db/session.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import os
from functools import lru_cache

import sqlalchemy as sa
from sqlalchemy.orm import Session, sessionmaker

from contextlib import contextmanager
from contextlib import contextmanager, asynccontextmanager
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from psycopg_pool import AsyncConnectionPool

from functools import cache

@cache
def get_engine():
@lru_cache()
def get_sync_engine():
return sa.create_engine(os.environ.get("DATABASE_URL"))

@contextmanager
def session_scope():
with Session(bind=get_engine()) as session:
def sync_session_scope():
with Session(bind=get_sync_engine()) as session:
try:
yield session
session.commit()
Expand All @@ -24,5 +25,47 @@ def session_scope():
session.close()

def get_session():
with session_scope() as session:
with sync_session_scope() as session:
return session

@lru_cache()
def get_async_engine():
return create_async_engine(os.environ.get("DATABASE_URL"))

@asynccontextmanager
async def async_session_scope():
async_session = sessionmaker(
get_async_engine(), class_=AsyncSession, expire_on_commit=False
)
async with async_session() as session:
try:
yield session
await session.commit()
except Exception as exc:
await session.rollback()
raise exc
finally:
await session.close()

async def get_async_session():
async with async_session_scope() as session:
return session

@lru_cache()
def create_async_psycopg_pool():
return AsyncConnectionPool(os.environ.get("DATABASE_URL"))

@asynccontextmanager
async def async_psycopg_connection():
pool = create_async_psycopg_pool()
async with pool.connection() as conn:
try:
yield conn
await conn.commit()
except Exception:
await conn.rollback()
raise

async def get_async_psycopg_connection():
async with async_psycopg_connection() as conn:
return conn
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ services:
build:
context: .
dockerfile: Dockerfile
entrypoint: pytest -vvv
command: pytest -vvv
stdin_open: true
tty: true
depends_on:
Expand Down
Loading

0 comments on commit f6147b1

Please sign in to comment.