Skip to content

Commit

Permalink
New postgres test
Browse files Browse the repository at this point in the history
  • Loading branch information
spodgorny9 committed May 23, 2024
1 parent aa77658 commit 36696bf
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 23 deletions.
2 changes: 2 additions & 0 deletions elm/wizard.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def query_vector_db(self, query, limit=100):
ranked strings/scores outputs.
"""

def test(self, question):
return question
def engineer_query(self, query, token_budget=None, new_info_threshold=0.7,
convo=False):
"""Engineer a query for GPT using the corpus of information
Expand Down
38 changes: 15 additions & 23 deletions tests/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from elm import TEST_DATA_DIR
from elm.wizard import EnergyWizardPostgres


FP_REF_TXT = os.path.join(TEST_DATA_DIR, 'postgres_ref_list.txt')
FP_QUERY_TXT = os.path.join(TEST_DATA_DIR, 'postgres_query_test.txt')

Expand All @@ -17,17 +18,6 @@

QUERY_TUPLE = ast.literal_eval(QUERY_TEXT)

DB_HOST = ("aurora-postgres-low-stage.cluster-"
"ccklrxkcenui.us-west-2.rds.amazonaws.com")
DB_PORT = "5432"
DB_NAME = "ewiz_analysis"
DB_SCHEMA = "ewiz_schema"
DB_TABLE = "ewiz_kb"

os.environ['AWS_ACCESS_KEY_ID'] = "dummy"
os.environ['AWS_SECRET_ACCESS_KEY'] = "dummy"
os.environ['AWS_SESSION_TOKEN'] = "dummy"


class MockClass:
"""Dummy class to mock EnergyWizardPostgres.make_ref_list()"""
Expand All @@ -39,25 +29,27 @@ def ref_call(*args, **kwargs): # pylint: disable=unused-argument

@staticmethod
def query_call(*args, **kwargs): # pylint: disable=unused-argument
"""Mock for EnergyWizardPostgres.make_ref_list()"""
"""Mock for EnergyWizardPostgres.query_vector_db()"""
return QUERY_TUPLE


def test_ref_list(mocker):
"""Test to ensure correct response from research hub."""
wizard = EnergyWizardPostgres(db_host=DB_HOST, db_port=DB_PORT,
db_name=DB_NAME, db_schema=DB_SCHEMA,
db_table=DB_TABLE)

mocker.patch.object(wizard,
'make_ref_list', MockClass.ref_call)
mocker.patch.object(wizard,
'query_vector_db', MockClass.query_call)
"""Test to ensure correct response vector db."""
wizard_mock = mocker.patch('elm.wizard.EnergyWizardPostgres', autospec=True)

wizard = wizard_mock.return_value
wizard.messages = []
wizard.MODEL_INSTRUCTION = "Model instruction dummy."
wizard.token_budget = 500
wizard.model = "dummy-model-name"
wizard.count_tokens = mocker.Mock(return_value=50)
wizard.make_ref_list.side_effect= MockClass.ref_call
wizard.query_vector_db.side_effect = MockClass.query_call

question = "What is a dummy question?"
wizard.messages.append({"role": "user", "content": question})

message = wizard.engineer_query(question)[0]
refs = wizard.engineer_query(question)[-1]
message, refs = EnergyWizardPostgres.engineer_query(wizard, question)

assert len(refs) > 0
assert 'parentTitle' in str(refs)
Expand Down

0 comments on commit 36696bf

Please sign in to comment.