This repository has been archived by the owner on Mar 1, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 738
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
608 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# Vanna AI LLamaPack | ||
|
||
Vanna AI is an open-source RAG framework for SQL generation. It works in two steps: | ||
1. Train a RAG model on your data | ||
2. Ask questions (use reference corpus to generate SQL queries that can run on your db). | ||
|
||
Check out the [Github project](https://github.com/vanna-ai/vanna) and the [docs](https://vanna.ai/docs/) for more details. | ||
|
||
This LlamaPack creates a simple `VannaQueryEngine` with vanna, ChromaDB and OpenAI, and allows you to train and ask questions over a SQL database. | ||
|
||
## CLI Usage | ||
|
||
You can download llamapacks directly using `llamaindex-cli`, which comes installed with the `llama-index` python package: | ||
|
||
```bash | ||
llamaindex-cli download-llamapack VannaPack --download-dir ./vanna_pack | ||
``` | ||
|
||
You can then inspect the files at `./vanna_pack` and use them as a template for your own project! | ||
|
||
## Code Usage | ||
|
||
You can download the pack to a `./vanna_pack` directory: | ||
|
||
```python | ||
from llama_index.llama_pack import download_llama_pack | ||
|
||
# download and install dependencies | ||
VannaPack = download_llama_pack( | ||
"VannaPack", "./vanna_pack" | ||
) | ||
``` | ||
|
||
From here, you can use the pack, or inspect and modify the pack in `./vanna_pack`. | ||
|
||
Then, you can set up the pack like so: | ||
|
||
```python | ||
pack = VannaPack( | ||
openai_api_key="<openai_api_key>", | ||
sql_db_url="chinook.db", | ||
openai_model="gpt-3.5-turbo" | ||
) | ||
``` | ||
|
||
The `run()` function is a light wrapper around `llm.complete()`. | ||
|
||
```python | ||
response = pack.run("List some sample albums") | ||
``` | ||
|
||
You can also use modules individually. | ||
|
||
```python | ||
query_engine = pack.get_modules()["vanna_query_engine"] | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Init params.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
"""Vanna AI Pack. | ||
Uses: https://vanna.ai/. | ||
""" | ||
|
||
|
||
from typing import Any, Dict, Optional, cast | ||
|
||
from llama_index.llama_pack.base import BaseLlamaPack | ||
from llama_index.query_engine import CustomQueryEngine | ||
import pandas as pd | ||
from llama_index.response.schema import RESPONSE_TYPE, Response | ||
|
||
|
||
class VannaQueryEngine(CustomQueryEngine): | ||
"""Vanna query engine. | ||
Uses chromadb and OpenAI. | ||
""" | ||
|
||
openai_api_key: str | ||
sql_db_url: str | ||
|
||
ask_kwargs: Dict[str, Any] | ||
vn: Any | ||
|
||
def __init__( | ||
self, | ||
openai_api_key: str, | ||
sql_db_url: str, | ||
openai_model: str = "gpt-3.5-turbo", | ||
ask_kwargs: Optional[Dict[str, Any]] = None, | ||
verbose: bool = True, | ||
**kwargs: Any, | ||
) -> None: | ||
"""Init params.""" | ||
from vanna.openai.openai_chat import OpenAI_Chat | ||
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore | ||
|
||
class MyVanna(ChromaDB_VectorStore, OpenAI_Chat): | ||
def __init__(self, config=None): | ||
ChromaDB_VectorStore.__init__(self, config=config) | ||
OpenAI_Chat.__init__(self, config=config) | ||
|
||
vn = MyVanna(config={"api_key": openai_api_key, "model": openai_model}) | ||
vn.connect_to_sqlite(sql_db_url) | ||
if verbose: | ||
print(f"> Connected to db: {sql_db_url}") | ||
|
||
# get every table DDL from db | ||
sql_results = cast( | ||
pd.DataFrame, | ||
vn.run_sql("SELECT sql FROM sqlite_master WHERE type='table';"), | ||
) | ||
# go through every sql statement, do vn.train(ddl=ddl) on it | ||
for idx, sql_row in sql_results.iterrows(): | ||
if verbose: | ||
print(f"> Training on {sql_row['sql']}") | ||
vn.train(ddl=sql_row["sql"]) | ||
|
||
super().__init__( | ||
openai_api_key=openai_api_key, | ||
sql_db_url=sql_db_url, | ||
vn=vn, | ||
ask_kwargs=ask_kwargs or {}, | ||
**kwargs, | ||
) | ||
|
||
def custom_query(self, query_str: str) -> RESPONSE_TYPE: | ||
"""Query.""" | ||
from vanna.base import VannaBase | ||
|
||
vn = cast(VannaBase, self.vn) | ||
ask_kwargs = {"visualize": False, "print_results": False} | ||
ask_kwargs.update(self.ask_kwargs) | ||
sql = vn.generate_sql( | ||
query_str, | ||
**ask_kwargs, | ||
) | ||
result = vn.run_sql(sql) | ||
if result is None: | ||
raise ValueError("Vanna returned None.") | ||
sql, df, _ = result | ||
|
||
return Response(response=str(df), metadata={"sql": sql, "df": df}) | ||
|
||
|
||
class VannaPack(BaseLlamaPack): | ||
"""Vanna AI pack. | ||
Uses OpenAI and ChromaDB. Of course Vanna.AI allows you to connect to many more dbs | ||
and use more models - feel free to refer to their page for more details: | ||
https://vanna.ai/docs/snowflake-openai-vanna-vannadb.html | ||
""" | ||
|
||
def __init__( | ||
self, | ||
openai_api_key: str, | ||
sql_db_url: str, | ||
**kwargs: Any, | ||
) -> None: | ||
"""Init params.""" | ||
|
||
self.vanna_query_engine = VannaQueryEngine( | ||
openai_api_key=openai_api_key, sql_db_url=sql_db_url, **kwargs | ||
) | ||
|
||
def get_modules(self) -> Dict[str, Any]: | ||
"""Get modules.""" | ||
return { | ||
"vanna_query_engine": self.vanna_query_engine, | ||
} | ||
|
||
def run(self, *args: Any, **kwargs: Any) -> Any: | ||
"""Run the pipeline.""" | ||
return self.vanna_query_engine.query(*args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
vanna==0.0.36 |
Oops, something went wrong.