Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support structured outputs response format based on signature in JSON adapter #1881

Merged
merged 12 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 59 additions & 2 deletions dspy/adapters/json_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import enum
import inspect
import json
import logging
import textwrap
from copy import deepcopy
from typing import Any, Dict, KeysView, Literal, NamedTuple, get_args, get_origin

import json_repair
import litellm
import pydantic
from pydantic import TypeAdapter
from pydantic import TypeAdapter, create_model
from pydantic.fields import FieldInfo

from dspy.adapters.base import Adapter
Expand All @@ -18,6 +20,8 @@
from ..signatures.signature import SignatureMeta
from ..signatures.utils import get_dspy_field_type

_logger = logging.getLogger(__name__)


class FieldInfoWithName(NamedTuple):
name: str
Expand All @@ -35,7 +39,16 @@ def __call__(self, lm, lm_kwargs, signature, demos, inputs, _parse_values=True):
try:
provider = lm.model.split("/", 1)[0] or "openai"
if "response_format" in litellm.get_supported_openai_params(model=lm.model, custom_llm_provider=provider):
outputs = lm(**inputs, **lm_kwargs, response_format={"type": "json_object"})
try:
response_format = _get_structured_outputs_response_format(signature)
outputs = lm(**inputs, **lm_kwargs, response_format=response_format)
except Exception:
Comment on lines +42 to +45
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LM providers have differing levels of support for response_format fields. For example, Databricks doesn't support anyOf / allOf, but OpenAI does.

A blanket try/catch seems appropriate here to start.

_logger.warning(
"Failed to obtain response using signature-based structured outputs"
" response format: Falling back to default 'json_object' response format."
" Exception: {e}"
)
Comment on lines +47 to +50
Copy link
Collaborator Author

@dbczumar dbczumar Dec 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We expect to hit this case for tuples, until there's support for prefixItems in the OpenAI structured outputs API. Other vendors, e.g. Databricks, will likely lag even further behind (e.g. Databricks doesn't support anyOf currently, but OpenAI does), meaning that we could hit this case for additional output types

outputs = lm(**inputs, **lm_kwargs, response_format={"type": "json_object"})
else:
outputs = lm(**inputs, **lm_kwargs)

Expand Down Expand Up @@ -303,3 +316,47 @@ def format_signature_fields_for_instructions(role, fields: Dict[str, FieldInfo])
# ", and then ending with the marker for `completed`.")

return "\n\n".join(parts).strip()


def _get_structured_outputs_response_format(signature: SignatureMeta) -> pydantic.BaseModel:
"""
Constructs the LiteLLM / OpenAI `response_format` parameter for obtaining structured outputs
from an LM request, based on the output fields of the specified DSPy signature.

Args:
signature: The DSPy signature for which to obtain the `response_format` request parameter.
Returns:
A Pydantic model representing the `response_format` parameter for the LM request.
"""

def filter_json_schema_extra(field_name: str, field_info: FieldInfo) -> FieldInfo:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs test coverage

"""
Recursively filter the `json_schema_extra` of a FieldInfo to include only `desc`.
Handles nested structures in FieldInfo objects while preserving original model names.
"""
field_copy = deepcopy(field_info) # Make a copy to avoid mutating the original

# Update `json_schema_extra` for the current field
if field_copy.json_schema_extra:
field_copy.json_schema_extra = {}
field_desc = field_info.json_schema_extra.get("desc")
if field_desc is not None and field_desc != f"${{{field_name}}}":
field_copy.json_schema_extra["desc"] = field_desc

# Handle nested models
if hasattr(field_copy.annotation, "__pydantic_model__"):
# Recursively update fields of the nested model
nested_model = field_copy.annotation.__pydantic_model__
updated_fields = {
key: filter_json_schema_extra(key, value) for key, value in nested_model.__fields__.items()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious - why do we need recursive handling? nested_model.__fields__ should be Pydantic fields instead of DSPy fields, do they also have these DSPy internal attributes like __dspy_field_type?

Copy link
Collaborator Author

@dbczumar dbczumar Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the vast majority of cases, hopefully not....

Though the following user error will silently produce this state, which is probably best to exclude from response_format because the program still runs

import dspy

import pydantic

class Obj(pydantic.BaseModel):
    a: int = dspy.OutputField()
    b: str

class MySig(dspy.Signature):
    inp: str = dspy.InputField() 
    outp: Obj = dspy.OutputField()

print(MySig.schema())
{'$defs': {'Obj': {'properties': {'a': {'__dspy_field_type': 'output', 'title': 'A', 'type': 'integer'}, 'b': {'title': 'B', 'type': 'string'}}, 'required': ['a', 'b'], 'title': 'Obj', 'type': 'object'}}, 'description': 'Given the fields `inp`, produce the fields `outp`.', 'properties': {'inp': {'__dspy_field_type': 'input', 'desc': '${inp}', 'prefix': 'Inp:', 'title': 'Inp', 'type': 'string'}, 'outp': {'$ref': '#/$defs/Obj', '__dspy_field_type': 'output', 'desc': '${outp}', 'prefix': 'Outp:'}}, 'required': ['inp', 'outp'], 'title': 'MySig', 'type': 'object'}

}
# Create a new model with the same name and updated fields
field_copy.annotation = create_model(nested_model.__name__, **updated_fields)

