Skip to content

Commit

Permalink
feat(wren-ai-service): litellm integration (#946)
Browse files Browse the repository at this point in the history
  • Loading branch information
cyyeh authored Dec 6, 2024
1 parent b747ddc commit 8d47d20
Show file tree
Hide file tree
Showing 25 changed files with 297 additions and 68 deletions.
66 changes: 27 additions & 39 deletions wren-ai-service/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions wren-ai-service/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pyyaml = "^6.0.2"
pydantic-settings = "^2.5.2"
google-auth = "^2.35.0"
tiktoken = "^0.8.0"
litellm = "^1.52.12"

[tool.poetry.group.dev.dependencies]
pre-commit = "^3.7.1"
Expand Down
2 changes: 1 addition & 1 deletion wren-ai-service/src/core/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def get_generator(self, *args, **kwargs):
...

def get_model(self):
return self._generation_model
return self._model

def get_model_kwargs(self):
return self._model_kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def prompt(
@async_timer
@observe(as_type="generation", capture_input=False)
async def data_assistance(prompt: dict, generator: Any, query_id: str) -> dict:
return await generator.run(prompt=prompt.get("prompt"), query_id=query_id)
return await generator(prompt=prompt.get("prompt"), query_id=query_id)


## End of Pipeline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def prompt(
@async_timer
@observe(as_type="generation", capture_input=False)
async def generate_sql_in_followup(prompt: dict, generator: Any) -> dict:
return await generator.run(prompt=prompt.get("prompt"))
return await generator(prompt=prompt.get("prompt"))


@async_timer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def prompt(
@async_timer
@observe(as_type="generation", capture_input=False)
async def classify_intent(prompt: dict, generator: Any) -> dict:
return await generator.run(prompt=prompt.get("prompt"))
return await generator(prompt=prompt.get("prompt"))


@timer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def prompt(

@observe(capture_input=False, as_type="generation")
async def generate(prompt: dict, generator: Any) -> dict:
return await generator.run(prompt=prompt.get("prompt"))
return await generator(prompt=prompt.get("prompt"))


@observe(capture_input=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def prompt(

@observe(as_type="generation", capture_input=False)
async def generate(prompt: dict, generator: Any) -> dict:
return await generator.run(prompt=prompt.get("prompt"))
return await generator(prompt=prompt.get("prompt"))


@observe(capture_input=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def prompt(

@observe(as_type="generation", capture_input=False)
async def generate(prompt: dict, generator: Any) -> dict:
return await generator.run(prompt=prompt.get("prompt"))
return await generator(prompt=prompt.get("prompt"))


@observe(capture_input=False)
Expand Down
2 changes: 1 addition & 1 deletion wren-ai-service/src/pipelines/generation/sql_breakdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def prompt(
@async_timer
@observe(as_type="generation", capture_input=False)
async def generate_sql_details(prompt: dict, generator: Any) -> dict:
return await generator.run(prompt=prompt.get("prompt"))
return await generator(prompt=prompt.get("prompt"))


@async_timer
Expand Down
2 changes: 1 addition & 1 deletion wren-ai-service/src/pipelines/generation/sql_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def prompts(
async def generate_sql_corrections(prompts: list[dict], generator: Any) -> list[dict]:
tasks = []
for prompt in prompts:
task = asyncio.ensure_future(generator.run(prompt=prompt.get("prompt")))
task = asyncio.ensure_future(generator(prompt=prompt.get("prompt")))
tasks.append(task)

return await asyncio.gather(*tasks)
Expand Down
2 changes: 1 addition & 1 deletion wren-ai-service/src/pipelines/generation/sql_expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def prompt(
@async_timer
@observe(as_type="generation", capture_input=False)
async def generate_sql_expansion(prompt: dict, generator: Any) -> dict:
return await generator.run(prompt=prompt.get("prompt"))
return await generator(prompt=prompt.get("prompt"))


@async_timer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ def prompts(
@observe(as_type="generation", capture_input=False)
async def generate_sql_explanation(prompts: List[dict], generator: Any) -> List[dict]:
async def _task(prompt: str, generator: Any):
return await generator.run(prompt=prompt.get("prompt"))
return await generator(prompt=prompt.get("prompt"))

tasks = [_task(prompt, generator) for prompt in prompts]
return await asyncio.gather(*tasks)
Expand Down
2 changes: 1 addition & 1 deletion wren-ai-service/src/pipelines/generation/sql_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ async def generate_sql(
prompt: dict,
generator: Any,
) -> dict:
return await generator.run(prompt=prompt.get("prompt"))
return await generator(prompt=prompt.get("prompt"))


@async_timer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ async def generate_sql_regeneration(
sql_regeneration_prompt: dict,
generator: Any,
) -> dict:
return await generator.run(prompt=sql_regeneration_prompt.get("prompt"))
return await generator(prompt=sql_regeneration_prompt.get("prompt"))


@async_timer
Expand Down
2 changes: 1 addition & 1 deletion wren-ai-service/src/pipelines/generation/sql_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def prompt(
@async_timer
@observe(as_type="generation", capture_input=False)
async def generate_sql_summary(prompt: dict, generator: Any) -> dict:
return await generator.run(prompt=prompt.get("prompt"))
return await generator(prompt=prompt.get("prompt"))


@timer
Expand Down
2 changes: 1 addition & 1 deletion wren-ai-service/src/pipelines/retrieval/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ async def filter_columns_in_tables(
prompt: dict, table_columns_selection_generator: Any
) -> dict:
if prompt:
return await table_columns_selection_generator.run(prompt=prompt.get("prompt"))
return await table_columns_selection_generator(prompt=prompt.get("prompt"))
else:
return {}

Expand Down
8 changes: 6 additions & 2 deletions wren-ai-service/src/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,16 @@ def llm_processor(entry: dict) -> dict:
"""
others = {k: v for k, v in entry.items() if k not in ["type", "provider", "models"]}
returned = {}
for model in entry["models"]:
model_name = f"{entry['provider']}.{model['model']}"
for model in entry.get("models", []):
model_name = f"{entry.get('provider')}.{model.get('model')}"
model_additional_params = {
k: v for k, v in model.items() if k not in ["model", "kwargs"]
}
returned[model_name] = {
"provider": entry["provider"],
"model": model["model"],
"kwargs": model["kwargs"],
**model_additional_params,
**others,
}
return returned
Expand Down
Loading

0 comments on commit 8d47d20

Please sign in to comment.