feat: update client with response format for openai sdk (#6282)
RogerHYang authored Feb 7, 2025
1 parent 5deb5c6 commit ce771ec
Showing 3 changed files with 158 additions and 39 deletions.
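In effect, a JSON-schema response format saved on a Phoenix prompt now flows into the `response_format` parameter of an OpenAI SDK call. A minimal sketch of the intended flow, assuming a prompt was previously saved with a response format (the identifier "my-prompt" and the `Client().prompts.get` call are illustrative, not part of this diff):

from phoenix.client import Client
from phoenix.client.utils import to_chat_messages_and_kwargs

# Hypothetical: fetch a prompt that was saved with a JSON-schema response format.
prompt = Client().prompts.get(prompt_identifier="my-prompt")

messages, kwargs = to_chat_messages_and_kwargs(prompt, variables={"x": "world"})
# After this commit, kwargs can carry "response_format" (and "tool_choice")
# in the shapes the OpenAI SDK expects, e.g.
#   kwargs["response_format"] == {"type": "json_schema", "json_schema": {...}}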
93 changes: 83 additions & 10 deletions packages/phoenix-client/src/phoenix/client/helpers/sdk/openai/chat.py
@@ -25,6 +25,7 @@
    JSONSchemaDraft7ObjectSchema,
    PromptFunctionToolV1,
    PromptMessage,
+   PromptResponseFormatJSONSchema,
    PromptToolChoiceNone,
    PromptToolChoiceOneOrMore,
    PromptToolChoiceSpecificFunctionTool,
@@ -67,10 +68,9 @@
)
from openai.types.chat.chat_completion_assistant_message_param import ContentArrayOfContentPart
from openai.types.chat.chat_completion_named_tool_choice_param import Function
-from openai.types.chat.completion_create_params import (
-    ResponseFormat,
-)
-from openai.types.shared_params import FunctionDefinition
+from openai.types.chat.completion_create_params import ResponseFormat
+from openai.types.shared_params import FunctionDefinition, ResponseFormatJSONSchema
+from openai.types.shared_params.response_format_json_schema import JSONSchema

def _(obj: PromptVersion) -> None:
messages, kwargs = to_chat_messages_and_kwargs(obj)
@@ -166,8 +166,18 @@ def _to_model_kwargs(
    if (v := parameters.get("reasoning_effort")) is not None:
        if v in ("low", "medium", "high"):
            ans["reasoning_effort"] = v
-   if "tools" in obj and obj["tools"] and (tools := list(_to_tools(obj["tools"]))):
-       ans["tools"] = tools
+   if "tools" in obj:
+       tool_kwargs = _to_tool_kwargs(obj["tools"])
+       if "tools" in tool_kwargs:
+           ans["tools"] = tool_kwargs["tools"]
+       if "tool_choice" in tool_kwargs:
+           ans["tool_choice"] = tool_kwargs["tool_choice"]
+   if "response_format" in obj:
+       response_format = obj["response_format"]
+       if response_format["type"] == "response-format-json-schema-v1":
+           ans["response_format"] = _to_response_format_json_schema(response_format)
+       elif TYPE_CHECKING:
+           assert_never(response_format)
    return ans
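For a prompt version that stores tools, a tool choice, and a JSON-schema response format, the kwargs assembled here come out roughly like this — a sketch with invented values; the "model" key is assumed to be filled in elsewhere in this function:

ans = {
    "model": "gpt-4o",  # assumed, carried over from the prompt version
    "tools": [{"type": "function", "function": {"name": "get_weather"}}],
    "tool_choice": "auto",
    "response_format": {
        "type": "json_schema",
        "json_schema": {"name": "Response", "schema": {"type": "object"}},
    },
}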


@@ -277,12 +287,19 @@ def _to_tools(
    obj: PromptToolsV1,
) -> Iterable[ChatCompletionToolParam]:
    for tool in obj["tools"]:
-       function: FunctionDefinition = {"name": tool["name"]}
+       definition: FunctionDefinition = {"name": tool["name"]}
        if "description" in tool:
-           function["description"] = tool["description"]
+           definition["description"] = tool["description"]
        if "schema" in tool:
-           function["parameters"] = dict(tool["schema"]["json"])
-       yield {"type": "function", "function": function}
+           definition["parameters"] = dict(tool["schema"]["json"])
+       if "extra_parameters" in tool:
+           extra_parameters = tool["extra_parameters"]
+           if "strict" in extra_parameters and (
+               isinstance(v := extra_parameters["strict"], bool) or v is None
+           ):
+               definition["strict"] = v
+       ans: ChatCompletionToolParam = {"type": "function", "function": definition}
+       yield ans


def _from_tools(
@@ -302,10 +319,66 @@ def _from_tools(
type="json-schema-draft-7-object-schema",
json=definition["parameters"],
)
if "strict" in definition:
function["extra_parameters"] = {"strict": definition["strict"]}
functions.append(function)
return PromptToolsV1(type="tools-v1", tools=functions)


+def _to_response_format_json_schema(
+    obj: PromptResponseFormatJSONSchema,
+) -> ResponseFormat:
+    json_schema: JSONSchema = {
+        "name": obj["name"],
+    }
+    schema = obj["schema"]
+    if schema["type"] == "json-schema-draft-7-object-schema":
+        json_schema["schema"] = dict(schema["json"])
+    elif TYPE_CHECKING:
+        assert_never(schema["type"])
+    if "description" in obj:
+        json_schema["description"] = obj["description"]
+    if "extra_parameters" in obj:
+        extra_parameters = obj["extra_parameters"]
+        if "strict" in extra_parameters and (
+            isinstance(v := extra_parameters["strict"], bool) or v is None
+        ):
+            json_schema["strict"] = v
+    ans: ResponseFormatJSONSchema = {
+        "type": "json_schema",
+        "json_schema": json_schema,
+    }
+    return ans


+def _from_response_format(
+    obj: ResponseFormat,
+) -> PromptResponseFormatJSONSchema:
+    if obj["type"] == "json_schema":
+        json_schema: JSONSchema = obj["json_schema"]
+        extra_parameters: dict[str, Any] = {}
+        if "strict" in json_schema:
+            extra_parameters["strict"] = json_schema["strict"]
+        ans = PromptResponseFormatJSONSchema(
+            type="response-format-json-schema-v1",
+            extra_parameters=extra_parameters,
+            name=json_schema["name"],
+            schema=JSONSchemaDraft7ObjectSchema(
+                type="json-schema-draft-7-object-schema",
+                json=json_schema["schema"] if "schema" in json_schema else {},
+            ),
+        )
+        if "description" in json_schema:
+            ans["description"] = json_schema["description"]
+        return ans
+    elif obj["type"] == "text":
+        raise NotImplementedError
+    elif obj["type"] == "json_object":
+        raise NotImplementedError
+    else:
+        assert_never(obj)


def _to_messages(
    obj: PromptMessage,
    variables: Mapping[str, str],
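The two new helpers above are inverses of one another. Roughly, they translate between the following two shapes — a hand-written sketch based on the code in this hunk, with an invented example schema:

# OpenAI-side shape, as built by _to_response_format_json_schema:
openai_side = {
    "type": "json_schema",
    "json_schema": {
        "name": "Response",
        "schema": {"type": "object", "properties": {"answer": {"type": "string"}}},
        "strict": True,  # emitted only when extra_parameters carries a bool or None
    },
}

# Phoenix-side shape, as built by _from_response_format:
phoenix_side = {
    "type": "response-format-json-schema-v1",
    "name": "Response",
    "schema": {
        "type": "json-schema-draft-7-object-schema",
        "json": {"type": "object", "properties": {"answer": {"type": "string"}}},
    },
    "extra_parameters": {"strict": True},
}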
49 changes: 48 additions & 1 deletion packages/phoenix-client/tests/canary/sdk/openai/test_chat.py
@@ -1,9 +1,13 @@
from __future__ import annotations

import json
-from typing import Any, Iterable, Mapping, Optional, Union
+from enum import Enum
+from typing import Any, Iterable, Mapping, Optional, Union, cast

import pytest
from deepdiff.diff import DeepDiff
from faker import Faker
+from openai.lib._parsing import type_to_response_format_param
from openai.types.chat import (
    ChatCompletionAssistantMessageParam,
    ChatCompletionContentPartImageParam,
@@ -19,11 +23,14 @@
from openai.types.chat.chat_completion_assistant_message_param import ContentArrayOfContentPart
from openai.types.chat.chat_completion_content_part_image_param import ImageURL
from openai.types.chat.chat_completion_message_tool_call_param import Function
+from openai.types.chat.completion_create_params import ResponseFormat
from openai.types.shared_params import FunctionDefinition
+from pydantic import BaseModel, create_model

from phoenix.client.__generated__.v1 import (
    ImageContentPart,
    PromptMessage,
+   PromptResponseFormatJSONSchema,
    PromptToolsV1,
    TextContentPart,
    TextContentValue,
@@ -32,12 +39,14 @@
from phoenix.client.helpers.sdk.openai.chat import (
    _from_image,
    _from_message,
+   _from_response_format,
    _from_text,
    _from_tool_call,
    _from_tool_kwargs,
    _from_tools,
    _to_image,
    _to_messages,
+   _to_response_format_json_schema,
    _to_text,
    _to_tool_call,
    _to_tool_kwargs,
@@ -242,6 +251,44 @@ def test_round_trip(self) -> None:
        assert not DeepDiff(obj, new_obj)


+class _UIType(str, Enum):
+    div = "div"
+    button = "button"
+    header = "header"
+    section = "section"
+    field = "field"
+    form = "form"
+
+
+class _Attribute(BaseModel):
+    name: str
+    value: str
+
+
+class _UI(BaseModel):
+    type: _UIType
+    label: str
+    children: list[_UI]
+    attributes: list[_Attribute]
+
+
+_UI.model_rebuild()
+
+
+class TestResponseFormat:
+    @pytest.mark.parametrize(
+        "type_",
+        [
+            create_model("Response", ui=(_UI, ...)),
+        ],
+    )
+    def test_round_trip(self, type_: type[BaseModel]) -> None:
+        obj = cast(ResponseFormat, type_to_response_format_param(type_))
+        x: PromptResponseFormatJSONSchema = _from_response_format(obj)
+        new_obj = _to_response_format_json_schema(x)
+        assert not DeepDiff(obj, new_obj)


class TestToolKwargs:
    @pytest.mark.parametrize(
        "obj",
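The round-trip test above leans on the OpenAI SDK's own converter, type_to_response_format_param, to produce a realistic fixture. A standalone sketch of that step (the Answer model is invented for illustration):

from openai.lib._parsing import type_to_response_format_param
from pydantic import BaseModel


class Answer(BaseModel):  # invented for this sketch
    text: str
    score: float


rf = type_to_response_format_param(Answer)
# rf looks like {"type": "json_schema", "json_schema": {"name": "Answer",
# "schema": {...}, "strict": True}}; the exact keys depend on the installed
# openai version.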
55 changes: 27 additions & 28 deletions tests/integration/prompts/test_prompts.py
@@ -1,15 +1,17 @@
from __future__ import annotations

+import json
from enum import Enum
from secrets import token_hex
from types import MappingProxyType
-from typing import Any, Literal, Mapping, Sequence, cast
+from typing import Any, Iterable, Literal, Mapping, Sequence, cast

import phoenix as px
import pytest
from deepdiff.diff import DeepDiff
from openai import pydantic_function_tool
from openai.lib._parsing import type_to_response_format_param
+from openai.types.chat import ChatCompletionToolParam
from openai.types.shared_params import ResponseFormatJSONSchema
from phoenix.client.__generated__.v1 import PromptVersion
from phoenix.client.utils import to_chat_messages_and_kwargs
@@ -36,10 +38,11 @@ def test_user_message(
) -> None:
    u = _get_user(_MEMBER).log_in()
    monkeypatch.setenv("PHOENIX_API_KEY", u.create_api_key())
-   prompt = _create_chat_prompt(u)
    x = token_hex(4)
+   expected = [{"role": "user", "content": f"hello {x}"}]
+   prompt = _create_chat_prompt(u, template_format="FSTRING")
    messages, _ = to_chat_messages_and_kwargs(prompt, variables={"x": x})
-   assert not DeepDiff(messages, [{"role": "user", "content": f"hello {x}"}])
+   assert not DeepDiff(expected, messages)


class _GetWeather(BaseModel):
@@ -62,24 +65,22 @@ def test_openai(
) -> None:
    u = _get_user(_MEMBER).log_in()
    monkeypatch.setenv("PHOENIX_API_KEY", u.create_api_key())
-   tools = [ToolDefinitionInput(definition=dict(pydantic_function_tool(t))) for t in types_]
-   prompt = _create_chat_prompt(u, tools=tools)
-   assert "tools" in prompt
-   actual = {
-       t["name"]: t["schema"]["json"]
-       for t in prompt["tools"]["tools"]
-       if "schema" in t and "json" in t["schema"]
-   }
-   assert len(actual) == len(tools)
-   expected = {
-       t.definition["function"]["name"]: t.definition["function"]["parameters"]
-       for t in tools
-       if "function" in t.definition
-       and "name" in t.definition["function"]
-       and "parameters" in t.definition["function"]
-   }
-   assert not DeepDiff(actual, expected)
+   expected: Mapping[str, ChatCompletionToolParam] = {
+       t.__name__: cast(
+           ChatCompletionToolParam, json.loads(json.dumps(pydantic_function_tool(t)))
+       )
+       for t in types_
+   }
+   tools = [ToolDefinitionInput(definition=dict(v)) for v in expected.values()]
+   prompt = _create_chat_prompt(u, tools=tools)
+   _, kwargs = to_chat_messages_and_kwargs(prompt)
+   assert "tools" in kwargs
+   actual: dict[str, ChatCompletionToolParam] = {
+       t["function"]["name"]: t
+       for t in cast(Iterable[ChatCompletionToolParam], kwargs["tools"])
+       if t["type"] == "function" and "parameters" in t["function"]
+   }
+   assert len(expected) == len(tools)
+   assert not DeepDiff(expected, actual)


class _UIType(str, Enum):
Expand Down Expand Up @@ -121,15 +122,13 @@ def test_openai(
) -> None:
    u = _get_user(_MEMBER).log_in()
    monkeypatch.setenv("PHOENIX_API_KEY", u.create_api_key())
-   response_format = ResponseFormatInput(
-       definition=dict(cast(ResponseFormatJSONSchema, type_to_response_format_param(type_)))
-   )
+   expected = cast(ResponseFormatJSONSchema, type_to_response_format_param(type_))
+   response_format = ResponseFormatInput(definition=dict(expected))
    prompt = _create_chat_prompt(u, response_format=response_format)
-   assert "response_format" in prompt
-   assert not DeepDiff(
-       prompt["response_format"]["schema"]["json"],
-       response_format.definition["json_schema"]["schema"],
-   )
+   _, kwargs = to_chat_messages_and_kwargs(prompt)
+   assert "response_format" in kwargs
+   actual = kwargs["response_format"]
+   assert not DeepDiff(expected, actual)


def _create_chat_prompt(
@@ -142,7 +141,7 @@ def _create_chat_prompt(
    response_format: ResponseFormatInput | None = None,
    tools: Sequence[ToolDefinitionInput] = (),
    invocation_parameters: Mapping[str, Any] = MappingProxyType({}),
-   template_format: Literal["FSTRING", "MUSTACHE", "NONE"] = "FSTRING",
+   template_format: Literal["FSTRING", "MUSTACHE", "NONE"] = "NONE",
) -> PromptVersion:
    messages = list(messages) or [
        PromptMessageInput(
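Note the default template_format for _create_chat_prompt moves from "FSTRING" to "NONE", so tests that rely on variable substitution (like test_user_message above) now opt in explicitly. A rough sketch of the difference:

from phoenix.client.utils import to_chat_messages_and_kwargs

# Assumes `prompt` is a PromptVersion created with template_format="FSTRING"
# and a user message template of "hello {x}", as in test_user_message above.
messages, _ = to_chat_messages_and_kwargs(prompt, variables={"x": "world"})
# -> [{"role": "user", "content": "hello world"}]
# Under the new default "NONE", the template text is passed through verbatim,
# with no variable substitution.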
