DB query optimization and reducing sqlalchemy logs (daxa-ai#575)
* Used the SQL 'IN' construct to reduce the number of queries performed while fetching snippet details.

* Implemented pagination.

* Added unit tests.
shreyas-damle authored Sep 30, 2024
1 parent ac2b289 commit c68f56e
Showing 6 changed files with 196 additions and 9 deletions.
12 changes: 10 additions & 2 deletions pebblo/app/models/sqltables.py
@@ -1,10 +1,15 @@
+import logging
+
 from sqlalchemy import JSON, Column, Integer, create_engine
-from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import declarative_base
 
 from pebblo.app.config.config import var_server_config_dict
 from pebblo.app.enums.common import StorageTypes
 from pebblo.app.enums.enums import CacheDir, SQLiteTables
 from pebblo.app.utils.utils import get_full_path
+from pebblo.log import get_logger
 
+logger = get_logger(__name__)
+
 Base = declarative_base()

@@ -66,7 +71,10 @@ class AiUser(Base):
 # Create an engine that stores data in the local directory's my_database.db file.
 full_path = get_full_path(CacheDir.HOME_DIR.value)
 sqlite_db_path = CacheDir.SQLITE_ENGINE.value.format(full_path)
-engine = create_engine(sqlite_db_path, echo=True)
+if logger.isEnabledFor(logging.DEBUG):
+    engine = create_engine(sqlite_db_path, echo=True)
+else:
+    engine = create_engine(sqlite_db_path)
 
 # Create all tables in the engine. This is equivalent to "Create Table" statements in raw SQL.
 Base.metadata.create_all(engine)
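A note on the logging half of this change: in SQLAlchemy, `echo=True` is shorthand for INFO-level output on the `sqlalchemy.engine` logger, which prints every emitted SQL statement. Since `create_engine` accepts a plain boolean for `echo`, the same gating can be expressed without the if/else; a minimal sketch (the `sqlite:///` URL is illustrative, the real path comes from `CacheDir.SQLITE_ENGINE`):

```python
import logging

from sqlalchemy import create_engine

logger = logging.getLogger(__name__)

# echo=True routes every emitted SQL statement to the "sqlalchemy.engine"
# logger at INFO level; passing a boolean gates it on the app's log level.
engine = create_engine(
    "sqlite:///./my_database.db",  # illustrative path only
    echo=logger.isEnabledFor(logging.DEBUG),
)
```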
20 changes: 15 additions & 5 deletions pebblo/app/service/local_ui/loader_apps.py
@@ -25,6 +25,7 @@
     get_current_time,
     get_full_path,
     get_pebblo_server_version,
+    timeit,
 )
 from pebblo.log import get_logger

@@ -58,13 +59,20 @@ def _get_snippet_details(
         This function finds snippet details based on labels
         """
         response = []
-        for snippet_id in snippet_ids:
-            result, output = self.db.query(AiSnippetsTable, {"id": snippet_id})
-            if not result or len(output) == 0:
-                continue
-            snippet_details = output[0].data
+        result, output = self.db.query_by_list(
+            AiSnippetsTable,
+            filter_key="id",
+            filter_values=snippet_ids[: ReportConstants.SNIPPET_LIMIT.value],
+        )
+
+        if not result or len(output) == 0:
+            return response
+
+        for row in output:
+            if len(response) >= ReportConstants.SNIPPET_LIMIT.value:
+                break
+
+            snippet_details = row.data
             entity_details = {}
             topic_details = {}
             if snippet_details.get("topicDetails") and snippet_details[
@@ -351,6 +359,7 @@ def _create_loader_app_model(self, app_list: list) -> LoaderAppModel:
         )
         return loader_response
 
+    @timeit
     def get_all_loader_apps(self):
         """
         Returns all necessary loader app details required for get all app functionality.
@@ -402,6 +411,7 @@ def get_all_loader_apps(self):
             # Closing the session
             self.db.session.close()
 
+    @timeit
     def get_loader_app_details(self, db: SQLiteClient, app_name: str) -> str:
         """
         This function is being used by the loader_doc_service to get data needed to generate pdf.
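The gist of the loader_apps change, restated outside the diff: replace one query per snippet id with a batched IN query over the id list. A sketch under stated assumptions (the helper function names are illustrative; the table, session, and `func.json_extract(...).in_(...)` filter follow the code above; the SQL in comments is approximate):

```python
from typing import List

from sqlalchemy import func
from sqlalchemy.orm import Session

from pebblo.app.models.sqltables import AiSnippetsTable


def fetch_snippets_one_by_one(session: Session, snippet_ids: List[str]):
    # Before: one round trip per id, roughly
    #   SELECT * FROM <snippets table> WHERE json_extract(data, '$.id') = :id
    # executed len(snippet_ids) times.
    rows = []
    for snippet_id in snippet_ids:
        rows.extend(
            session.query(AiSnippetsTable)
            .filter(func.json_extract(AiSnippetsTable.data, "$.id") == snippet_id)
            .all()
        )
    return rows


def fetch_snippets_batched(session: Session, snippet_ids: List[str]):
    # After: a single round trip, roughly
    #   SELECT * FROM <snippets table>
    #   WHERE json_extract(data, '$.id') IN (:id1, ..., :idN)
    return (
        session.query(AiSnippetsTable)
        .filter(func.json_extract(AiSnippetsTable.data, "$.id").in_(snippet_ids))
        .all()
    )
```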
3 changes: 3 additions & 0 deletions pebblo/app/service/local_ui/retriever_apps.py
@@ -20,6 +20,7 @@
     AiUser,
 )
 from pebblo.app.storage.sqlite_db import SQLiteClient
+from pebblo.app.utils.utils import timeit
 from pebblo.log import get_logger
 
 config_details = var_server_config_dict.get()
@@ -385,6 +386,7 @@ def prepare_retrieval_app_response(self, app_data, retrieval_data):
         )
         return json.dumps(response.model_dump(), default=str, indent=4)
 
+    @timeit
     def get_all_retriever_apps(self):
         try:
             self.db = SQLiteClient()
@@ -462,6 +464,7 @@ def get_all_retriever_apps(self):
             # Closing the session
             self.db.session.close()
 
+    @timeit
     def get_retriever_app_details(self, app_name):
         try:
             retrieval_data = []
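The `timeit` decorator is imported from `pebblo.app.utils.utils`; its implementation is not part of this diff. A minimal sketch of what such a timing decorator typically looks like (the logging target and message format here are assumptions):

```python
import functools
import time


def timeit(func):
    """Illustrative sketch: log the wall-clock duration of each call."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            elapsed = time.perf_counter() - start
            # A real implementation would likely use the project's logger.
            print(f"{func.__name__} took {elapsed:.4f}s")

    return wrapper
```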
72 changes: 70 additions & 2 deletions pebblo/app/storage/sqlite_db.py
@@ -1,6 +1,11 @@
-from sqlalchemy import and_, create_engine, text
+import logging
+from math import ceil
+from typing import List, Type
+
+from sqlalchemy import and_, create_engine, func, text
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.orm.attributes import flag_modified
+from sqlalchemy.orm.decl_api import DeclarativeMeta
 
 from pebblo.app.enums.enums import CacheDir
 from pebblo.app.storage.database import Database
@@ -21,7 +26,10 @@ def _create_engine():
         # Create an engine that stores data in the local directory's db file.
         full_path = get_full_path(CacheDir.HOME_DIR.value)
         sqlite_db_path = CacheDir.SQLITE_ENGINE.value.format(full_path)
-        engine = create_engine(sqlite_db_path, echo=True)
+        if logger.isEnabledFor(logging.DEBUG):
+            engine = create_engine(sqlite_db_path, echo=True)
+        else:
+            engine = create_engine(sqlite_db_path)
         return engine
 
     def create_session(self):
@@ -104,6 +112,66 @@ def query_by_id(self, table_obj, id):
             )
             return False, err
 
+    @timeit
+    def query_by_list(
+        self,
+        table_obj: Type[DeclarativeMeta],
+        filter_key: str,
+        filter_values: List[str],
+        page_size: int = 100,
+    ):
+        """
+        Pass the filter as a list. For example, get snippets with ids in [<id1>, <id2>].
+        :param table_obj: Table object on which the query is to be performed
+        :param filter_key: Search key
+        :param filter_values: List of strings to be added to the filter criteria.
+        :param page_size: Page size to be used per iteration.
+            All items from filter_values are searched in batches of page_size.
+        """
+        table_name = table_obj.__tablename__
+        try:
+            logger.debug(f"Fetching data from table {table_name}")
+            total_records = len(filter_values)
+            total_pages = ceil(total_records / page_size)
+            results = []
+            for page in range(total_pages):
+                try:
+                    # Calculate start and end indices for the current batch
+                    start_idx = page * page_size
+                    end_idx = start_idx + page_size
+
+                    # Slice filter_values to match the current batch
+                    current_batch = filter_values[start_idx:end_idx]
+
+                    logger.debug(
+                        f"Processing batch {page + 1}/{total_pages}, filter values: {current_batch}"
+                    )
+
+                    # Execute the query for the current batch
+                    batch_result = (
+                        self.session.query(table_obj)
+                        .filter(
+                            func.json_extract(table_obj.data, f"$.{filter_key}").in_(
+                                current_batch
+                            )
+                        )
+                        .all()
+                    )
+                    results.extend(batch_result)
+                except Exception as err:
+                    logger.error(
+                        f"Failed in fetching data from table {table_name}, Error: {err}"
+                    )
+                    continue
+
+            return True, results
+
+        except Exception as err:
+            logger.error(
+                f"Failed in fetching data from table {table_name}, Error: {err}"
+            )
+            return False, []
+
     @timeit
     def update_data(self, table_obj, data):
         table_name = table_obj.__tablename__
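Why batch the IN filter at all: SQLite caps the number of bound parameters per statement (999 by default before SQLite 3.32.0, 32766 since), so an unbounded id list cannot always be sent in a single IN clause. The ceil-based slicing in query_by_list works as follows (standalone sketch, numbers illustrative):

```python
from math import ceil

filter_values = [f"snippet_{i}" for i in range(250)]
page_size = 100

# 250 ids with 100 per page -> 3 batches of sizes 100, 100, 50.
total_pages = ceil(len(filter_values) / page_size)

for page in range(total_pages):
    start_idx = page * page_size
    batch = filter_values[start_idx : start_idx + page_size]
    print(f"batch {page + 1}/{total_pages}: {len(batch)} ids")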
Empty file added tests/app/storage/__init__.py
98 changes: 98 additions & 0 deletions tests/app/storage/test_sqlite_db.py
@@ -0,0 +1,98 @@
from unittest.mock import MagicMock

import pytest
from sqlalchemy.orm import Session

from pebblo.app.models.sqltables import AiSnippetsTable


@pytest.fixture
def sqlite_client():
    """Fixture for creating an SQLiteClient instance."""
    from pebblo.app.storage.sqlite_db import SQLiteClient

    client = SQLiteClient()
    client.session = MagicMock(spec=Session)
    return client


def test_query_by_list_success(sqlite_client):
    """Test successful query with query_by_list."""
    mock_session = sqlite_client.session
    table_obj = AiSnippetsTable
    filter_key = "id"
    filter_values = ["snippet_id1", "snippet_id2"]

    # Mocking the query result
    mock_query = mock_session.query.return_value
    mock_filter = mock_query.filter.return_value
    mock_filter.all.return_value = ["result1", "result2"]  # Mocked results

    # Call the method
    success, result = sqlite_client.query_by_list(table_obj, filter_key, filter_values)

    # Assertions
    assert success is True
    assert result == ["result1", "result2"]

    # Ensure the query was executed only once (two ids fit in a single page)
    assert mock_session.query().filter().all.call_count == 1


def test_query_by_list_page_size(sqlite_client):
    """Test successful query with query_by_list to verify page_size batching."""
    mock_session = sqlite_client.session
    table_obj = AiSnippetsTable
    filter_key = "id"
    filter_values = [
        "snippet_id1",
        "snippet_id2",
        "snippet_id3",
        "snippet_id4",
        "snippet_id5",
    ]
    page_size = 2

    # Mocking query results: five ids with page_size=2 yield three batches
    mock_result_page_1 = ["result1", "result2"]
    mock_result_page_2 = ["result3", "result4"]
    mock_result_page_3 = ["result5"]
    mock_query = mock_session.query().filter().all
    mock_query.side_effect = [
        mock_result_page_1,
        mock_result_page_2,
        mock_result_page_3,
    ]

    # Call the method
    success, result = sqlite_client.query_by_list(
        table_obj, filter_key, filter_values, page_size
    )

    # Assertions
    assert success is True
    assert result == ["result1", "result2", "result3", "result4", "result5"]


def test_query_by_list_failure(sqlite_client):
    mock_session = sqlite_client.session
    mock_table_obj = AiSnippetsTable
    filter_key = "id"
    filter_values = ["value1", "value2"]

    # An invalid page size makes the page computation raise a TypeError, which
    # query_by_list catches before any query is issued, returning (False, []).
    page_size = "abcd"

    # Call the query_by_list function
    success, results = sqlite_client.query_by_list(
        table_obj=mock_table_obj,
        filter_key=filter_key,
        filter_values=filter_values,
        page_size=page_size,
    )

    assert success is False
    assert results == []
    mock_session.query.assert_not_called()
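For completeness, a hypothetical end-to-end call with a real database rather than mocks (assumes the engine and session setup shown earlier in this diff):

```python
from pebblo.app.models.sqltables import AiSnippetsTable
from pebblo.app.storage.sqlite_db import SQLiteClient

client = SQLiteClient()
client.create_session()

success, rows = client.query_by_list(
    AiSnippetsTable,
    filter_key="id",
    filter_values=["id1", "id2", "id3"],
    page_size=100,
)
if success:
    # Each ORM row keeps its JSON payload in the `data` column.
    snippets = [row.data for row in rows]
```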
