added Databricks integrations for DSPy LMs/RM configurations #430

Merged · 5 commits · Feb 25, 2024
41 changes: 41 additions & 0 deletions docs/language_models_client.md
@@ -216,3 +216,44 @@ class Together(HFModel):
### Methods

Refer to [`dspy.OpenAI`](#openai) documentation.


## Databricks (Model Serving Endpoints)

### Usage
```python
lm = dspy.Databricks(model="databricks-mpt-30b-instruct")
```
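
For a slightly fuller setup, the configured LM can be registered with DSPy's settings and called directly (a minimal sketch; the endpoint name and prompt are illustrative):

```python
import dspy

lm = dspy.Databricks(model="databricks-mpt-30b-instruct")
dspy.settings.configure(lm=lm)

# Calling the LM returns a list of completions for the prompt.
completions = lm("Write a one-sentence summary of retrieval-augmented generation.")
```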

### Constructor

The constructor inherits from the `GPT3` class and verifies the Databricks authentication credentials used to call the Databricks Model Serving API through the OpenAI SDK. It expects the following OpenAI SDK attributes to be set, as in the sketch below:
- `openai.api_key`: Databricks API key.
- `openai.base_url`: Databricks Model Serving endpoint URL.
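
For example, these attributes might be set before constructing the client (a minimal sketch; the token and workspace URL are placeholders):

```python
import openai

openai.api_key = "dapi-..."  # placeholder Databricks API key
openai.base_url = "https://<your-workspace>.cloud.databricks.com/serving-endpoints/"  # placeholder endpoint URL
```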

The `kwargs` attribute is initialized with default values for the text generation parameters used when calling the Databricks endpoint through the OpenAI SDK, such as `temperature`, `max_tokens`, `top_p`, and `n`. It removes the `frequency_penalty` and `presence_penalty` arguments, as these are not currently supported by the Databricks API.

```python
class Databricks(GPT3):
def __init__(
self,
model: str,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
model_type: Literal["chat", "text", "embeddings"] = None,
**kwargs,
):
```

**Parameters:**
- `model` (_str_): Name of the model hosted on Databricks.
- `stop` (_List[str]_, _optional_): List of stopping tokens to end generation.
- `api_key` (_Optional[str]_): Databricks API key. Defaults to `None`.
- `api_base` (_Optional[str]_): Databricks Model Serving endpoint URL. Defaults to `None`.
- `model_type` (_Literal["chat", "text", "embeddings"]_): Specified model type to use.
- `**kwargs`: Additional language model arguments to pass to the API provider.
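
As an illustration, the parameters above can also be supplied explicitly, together with generation kwargs (a hypothetical sketch; the key, URL, and values are placeholders):

```python
lm = dspy.Databricks(
    model="databricks-mpt-30b-instruct",
    api_key="dapi-...",  # placeholder Databricks API key
    api_base="https://<your-workspace>.cloud.databricks.com/serving-endpoints/",  # placeholder URL
    model_type="chat",
    temperature=0.2,
    max_tokens=512,
)
```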

### Methods

Refer to [`dspy.OpenAI`](#openai) documentation.
1 change: 1 addition & 0 deletions dsp/modules/__init__.py
@@ -1,5 +1,6 @@
from .cache_utils import *
from .gpt3 import *
from .databricks import *
from .hf import HFModel
from .colbertv2 import ColBERTv2
from .sentence_vectorizer import *
158 changes: 158 additions & 0 deletions dsp/modules/databricks.py
@@ -0,0 +1,158 @@
import logging
from logging.handlers import RotatingFileHandler

# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(message)s',
handlers=[
logging.FileHandler('openai_usage.log')
]
)

import functools
import json
from typing import Any, Literal, Optional, cast

import dsp
import backoff
import openai
from openai import OpenAI

from dsp.modules.cache_utils import CacheMemory, NotebookCacheMemory, cache_turn_on
from dsp.modules.gpt3 import GPT3

try:
from openai.openai_object import OpenAIObject
import openai.error
ERRORS = (openai.error.RateLimitError, openai.error.ServiceUnavailableError, openai.error.APIError)
except Exception:
ERRORS = (openai.RateLimitError, openai.APIError)
OpenAIObject = dict


def backoff_hdlr(details):
"""Handler from https://pypi.org/project/backoff/"""
print(
"Backing off {wait:0.1f} seconds after {tries} tries "
"calling function {target} with kwargs "
"{kwargs}".format(**details)
)

class Databricks(GPT3):
"""Wrapper around DSPy's OpenAI Wrapper. Supports Databricks Model Serving Endpoints for OpenAI SDK on both Chat, Completions, and Embeddings models.

Args:
model (str, required): Databricks-hosted LLM model to use.
api_key (Optional[str], optional): Databricks authentication token. Defaults to None.
api_base (Optional[str], optional): Databricks model serving endpoint. Defaults to None.
model_type (Literal["chat", "text", "embeddings"], optional): The type of model being served. Mainly used to decide the prompting strategy. Defaults to "text".
**kwargs: Additional arguments to pass to the OpenAI API provider.
"""

def __init__(
self,
model: str,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
model_type: Literal["chat", "text", "embeddings"] = None,
**kwargs,
):
super().__init__(
model=model,
api_key=api_key,
api_provider="openai",
api_base=api_base,
model_type=model_type,
**kwargs,
)

self.kwargs.pop('frequency_penalty', None)
self.kwargs.pop('presence_penalty', None)

def basic_request(self, prompt: str, **kwargs):
raw_kwargs = kwargs

kwargs = {**self.kwargs, **kwargs}
if self.model_type == "chat":
kwargs["messages"] = [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": prompt}]
kwargs = {"stringify_request": json.dumps(kwargs)}
response = custom_client_chat_request(**kwargs).json()
response = json.loads(response)
else:
kwargs["prompt"] = prompt
response = custom_client_completions_request(**kwargs).json()
response = json.loads(response)

history = {
"prompt": prompt,
"response": response,
"kwargs": kwargs,
"raw_kwargs": raw_kwargs,
}
self.history.append(history)
return response

def embeddings(self, prompt: str, **kwargs):
kwargs = {**self.kwargs, **kwargs}
kwargs["input"] = prompt
kwargs.pop('temperature', None)
kwargs.pop('max_tokens', None)
kwargs.pop('top_p', None)
kwargs.pop('n', None)
response = custom_client_embeddings_request(**kwargs).json()
response = json.loads(response)
embeddings = [cur_obj['embedding'] for cur_obj in response['data']][0]
return embeddings

def __call__(self, prompt: str, **kwargs):
if self.model_type == "embeddings":
return self.embeddings(prompt, **kwargs)
else:
return super().__call__(prompt, **kwargs)

def create_custom_client():
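# Build an OpenAI client from the module-level credentials (openai.api_key / openai.base_url).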
client = OpenAI(api_key=openai.api_key, base_url=openai.base_url)
return client

def custom_client_chat_request(**kwargs):
return cached_custom_client_chat_request_v2_wrapped(**kwargs)

def custom_client_embeddings_request(**kwargs):
return cached_custom_client_embeddings_request_v2_wrapped(**kwargs)

def custom_client_completions_request(**kwargs):
return cached_custom_client_completions_request_v2_wrapped(**kwargs)


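# Cached request helpers, following the same layered caching pattern (CacheMemory, NotebookCacheMemory, lru_cache) used by the GPT3 module.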
@CacheMemory.cache
def cached_custom_client_chat_request_v2(**kwargs):
client = create_custom_client()
return client.chat.completions.create(**kwargs)

@functools.lru_cache(maxsize=None if cache_turn_on else 0)
@NotebookCacheMemory.cache
def cached_custom_client_chat_request_v2_wrapped(**kwargs):
if "stringify_request" in kwargs:
kwargs = json.loads(kwargs["stringify_request"])
return cached_custom_client_chat_request_v2(**kwargs)

@CacheMemory.cache
def cached_custom_client_completions_request_v2(**kwargs):
client = create_custom_client()
return client.completions.create(**kwargs)

@functools.lru_cache(maxsize=None if cache_turn_on else 0)
@NotebookCacheMemory.cache
def cached_custom_client_completions_request_v2_wrapped(**kwargs):
return cached_custom_client_completions_request_v2(**kwargs)

@CacheMemory.cache
def cached_custom_client_embeddings_request_v2(**kwargs):
client = create_custom_client()
return client.embeddings.create(**kwargs)

@functools.lru_cache(maxsize=None if cache_turn_on else 0)
@NotebookCacheMemory.cache
def cached_custom_client_embeddings_request_v2_wrapped(**kwargs):
return cached_custom_client_embeddings_request_v2(**kwargs)
1 change: 1 addition & 0 deletions dspy/__init__.py
@@ -17,6 +17,7 @@
settings = dsp.settings

OpenAI = dsp.GPT3
Databricks = dsp.Databricks
Cohere = dsp.Cohere
ColBERTv2 = dsp.ColBERTv2
Pyserini = dsp.PyseriniRetriever
130 changes: 130 additions & 0 deletions dspy/retrieve/databricks_rm.py
@@ -0,0 +1,130 @@
import dspy
import os
import requests
from typing import Union, List, Optional
from collections import defaultdict
from dspy.primitives.prediction import Prediction

class DatabricksRM(dspy.Retrieve):
"""
A retrieval module that uses Databricks Vector Search Endpoint to return the top-k embeddings for a given query.

Args:
databricks_index_name (str): Databricks vector search index to query
databricks_endpoint (str): Databricks index endpoint url
databricks_token (str): Databricks authentication token
columns (list[str]): Column names to include in response
filters_json (str, optional): JSON string for query filters
k (int, optional): Number of top embeddings to retrieve. Defaults to 3.

Examples:
Below is a code snippet that shows how to configure Databricks Vector Search endpoints:

(example adapted from "Databricks: How to create and query a Vector Search Index",
https://docs.databricks.com/en/generative-ai/create-query-vector-search.html#create-a-vector-search-index)

```python
from databricks.vector_search.client import VectorSearchClient

# Creating a Vector Search client

client = VectorSearchClient()

client.create_endpoint(
name="your_vector_search_endpoint_name",
endpoint_type="STANDARD"
)

# Creating a Vector Search index using the Python SDK
# Example for a Direct Vector Access Index

index = client.create_direct_access_index(
endpoint_name="your_databricks_host_url",
index_name="your_index_name",
primary_key="id",
embedding_dimension=1024,
embedding_vector_column="text_vector",
schema={
"id": "int",
"field2": "str",
"field3": "float",
"text_vector": "array<float>"}
)

llm = dspy.OpenAI(model="gpt-3.5-turbo")
retriever_model = DatabricksRM(
    databricks_index_name="your_index_name",
    databricks_endpoint="your_databricks_host_url",
    databricks_token="your_databricks_token",
    columns=["id", "field2", "field3", "text_vector"],
    k=3,
)
dspy.settings.configure(lm=llm, rm=retriever_model)
```

Below is a code snippet that shows how to query the Databricks Direct Vector Access Index using the `forward()` function.
```python
retrieved_results = retriever_model(query=[1, 2, 3], query_type='vector')
```
"""
def __init__(self, databricks_index_name = None, databricks_endpoint = None, databricks_token = None, columns = None, filters_json = None, k = 3):
super().__init__(k=k)
if not databricks_token and not os.environ.get("DATABRICKS_TOKEN"):
raise ValueError("You must supply databricks_token or set environment variable DATABRICKS_TOKEN")
if not databricks_endpoint and not os.environ.get("DATABRICKS_HOST"):
raise ValueError("You must supply databricks_endpoint or set environment variable DATABRICKS_HOST")
if not databricks_index_name:
raise ValueError("You must supply vector index name")
if not columns:
raise ValueError("You must specify a list of column names to be included in the response")
self.databricks_token = databricks_token if databricks_token else os.environ["DATABRICKS_TOKEN"]
self.databricks_endpoint = databricks_endpoint if databricks_endpoint else os.environ["DATABRICKS_HOST"]
self.databricks_index_name = databricks_index_name
self.columns = columns
self.filters_json = filters_json
self.k = k

def forward(self, query: Union[str, List[float]], query_type: str = 'vector') -> dspy.Prediction:
"""Search with Databricks Vector Search Client for self.k top results for query

Args:
query (Union[str, List[float]]): query to search for.
query_type (str): 'vector' for Direct Vector Access Index and Delta Sync Index using self-managed vectors or 'text' for Delta Sync Index using model endpoint.

Returns:
dspy.Prediction: An object containing the retrieved results.
"""
headers = {
"Authorization": f"Bearer {self.databricks_token}",
"Content-Type": "application/json"
}
payload = {
"columns": self.columns,
"num_results": self.k
}
if query_type == 'vector':
if not isinstance(query, list):
raise ValueError("Query must be a list of floats for query_vector")
payload["query_vector"] = query
elif query_type == 'text':
if not isinstance(query, str):
raise ValueError("Query must be a string for query_text")
payload["query_text"] = query
else:
raise ValueError("Invalid query type specified. Use 'vector' or 'text'.")
if self.filters_json:
payload["filters_json"] = self.filters_json
response = requests.post(
f"{self.databricks_endpoint}/api/2.0/vector-search/indexes/{self.databricks_index_name}/query",
json=payload,
headers=headers
)
results = response.json()

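# Accumulate the score for each returned 'text' value; duplicate texts are merged by summing their scores.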
docs = defaultdict(float)
text, score = None, None
for data_row in results["result"]["data_array"]:
for col, val in zip(results["manifest"]["columns"], data_row):
if col["name"] == 'text':
text = val
if col["name"] == 'score':
score = val
docs[text] += score

sorted_docs = sorted(docs.items(), key=lambda x: x[1], reverse=True)[:self.k]
return Prediction(docs=[doc for doc, _ in sorted_docs])