-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Integrate cachetools for in-memory LM caching, including unhashable types & pydantic (#1896) * Impl Signed-off-by: dbczumar <[email protected]> * Cachetools add Signed-off-by: dbczumar <[email protected]> * Inline Signed-off-by: dbczumar <[email protected]> * tweak Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * fix Signed-off-by: dbczumar <[email protected]> * Update lm.py --------- Signed-off-by: dbczumar <[email protected]>
- Loading branch information
1 parent
f5d8c12
commit 92199e4
Showing
6 changed files
with
142 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,24 +1,77 @@ | ||
from unittest import mock | ||
|
||
import pydantic | ||
import pytest | ||
|
||
import dspy | ||
from tests.test_utils.server import litellm_test_server | ||
|
||
|
||
def test_chat_lms_can_be_queried(litellm_test_server):
    """Chat-mode LMs for both the OpenAI and Azure OpenAI providers return
    the litellm test server's canned single-completion response."""
    api_base, _ = litellm_test_server

    # (model, prompt) pairs — one per provider exposed by the stub server.
    cases = [
        ("openai/dspy-test-model", "openai query"),
        ("azure/dspy-test-model", "azure openai query"),
    ]
    for model_name, prompt in cases:
        lm = dspy.LM(
            model=model_name,
            api_base=api_base,
            api_key="fakekey",
            model_type="chat",
        )
        # The stub server always replies with a single completion: "Hi!".
        assert lm(prompt) == ["Hi!"]
|
||
|
||
def test_text_lms_can_be_queried(litellm_test_server):
    """Text-completion-mode LMs for both the OpenAI and Azure OpenAI
    providers return the litellm test server's canned response."""
    api_base, _ = litellm_test_server

    # (model, prompt) pairs — one per provider exposed by the stub server.
    cases = [
        ("openai/dspy-test-model", "openai query"),
        ("azure/dspy-test-model", "azure openai query"),
    ]
    for model_name, prompt in cases:
        lm = dspy.LM(
            model=model_name,
            api_base=api_base,
            api_key="fakekey",
            model_type="text",
        )
        # The stub server always replies with a single completion: "Hi!".
        assert lm(prompt) == ["Hi!"]
|
||
|
||
def test_lm_calls_support_unhashable_types(litellm_test_server):
    """LM calls whose kwargs contain unhashable values (e.g. callables)
    still execute without raising.

    Presumably this exercises the in-memory cache's key construction for
    kwargs that cannot be hashed directly — TODO confirm against the
    caching implementation.
    """
    # Second fixture element (server log file path) is unused here.
    api_base, _ = litellm_test_server

    lm_with_unhashable_callable = dspy.LM(
        model="openai/dspy-test-model",
        api_base=api_base,
        api_key="fakekey",
        # Define a callable kwarg for the LM to use during inference;
        # callables are unhashable.
        azure_ad_token_provider=lambda *args, **kwargs: None,
    )
    # Should not raise despite the unhashable kwarg.
    lm_with_unhashable_callable("Query")
|
||
|
||
def test_lm_calls_support_pydantic_models(litellm_test_server):
    """LM calls that pass a pydantic model class as a kwarg
    (``response_format``) execute without raising.

    Presumably this exercises cache-key construction for model-typed
    kwargs — TODO confirm against the caching implementation.
    """
    # Second fixture element (server log file path) is unused here.
    api_base, _ = litellm_test_server

    class ResponseFormat(pydantic.BaseModel):
        response: str

    lm = dspy.LM(
        model="openai/dspy-test-model",
        api_base=api_base,
        api_key="fakekey",
        response_format=ResponseFormat,
    )
    # Should not raise despite the pydantic-model-valued kwarg.
    lm("Query")