return field_copy

output_pydantic_fields = {
key: (value.annotation, filter_json_schema_extra(key, value)) for key, value in signature.output_fields.items()
}
DSPyProgramOutputs = create_model("DSPyProgramOutputs", **output_pydantic_fields)
return DSPyProgramOutputs
16 changes: 10 additions & 6 deletions dspy/clients/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
completion = cached_litellm_text_completion if cache else litellm_text_completion

response = completion(
request=ujson.dumps(dict(model=self.model, messages=messages, **kwargs)),
Copy link
Collaborator Author

@dbczumar dbczumar Dec 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When response_format is a pydantic model (recommended by LiteLLM), ujson.dumps() fails because pydantic models are not directly serializable using ujson.dumps(). This line diff is a temporary hack to get the implementation working end-to-end for test purposes. We need a proper solution before merge.

Copy link

@rohitgarud rohitgarud Dec 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @dbczumar, please let me know if this approach is worthy of a PR, I am happy to contribute. Custom adapter using this approach.

request=dict(model=self.model, messages=messages, **kwargs),
num_retries=self.num_retries,
)
outputs = [c.message.content if hasattr(c, "message") else c["text"] for c in response["choices"]]
Expand Down Expand Up @@ -153,7 +153,11 @@ def thread_function_wrapper():
thread = threading.Thread(target=thread_function_wrapper)
model_to_finetune = self.finetuning_model or self.model
job = self.provider.TrainingJob(
thread=thread, model=model_to_finetune, train_data=train_data, train_kwargs=train_kwargs, data_format=data_format
thread=thread,
model=model_to_finetune,
train_data=train_data,
train_kwargs=train_kwargs,
data_format=data_format,
)
thread.start()

Expand Down Expand Up @@ -212,7 +216,7 @@ def copy(self, **kwargs):
return new_instance


@functools.lru_cache(maxsize=None)
# @functools.lru_cache(maxsize=None)
Copy link
Collaborator Author

@dbczumar dbczumar Dec 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a hack to get the implementation working end to end for test purposes (see also https://github.com/stanfordnlp/dspy/pull/1881/files#r1867211773). We need a proper fix before merge, e.g. #1862 (though it's not 100% clear to me why we need LRU caching here in the first place on top of the caching that LiteLLM is already providing)

Copy link
Collaborator Author

@dbczumar dbczumar Dec 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @okhat I'm sure I'm missing something here - let me know if there's additional context motivating this lru_cache.

def cached_litellm_completion(request, num_retries: int):
return litellm_completion(
request,
Expand All @@ -222,12 +226,12 @@ def cached_litellm_completion(request, num_retries: int):


def litellm_completion(request, num_retries: int, cache={"no-cache": True, "no-store": True}):
kwargs = ujson.loads(request)
return litellm.completion(
result = litellm.completion(
num_retries=num_retries,
cache=cache,
**kwargs,
**request,
)
return result


@functools.lru_cache(maxsize=None)
Expand Down
16 changes: 12 additions & 4 deletions tests/reliability/test_pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List

import pydantic
import pytest

import dspy
from tests.reliability.utils import assert_program_output_correct, known_failing_models
Expand Down Expand Up @@ -33,22 +34,29 @@ class QA(dspy.Signature):
assert_program_output_correct(
program_input=question,
program_output=answer.comments,
grading_guidelines="The comments should be relevant to the answer",
grading_guidelines=(
"The comments should be relevant to the answer. They don't need to restate the answer explicitly."
),
)
assert answer.certainty >= 0
assert answer.certainty <= 1
assert len(answer.comments) >= 2


def test_color_classification_using_enum():
@pytest.mark.parametrize("module", [dspy.Predict, dspy.ChainOfThought])
Copy link
Collaborator Author

@dbczumar dbczumar Dec 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CoT fails this test case with chat and json adapters on master:

FAILED test_pydantic_models.py::test_color_classification_using_enum[llama-3.1-70b-instruct-ChainOfThought] - ValueError: Color.BLUE is not a valid name or value for the enum Color
================================================ 1 failed, 1 passed, 24 skipped, 26 deselected, 2 warnings in 0.22s

However, it passes on the PR branch :D

def test_color_classification_using_enum(module):
Color = Enum("Color", ["RED", "GREEN", "BLUE"])

class Colorful(dspy.Signature):
text: str = dspy.InputField()
color: Color = dspy.OutputField()

program = dspy.Predict(Colorful)
color = program(text="The sky is blue").color
program = module(Colorful)
# Note: The precise text, including the trailing period, is important here for ensuring that
# the program is correctly extracting the color from the text; previous implementations have
# produced invalid enum responses for "The sky is blue.", but they have produced valid enum
# responses for "The sky is blue" (without the period).
color = program(text="The sky is blue.").color

assert color == Color.BLUE

Expand Down
1 change: 0 additions & 1 deletion tests/reliability/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def assert_program_output_correct(
grading_guidelines = [grading_guidelines]

with judge_dspy_configuration():
print("GUIDELINES", grading_guidelines)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing a leftover & unintentional debugging statement from test generation code

for guideline_entry in grading_guidelines:
judge_response = _get_judge_program()(
program_input=str(program_input),
Expand Down
Loading