Skip to content

Commit

Permalink
Merge pull request #1527 from sfc-gh-alherrera/dspy-snowflake
Browse files Browse the repository at this point in the history
fix(dspy): updating SnowflakeRM implementation to support passing filters/columns per query
  • Loading branch information
okhat authored Sep 23, 2024
2 parents 68eb90d + 4dd5bbd commit 915845b
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 23 deletions.
22 changes: 11 additions & 11 deletions docs/api/retrieval_model_clients/SnowflakeRM.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ SnowflakeRM(
cortex_search_service: str,
snowflake_database: str,
snowflake_schema: dict,
retrieval_columns: list,
search_filter: dict = None,
k: int = 3,
)
```
Expand All @@ -26,8 +24,6 @@ SnowflakeRM(
- `cortex_search_service (str)`: The name of the Cortex Search service to be used.
- `snowflake_database (str)`: The name of the Snowflake database to be used with the Cortex Search service.
- `snowflake_schema (str)`: The name of the Snowflake schema to be used with the Cortex Search service.
- `retrieval_columns (str)`: A list of columns to return for each relevant result in the response.
- `search_filter (dict, optional)`: Optional filter object used for filtering results based on data in the ATTRIBUTES columns. See [Filter syntax](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-search/query-cortex-search-service#filter-syntax)
- `k (int, optional)`: The number of top passages to retrieve. Defaults to 3.

### Methods
Expand All @@ -39,6 +35,8 @@ Query the Cortex Search service to retrieve the top k relevant results given a q
**Parameters:**

- `query_or_queries` (_Union[str, List[str]]_): The query or list of queries to search for.
- `retrieval_columns` (str)`: A list of columns to return for each relevant result in the response.
- `search_filter` (_Optional[dict]_): Optional filter object used for filtering results based on data in the ATTRIBUTES columns. See [Filter syntax](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-search/query-cortex-search-service#filter-syntax)
- `k` (_Optional[int]_): The number of results to retrieve. If not specified, defaults to the value set during initialization.

**Returns:**
Expand All @@ -47,7 +45,7 @@ Query the Cortex Search service to retrieve the top k relevant results given a q

### Quickstart

To support passage retrieval from a Snowflake table with this integration, a Cortex Search endpoint must be configured.
To support passage retrieval from a Snowflake table with this integration, a [Cortex Search](https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-search/cortex-search-overview) endpoint must first be configured.

```python
from dspy.retrieve.snowflake_rm import SnowflakeRM
Expand All @@ -68,12 +66,14 @@ connection_parameters = {
snowpark = Session.builder.configs(connection_parameters).create()

snowflake_retriever = SnowflakeRM(snowflake_session=snowpark,
snowflake_database="<YOUR_SNOWFLAKE_DATABASE_NAME>",
snowflake_schema="<YOUR_SNOWFLAKE_SCHEMA_NAME>",
cortex_search_service="<YOUR_CORTEX_SERACH_SERVICE_NAME>",
k = 5)

results = snowflake_retriever("Explore the meaning of life",response_columns=["<NAME_OF_COLUMN_CONTAINING_TEXT>"])
cortex_search_service="<YOUR_CORTEX_SERACH_SERVICE_NAME>",
snowflake_database="<YOUR_SNOWFLAKE_DATABASE_NAME>",
snowflake_schema="<YOUR_SNOWFLAKE_SCHEMA_NAME>",
k = 5)

results = snowflake_retriever("Explore the meaning of life",
response_columns=["<NAME_OF_COLUMN_CONTAINING_TEXT>"],
filter={ "@eq": { "string_col": "value" } })

for result in results:
print("Document:", result.long_text, "\n")
Expand Down
27 changes: 15 additions & 12 deletions dspy/retrieve/snowflake_rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,28 @@ def __init__(
cortex_search_service: str,
snowflake_database: str,
snowflake_schema: str,
retrieval_columns: list,
search_filter=None,
k: int = 3,
):
super().__init__(k=k)
self.k = k
self.cortex_search_service_name = cortex_search_service
self.retrieval_columns = retrieval_columns
self.search_filter = search_filter
self.client = self._fetch_cortex_service(
snowflake_session, snowflake_database, snowflake_schema, cortex_search_service
)

def forward(self, query_or_queries: Union[str, list[str]], k: Optional[int] = None) -> dspy.Prediction:
"""Query Cortex Search endpoint for top relevant passages.
def forward(
self,
query_or_queries: Union[str, list[str]],
retrieval_columns: list[str],
filter: Optional[dict] = None,
k: Optional[int] = None,
) -> dspy.Prediction:
"""Query Cortex Search endpoint for top k relevant passages.
Args:
query_or_queries (Union[str, List[str]]): The query or queries to search for.
retrieval_columns (List[str]): Columns to include in response.
filter (Optional[json]):Filter query.
k (Optional[int]): The number of top passages to retrieve. Defaults to self.k.
Returns:
dspy.Prediction: An object containing the retrieved passages.
"""
Expand All @@ -66,12 +69,12 @@ def forward(self, query_or_queries: Union[str, list[str]], k: Optional[int] = No
response_chunks = self._query_cortex_search(
cortex_search_service=self.client,
query=cortex_query,
columns=self.retrieval_columns,
filter=self.search_filter,
columns=retrieval_columns,
filter=filter,
k=k,
)

if len(self.retrieval_columns) == 1:
if len(retrieval_columns) == 1:
passages.extend(
dotdict({"long_text": passage[self.retrieval_columns[0]]}) for passage in response_chunks["results"]
)
Expand Down Expand Up @@ -100,8 +103,8 @@ def _query_cortex_search(self, cortex_search_service, query, columns, filter, k)
Args:
cortex_search_service (object): cortex search service for querying
query (str): The query or queries to search for.
repsonse_columns: A comma-separated list of columns to return for each relevant result in the response. These columns must be included in the source query for the service.
filters: A filter object for filtering results based on data in the ATTRIBUTES columns. See Filter syntax.
columns (Optional[list]): A comma-separated list of columns to return for each relevant result in the response. These columns must be included in the source query for the service.
filter (Optional[json]): A filter object for filtering results based on data in the ATTRIBUTES columns. See Filter syntax.
k (Optional[int]): The number of top passages to retrieve. Defaults to self.k.
Returns:
Expand Down

0 comments on commit 915845b

Please sign in to comment.