-
Notifications
You must be signed in to change notification settings - Fork 1.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support structured outputs response format based on signature in JSON adapter #1881
Changes from 1 commit
7a1a84f
b7dfbb8
ed9d504
6e01b2e
7c0e03b
5af146d
6007f61
68d8877
b87cf96
90dc353
40dee38
80f9f34
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,13 +2,15 @@ | |
import enum | ||
import inspect | ||
import json | ||
import logging | ||
import textwrap | ||
from copy import deepcopy | ||
from typing import Any, Dict, KeysView, Literal, NamedTuple, get_args, get_origin | ||
|
||
import json_repair | ||
import litellm | ||
import pydantic | ||
from pydantic import TypeAdapter | ||
from pydantic import TypeAdapter, create_model | ||
from pydantic.fields import FieldInfo | ||
|
||
from dspy.adapters.base import Adapter | ||
|
@@ -18,6 +20,8 @@ | |
from ..signatures.signature import SignatureMeta | ||
from ..signatures.utils import get_dspy_field_type | ||
|
||
_logger = logging.getLogger(__name__) | ||
|
||
|
||
class FieldInfoWithName(NamedTuple): | ||
name: str | ||
|
@@ -35,7 +39,16 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True): | |
try: | ||
provider = lm.model.split("/", 1)[0] or "openai" | ||
if "response_format" in litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider): | ||
outputs = lm(**inputs, **lm_kwargs, response_format={"type": "json_object"}) | ||
try: | ||
response_format = _get_structured_outputs_response_format(signature) | ||
outputs = lm(**inputs, **lm_kwargs, response_format=response_format) | ||
except Exception: | ||
_logger.warning( | ||
"Failed to obtain response using signature-based structured outputs" | ||
" response format: Falling back to default 'json_object' response format." | ||
" Exception: {e}" | ||
) | ||
Comment on lines
+47
to
+50
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We expect to hit this case for tuples, until there's support for |
||
outputs = lm(**inputs, **lm_kwargs, response_format={"type": "json_object"}) | ||
else: | ||
outputs = lm(**inputs, **lm_kwargs) | ||
|
||
|
@@ -303,3 +316,47 @@ def format_signature_fields_for_instructions(role, fields: Dict[str, FieldInfo]) | |
# ", and then ending with the marker for `completed`.") | ||
|
||
return "\n\n".join(parts).strip() | ||
|
||
|
||
def _get_structured_outputs_response_format(signature: SignatureMeta) -> pydantic.BaseModel: | ||
""" | ||
Obtains the LiteLLM / OpenAI `response_format` parameter for obtaining structured outputs from | ||
an LM request, based on the output fields of the specified DSPy signature. | ||
|
||
Args: | ||
signature: The DSPy signature for which to obtain the `response_format` request parameter. | ||
Returns: | ||
A Pydantic model representing the `response_format` parameter for the LM request. | ||
""" | ||
|
||
def filter_json_schema_extra(field_name: str, field_info: FieldInfo) -> FieldInfo: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This needs test coverage |
||
""" | ||
Recursively filter the `json_schema_extra` of a FieldInfo to include only `desc`. | ||
Handles nested structures in FieldInfo objects while preserving original model names. | ||
""" | ||
field_copy = deepcopy(field_info) # Make a copy to avoid mutating the original | ||
|
||
# Update `json_schema_extra` for the current field | ||
if field_copy.json_schema_extra: | ||
field_copy.json_schema_extra = {} | ||
field_desc = field_info.json_schema_extra.get("desc") | ||
if field_desc is not None and field_desc != f"${{{field_name}}}": | ||
field_copy.json_schema_extra["desc"] = field_desc | ||
|
||
# Handle nested models | ||
if hasattr(field_copy.annotation, "__pydantic_model__"): | ||
# Recursively update fields of the nested model | ||
nested_model = field_copy.annotation.__pydantic_model__ | ||
updated_fields = { | ||
key: filter_json_schema_extra(key, value) for key, value in nested_model.__fields__.items() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Curious - why do we need recursive handling? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the vast majority of cases, hopefully not.... Though the following user error will silently produce this state, which is probably best to exclude from
|
||
} | ||
# Create a new model with the same name and updated fields | ||
field_copy.annotation = create_model(nested_model.__name__, **updated_fields) | ||
|
||
return field_copy | ||
|
||
output_pydantic_fields = { | ||
key: (value.annotation, filter_json_schema_extra(key, value)) for key, value in signature.output_fields.items() | ||
} | ||
DSPyProgramOutputs = create_model("DSPyProgramOutputs", **output_pydantic_fields) | ||
return DSPyProgramOutputs |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -92,7 +92,7 @@ def __call__(self, prompt=None, messages=None, **kwargs): | |
completion = cached_litellm_text_completion if cache else litellm_text_completion | ||
|
||
response = completion( | ||
request=ujson.dumps(dict(model=self.model, messages=messages, **kwargs)), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When response_format is a pydantic model (recommended by LiteLLM), There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi @dbczumar, please let me know if this approach is worthy of a PR, I am happy to contribute. Custom adapter using this approach. |
||
request=dict(model=self.model, messages=messages, **kwargs), | ||
num_retries=self.num_retries, | ||
) | ||
outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]] | ||
|
@@ -153,7 +153,11 @@ def thread_function_wrapper(): | |
thread = threading.Thread(target=thread_function_wrapper) | ||
model_to_finetune = self.finetuning_model or self.model | ||
job = self.provider.TrainingJob( | ||
thread=thread, model=model_to_finetune, train_data=train_data, train_kwargs=train_kwargs, data_format=data_format | ||
thread=thread, | ||
model=model_to_finetune, | ||
train_data=train_data, | ||
train_kwargs=train_kwargs, | ||
data_format=data_format, | ||
) | ||
thread.start() | ||
|
||
|
@@ -212,7 +216,7 @@ def copy(self, **kwargs): | |
return new_instance | ||
|
||
|
||
@functools.lru_cache(maxsize=None) | ||
# @functools.lru_cache(maxsize=None) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a hack to get the implementation working end to end for test purposes (see also https://github.com/stanfordnlp/dspy/pull/1881/files#r1867211773). We need a proper fix before merge, e.g. #1862 (though it's not 100% clear to me why we need LRU caching here in the first place on top of the caching that LiteLLM is already providing) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. cc @okhat I'm sure I'm missing something here - let me know if there's additional context motivating this |
||
def cached_litellm_completion(request, num_retries: int): | ||
return litellm_completion( | ||
request, | ||
|
@@ -222,12 +226,12 @@ def cached_litellm_completion(request, num_retries: int): | |
|
||
|
||
def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}): | ||
kwargs = ujson.loads(request) | ||
return litellm.completion( | ||
result = litellm.completion( | ||
num_retries=num_retries, | ||
cache=cache, | ||
**kwargs, | ||
**request, | ||
) | ||
return result | ||
|
||
|
||
@functools.lru_cache(maxsize=None) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
from typing import List | ||
|
||
import pydantic | ||
import pytest | ||
|
||
import dspy | ||
from tests.reliability.utils import assert_program_output_correct, known_failing_models | ||
|
@@ -33,22 +34,29 @@ class QA(dspy.Signature): | |
assert_program_output_correct( | ||
program_input=question, | ||
program_output=answer.comments, | ||
grading_guidelines="The comments should be relevant to the answer", | ||
grading_guidelines=( | ||
"The comments should be relevant to the answer. They don't need to restate the answer explicitly." | ||
), | ||
) | ||
assert answer.certainty >= 0 | ||
assert answer.certainty <= 1 | ||
assert len(answer.comments) >= 2 | ||
|
||
|
||
def test_color_classification_using_enum(): | ||
@pytest.mark.parametrize("module", [dspy.Predict, dspy.ChainOfThought]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. CoT fails this test case with chat and json adapters on master:
However, it passes on the PR branch :D |
||
def test_color_classification_using_enum(module): | ||
Color = Enum("Color", ["RED", "GREEN", "BLUE"]) | ||
|
||
class Colorful(dspy.Signature): | ||
text: str = dspy.InputField() | ||
color: Color = dspy.OutputField() | ||
|
||
program = dspy.Predict(Colorful) | ||
color = program(text="The sky is blue").color | ||
program = module(Colorful) | ||
# Note: The precise text, including the trailing period, is important here for ensuring that | ||
# the program is correctly extracting the color from the text; previous implementations have | ||
# produced invalid enum responses for "The sky is blue.", but they have produced valid enum | ||
# responses for "The sky is blue" (without the period). | ||
color = program(text="The sky is blue.").color | ||
|
||
assert color == Color.BLUE | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,7 +31,6 @@ def assert_program_output_correct( | |
grading_guidelines = [grading_guidelines] | ||
|
||
with judge_dspy_configuration(): | ||
print("GUIDELINES", grading_guidelines) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removing a leftover & unintentional debugging statement from test generation code |
||
for guideline_entry in grading_guidelines: | ||
judge_response = _get_judge_program()( | ||
program_input=str(program_input), | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LM providers have differing levels of support for
response_format
fields. For example, Databricks doesn't support anyOf / allOf, but OpenAI does.A blanket try/catch seems appropriate here to start.