Merge pull request #1564 from dbczumar/databricks_rm
[1] DatabricksRM: Use databricks-sdk to fetch token / workspace URL + several small improvements
okhat authored Sep 30, 2024
2 parents 6041955 + ae86c9c commit 59ae987
Showing 12 changed files with 405 additions and 142 deletions.
2 changes: 1 addition & 1 deletion docs/api/retrieval_model_clients/AzureCognitiveSearch.md
@@ -1,5 +1,5 @@
---
sidebar_position: 2
sidebar_position: 3
---

# retrieve.AzureCognitiveSearch
5 changes: 4 additions & 1 deletion docs/api/retrieval_model_clients/ChromadbRM.md
@@ -1,5 +1,5 @@
---
sidebar_position: 3
sidebar_position: 4
---

# retrieve.ChromadbRM
@@ -18,6 +18,7 @@ ChromadbRM(
```

**Parameters:**

- `collection_name` (_str_): The name of the chromadb collection.
- `persist_directory` (_str_): Path to the directory where chromadb data is persisted.
- `embedding_function` (_Optional[EmbeddingFunction[Embeddable]]_, _optional_): The function used for embedding documents and queries. Defaults to `DefaultEmbeddingFunction()` if not specified.
@@ -30,10 +31,12 @@ ChromadbRM(
Search the chromadb collection for the top `k` passages matching the given query or queries, using embeddings generated via the specified `embedding_function`.

**Parameters:**

- `query_or_queries` (_Union[str, List[str]]_): The query or list of queries to search for.
- `k` (_Optional[int]_, _optional_): The number of results to retrieve. If not specified, defaults to the value set during initialization.

**Returns:**

- `dspy.Prediction`: Contains the retrieved passages, each represented as a `dotdict` with schema `[{"id": str, "score": float, "long_text": str, "metadatas": dict }]`
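
The returned schema can be sketched with plain dictionaries — the values below are hypothetical stand-ins for the `dotdict`s in the prediction, and the sketch assumes the scores are distances (lower means a closer match):

```python
# Hypothetical retrieval results mirroring the documented schema.
passages = [
    {"id": "doc2", "score": 0.34, "long_text": "Second passage", "metadatas": {"source": "b"}},
    {"id": "doc1", "score": 0.12, "long_text": "First passage", "metadatas": {"source": "a"}},
]

# Rank by score (assumed to be a distance: smaller is closer) and keep the text.
ranked = sorted(passages, key=lambda p: p["score"])
texts = [p["long_text"] for p in ranked]
print(texts)  # ['First passage', 'Second passage']
```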

### Quickstart with OpenAI Embeddings
120 changes: 120 additions & 0 deletions docs/api/retrieval_model_clients/DatabricksRM.md
@@ -0,0 +1,120 @@
---
sidebar_position: 2
---

# retrieve.DatabricksRM

### Constructor

Initialize an instance of the `DatabricksRM` retriever class, which enables DSPy programs to query
[Databricks Mosaic AI Vector Search](https://docs.databricks.com/en/generative-ai/vector-search.html#mosaic-ai-vector-search)
indexes for document retrieval.

```python
DatabricksRM(
databricks_index_name: str,
databricks_endpoint: Optional[str] = None,
databricks_token: Optional[str] = None,
columns: Optional[List[str]] = None,
filters_json: Optional[str] = None,
k: int = 3,
docs_id_column_name: str = "id",
text_column_name: str = "text",
)
```

**Parameters:**

- `databricks_index_name (str)`: The name of the Databricks Vector Search Index to query.
- `databricks_endpoint (Optional[str])`: The URL of the Databricks Workspace containing
the Vector Search Index. Defaults to the value of the `DATABRICKS_HOST` environment variable.
If unspecified, the Databricks SDK is used to identify the endpoint based on the current
environment.
- `databricks_token (Optional[str])`: The Databricks Workspace authentication token to use
when querying the Vector Search Index. Defaults to the value of the `DATABRICKS_TOKEN`
environment variable. If unspecified, the Databricks SDK is used to identify the token based on
the current environment.
- `columns (Optional[List[str]])`: Extra column names to include in response, in addition to the
document id and text columns specified by `docs_id_column_name` and `text_column_name`.
- `filters_json (Optional[str])`: A JSON string specifying additional query filters.
Example filters: `{"id <": 5}` selects records that have an `id` column value
less than 5, and `{"id >=": 5, "id <": 10}` selects records that have an `id`
column value greater than or equal to 5 and less than 10.
- `k (int)`: The number of documents to retrieve.
- `docs_id_column_name (str)`: The name of the column in the Databricks Vector Search Index
containing document IDs.
- `text_column_name (str)`: The name of the column in the Databricks Vector Search Index
containing document text to retrieve.
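
Because `filters_json` is a JSON string rather than a Python dict, one convenient way to construct it is `json.dumps` — a minimal sketch (the commented-out constructor call uses a hypothetical index name):

```python
import json

# Build the filter as a dict, then serialize it to the JSON string expected by filters_json.
filters = {"id >=": 5, "id <": 10}
filters_json = json.dumps(filters)
print(filters_json)  # {"id >=": 5, "id <": 10}

# It would then be passed to the retriever, e.g.:
# DatabricksRM(databricks_index_name="your_index_name", filters_json=filters_json)
```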

### Methods

#### `def forward(self, query: Union[str, List[float]], query_type: str = "ANN", filters_json: Optional[str] = None) -> dspy.Prediction:`

Retrieve documents from a Databricks Mosaic AI Vector Search Index that are relevant to the
specified query.

**Parameters:**

- `query (Union[str, List[float]])`: The query text or numeric query vector
for which to retrieve relevant documents.
- `query_type (str)`: The type of search query to perform against the
Databricks Vector Search Index. Must be either 'ANN' (approximate nearest neighbor) or 'HYBRID'
(hybrid search).
- `filters_json (Optional[str])`: A JSON string specifying additional query filters.
Example filters: `{"id <": 5}` selects records that have an `id` column value
less than 5, and `{"id >=": 5, "id <": 10}` selects records that have an `id`
column value greater than or equal to 5 and less than 10. If specified, this
parameter overrides the `filters_json` parameter passed to the constructor.

**Returns:**

- `dspy.Prediction`: A `dotdict` containing retrieved documents. The schema is
`{'docs': List[str], 'doc_ids': List[Any], 'extra_columns': List[Dict[str, Any]]}`.
The `docs` entry contains the retrieved document content.
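
The parallel lists in the returned schema line up index by index — a small sketch with hypothetical values:

```python
# Hypothetical forward() output, shown as plain Python structures with the documented keys.
result = {
    "docs": ["Passage about Spark", "Passage about MLflow"],
    "doc_ids": [7, 42],
    "extra_columns": [{"source": "wiki"}, {"source": "blog"}],
}

# docs[i], doc_ids[i], and extra_columns[i] all describe the same retrieved document.
pairs = list(zip(result["doc_ids"], result["docs"]))
print(pairs)  # [(7, 'Passage about Spark'), (42, 'Passage about MLflow')]
```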

### Quickstart

To retrieve documents using Databricks Mosaic AI Vector Search, you must [create a
Databricks Mosaic AI Vector Search Index](https://docs.databricks.com/en/generative-ai/create-query-vector-search.html)
first.

The following example code demonstrates how to set up a Databricks Mosaic AI
[Direct Access Vector Search Index](https://docs.databricks.com/en/generative-ai/create-query-vector-search.html#create-a-vector-search-index)
and use the `DatabricksRM` DSPy retriever module to query the index. The example requires
the `databricks-vectorsearch` Python library to be installed.

```python
from databricks.vector_search.client import VectorSearchClient

from dspy.retrieve.databricks_rm import DatabricksRM

# Create a Databricks Vector Search Endpoint
client = VectorSearchClient()
client.create_endpoint(
    name="your_vector_search_endpoint_name",
    endpoint_type="STANDARD",
)

# Create a Databricks Direct Access Vector Search Index
index = client.create_direct_access_index(
    endpoint_name="your_vector_search_endpoint_name",
    index_name="your_index_name",
    primary_key="id",
    embedding_dimension=1024,
    embedding_vector_column="text_vector",
    schema={
        "id": "int",
        "field2": "str",
        "field3": "float",
        "text_vector": "array<float>",
    },
)

# Create a DatabricksRM retriever and retrieve the top-3 most relevant documents from the
# Databricks Direct Access Vector Search Index corresponding to an example query
retriever = DatabricksRM(
    databricks_index_name="your_index_name",
    docs_id_column_name="id",
    text_column_name="field2",
    k=3,
)
retrieved_results = retriever(query="Example query text", query_type="HYBRID")
```
6 changes: 4 additions & 2 deletions docs/api/retrieval_model_clients/FaissRM.md
@@ -1,5 +1,5 @@
---
sidebar_position: 4
sidebar_position: 5
---

# retrieve.FaissRM
@@ -17,6 +17,7 @@ FaissRM(
```

**Parameters:**

- `document_chunks` (_List[str]_): a list of strings that comprises the corpus to search. You cannot add/insert/upsert to this list after creating this FaissRM object.
- `vectorizer` (_dsp.modules.sentence_vectorizer.BaseSentenceVectorizer_, _optional_): If not provided, a dsp.modules.sentence_vectorizer.SentenceTransformersVectorizer object is created and used.
- `k` (_int_, _optional_): The number of top passages to retrieve. Defaults to 3.
@@ -28,16 +29,17 @@ FaissRM(
Search the FaissRM vector database for the top `k` passages matching the given query or queries, using embeddings generated via the vectorizer specified at FaissRM construction time

**Parameters:**

- `query_or_queries` (_Union[str, List[str]]_): The query or list of queries to search for.

**Returns:**

- `dspy.Prediction`: Contains the retrieved passages, each represented as a `dotdict` with a `long_text` attribute and an `index` attribute. The `index` attribute is the index in the document_chunks array provided to this FaissRM object at construction time.
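
The `index` attribute maps each result back into the `document_chunks` list supplied at construction time — illustrated here with hypothetical values:

```python
# Hypothetical corpus and retrieval results; each result's index points into document_chunks.
document_chunks = ["chunk zero", "chunk one", "chunk two"]
results = [
    {"long_text": "chunk two", "index": 2},
    {"long_text": "chunk zero", "index": 0},
]

# Recovering the original chunks via the index attribute reproduces long_text exactly.
recovered = [document_chunks[r["index"]] for r in results]
print(recovered)  # ['chunk two', 'chunk zero']
```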

### Quickstart with the default vectorizer

The **FaissRM** module provides a retriever that uses an in-memory Faiss vector database. This module does not include a vectorizer; instead it supports any subclass of **dsp.modules.sentence_vectorizer.BaseSentenceVectorizer**. If a vectorizer is not provided, an instance of **dsp.modules.sentence_vectorizer.SentenceTransformersVectorizer** is created and used by **FaissRM**. Note that the default embedding model for **SentenceTransformersVectorizer** is **all-MiniLM-L6-v2**.

```python
import dspy
from dspy.retrieve.faiss_rm import FaissRM
10 changes: 6 additions & 4 deletions docs/api/retrieval_model_clients/MilvusRM.md
@@ -1,5 +1,5 @@
---
sidebar_position: 5
sidebar_position: 6
---

# retrieve.MilvusRM
@@ -20,13 +20,14 @@ MilvusRM(
```

**Parameters:**

- `collection_name (str)`: The name of the Milvus collection to query against.
- `uri (str, optional)`: The Milvus connection uri. Defaults to "http://localhost:19530".
- `token (str, optional)`: The Milvus connection token. Defaults to None.
- `db_name (str, optional)`: The Milvus database name. Defaults to "default".
- `embedding_function (callable, optional)`: The function to convert a list of text to embeddings.
  The embedding function should take a list of text strings as input and output a list of embeddings.
  Defaults to None. If unspecified, an OpenAI client is created using the `OPENAI_API_KEY` environment variable, and OpenAI's "text-embedding-3-small" embedding model is used with the default dimension.
- `k (int, optional)`: The number of top passages to retrieve. Defaults to 3.

### Methods
@@ -36,10 +37,12 @@ MilvusRM(
Search the Milvus collection for the top `k` passages matching the given query or queries, using embeddings generated via the default OpenAI embedding or the specified `embedding_function`.

**Parameters:**

- `query_or_queries` (_Union[str, List[str]]_): The query or list of queries to search for.
- `k` (_Optional[int]_, _optional_): The number of results to retrieve. If not specified, defaults to the value set during initialization.

**Returns:**

- `dspy.Prediction`: Contains the retrieved passages, each represented as a `dotdict` with schema `[{"id": str, "score": float, "long_text": str, "metadatas": dict }]`

### Quickstart
@@ -71,7 +74,6 @@ for result in results:
print("Document:", result.long_text, "\n")
```


#### Customized Embedding Function

```python
14 changes: 11 additions & 3 deletions docs/api/retrieval_model_clients/MyScaleRM.md
@@ -1,5 +1,5 @@
---
sidebar_position: 6
sidebar_position: 7
---

# retrieve.MyScaleRM
@@ -9,6 +9,7 @@ sidebar_position: 6
Initializes an instance of the `MyScaleRM` class, which is designed to use MyScaleDB (a ClickHouse fork optimized for vector similarity and full-text search) to retrieve documents based on query embeddings. This class supports embedding generation using either local models or OpenAI's API and manages database interactions efficiently.

### Syntax

```python
MyScaleRM(
client: clickhouse_connect.driver.client.Client,
@@ -22,7 +23,9 @@ MyScaleRM(
local_embed_model: Optional[str] = None
)
```

## Parameters for `MyScaleRM` Constructor

- `client` (_clickhouse_connect.driver.client.Client_): A client connection to the MyScaleDB database, used to execute queries and manage interactions with the database.
- `table` (_str_): Specifies the table within MyScaleDB from which data will be retrieved. This table should be equipped with a vector column for conducting similarity searches.
- `database` (_str_, optional): The name of the database where the table is located, defaulting to `"default"`.
@@ -34,21 +37,26 @@ MyScaleRM(
- `local_embed_model` (_str, optional_): Specifies a local model for embedding generation, chosen if local computation is preferred.

## Methods

### `forward`

Executes a retrieval operation based on a user's query and returns the top `k` relevant results using the embeddings generated by the specified method.

### Syntax

```python
def forward(self, user_query: str, k: Optional[int] = None) -> dspy.Prediction
```

## Parameters

- `user_query` (_str_): The query to retrieve matching passages.
- `k` (_Optional[int], optional_): The number of top matches to retrieve. If not provided, it defaults to the `k` value set during class initialization.

## Returns

- `dspy.Prediction`: Contains the retrieved passages, formatted as a list of `dotdict` objects. Each entry includes:
  - **long_text (str)**: The text content of the retrieved passage.

## Description

@@ -77,4 +85,4 @@ passages = results.passages
for passage in passages:
print(passage['long_text'], "\n")

```
12 changes: 7 additions & 5 deletions docs/api/retrieval_model_clients/Neo4jRM.md
@@ -1,5 +1,5 @@
---
sidebar_position: 7
sidebar_position: 8
---

# retrieve.neo4j_rm
@@ -34,36 +34,38 @@ You need to define the credentials as environment variables:
- `OPENAI_API_KEY` (_str_): Specifies the API key required for authenticating with OpenAI's services.

**Parameters:**

- `index_name` (_str_): Specifies the name of the vector index to be used within Neo4j for organizing and querying data.
- `text_node_property` (_str_, _optional_): Defines the specific property of nodes that will be returned.
- `k` (_int_, _optional_): The number of top results to return from the retrieval operation. It defaults to 5 if not explicitly specified.
- `retrieval_query` (_str_, _optional_): A custom query string provided for retrieving data. If not provided, a default query tailored to the `text_node_property` will be used.
- `embedding_provider` (_str_, _optional_): The name of the service provider for generating embeddings. Defaults to "openai" if not specified.
- `embedding_model` (_str_, _optional_): The specific embedding model to use from the provider. By default, it uses the "text-embedding-ada-002" model from OpenAI.


### Methods

#### `forward(self, query: [str], k: Optional[int] = None) -> dspy.Prediction`

Search the neo4j vector index for the top `k` passages matching the given query or queries, using embeddings generated via the specified `embedding_model`.

**Parameters:**

- `query` (_str_): The query.
- `k` (_Optional[int]_, _optional_): The number of results to retrieve. If not specified, defaults to the value set during initialization.

**Returns:**

- `dspy.Prediction`: Contains the retrieved passages as a list of strings under the prediction signature.

Example:

```python
Prediction(
passages=['Passage 1 Lorem Ipsum awesome', 'Passage 2 Lorem Ipsum Youppidoo', 'Passage 3 Lorem Ipsum Yassssss']
)
```

### Quick example: how to use Neo4j in a local environment

```python
from dspy.retrieve.neo4j_rm import Neo4jRM
8 changes: 5 additions & 3 deletions docs/api/retrieval_model_clients/RAGatouilleRM.md
@@ -1,5 +1,5 @@
---
sidebar_position: 8
sidebar_position: 9
---

# retrieve.RAGatouilleRM
@@ -13,12 +13,13 @@ class RAGatouilleRM(dspy.Retrieve):
    def __init__(
        self,
        index_root: str,
        index_name: str,
        k: int = 3,
    ):
```

**Parameters:**

- `index_root` (_str_): Folder path where your index is stored.
- `index_name` (_str_): Name of the index you want to retrieve from.
- `k` (_int_): The default number of passages to retrieve. Defaults to `3`.
@@ -30,9 +31,10 @@ class RAGatouilleRM(dspy.Retrieve):
Enables making queries to the RAGatouille-made index for retrieval. Internally, the method handles the specifics of preparing the query to obtain the response. The function handles the retrieval of the top-k passages based on the provided query.

**Parameters:**

- `query_or_queries` (Union[str, List[str]]): Query string used for retrieval.
- `k` (_int_, _optional_): Number of passages to retrieve. Defaults to 3.

**Returns:**

- `dspy.Prediction`: List of k passages
2 changes: 1 addition & 1 deletion docs/api/retrieval_model_clients/SnowflakeRM.md
@@ -1,5 +1,5 @@
---
sidebar_position: 9
sidebar_position: 10
---

# retrieve.SnowflakeRM