fix(Views): updated views to reflect refactors #1582

Merged · 3 commits · Feb 4, 2025
Changes from 2 commits
3 changes: 2 additions & 1 deletion pandasai/__init__.py
@@ -118,7 +118,8 @@ def create(
     if df is not None:
         schema = df.schema
         schema.name = sanitize_sql_table_name(dataset_name)
-        df.to_parquet(parquet_file_path, index=False)
+        parquet_file_path_abs_path = file_manager.abs_path(parquet_file_path)
+        df.to_parquet(parquet_file_path_abs_path, index=False)
     elif view:
         _relation = [Relation(**relation) for relation in relations or ()]
         schema: SemanticLayerSchema = SemanticLayerSchema(
7 changes: 5 additions & 2 deletions pandasai/agent/base.py
@@ -123,7 +123,7 @@ def _execute_local_sql_query(self, query: str) -> pd.DataFrame:
         with duckdb.connect() as con:
             # Register all DataFrames in the state
             for df in self._state.dfs:
-                con.register(df.schema.source.table, df)
+                con.register(df.schema.name, df)
 
             # Execute the query and fetch the result as a pandas DataFrame
             result = con.sql(query).df()
@@ -145,7 +145,10 @@ def _execute_sql_query(self, query: str) -> pd.DataFrame:
         if not self._state.dfs:
             raise ValueError("No DataFrames available to register for query execution.")
 
-        if self._state.dfs[0].schema.source.type in LOCAL_SOURCE_TYPES:
+        if (
+            self._state.dfs[0].schema.source
+            and self._state.dfs[0].schema.source.type in LOCAL_SOURCE_TYPES
+        ):
             return self._execute_local_sql_query(query)
         else:
             return self._state.dfs[0].execute_sql_query(query)
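For context on the hunk above: the local path registers every in-state DataFrame with DuckDB under its schema name (rather than the source table name) before running the SQL. A minimal, self-contained sketch of that pattern using the public duckdb API — the frame and column names here are invented for illustration:

    import duckdb
    import pandas as pd

    # Stand-in for self._state.dfs; "orders" plays the role of df.schema.name.
    orders = pd.DataFrame({"id": [1, 2, 3], "total": [10.0, 20.5, 7.25]})

    with duckdb.connect() as con:
        con.register("orders", orders)  # expose the frame to SQL under its schema name
        result = con.sql("SELECT SUM(total) AS revenue FROM orders").df()

    print(result)  # one-row DataFrame: revenue = 37.75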
7 changes: 2 additions & 5 deletions pandasai/core/code_generation/code_cleaning.py
@@ -55,11 +55,8 @@ def _clean_sql_query(self, sql_query: str) -> str:
         sql_query = sql_query.rstrip(";")
         table_names = extract_table_names(sql_query)
         allowed_table_names = {
-            df.schema.source.table: df.schema.source.table for df in self.context.dfs
-        } | {
-            f'"{df.schema.source.table}"': df.schema.source.table
-            for df in self.context.dfs
-        }
+            df.schema.name: df.schema.name for df in self.context.dfs
+        } | {f'"{df.schema.name}"': df.schema.name for df in self.context.dfs}
         return self._replace_table_names(sql_query, table_names, allowed_table_names)
 
     def _validate_and_make_table_name_case_sensitive(self, node: ast.AST) -> ast.AST:
4 changes: 2 additions & 2 deletions pandasai/data_loader/loader.py
@@ -5,22 +5,22 @@
 
 from pandasai.dataframe.base import DataFrame
 from pandasai.exceptions import MethodNotImplementedError
+from pandasai.helpers.path import get_validated_dataset_path
 from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name
 
 from .. import ConfigManager
 from ..constants import (
     LOCAL_SOURCE_TYPES,
 )
-from .query_builder import QueryBuilder
 from .semantic_layer_schema import SemanticLayerSchema
 from .transformation_manager import TransformationManager
-from .view_query_builder import ViewQueryBuilder
 
 
 class DatasetLoader:
     def __init__(self, schema: SemanticLayerSchema, dataset_path: str):
         self.schema = schema
         self.dataset_path = dataset_path
+        self.org_name, self.dataset_name = get_validated_dataset_path(self.dataset_path)
 
     @classmethod
     def create_loader_from_schema(
7 changes: 4 additions & 3 deletions pandasai/data_loader/local_loader.py
@@ -5,11 +5,11 @@
 from pandasai.dataframe.base import DataFrame
 from pandasai.exceptions import InvalidDataSourceType
 
+from ..config import ConfigManager
 from ..constants import (
     LOCAL_SOURCE_TYPES,
 )
 from .loader import DatasetLoader
-from .transformation_manager import TransformationManager
 
 
 class LocalDatasetLoader(DatasetLoader):
@@ -44,10 +44,11 @@ def _load_from_local_source(self) -> pd.DataFrame:
         return self._read_csv_or_parquet(filepath, source_type)
 
     def _read_csv_or_parquet(self, file_path: str, file_format: str) -> pd.DataFrame:
+        file_manager = ConfigManager.get().file_manager
         if file_format == "parquet":
-            df = pd.read_parquet(file_path)
+            df = pd.read_parquet(file_manager.abs_path(file_path))
         elif file_format == "csv":
-            df = pd.read_csv(file_path)
+            df = pd.read_csv(file_manager.abs_path(file_path))
         else:
             raise ValueError(f"Unsupported file format: {file_format}")
 
13 changes: 9 additions & 4 deletions pandasai/data_loader/query_builder.py
@@ -1,3 +1,4 @@
+import re
 from typing import Any, Dict, List, Union
 
 from pandasai.data_loader.semantic_layer_schema import Relation, SemanticLayerSchema
@@ -8,11 +9,15 @@ def __init__(self, schema: SemanticLayerSchema):
         self.schema = schema
 
     def format_query(self, query):
-        return query
+        pattern = re.compile(
+            rf"\bFROM\s+{re.escape(self.schema.name)}\b", re.IGNORECASE
+        )
+        replacement = self._get_from_statement()
+        return pattern.sub(replacement, query)
 
     def build_query(self) -> str:
         columns = self._get_columns()
-        query = f"SELECT {columns}"
+        query = f"SELECT {columns} "
         query += self._get_from_statement()
         query += self._add_order_by()
         query += self._add_limit()
@@ -26,7 +31,7 @@ def _get_columns(self) -> str:
         return "*"
 
     def _get_from_statement(self):
-        return f" FROM {self.schema.source.table.lower()}"
+        return f"FROM {self.schema.source.table.lower()}"
 
     def _add_order_by(self) -> str:
         if not self.schema.order_by:
@@ -47,7 +52,7 @@ def _add_limit(self, n=None) -> str:
     def get_head_query(self, n=5):
         source_type = self.schema.source.type
         columns = self._get_columns()
-        query = f"SELECT {columns}"
+        query = f"SELECT {columns} "
         query += self._get_from_statement()
         order_by = "RANDOM()" if source_type in {"sqlite", "postgres"} else "RAND()"
         return f"{query} ORDER BY {order_by} LIMIT {n}"
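The new base-class format_query is a plain regex substitution: any `FROM <dataset-name>` in a generated query is swapped for the builder's real FROM clause, and the trailing space now carried by `f"SELECT {columns} "` means _get_from_statement() no longer needs a leading one. A small sketch of the substitution with invented names (a dataset "sales" backed by a table "fact_sales"; neither name comes from the PR):

    import re

    schema_name = "sales"                # assumed dataset name
    from_statement = "FROM fact_sales"   # what _get_from_statement() might return

    pattern = re.compile(rf"\bFROM\s+{re.escape(schema_name)}\b", re.IGNORECASE)
    print(pattern.sub(from_statement, "SELECT * FROM sales LIMIT 10"))
    # -> SELECT * FROM fact_sales LIMIT 10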
16 changes: 16 additions & 0 deletions pandasai/data_loader/semantic_layer_schema.py
@@ -29,6 +29,15 @@ class SQLConnectionConfig(BaseModel):
     user: str = Field(..., description="Database username")
     password: str = Field(..., description="Database password")
 
+    def __eq__(self, other):
+        return (
+            self.host == other.host
+            and self.port == other.port
+            and self.database == other.database
+            and self.user == other.user
+            and self.password == other.password
+        )
+
 
 class Column(BaseModel):
     name: str = Field(..., description="Name of the column.")
@@ -174,6 +183,13 @@ class Source(BaseModel):
     )
     table: Optional[str] = Field(None, description="Table of the data source.")
 
+    def is_compatible_source(self, source2: "Source"):
+        if self.type in LOCAL_SOURCE_TYPES and source2.type in LOCAL_SOURCE_TYPES:
+            return True
+        if self.type in REMOTE_SOURCE_TYPES and source2.type in REMOTE_SOURCE_TYPES:
+            return self.connection == source2.connection
+        return False
+
     @model_validator(mode="before")
     @classmethod
     def validate_type_and_fields(cls, values):
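Taken together, the two additions let view loading verify that every dataset behind a view is queryable over one connection: local sources are always mutually compatible, while remote sources are compatible only when their connection configs compare equal via the new __eq__. An illustrative sketch — the constructor arguments and field values are assumptions for illustration, not taken from this diff:

    conn = SQLConnectionConfig(
        host="db.example.com", port=5432, database="sales",
        user="reader", password="secret",
    )
    orders = Source(type="postgres", connection=conn, table="orders")
    customers = Source(type="postgres", connection=conn, table="customers")

    # Same connection config on both sides -> joinable within a single view.
    assert orders.is_compatible_source(customers)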
77 changes: 75 additions & 2 deletions pandasai/data_loader/view_loader.py
@@ -1,6 +1,14 @@
 from typing import Dict, Optional
 
+import pandas as pd
+
+from pandasai.dataframe.virtual_dataframe import VirtualDataFrame
+
+from .. import InvalidConfigError
+from ..exceptions import MaliciousQueryError
+from ..helpers.sql_sanitizer import is_sql_query_safe
 from .loader import DatasetLoader
-from .semantic_layer_schema import SemanticLayerSchema
+from .semantic_layer_schema import SemanticLayerSchema, Source, is_schema_source_same
 from .sql_loader import SQLDatasetLoader
 from .view_query_builder import ViewQueryBuilder
 
@@ -10,13 +18,78 @@ class ViewDatasetLoader(SQLDatasetLoader):
     Loader for view-based datasets.
     """
 
+    # get the datasets name
+    # get the datasets schemas
+    # pass to the query builder
+    # generate SELECT IN statement with JOINs
+
     def __init__(self, schema: SemanticLayerSchema, dataset_path: str):
         super().__init__(schema, dataset_path)
-        self.query_builder: ViewQueryBuilder = ViewQueryBuilder(schema)
+        self.dependencies_datasets = self._get_dependencies_datasets()
+        self.schema_dependencies_dict: dict[
+            str, DatasetLoader
+        ] = self._get_dependencies_schemas()
+        self.source: Source = list(self.schema_dependencies_dict.values())[
+            0
+        ].schema.source
+        self.query_builder: ViewQueryBuilder = ViewQueryBuilder(
+            schema, self.schema_dependencies_dict
+        )
+
+    def _get_dependencies_datasets(self) -> set[str]:
+        return {
+            table.split(".")[0]
+            for relation in self.schema.relations
+            for table in (relation.from_, relation.to)
+        }
+
+    def _get_dependencies_schemas(self) -> dict[str, DatasetLoader]:
+        dependency_dict = {
+            dep: DatasetLoader.create_loader_from_path(f"{self.org_name}/{dep}")
+            for dep in self.dependencies_datasets
+        }
+
+        loaders = list(dependency_dict.values())
+        base_source = loaders[0].schema.source
+
+        for loader in loaders[1:]:
+            if not base_source.is_compatible_source(loader.schema.source):
+                raise ValueError(
+                    f"Source in loader with schema {loader.schema} is not compatible with the first loader's source."
+                )
+
+        return dependency_dict
+
     def load(self) -> VirtualDataFrame:
         return VirtualDataFrame(
             schema=self.schema,
             data_loader=ViewDatasetLoader(self.schema, self.dataset_path),
             path=self.dataset_path,
         )
+
+    def execute_query(self, query: str, params: Optional[list] = None) -> pd.DataFrame:
+        source_type = self.source.type
+        connection_info = self.source.connection
+
+        formatted_query = self.query_builder.format_query(query)
+        load_function = self._get_loader_function(source_type)
+
+        if not is_sql_query_safe(formatted_query):
+            raise MaliciousQueryError(
+                "The SQL query is deemed unsafe and will not be executed."
+            )
+        try:
+            dataframe: pd.DataFrame = load_function(
+                connection_info, formatted_query, params
+            )
+            return dataframe
+
+        except ModuleNotFoundError as e:
+            raise ImportError(
+                f"{source_type.capitalize()} connector not found. Please install the pandasai_sql[{source_type}] library, e.g. `pip install pandasai_sql[{source_type}]`."
+            ) from e
+
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to execute query for '{source_type}' with: {formatted_query}"
+            ) from e
68 changes: 50 additions & 18 deletions pandasai/data_loader/view_query_builder.py
@@ -1,20 +1,24 @@
-from typing import Any, Dict, List, Union
+import re
+from typing import Dict
 
+from pandasai.data_loader.loader import DatasetLoader
 from pandasai.data_loader.query_builder import QueryBuilder
-from pandasai.data_loader.semantic_layer_schema import Relation, SemanticLayerSchema
+from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
+from pandasai.data_loader.sql_loader import SQLDatasetLoader
 
 
 class ViewQueryBuilder(QueryBuilder):
-    def __init__(self, schema: SemanticLayerSchema):
+    def __init__(
+        self,
+        schema: SemanticLayerSchema,
+        schema_dependencies_dict: Dict[str, DatasetLoader],
+    ):
         super().__init__(schema)
-
-    def format_query(self, query):
-        return f"{self._get_with_statement()}{query}"
+        self.schema_dependencies_dict = schema_dependencies_dict
 
     def build_query(self) -> str:
         columns = self._get_columns()
-        query = self._get_with_statement()
-        query += f"SELECT {columns}"
+        query = f"SELECT {columns} "
         query += self._get_from_statement()
         query += self._add_order_by()
         query += self._add_limit()
@@ -28,13 +32,31 @@ def _get_columns(self) -> str:
         else:
             return super()._get_columns()
 
-    def _get_from_statement(self):
-        return f" FROM {self.schema.name}"
+    def _get_columns_for_table(self, query):
+        match = re.search(r"SELECT\s+(.*?)\s+FROM", query, re.IGNORECASE)
+        if not match:
+            return None
+
+        columns = match.group(1).split(",")
+        return [col.strip() for col in columns]
+
+    def _get_sub_query_from_loader(self, loader: SQLDatasetLoader) -> (str, str):
+        query = loader.query_builder.build_query()
+        return query, loader.schema.name
 
-    def _get_with_statement(self):
+    def _get_from_statement(self):
         relations = self.schema.relations
-        first_table = relations[0].from_.split(".")[0]
-        query = f"WITH {self.schema.name} AS ( SELECT\n"
+        first_dataset = relations[0].from_.split(".")[0]
+        first_loader = self.schema_dependencies_dict[first_dataset]
+
+        if isinstance(first_loader, SQLDatasetLoader):
+            first_query, first_name = self._get_sub_query_from_loader(first_loader)
+        else:
+            raise ValueError(
+                f"Views for local datasets or nested views are currently not supported."
+            )
+
+        query = f"FROM ( SELECT\n"
 
         if self.schema.columns:
             query += ", ".join(
@@ -44,11 +66,21 @@ def _get_with_statement(self):
                 ]
             )
         else:
-            query += "*"
+            query += "* "
 
-        query += f"\nFROM {first_table}"
+        query += f"\nFROM ( {first_query} ) AS {first_name}"
         for relation in relations:
-            to_table = relation.to.split(".")[0]
-            query += f"\nJOIN {to_table} ON {relation.from_} = {relation.to}"
-        query += ")\n"
+            to_datasets = relation.to.split(".")[0]
+            loader = self.schema_dependencies_dict[to_datasets]
+            subquery, dataset_name = self._get_sub_query_from_loader(loader)
+            query += f"\nJOIN ( {subquery} ) AS {dataset_name}\n"
+            query += f"ON {relation.from_} = {relation.to}"
+        query += f") AS {self.schema.name}\n"
 
         return query
+
+    def get_head_query(self, n=5):
+        columns = self._get_columns()
+        query = f"SELECT {columns}"
+        query += self._get_from_statement()
+        return f"{query} LIMIT {n}"
9 changes: 4 additions & 5 deletions pandasai/dataframe/base.py
@@ -163,17 +163,16 @@ def push(self):
             "name": self.schema.name,
         }
 
-        dataset_directory = os.path.join("datasets", self.path)
+        file_manager = ConfigManager.get().file_manager
         headers = {"accept": "application/json", "x-authorization": f"Bearer {api_key}"}
 
-        files = []
-        schema_file_path = os.path.join(dataset_directory, "schema.yaml")
-        data_file_path = os.path.join(dataset_directory, "data.parquet")
+        schema_file_path = os.path.join(self.path, "schema.yaml")
+        data_file_path = os.path.join(self.path, "data.parquet")
 
         # Open schema.yaml
         schema_file = file_manager.load_binary(schema_file_path)
-        files.append(("files", ("schema.yaml", schema_file, "application/x-yaml")))
+        files = [("files", ("schema.yaml", schema_file, "application/x-yaml"))]
 
         # Check if data.parquet exists and open it
         if file_manager.exists(data_file_path):
6 changes: 1 addition & 5 deletions pandasai/helpers/dataframe_serializer.py
@@ -18,11 +18,7 @@ def serialize(df: "DataFrame") -> str:
         Returns:
             str: dataframe stringify
         """
-        dataframe_info = "<table"
-
-        # Add name attribute if available
-        if df.schema.source.table is not None:
-            dataframe_info += f' table_name="{df.schema.source.table}"'
+        dataframe_info = f'<table table_name="{df.schema.name}"'
 
         # Add description attribute if available
         if df.schema.description is not None: