Commit

Integrate cachetools for in-memory LM caching, including unhashable types & pydantic (#1896)

* Impl

Signed-off-by: dbczumar <[email protected]>

* Cachetools add

Signed-off-by: dbczumar <[email protected]>

* Inline

Signed-off-by: dbczumar <[email protected]>

* tweak

Signed-off-by: dbczumar <[email protected]>

* fix

Signed-off-by: dbczumar <[email protected]>

* fix

Signed-off-by: dbczumar <[email protected]>

* Update lm.py

---------

Signed-off-by: dbczumar <[email protected]>
dbczumar authored and isaacbmiller committed Dec 11, 2024
1 parent f5d8c12 commit 92199e4
Showing 6 changed files with 142 additions and 25 deletions.
77 changes: 59 additions & 18 deletions dspy/clients/lm.py
@@ -4,10 +4,13 @@
import threading
import uuid
from datetime import datetime
from hashlib import sha256
from typing import Any, Dict, List, Literal, Optional

import litellm
import pydantic
import ujson
from cachetools import LRUCache, cached

from dspy.adapters.base import Adapter
from dspy.clients.openai import OpenAIProvider
@@ -92,7 +95,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
completion = cached_litellm_text_completion if cache else litellm_text_completion

response = completion(
request=ujson.dumps(dict(model=self.model, messages=messages, **kwargs)),
request=dict(model=self.model, messages=messages, **kwargs),
num_retries=self.num_retries,
)
outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
@@ -153,7 +156,11 @@ def thread_function_wrapper():
thread = threading.Thread(target=thread_function_wrapper)
model_to_finetune = self.finetuning_model or self.model
job = self.provider.TrainingJob(
thread=thread, model=model_to_finetune, train_data=train_data, train_kwargs=train_kwargs, data_format=data_format
thread=thread,
model=model_to_finetune,
train_data=train_data,
train_kwargs=train_kwargs,
data_format=data_format,
)
thread.start()

@@ -212,47 +219,81 @@ def copy(self, **kwargs):
return new_instance


@functools.lru_cache(maxsize=None)
def cached_litellm_completion(request, num_retries: int):
def request_cache(maxsize: Optional[int] = None):
"""
A threadsafe decorator to create an in-memory LRU cache for LM inference functions that accept
a dictionary-like LM request. An in-memory cache for LM calls is critical for ensuring
good performance when optimizing and evaluating DSPy LMs (disk caching alone is too slow).
Args:
maxsize: The maximum size of the cache. If unspecified, no max size is enforced (cache is unbounded).
Returns:
A decorator that wraps the target function with caching.
"""

def cache_key(request: Dict[str, Any]) -> str:
# Transform Pydantic models into JSON-convertible format and exclude unhashable objects
params = {k: (v.dict() if isinstance(v, pydantic.BaseModel) else v) for k, v in request.items()}
params = {k: v for k, v in params.items() if not callable(v)}
return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest()

def decorator(func):
@cached(
# NB: cachetools doesn't support maxsize=None; it recommends using float("inf") instead
cache=LRUCache(maxsize=maxsize or float("inf")),
key=lambda request, *args, **kwargs: cache_key(request),
# Use a lock to ensure thread safety for the cache when DSPy LMs are queried
# concurrently, e.g. during optimization and evaluation
lock=threading.Lock(),
)
@functools.wraps(func)
def wrapper(request: dict, *args, **kwargs):
return func(request, *args, **kwargs)

return wrapper

return decorator


@request_cache(maxsize=None)
def cached_litellm_completion(request: Dict[str, Any], num_retries: int):
return litellm_completion(
request,
cache={"no-cache": False, "no-store": False},
num_retries=num_retries,
)


def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
kwargs = ujson.loads(request)
def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
return litellm.completion(
num_retries=num_retries,
cache=cache,
**kwargs,
**request,
)


@functools.lru_cache(maxsize=None)
def cached_litellm_text_completion(request, num_retries: int):
@request_cache(maxsize=None)
def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int):
return litellm_text_completion(
request,
num_retries=num_retries,
cache={"no-cache": False, "no-store": False},
)


def litellm_text_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
kwargs = ujson.loads(request)

def litellm_text_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
# Extract the provider and model from the model string.
# TODO: Not all the models are in the format of "provider/model"
model = kwargs.pop("model").split("/", 1)
model = request.pop("model").split("/", 1)
provider, model = model[0] if len(model) > 1 else "openai", model[-1]

# Use the API key and base from the kwargs, or from the environment.
api_key = kwargs.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
api_base = kwargs.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")
# Use the API key and base from the request, or from the environment.
api_key = request.pop("api_key", None) or os.getenv(f"{provider}_API_KEY")
api_base = request.pop("api_base", None) or os.getenv(f"{provider}_API_BASE")

# Build the prompt from the messages.
prompt = "\n\n".join([x["content"] for x in kwargs.pop("messages")] + ["BEGIN RESPONSE:"])
prompt = "\n\n".join([x["content"] for x in request.pop("messages")] + ["BEGIN RESPONSE:"])

return litellm.text_completion(
cache=cache,
@@ -261,5 +302,5 @@ def litellm_text_completion(request, num_retries: int, cache={"no-cache": True,
api_base=api_base,
prompt=prompt,
num_retries=num_retries,
**kwargs,
**request,
)
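
For reference, the following is a minimal, self-contained sketch of the caching pattern that request_cache introduces; it is not part of this commit, and fake_completion, ResponseFormat, and the example request are illustrative names only (assuming cachetools, pydantic, and ujson are installed). Requests are reduced to a JSON-stable form (pydantic model instances via .dict(), callables dropped), hashed with SHA-256, and used as keys in a thread-safe LRU cache.

import threading
from hashlib import sha256
from typing import Any, Dict

import pydantic
import ujson
from cachetools import LRUCache, cached

def cache_key(request: Dict[str, Any]) -> str:
    # Convert pydantic model instances to plain dicts and drop unhashable
    # callables (e.g. token providers) before hashing the request.
    params = {k: (v.dict() if isinstance(v, pydantic.BaseModel) else v) for k, v in request.items()}
    params = {k: v for k, v in params.items() if not callable(v)}
    return sha256(ujson.dumps(params, sort_keys=True).encode()).hexdigest()

calls = 0

@cached(
    cache=LRUCache(maxsize=float("inf")),
    key=lambda request, *args, **kwargs: cache_key(request),
    lock=threading.Lock(),  # keeps the cache safe under concurrent queries
)
def fake_completion(request: Dict[str, Any]) -> str:
    # Stand-in for an LM call; the counter reveals cache hits vs. misses.
    global calls
    calls += 1
    return f"response #{calls}"

class ResponseFormat(pydantic.BaseModel):
    response: str

request = {
    "model": "openai/dspy-test-model",
    "messages": [{"role": "user", "content": "Query"}],
    "response_format": ResponseFormat(response=""),  # pydantic model: hashed via .dict()
    "azure_ad_token_provider": lambda: None,         # callable: excluded from the key
}

print(fake_completion(request))        # cache miss -> "response #1"
print(fake_completion(dict(request)))  # identical key -> cache hit, "response #1" again

As the inline comment in the diff notes, cachetools does not accept maxsize=None, so an unbounded cache is expressed as an LRUCache with maxsize=float("inf").
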
4 changes: 2 additions & 2 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -44,6 +44,7 @@ dependencies = [
"tenacity>=8.2.3",
"anyio",
"asyncer==0.0.8",
"cachetools",
]

[project.optional-dependencies]
@@ -138,6 +139,7 @@ falkordb = "^1.0.9"
json-repair = "^0.30.0"
tenacity = ">=8.2.3"
asyncer = "0.0.8"
cachetools = "^5.5.0"

[tool.poetry.group.dev.dependencies]
pytest = "^8.3.3"
5 changes: 3 additions & 2 deletions requirements.txt
@@ -1,4 +1,7 @@
anyio
asyncer==0.0.8
backoff
cachetools
datasets
diskcache
httpx
@@ -15,5 +18,3 @@ requests
tenacity>=8.2.3
tqdm
ujson
anyio
asyncer==0.0.8
20 changes: 20 additions & 0 deletions tests/caching/test_caching.py
@@ -88,3 +88,23 @@ def test_lm_calls_are_cached_across_interpreter_sessions(litellm_test_server, te

request_logs = read_litellm_test_server_request_logs(server_log_file_path)
assert len(request_logs) == 0


def test_lm_calls_are_cached_in_memory_when_expected(litellm_test_server, temporary_blank_cache_dir):
api_base, server_log_file_path = litellm_test_server

lm1 = dspy.LM(
model="openai/dspy-test-model",
api_base=api_base,
api_key="fakekey",
)
lm1("Example query")
# Remove the disk cache, after which the LM must rely on in-memory caching
shutil.rmtree(temporary_blank_cache_dir)
lm1("Example query2")
lm1("Example query2")
lm1("Example query2")
lm1("Example query2")

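# Only two unique requests reach the server: "Example query" and the first
# "Example query2"; the repeated "Example query2" calls are served from the
# in-memory LRU cache after the disk cache has been deleted.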
request_logs = read_litellm_test_server_request_logs(server_log_file_path)
assert len(request_logs) == 2
59 changes: 56 additions & 3 deletions tests/clients/test_lm.py
@@ -1,24 +1,77 @@
from unittest import mock

import pydantic
import pytest

import dspy
from tests.test_utils.server import litellm_test_server


def test_lms_can_be_queried(litellm_test_server):
def test_chat_lms_can_be_queried(litellm_test_server):
api_base, _ = litellm_test_server
expected_response = ["Hi!"]

openai_lm = dspy.LM(
model="openai/dspy-test-model",
api_base=api_base,
api_key="fakekey",
model_type="chat",
)
openai_lm("openai query")
assert openai_lm("openai query") == expected_response

azure_openai_lm = dspy.LM(
model="azure/dspy-test-model",
api_base=api_base,
api_key="fakekey",
model_type="chat",
)
azure_openai_lm("azure openai query")
assert azure_openai_lm("azure openai query") == expected_response


def test_text_lms_can_be_queried(litellm_test_server):
api_base, _ = litellm_test_server
expected_response = ["Hi!"]

openai_lm = dspy.LM(
model="openai/dspy-test-model",
api_base=api_base,
api_key="fakekey",
model_type="text",
)
assert openai_lm("openai query") == expected_response

azure_openai_lm = dspy.LM(
model="azure/dspy-test-model",
api_base=api_base,
api_key="fakekey",
model_type="text",
)
assert azure_openai_lm("azure openai query") == expected_response


def test_lm_calls_support_unhashable_types(litellm_test_server):
api_base, server_log_file_path = litellm_test_server

lm_with_unhashable_callable = dspy.LM(
model="openai/dspy-test-model",
api_base=api_base,
api_key="fakekey",
# Define a callable kwarg for the LM to use during inference
azure_ad_token_provider=lambda *args, **kwargs: None,
)
lm_with_unhashable_callable("Query")


def test_lm_calls_support_pydantic_models(litellm_test_server):
api_base, server_log_file_path = litellm_test_server

class ResponseFormat(pydantic.BaseModel):
response: str

lm = dspy.LM(
model="openai/dspy-test-model",
api_base=api_base,
api_key="fakekey",
response_format=ResponseFormat,
)
lm("Query")
