feat(wren-ai-service): Consolidate SQL Pairs Service and Remove Redundant Code #1268

Merged: 14 commits, Feb 6, 2025. Changes from 13 commits.
2 changes: 1 addition & 1 deletion wren-ai-service/src/config.py
@@ -60,7 +60,7 @@ class Settings(BaseSettings):
     config_path: str = Field(default="config.yaml")
     _components: list[dict]
 
-    sql_pairs_path: str = Field(default="pairs.json")
+    sql_pairs_path: str = Field(default="sql_pairs.json")
 
     def __init__(self):
        load_dotenv(".env.dev", override=True)
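The only code change here is the default filename, pairs.json to sql_pairs.json, matching the naming used throughout the rest of this PR. Since Settings is a pydantic BaseSettings, the field should still be overridable per deployment; a minimal sketch, assuming pydantic-settings' default environment-variable mapping (the variable name is an inference, not something shown in this diff):

```python
# Minimal sketch, assuming pydantic-settings' default env mapping:
# the sql_pairs_path field would be overridable via SQL_PAIRS_PATH.
from pydantic import Field
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    sql_pairs_path: str = Field(default="sql_pairs.json")


settings = Settings()
print(settings.sql_pairs_path)  # "sql_pairs.json" unless SQL_PAIRS_PATH is set
```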
79 changes: 32 additions & 47 deletions wren-ai-service/src/globals.py
@@ -7,40 +7,27 @@
 from src.core.pipeline import PipelineComponent
 from src.core.provider import EmbedderProvider, LLMProvider
 from src.pipelines import generation, indexing, retrieval
-from src.web.v1.services.ask import AskService
-from src.web.v1.services.ask_details import AskDetailsService
-from src.web.v1.services.chart import ChartService
-from src.web.v1.services.chart_adjustment import ChartAdjustmentService
-from src.web.v1.services.question_recommendation import QuestionRecommendation
-from src.web.v1.services.relationship_recommendation import RelationshipRecommendation
-from src.web.v1.services.semantics_description import SemanticsDescription
-from src.web.v1.services.semantics_preparation import SemanticsPreparationService
-from src.web.v1.services.sql_answer import SqlAnswerService
-from src.web.v1.services.sql_expansion import SqlExpansionService
-from src.web.v1.services.sql_explanation import SqlExplanationService
-from src.web.v1.services.sql_pairs_preparation import SqlPairsPreparationService
-from src.web.v1.services.sql_question import SqlQuestionService
-from src.web.v1.services.sql_regeneration import SqlRegenerationService
+from src.web.v1 import services
 
 logger = logging.getLogger("wren-ai-service")
 
 
 @dataclass
 class ServiceContainer:
-    ask_service: AskService
-    ask_details_service: AskDetailsService
-    question_recommendation: QuestionRecommendation
-    relationship_recommendation: RelationshipRecommendation
-    semantics_description: SemanticsDescription
-    semantics_preparation_service: SemanticsPreparationService
-    chart_service: ChartService
-    chart_adjustment_service: ChartAdjustmentService
-    sql_answer_service: SqlAnswerService
-    sql_expansion_service: SqlExpansionService
-    sql_explanation_service: SqlExplanationService
-    sql_regeneration_service: SqlRegenerationService
-    sql_pairs_preparation_service: SqlPairsPreparationService
-    sql_question_service: SqlQuestionService
+    ask_service: services.AskService
+    ask_details_service: services.AskDetailsService
+    question_recommendation: services.QuestionRecommendation
+    relationship_recommendation: services.RelationshipRecommendation
+    semantics_description: services.SemanticsDescription
+    semantics_preparation_service: services.SemanticsPreparationService
+    chart_service: services.ChartService
+    chart_adjustment_service: services.ChartAdjustmentService
+    sql_answer_service: services.SqlAnswerService
+    sql_expansion_service: services.SqlExpansionService
+    sql_explanation_service: services.SqlExplanationService
+    sql_regeneration_service: services.SqlRegenerationService
+    sql_pairs_service: services.SqlPairsService
+    sql_question_service: services.SqlQuestionService
 
 
 @dataclass
@@ -58,15 +45,15 @@ def create_service_container(
         "ttl": settings.query_cache_ttl,
     }
     return ServiceContainer(
-        semantics_description=SemanticsDescription(
+        semantics_description=services.SemanticsDescription(
             pipelines={
                 "semantics_description": generation.SemanticsDescription(
                     **pipe_components["semantics_description"],
                 )
             },
             **query_cache,
         ),
-        semantics_preparation_service=SemanticsPreparationService(
+        semantics_preparation_service=services.SemanticsPreparationService(
             pipelines={
                 "db_schema": indexing.DBSchema(
                     **pipe_components["db_schema_indexing"],
@@ -85,7 +72,7 @@
             },
             **query_cache,
         ),
-        ask_service=AskService(
+        ask_service=services.AskService(
             pipelines={
                 "intent_classification": generation.IntentClassification(
                     **pipe_components["intent_classification"],
@@ -127,7 +114,7 @@
             allow_sql_generation_reasoning=settings.allow_sql_generation_reasoning,
             **query_cache,
         ),
-        chart_service=ChartService(
+        chart_service=services.ChartService(
             pipelines={
                 "sql_executor": retrieval.SQLExecutor(
                     **pipe_components["sql_executor"],
@@ -139,7 +126,7 @@
             },
             **query_cache,
         ),
-        chart_adjustment_service=ChartAdjustmentService(
+        chart_adjustment_service=services.ChartAdjustmentService(
             pipelines={
                 "sql_executor": retrieval.SQLExecutor(
                     **pipe_components["sql_executor"],
@@ -151,7 +138,7 @@
             },
             **query_cache,
         ),
-        sql_answer_service=SqlAnswerService(
+        sql_answer_service=services.SqlAnswerService(
             pipelines={
                 "preprocess_sql_data": retrieval.PreprocessSqlData(
                     **pipe_components["preprocess_sql_data"],
@@ -163,7 +150,7 @@
             },
             **query_cache,
         ),
-        ask_details_service=AskDetailsService(
+        ask_details_service=services.AskDetailsService(
             pipelines={
                 "sql_breakdown": generation.SQLBreakdown(
                     **pipe_components["sql_breakdown"],
@@ -175,7 +162,7 @@
             },
             **query_cache,
         ),
-        sql_expansion_service=SqlExpansionService(
+        sql_expansion_service=services.SqlExpansionService(
             pipelines={
                 "retrieval": retrieval.Retrieval(
                     **pipe_components["db_schema_retrieval"],
@@ -196,23 +183,23 @@
             },
             **query_cache,
         ),
-        sql_explanation_service=SqlExplanationService(
+        sql_explanation_service=services.SqlExplanationService(
             pipelines={
                 "sql_explanation": generation.SQLExplanation(
                     **pipe_components["sql_explanation"],
                 )
             },
             **query_cache,
         ),
-        sql_regeneration_service=SqlRegenerationService(
+        sql_regeneration_service=services.SqlRegenerationService(
             pipelines={
                 "sql_regeneration": generation.SQLRegeneration(
                     **pipe_components["sql_regeneration"],
                 )
             },
             **query_cache,
         ),
-        relationship_recommendation=RelationshipRecommendation(
+        relationship_recommendation=services.RelationshipRecommendation(
             pipelines={
                 "relationship_recommendation": generation.RelationshipRecommendation(
                     **pipe_components["relationship_recommendation"],
@@ -221,7 +208,7 @@
             },
             **query_cache,
         ),
-        question_recommendation=QuestionRecommendation(
+        question_recommendation=services.QuestionRecommendation(
             pipelines={
                 "question_recommendation": generation.QuestionRecommendation(
                     **pipe_components["question_recommendation"],
@@ -242,18 +229,16 @@
             },
             **query_cache,
         ),
-        sql_pairs_preparation_service=SqlPairsPreparationService(
+        sql_pairs_service=services.SqlPairsService(
             pipelines={
-                "sql_pairs_preparation": indexing.SqlPairs(
+                "sql_pairs": indexing.SqlPairs(
                     **pipe_components["sql_pairs_indexing"],
-                ),
-                "sql_pairs_deletion": indexing.SqlPairsDeletion(
-                    **pipe_components["sql_pairs_deletion"],
-                ),
+                    sql_pairs_path=settings.sql_pairs_path,
+                )
             },
             **query_cache,
         ),
-        sql_question_service=SqlQuestionService(
+        sql_question_service=services.SqlQuestionService(
             pipelines={
                 "sql_question_generation": generation.SQLQuestion(
                     **pipe_components["sql_question_generation"],
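All fourteen per-module service imports collapse into a single `from src.web.v1 import services` namespace import, and the container fields follow suit; notably, sql_pairs_preparation_service becomes sql_pairs_service and no longer wires a separate deletion pipeline. For the services.* references to resolve, the package __init__ presumably re-exports each class. That file is not shown in this diff, so the following is only a sketch with assumed module paths:

```python
# Hypothetical src/web/v1/services/__init__.py (not part of this diff).
# Module paths are assumptions inferred from the imports removed above.
from src.web.v1.services.ask import AskService
from src.web.v1.services.ask_details import AskDetailsService
from src.web.v1.services.sql_pairs import SqlPairsService
# ... one re-export per service referenced by ServiceContainer

__all__ = ["AskService", "AskDetailsService", "SqlPairsService"]
```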
11 changes: 8 additions & 3 deletions wren-ai-service/src/pipelines/common.py
@@ -33,21 +33,26 @@ def build_table_ddl(
     ), has_calculated_field
 
 
-def dry_run_pipeline(pipeline_cls: BasicPipeline, pipeline_name: str, **kwargs):
+def dry_run_pipeline(
+    pipeline_cls: BasicPipeline,
+    pipeline_name: str,
+    method: str = "run",
+    **kwargs,
+):
     from langfuse.decorators import langfuse_context
 
     from src.config import settings
     from src.core.pipeline import async_validate
     from src.providers import generate_components
     from src.utils import init_langfuse, setup_custom_logger
 
-    setup_custom_logger("wren-ai-service", level_str=settings.logging_level)
+    setup_custom_logger("wren-ai-service", level_str=settings.logging_level, is_dev=True)
 
     pipe_components = generate_components(settings.components)
     pipeline = pipeline_cls(**pipe_components[pipeline_name])
     init_langfuse(settings)
 
-    async_validate(lambda: pipeline.run(**kwargs))
+    async_validate(lambda: getattr(pipeline, method)(**kwargs))
Review comment (Contributor):

🛠️ Refactor suggestion

Verify error handling for dynamic method invocation. Using getattr without checking whether the method exists could raise AttributeError.

-    async_validate(lambda: getattr(pipeline, method)(**kwargs))
+    if not hasattr(pipeline, method):
+        raise ValueError(f"Method '{method}' not found in pipeline")
+    async_validate(lambda: getattr(pipeline, method)(**kwargs))

     langfuse_context.flush()
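The new method parameter generalizes dry runs beyond run, which matters now that SqlPairs exposes a second entry point, clean. A usage sketch, assuming a configured environment (.env.dev and config.yaml present) and using the pipe-component key and entry points visible elsewhere in this PR; the mdl_str and project_id values are placeholders:

```python
# Illustrative dry runs of the consolidated SqlPairs pipeline.
# "sql_pairs_indexing" matches the pipe_components key in globals.py.
from src.pipelines.common import dry_run_pipeline
from src.pipelines.indexing import SqlPairs
from src.pipelines.indexing.sql_pairs import SqlPair

# Default entry point: SqlPairs.run
dry_run_pipeline(SqlPairs, "sql_pairs_indexing", mdl_str="{}", project_id="demo")

# New: target SqlPairs.clean via the method parameter
dry_run_pipeline(
    SqlPairs,
    "sql_pairs_indexing",
    method="clean",
    sql_pairs=[SqlPair(id="pair-1")],
    project_id="demo",
)
```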
26 changes: 0 additions & 26 deletions wren-ai-service/src/pipelines/indexing/__init__.py
@@ -87,41 +87,15 @@ async def run(
         return {"documents_written": documents_written}
 
 
-@component
-class SqlPairsCleaner:
-    def __init__(self, sql_pairs_store: DocumentStore) -> None:
-        self._sql_pairs_store = sql_pairs_store
-
-    @component.output_types()
-    async def run(
-        self, sql_pair_ids: List[str], project_id: Optional[str] = None
-    ) -> None:
-        filters = {
-            "operator": "AND",
-            "conditions": [
-                {"field": "sql_pair_id", "operator": "in", "value": sql_pair_ids},
-            ],
-        }
-
-        if project_id:
-            filters["conditions"].append(
-                {"field": "project_id", "operator": "==", "value": project_id}
-            )
-
-        return await self._sql_pairs_store.delete_documents(filters)
-
-
 # Put the pipelines imports here to avoid circular imports and make them accessible directly to the rest of packages
 from .db_schema import DBSchema  # noqa: E402
 from .historical_question import HistoricalQuestion  # noqa: E402
 from .sql_pairs import SqlPairs  # noqa: E402
-from .sql_pairs_deletion import SqlPairsDeletion  # noqa: E402
 from .table_description import TableDescription  # noqa: E402
 
 __all__ = [
     "DBSchema",
     "TableDescription",
     "HistoricalQuestion",
-    "SqlPairsDeletion",
     "SqlPairs",
 ]
64 changes: 53 additions & 11 deletions wren-ai-service/src/pipelines/indexing/sql_pairs.py
@@ -8,21 +8,21 @@
 from hamilton import base
 from hamilton.async_driver import AsyncDriver
 from haystack import Document, component
-from haystack.document_stores.types import DuplicatePolicy
+from haystack.document_stores.types import DocumentStore, DuplicatePolicy
 from langfuse.decorators import observe
 from pydantic import BaseModel
 
 from src.core.pipeline import BasicPipeline
 from src.core.provider import DocumentStoreProvider, EmbedderProvider
-from src.pipelines.indexing import AsyncDocumentWriter, SqlPairsCleaner
+from src.pipelines.indexing import AsyncDocumentWriter
 
 logger = logging.getLogger("wren-ai-service")
 
 
 class SqlPair(BaseModel):
     id: str
-    sql: str
-    question: str
+    sql: str = ""
+    question: str = ""
@@ -49,6 +49,30 @@ def run(self, sql_pairs: List[SqlPair], project_id: Optional[str] = ""):
         }
 
 
+@component
+class SqlPairsCleaner:
+    def __init__(self, sql_pairs_store: DocumentStore) -> None:
+        self.store = sql_pairs_store
+
+    @component.output_types()
+    async def run(
+        self, sql_pair_ids: List[str], project_id: Optional[str] = None
+    ) -> None:
+        filter = {
+            "operator": "AND",
+            "conditions": [
+                {"field": "sql_pair_id", "operator": "in", "value": sql_pair_ids},
+            ],
+        }
+
+        if project_id:
+            filter["conditions"].append(
+                {"field": "project_id", "operator": "==", "value": project_id}
+            )
+
+        return await self.store.delete_documents(filter)
+
+
 ## Start of Pipeline
 @observe(capture_input=False)
 def boilerplates(
@@ -155,9 +179,10 @@ def __init__(
                 document_store=store,
                 policy=DuplicatePolicy.OVERWRITE,
             ),
-            "external_pairs": _load_sql_pairs(sql_pairs_path),
         }
 
+        self._external_pairs = _load_sql_pairs(sql_pairs_path)
+
         super().__init__(
             AsyncDriver({}, sys.modules[__name__], result_builder=base.DictResult())
         )
@@ -167,18 +192,35 @@ async def run(
         self,
         mdl_str: str,
         project_id: Optional[str] = "",
+        external_pairs: Optional[Dict[str, Any]] = {},
     ) -> Dict[str, Any]:
         logger.info(
             f"Project ID: {project_id} SQL Pairs Indexing pipeline is running..."
         )
 
-        return await self._pipe.execute(
-            ["write"],
-            inputs={
-                "mdl_str": mdl_str,
-                "project_id": project_id,
-                **self._components,
+        input = {
+            "mdl_str": mdl_str,
+            "project_id": project_id,
+            "external_pairs": {
+                **self._external_pairs,
+                **external_pairs,
             },
+            **self._components,
+        }
+
+        return await self._pipe.execute(["write"], inputs=input)

Review comment (Contributor) on lines 200 to 212:

⚠️ Potential issue

Consider error handling in pipeline execution. The call self._pipe.execute(["write"], inputs=input) might fail if the store or embedding process errors, and that error is currently not caught here. For user-facing reliability, consider a try-except to either log or handle partial writes more gracefully.

+        try:
+            return await self._pipe.execute(["write"], inputs=input)
+        except Exception as e:
+            logger.error(f"Pipeline execution failed: {e}")
+            raise
+    @observe(name="Clean Documents for SQL Pairs")
+    async def clean(
+        self,
+        sql_pairs: List[SqlPair],
+        project_id: Optional[str] = None,
+    ) -> None:
+        await clean(
+            sql_pairs=sql_pairs,
+            embedding={"documents": []},
+            cleaner=self._components["cleaner"],
+            project_id=project_id,
+        )


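Taken together, the consolidated pipeline replaces the SqlPairs plus SqlPairsDeletion pair: run merges the file-based pairs loaded at construction with request-scoped external_pairs, and clean routes deletion through the relocated SqlPairsCleaner. A rough service-side usage sketch (constructor wiring elided; globals.py above shows the real setup, and the pipeline argument stands in for a fully constructed SqlPairs instance):

```python
# Rough sketch of the consolidated API; see globals.py for real wiring.
from src.pipelines.indexing import SqlPairs
from src.pipelines.indexing.sql_pairs import SqlPair


async def demo(pipeline: SqlPairs, mdl_str: str) -> None:
    # Indexing: pairs from sql_pairs_path (loaded in __init__) are merged
    # with request-scoped external_pairs before the "write" step executes.
    await pipeline.run(mdl_str=mdl_str, project_id="p1", external_pairs={})

    # Deletion: replaces the removed SqlPairsDeletion pipeline; clean()
    # calls the relocated SqlPairsCleaner with an empty embedding payload.
    await pipeline.clean(sql_pairs=[SqlPair(id="pair-1")], project_id="p1")
```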