Skip to content

Commit

Permalink
feat(wren-ai-service): Consolidate SQL Pairs Service and Remove Redun…
Browse files Browse the repository at this point in the history
…dant Code (#1268)
  • Loading branch information
paopa authored Feb 6, 2025
1 parent 7292b1f commit 69b220e
Show file tree
Hide file tree
Showing 28 changed files with 677 additions and 699 deletions.
3 changes: 0 additions & 3 deletions deployment/kustomizations/base/cm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,6 @@ data:
- name: sql_pairs_indexing
document_store: qdrant
embedder: litellm_embedder.text-embedding-3-large
- name: sql_pairs_deletion
document_store: qdrant
embedder: litellm_embedder.text-embedding-3-large
- name: sql_pairs_retrieval
document_store: qdrant
embedder: litellm_embedder.text-embedding-3-large
Expand Down
3 changes: 0 additions & 3 deletions docker/config.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,6 @@ pipes:
- name: sql_pairs_indexing
document_store: qdrant
embedder: litellm_embedder.text-embedding-3-large
- name: sql_pairs_deletion
document_store: qdrant
embedder: litellm_embedder.text-embedding-3-large
- name: sql_pairs_retrieval
document_store: qdrant
embedder: litellm_embedder.text-embedding-3-large
Expand Down
3 changes: 0 additions & 3 deletions wren-ai-service/docs/config_examples/config.azure.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,6 @@ pipes:
document_store: qdrant
embedder: litellm_embedder.azure/text-embedding-ada-002
llm: litellm_llm.azure/gpt-4
- name: sql_pairs_deletion
document_store: qdrant
embedder: litellm_embedder.azure/text-embedding-ada-002
- name: sql_pairs_retrieval
document_store: qdrant
embedder: litellm_embedder.azure/text-embedding-ada-002
Expand Down
3 changes: 0 additions & 3 deletions wren-ai-service/docs/config_examples/config.deepseek.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,6 @@ pipes:
- name: sql_pairs_indexing
document_store: qdrant
embedder: litellm_embedder.text-embedding-3-large
- name: sql_pairs_deletion
document_store: qdrant
embedder: litellm_embedder.text-embedding-3-large
- name: sql_pairs_retrieval
document_store: qdrant
embedder: litellm_embedder.text-embedding-3-large
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,6 @@ pipes:
- name: sql_pairs_indexing
document_store: qdrant
embedder: litellm_embedder.gemini/text-embedding-004
- name: sql_pairs_deletion
document_store: qdrant
embedder: litellm_embedder.gemini/text-embedding-004
- name: sql_pairs_retrieval
document_store: qdrant
embedder: litellm_embedder.gemini/text-embedding-004
Expand Down
3 changes: 0 additions & 3 deletions wren-ai-service/docs/config_examples/config.groq.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,6 @@ pipes:
- name: sql_pairs_indexing
document_store: qdrant
embedder: litellm_embedder.text-embedding-3-large
- name: sql_pairs_deletion
document_store: qdrant
embedder: litellm_embedder.text-embedding-3-large
- name: sql_pairs_retrieval
document_store: qdrant
embedder: litellm_embedder.text-embedding-3-large
Expand Down
3 changes: 0 additions & 3 deletions wren-ai-service/docs/config_examples/config.ollama.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,6 @@ pipes:
- name: sql_pairs_indexing
document_store: qdrant
embedder: litellm_embedder.openai/nomic-embed-text
- name: sql_pairs_deletion
document_store: qdrant
embedder: litellm_embedder.openai/nomic-embed-text
- name: sql_pairs_retrieval
document_store: qdrant
embedder: litellm_embedder.openai/nomic-embed-text
Expand Down
2 changes: 1 addition & 1 deletion wren-ai-service/src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class Settings(BaseSettings):
config_path: str = Field(default="config.yaml")
_components: list[dict]

sql_pairs_path: str = Field(default="pairs.json")
sql_pairs_path: str = Field(default="sql_pairs.json")

def __init__(self):
load_dotenv(".env.dev", override=True)
Expand Down
79 changes: 32 additions & 47 deletions wren-ai-service/src/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,40 +7,27 @@
from src.core.pipeline import PipelineComponent
from src.core.provider import EmbedderProvider, LLMProvider
from src.pipelines import generation, indexing, retrieval
from src.web.v1.services.ask import AskService
from src.web.v1.services.ask_details import AskDetailsService
from src.web.v1.services.chart import ChartService
from src.web.v1.services.chart_adjustment import ChartAdjustmentService
from src.web.v1.services.question_recommendation import QuestionRecommendation
from src.web.v1.services.relationship_recommendation import RelationshipRecommendation
from src.web.v1.services.semantics_description import SemanticsDescription
from src.web.v1.services.semantics_preparation import SemanticsPreparationService
from src.web.v1.services.sql_answer import SqlAnswerService
from src.web.v1.services.sql_expansion import SqlExpansionService
from src.web.v1.services.sql_explanation import SqlExplanationService
from src.web.v1.services.sql_pairs_preparation import SqlPairsPreparationService
from src.web.v1.services.sql_question import SqlQuestionService
from src.web.v1.services.sql_regeneration import SqlRegenerationService
from src.web.v1 import services

logger = logging.getLogger("wren-ai-service")


@dataclass
class ServiceContainer:
ask_service: AskService
ask_details_service: AskDetailsService
question_recommendation: QuestionRecommendation
relationship_recommendation: RelationshipRecommendation
semantics_description: SemanticsDescription
semantics_preparation_service: SemanticsPreparationService
chart_service: ChartService
chart_adjustment_service: ChartAdjustmentService
sql_answer_service: SqlAnswerService
sql_expansion_service: SqlExpansionService
sql_explanation_service: SqlExplanationService
sql_regeneration_service: SqlRegenerationService
sql_pairs_preparation_service: SqlPairsPreparationService
sql_question_service: SqlQuestionService
ask_service: services.AskService
ask_details_service: services.AskDetailsService
question_recommendation: services.QuestionRecommendation
relationship_recommendation: services.RelationshipRecommendation
semantics_description: services.SemanticsDescription
semantics_preparation_service: services.SemanticsPreparationService
chart_service: services.ChartService
chart_adjustment_service: services.ChartAdjustmentService
sql_answer_service: services.SqlAnswerService
sql_expansion_service: services.SqlExpansionService
sql_explanation_service: services.SqlExplanationService
sql_regeneration_service: services.SqlRegenerationService
sql_pairs_service: services.SqlPairsService
sql_question_service: services.SqlQuestionService


@dataclass
Expand All @@ -58,15 +45,15 @@ def create_service_container(
"ttl": settings.query_cache_ttl,
}
return ServiceContainer(
semantics_description=SemanticsDescription(
semantics_description=services.SemanticsDescription(
pipelines={
"semantics_description": generation.SemanticsDescription(
**pipe_components["semantics_description"],
)
},
**query_cache,
),
semantics_preparation_service=SemanticsPreparationService(
semantics_preparation_service=services.SemanticsPreparationService(
pipelines={
"db_schema": indexing.DBSchema(
**pipe_components["db_schema_indexing"],
Expand All @@ -85,7 +72,7 @@ def create_service_container(
},
**query_cache,
),
ask_service=AskService(
ask_service=services.AskService(
pipelines={
"intent_classification": generation.IntentClassification(
**pipe_components["intent_classification"],
Expand Down Expand Up @@ -127,7 +114,7 @@ def create_service_container(
allow_sql_generation_reasoning=settings.allow_sql_generation_reasoning,
**query_cache,
),
chart_service=ChartService(
chart_service=services.ChartService(
pipelines={
"sql_executor": retrieval.SQLExecutor(
**pipe_components["sql_executor"],
Expand All @@ -139,7 +126,7 @@ def create_service_container(
},
**query_cache,
),
chart_adjustment_service=ChartAdjustmentService(
chart_adjustment_service=services.ChartAdjustmentService(
pipelines={
"sql_executor": retrieval.SQLExecutor(
**pipe_components["sql_executor"],
Expand All @@ -151,7 +138,7 @@ def create_service_container(
},
**query_cache,
),
sql_answer_service=SqlAnswerService(
sql_answer_service=services.SqlAnswerService(
pipelines={
"preprocess_sql_data": retrieval.PreprocessSqlData(
**pipe_components["preprocess_sql_data"],
Expand All @@ -163,7 +150,7 @@ def create_service_container(
},
**query_cache,
),
ask_details_service=AskDetailsService(
ask_details_service=services.AskDetailsService(
pipelines={
"sql_breakdown": generation.SQLBreakdown(
**pipe_components["sql_breakdown"],
Expand All @@ -175,7 +162,7 @@ def create_service_container(
},
**query_cache,
),
sql_expansion_service=SqlExpansionService(
sql_expansion_service=services.SqlExpansionService(
pipelines={
"retrieval": retrieval.Retrieval(
**pipe_components["db_schema_retrieval"],
Expand All @@ -196,23 +183,23 @@ def create_service_container(
},
**query_cache,
),
sql_explanation_service=SqlExplanationService(
sql_explanation_service=services.SqlExplanationService(
pipelines={
"sql_explanation": generation.SQLExplanation(
**pipe_components["sql_explanation"],
)
},
**query_cache,
),
sql_regeneration_service=SqlRegenerationService(
sql_regeneration_service=services.SqlRegenerationService(
pipelines={
"sql_regeneration": generation.SQLRegeneration(
**pipe_components["sql_regeneration"],
)
},
**query_cache,
),
relationship_recommendation=RelationshipRecommendation(
relationship_recommendation=services.RelationshipRecommendation(
pipelines={
"relationship_recommendation": generation.RelationshipRecommendation(
**pipe_components["relationship_recommendation"],
Expand All @@ -221,7 +208,7 @@ def create_service_container(
},
**query_cache,
),
question_recommendation=QuestionRecommendation(
question_recommendation=services.QuestionRecommendation(
pipelines={
"question_recommendation": generation.QuestionRecommendation(
**pipe_components["question_recommendation"],
Expand All @@ -242,18 +229,16 @@ def create_service_container(
},
**query_cache,
),
sql_pairs_preparation_service=SqlPairsPreparationService(
sql_pairs_service=services.SqlPairsService(
pipelines={
"sql_pairs_preparation": indexing.SqlPairs(
"sql_pairs": indexing.SqlPairs(
**pipe_components["sql_pairs_indexing"],
),
"sql_pairs_deletion": indexing.SqlPairsDeletion(
**pipe_components["sql_pairs_deletion"],
),
sql_pairs_path=settings.sql_pairs_path,
)
},
**query_cache,
),
sql_question_service=SqlQuestionService(
sql_question_service=services.SqlQuestionService(
pipelines={
"sql_question_generation": generation.SQLQuestion(
**pipe_components["sql_question_generation"],
Expand Down
11 changes: 8 additions & 3 deletions wren-ai-service/src/pipelines/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,26 @@ def build_table_ddl(
), has_calculated_field


def dry_run_pipeline(pipeline_cls: BasicPipeline, pipeline_name: str, **kwargs):
def dry_run_pipeline(
pipeline_cls: BasicPipeline,
pipeline_name: str,
method: str = "run",
**kwargs,
):
from langfuse.decorators import langfuse_context

from src.config import settings
from src.core.pipeline import async_validate
from src.providers import generate_components
from src.utils import init_langfuse, setup_custom_logger

setup_custom_logger("wren-ai-service", level_str=settings.logging_level)
setup_custom_logger("wren-ai-service", level_str=settings.logging_level, is_dev=True)

pipe_components = generate_components(settings.components)
pipeline = pipeline_cls(**pipe_components[pipeline_name])
init_langfuse(settings)

async_validate(lambda: pipeline.run(**kwargs))
async_validate(lambda: getattr(pipeline, method)(**kwargs))

langfuse_context.flush()

Expand Down
26 changes: 0 additions & 26 deletions wren-ai-service/src/pipelines/indexing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,41 +87,15 @@ async def run(
return {"documents_written": documents_written}


@component
class SqlPairsCleaner:
def __init__(self, sql_pairs_store: DocumentStore) -> None:
self._sql_pairs_store = sql_pairs_store

@component.output_types()
async def run(
self, sql_pair_ids: List[str], project_id: Optional[str] = None
) -> None:
filters = {
"operator": "AND",
"conditions": [
{"field": "sql_pair_id", "operator": "in", "value": sql_pair_ids},
],
}

if project_id:
filters["conditions"].append(
{"field": "project_id", "operator": "==", "value": project_id}
)

return await self._sql_pairs_store.delete_documents(filters)


# Put the pipelines imports here to avoid circular imports and make them accessible directly to the rest of packages
from .db_schema import DBSchema # noqa: E402
from .historical_question import HistoricalQuestion # noqa: E402
from .sql_pairs import SqlPairs # noqa: E402
from .sql_pairs_deletion import SqlPairsDeletion # noqa: E402
from .table_description import TableDescription # noqa: E402

__all__ = [
"DBSchema",
"TableDescription",
"HistoricalQuestion",
"SqlPairsDeletion",
"SqlPairs",
]
Loading

0 comments on commit 69b220e

Please sign in to comment.