Skip to content

Commit

Permalink
Add postgres tests
Browse files Browse the repository at this point in the history
  • Loading branch information
spodgorny9 committed May 22, 2024
1 parent 981ba77 commit f9b9641
Show file tree
Hide file tree
Showing 5 changed files with 329 additions and 24 deletions.
6 changes: 3 additions & 3 deletions elm/utilities/try_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def try_import(package_name):
Returns
-------
p : package
p : module
imported package.
"""
try:
Expand All @@ -26,7 +26,7 @@ def try_import(package_name):
except ImportError:
msg = (f'Unable to import {package_name}. '
'Please ensure you have the package '
'installed and spelled correctly '
'before proceeding.')
'installed. This is an extra requirement '
'for the package you are running')
logger.warning(msg)
warn(msg)
56 changes: 35 additions & 21 deletions elm/wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,38 +384,52 @@ class EnergyWizardPostgres(EnergyWizardBase):
"""Interface to ask OpenAI LLMs about energy research.
This class is for execution with a postgres vector database.
Connecting to the database requires the use of the psycopg2
python package, environment variables storing the db user and
password, and the specification of other connection paremeters
such as host, port, and name. The database has the following
columns: id, embedding, chunks, and metadata.
Querying the database requires the use of the psycopg2 and
boto3 python packages, environment variables ('EWIZ_DB_USER'
and 'EWIZ_DB_PASSWORD') storing the db user and password, and
the specification of other connection paremeters such as host,
port, and name. The database has the following columns: id,
embedding, chunks, and metadata.
This class is designed as follows:
Vector database: PostgreSQL database accessed using psycopg2.
Query Embedding: AWS titan using boto3
LLM Application: GPT4 via Azure deployment
"""
EMBEDDING_MODEL = 'amazon.titan-embed-text-v1'

def __init__(self, db_host, db_port, db_name,
model=None, token_budget=3500):
db_schema, db_table, model=None,
token_budget=3500):
"""
Parameters
----------
model : str
GPT model name, default is the DEFAULT_MODEL global var
token_budget : int
Number of tokens that can be embedded in the prompt. Note that the
default budget for GPT-3.5-Turbo is 4096, but you want to subtract
some tokens to account for the response budget.
db_host : str
Host url for postgres database.
db_port : str
Port for postres database. ex: '5432'
db_name : str
Postgres database name.
db_schema : str
Schema name for postres database.
db_table : str
Table to query in Postgres database.
model : str
GPT model name, default is the DEFAULT_MODEL global var
token_budget : int
Number of tokens that can be embedded in the prompt. Note that the
default budget for GPT-3.5-Turbo is 4096, but you want to subtract
some tokens to account for the response budget.
"""
boto3 = try_import('boto3')
psycopg2 = try_import('psycopg2')

db_user = os.getenv("EWIZ_DB_USER")
db_password = os.getenv('EWIZ_DB_PASSWORD')
assert db_user is not None, "Must set user for postgres database!"
assert db_password is not None, "Must set user for postgres database!"
self.db_schema = db_schema
self.db_table = db_table
assert db_user is not None, "Must set EWIZ_DB_USER!"
assert db_password is not None, "Must set EWIZ_DB_PASSWORD!"

self.conn = psycopg2.connect(user=db_user,
password=db_password,
Expand Down Expand Up @@ -497,10 +511,10 @@ def query_vector_db(self, query, limit=100):

query_embedding = self.get_embedding(query)

self.cursor.execute("SELECT ewiz_kb.id, "
"ewiz_kb.chunks, "
"ewiz_kb.embedding <=> %s::vector as score "
"FROM ewiz_schema.ewiz_kb "
self.cursor.execute(f"SELECT {self.db_table}.id, "
f"{self.db_table}.chunks, "
f"{self.db_table}.embedding <=> %s::vector as score "
f"FROM {self.db_schema}.{self.db_table} "
"ORDER BY embedding <=> %s::vector LIMIT %s;",
(query_embedding, query_embedding, limit,), )

Expand Down Expand Up @@ -529,9 +543,9 @@ def make_ref_list(self, ids):

placeholders = ', '.join(['%s'] * len(ids))

sql_query = ("SELECT ewiz_kb.title, ewiz_kb.url "
"FROM ewiz_schema.ewiz_kb "
"WHERE ewiz_kb.id IN (" + placeholders + ")")
sql_query = (f"SELECT {self.db_table}.title, {self.db_table}.url "
f"FROM {self.db_schema}.{self.db_table} "
f"WHERE {self.db_table}.id IN (" + placeholders + ")")

self.cursor.execute(sql_query, ids)

Expand Down
Loading

0 comments on commit f9b9641

Please sign in to comment.