Integrate cachetools for in-memory LM caching, including unhashable types & pydantic #1896
Conversation
    key=lambda request, *args, **kwargs: cache_key(request),
    # Use a lock to ensure thread safety for the cache when DSPy LMs are queried
    # concurrently, e.g. during optimization and evaluation
    lock=threading.Lock(),
cachetools provides thread safety natively. Alternatively, we could implement our own cache with the required thread-safety functionality, but I suspect there might be bugs (best to reuse something that is known to work).
This is not a blocker for merge, but I'm slightly uneasy about Python-level locking (compared to whatever functools normally does?). Maybe it's required for thread safety, but since it's happening for every single LM call it's a bit worrisome.
Thanks @okhat! functools uses a Python lock as well (RLock). I'll follow up with a small PR to use RLock instead of Lock.
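For context, a minimal sketch of what that follow-up could look like, assuming the decorator arguments shown in this PR (the key function below is a simplified stand-in for DSPy's cache_key helper):

```python
import threading

from cachetools import LRUCache, cached


@cached(
    cache=LRUCache(maxsize=float("inf")),
    key=lambda request, *args, **kwargs: str(sorted(request.items())),  # simplified stand-in for cache_key
    # RLock is reentrant: the thread already holding it can re-acquire it,
    # so re-entrant calls on the same thread won't deadlock.
    lock=threading.RLock(),
)
def cached_completion(request: dict, num_retries: int = 0):
    ...
```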
@cached(
    # NB: cachetools doesn't support maxsize=None; it recommends using float("inf") instead
    cache=LRUCache(maxsize=maxsize or float("inf")),
    key=lambda request, *args, **kwargs: cache_key(request),
This is the key advantage of cachetools. Unlike lru_cache, it allows us to define a cache key by applying a custom function to one or more arguments, rather than forcing all arguments to be hashed / JSON-encoded, passed to the function, and then decoded afterwards. Encoding / decoding is infeasible for callables.
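A rough sketch of that pattern, with a hypothetical cache_key that simply skips any value it can't JSON-encode (the helper in this PR is more involved):

```python
import json

from cachetools import LRUCache, cached


def cache_key(request: dict) -> str:
    # Keep only the JSON-serializable parts of the request, so callables,
    # pydantic model classes, etc. don't break key computation.
    def is_serializable(value) -> bool:
        try:
            json.dumps(value)
            return True
        except TypeError:
            return False

    return json.dumps(
        {k: v for k, v in request.items() if is_serializable(v)}, sort_keys=True
    )


@cached(
    cache=LRUCache(maxsize=float("inf")),
    # Only `request` contributes to the key; other arguments like num_retries are ignored.
    key=lambda request, *args, **kwargs: cache_key(request),
)
def completion(request: dict, num_retries: int = 0):
    ...  # issue the underlying LM call here
```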
Can you do a global dspy.settings.request_cache default = LRUCache(maxsize=10_000_000) and then have this function pull from that?
With this naming, it could be confused with the disk cache though, right? It seems like we'd want some unified way to refer to both caches, or more distinctive naming. Thoughts?
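Purely for illustration, the suggested default might look roughly like this (hypothetical names; whether it belongs on dspy.settings, and how to distinguish it from the disk cache, is exactly the open question):

```python
from cachetools import LRUCache, cached


class _Settings:
    # Hypothetical stand-in for a dspy.settings-style object holding the default cache
    request_cache = LRUCache(maxsize=10_000_000)


settings = _Settings()


def request_cache():
    # Pull the shared cache from global settings instead of constructing a new
    # LRUCache for each decorated function.
    return cached(
        cache=settings.request_cache,
        key=lambda request, *args, **kwargs: str(sorted(request.items())),  # stand-in key
    )
```

Both the chat and text completion paths could then share one in-memory cache, while the on-disk LiteLLM cache stays separate.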
    return litellm_completion(
        request,
        cache={"no-cache": False, "no-store": False},
        num_retries=num_retries,
    )


def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
    kwargs = ujson.loads(request)
We no longer have to serialize / deserialize request within the litellm_completion and litellm_text_completion calls.
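To make the change concrete, a simplified before/after (function bodies trimmed; not the exact code in this PR):

```python
import ujson


# Before: the caller JSON-encoded the request so functools.lru_cache could hash it,
# and it had to be decoded again here; that round trip fails for callables and
# pydantic models.
def litellm_completion_before(request: str, num_retries: int):
    kwargs = ujson.loads(request)
    ...


# After: the request dict is passed straight through; hashing is handled by the
# cachetools key function, so no encode/decode round trip is needed.
def litellm_completion_after(request: dict, num_retries: int):
    kwargs = dict(request)
    ...
```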
tests/caching/test_caching.py
Outdated
def test_lm_calls_support_unhashable_types(litellm_test_server, temporary_blank_cache_dir):
    api_base, server_log_file_path = litellm_test_server

    lm_with_unhashable_callable = dspy.LM(
        model="openai/dspy-test-model",
        api_base=api_base,
        api_key="fakekey",
        # Define a callable kwarg for the LM to use during inference
        azure_ad_token_provider=lambda *args, **kwargs: None,
    )
    lm_with_unhashable_callable("Query")
Fails on main with:

E TypeError: <function test_lm_calls_support_unhashable_types.<locals>.<lambda> at 0x31204d5a0> is not JSON serializable
tests/caching/test_caching.py
Outdated
def test_lm_calls_support_pydantic_models(litellm_test_server, temporary_blank_cache_dir):
    api_base, server_log_file_path = litellm_test_server

    class ResponseFormat(pydantic.BaseModel):
        response: str

    lm = dspy.LM(
        model="openai/dspy-test-model",
        api_base=api_base,
        api_key="fakekey",
        response_format=ResponseFormat,
    )
    lm("Query")
Fails on main with:

TypeError: <class 'tests.caching.test_caching.test_lm_calls_support_pydantic_models.<locals>.ResponseFormat'> is not JSON serializable
@@ -212,47 +219,82 @@ def copy(self, **kwargs):
     return new_instance


-@functools.lru_cache(maxsize=None)
-def cached_litellm_completion(request, num_retries: int):
+def request_cache(maxsize: Optional[int] = None):
@okhat @bahtman @CyrusNuevoDia Thoughts on this approach? See inline comments discussing advantages below
Looks cool! Could set default maxsize = float("inf") here.
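Putting the pieces from this thread together, the decorator factory might end up looking roughly like this with the suggested default (the cache_key here is again a stand-in for the real helper; not necessarily the final implementation):

```python
import threading
from typing import Callable

from cachetools import LRUCache, cached


def cache_key(request: dict) -> str:
    # Stand-in for the real helper that hashes only the serializable parts of the request
    return str(sorted(request.items()))


def request_cache(maxsize: float = float("inf")):
    # NB: cachetools doesn't support maxsize=None; float("inf") means effectively unbounded
    def decorator(func: Callable) -> Callable:
        return cached(
            cache=LRUCache(maxsize=maxsize),
            key=lambda request, *args, **kwargs: cache_key(request),
            # Guard against concurrent access when LMs are queried from multiple threads
            lock=threading.Lock(),
        )(func)

    return decorator


@request_cache()
def cached_litellm_completion(request: dict, num_retries: int = 0):
    ...
```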
    assert azure_openai_lm("azure openai query") == expected_response


def test_text_lms_can_be_queried(litellm_test_server):
Since we're making changes to litellm_text_completion as well, we should have some coverage for LM queries with model_type="text".
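For example, something along these lines would exercise the litellm_text_completion path (a sketch only; the expected response depends on how the local test server is configured):

```python
def test_text_lms_can_be_queried(litellm_test_server):
    api_base, server_log_file_path = litellm_test_server
    expected_response = ["Hi!"]  # assumed canned response from the local test server

    # model_type="text" routes the request through litellm_text_completion
    # instead of litellm_completion
    openai_lm = dspy.LM(
        model="openai/dspy-test-model",
        api_base=api_base,
        api_key="fakekey",
        model_type="text",
    )
    assert openai_lm("openai query") == expected_response
```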
def test_lm_calls_support_unhashable_types(litellm_test_server):
    api_base, server_log_file_path = litellm_test_server

    lm_with_unhashable_callable = dspy.LM(
        model="openai/dspy-test-model",
        api_base=api_base,
        api_key="fakekey",
        # Define a callable kwarg for the LM to use during inference
        azure_ad_token_provider=lambda *args, **kwargs: None,
    )
    lm_with_unhashable_callable("Query")
Fails on main with:

E TypeError: <function test_lm_calls_support_unhashable_types.<locals>.<lambda> at 0x31204d5a0> is not JSON serializable
def test_lm_calls_support_pydantic_models(litellm_test_server):
    api_base, server_log_file_path = litellm_test_server

    class ResponseFormat(pydantic.BaseModel):
        response: str

    lm = dspy.LM(
        model="openai/dspy-test-model",
        api_base=api_base,
        api_key="fakekey",
        response_format=ResponseFormat,
    )
    lm("Query")
Fails on main with:

TypeError: <class 'tests.caching.test_caching.test_lm_calls_support_pydantic_models.<locals>.ResponseFormat'> is not JSON serializable
Looks awesome! Is there a way to have a global cache that we can dump/load?
Totally! We can add that if / when we need it by leveraging cachetools.
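If it's ever needed, a rough sketch of dump/load on top of cachetools (assuming a hypothetical module-level LRUCache; not part of this PR):

```python
import pickle

from cachetools import LRUCache

# Hypothetical module-level cache, e.g. the one wrapped by request_cache()
REQUEST_CACHE = LRUCache(maxsize=float("inf"))


def dump_cache(path: str) -> None:
    # cachetools caches are MutableMappings, so their contents copy out as a plain
    # dict; this assumes cached keys and values are picklable.
    with open(path, "wb") as f:
        pickle.dump(dict(REQUEST_CACHE), f)


def load_cache(path: str) -> None:
    with open(path, "rb") as f:
        REQUEST_CACHE.update(pickle.load(f))
```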
Awesome, lgtm! Appreciate you 🙏
One quick fix and lgtm
Integrate cachetools for in-memory LM caching, including unhashable types & pydantic (#1896)

* Impl
* Cachetools add
* Inline
* tweak
* fix
* fix
* Update lm.py

Signed-off-by: dbczumar <[email protected]>
Integrate cachetools for in-memory LM caching, including unhashable types & pydantic
Fixes #1759