Skip to content

Commit

Permalink
ci: tool choice integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerHYang committed Feb 7, 2025
1 parent 5deb5c6 commit e9ed1bc
Show file tree
Hide file tree
Showing 7 changed files with 286 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,12 @@ def _to_model_kwargs(
if (v := parameters.get("reasoning_effort")) is not None:
if v in ("low", "medium", "high"):
ans["reasoning_effort"] = v
if "tools" in obj and obj["tools"] and (tools := list(_to_tools(obj["tools"]))):
ans["tools"] = tools
if "tools" in obj:
tool_kwargs = _to_tool_kwargs(obj["tools"])
if "tools" in tool_kwargs:
ans["tools"] = tool_kwargs["tools"]
if "tool_choice" in tool_kwargs:
ans["tool_choice"] = tool_kwargs["tool_choice"]
return ans


Expand Down
Empty file.
87 changes: 87 additions & 0 deletions src/phoenix/server/api/helpers/prompts/conversions/anthropic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Union

from typing_extensions import assert_never

if TYPE_CHECKING:
from anthropic.types import (
ToolChoiceAnyParam,
ToolChoiceAutoParam,
ToolChoiceParam,
ToolChoiceToolParam,
)

from phoenix.server.api.helpers.prompts.models import (
PromptToolChoiceOneOrMore,
PromptToolChoiceSpecificFunctionTool,
PromptToolChoiceZeroOrMore,
)


class AnthropicToolChoiceConversion:
    """Translates Phoenix prompt tool-choice models to and from Anthropic's
    ``tool_choice`` request parameters."""

    @staticmethod
    def to_anthropic(
        obj: Union[
            PromptToolChoiceZeroOrMore,
            PromptToolChoiceOneOrMore,
            PromptToolChoiceSpecificFunctionTool,
        ],
        disable_parallel_tool_use: Optional[bool] = None,
    ) -> ToolChoiceParam:
        """Convert a Phoenix tool-choice model into an Anthropic ``ToolChoiceParam``.

        ``zero-or-more`` maps to ``auto``, ``one-or-more`` to ``any``, and
        ``specific-function-tool`` to ``tool`` with the function's name.
        ``disable_parallel_tool_use`` is copied onto the result only when it is
        explicitly provided (not ``None``).
        """
        choice: ToolChoiceParam
        if obj.type == "zero-or-more":
            choice = {"type": "auto"}
        elif obj.type == "one-or-more":
            choice = {"type": "any"}
        elif obj.type == "specific-function-tool":
            choice = {"type": "tool", "name": obj.function_name}
        else:
            assert_never(obj.type)
        if disable_parallel_tool_use is not None:
            choice["disable_parallel_tool_use"] = disable_parallel_tool_use
        return choice

    @staticmethod
    def from_anthropic(
        obj: ToolChoiceParam,
    ) -> tuple[
        Union[
            PromptToolChoiceZeroOrMore,
            PromptToolChoiceOneOrMore,
            PromptToolChoiceSpecificFunctionTool,
        ],
        Optional[bool],
    ]:
        """Convert an Anthropic ``ToolChoiceParam`` back into a Phoenix tool-choice
        model, returning it together with the ``disable_parallel_tool_use`` flag
        (``None`` when the key is absent)."""
        # Imported inside the function — presumably to avoid an import cycle with
        # the models module (which imports this converter at module level).
        from phoenix.server.api.helpers.prompts.models import (
            PromptToolChoiceOneOrMore,
            PromptToolChoiceSpecificFunctionTool,
            PromptToolChoiceZeroOrMore,
        )

        # All three variants carry the flag the same way; read it once up front.
        disable_parallel_tool_use = obj.get("disable_parallel_tool_use")
        kind = obj["type"]
        if kind == "auto":
            return PromptToolChoiceZeroOrMore(type="zero-or-more"), disable_parallel_tool_use
        if kind == "any":
            return PromptToolChoiceOneOrMore(type="one-or-more"), disable_parallel_tool_use
        if kind == "tool":
            return (
                PromptToolChoiceSpecificFunctionTool(
                    type="specific-function-tool",
                    function_name=obj["name"],
                ),
                disable_parallel_tool_use,
            )
        assert_never(obj)
78 changes: 78 additions & 0 deletions src/phoenix/server/api/helpers/prompts/conversions/openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Union

from typing_extensions import assert_never

if TYPE_CHECKING:
from openai.types.chat import (
ChatCompletionNamedToolChoiceParam,
ChatCompletionToolChoiceOptionParam,
)
from openai.types.chat.chat_completion_named_tool_choice_param import Function

from phoenix.server.api.helpers.prompts.models import (
PromptToolChoiceNone,
PromptToolChoiceOneOrMore,
PromptToolChoiceSpecificFunctionTool,
PromptToolChoiceZeroOrMore,
)


class OpenAIToolChoiceConversion:
    """Translates Phoenix prompt tool-choice models to and from OpenAI's
    ``tool_choice`` request parameters."""

    @staticmethod
    def to_openai(
        obj: Union[
            PromptToolChoiceNone,
            PromptToolChoiceZeroOrMore,
            PromptToolChoiceOneOrMore,
            PromptToolChoiceSpecificFunctionTool,
        ],
    ) -> ChatCompletionToolChoiceOptionParam:
        """Convert a Phoenix tool-choice model into an OpenAI tool-choice option.

        The first three variants map to the plain strings ``"none"``, ``"auto"``,
        and ``"required"``; a specific function tool becomes a named-tool dict.
        """
        kind = obj.type
        if kind == "specific-function-tool":
            named_choice: ChatCompletionNamedToolChoiceParam = {
                "type": "function",
                "function": {"name": obj.function_name},
            }
            return named_choice
        if kind == "none":
            return "none"
        if kind == "zero-or-more":
            return "auto"
        if kind == "one-or-more":
            return "required"
        assert_never(obj)

    @staticmethod
    def from_openai(
        obj: ChatCompletionToolChoiceOptionParam,
    ) -> Union[
        PromptToolChoiceNone,
        PromptToolChoiceZeroOrMore,
        PromptToolChoiceOneOrMore,
        PromptToolChoiceSpecificFunctionTool,
    ]:
        """Convert an OpenAI tool-choice option back into a Phoenix tool-choice model."""
        # Imported inside the function — presumably to avoid an import cycle with
        # the models module (which imports this converter at module level).
        from phoenix.server.api.helpers.prompts.models import (
            PromptToolChoiceNone,
            PromptToolChoiceOneOrMore,
            PromptToolChoiceSpecificFunctionTool,
            PromptToolChoiceZeroOrMore,
        )

        if obj == "none":
            return PromptToolChoiceNone(type="none")
        if obj == "auto":
            return PromptToolChoiceZeroOrMore(type="zero-or-more")
        if obj == "required":
            return PromptToolChoiceOneOrMore(type="one-or-more")
        if obj["type"] == "function":
            return PromptToolChoiceSpecificFunctionTool(
                type="specific-function-tool",
                function_name=obj["function"]["name"],
            )
        assert_never(obj["type"])
23 changes: 20 additions & 3 deletions src/phoenix/server/api/helpers/prompts/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from enum import Enum
from typing import Any, Literal, Optional, Union
from typing import Any, Literal, Mapping, Optional, Union

from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated, TypeAlias, assert_never
Expand All @@ -10,6 +10,8 @@
JSONSchemaDraft7ObjectSchema,
JSONSchemaObjectSchema,
)
from phoenix.server.api.helpers.prompts.conversions.anthropic import AnthropicToolChoiceConversion
from phoenix.server.api.helpers.prompts.conversions.openai import OpenAIToolChoiceConversion

JSONSerializable = Union[None, bool, int, float, str, dict[str, Any], list[Any]]

Expand Down Expand Up @@ -349,7 +351,11 @@ class AnthropicToolDefinition(PromptModel):
description: str = UNDEFINED


def normalize_tools(schemas: list[dict[str, Any]], model_provider: str) -> PromptToolsV1:
def normalize_tools(
schemas: list[dict[str, Any]],
model_provider: str,
tool_choice: Optional[Union[str, Mapping[str, Any]]] = None,
) -> PromptToolsV1:
tools: list[PromptFunctionToolV1]
if model_provider.lower() == "openai":
openai_tools = [OpenAIToolDefinition.model_validate(schema) for schema in schemas]
Expand All @@ -359,7 +365,18 @@ def normalize_tools(schemas: list[dict[str, Any]], model_provider: str) -> Promp
tools = [_anthropic_to_prompt_tool(anthropic_tool) for anthropic_tool in anthropic_tools]
else:
raise ValueError(f"Unsupported model provider: {model_provider}")
return PromptToolsV1(type="tools-v1", tools=tools)
ans = PromptToolsV1(type="tools-v1", tools=tools)
if tool_choice is not None:
if model_provider.lower() == "openai":
ans.tool_choice = OpenAIToolChoiceConversion.from_openai(tool_choice) # type: ignore[arg-type]
if model_provider.lower() == "anthropic":
choice, disable_parallel_tool_calls = AnthropicToolChoiceConversion.from_anthropic(
tool_choice # type: ignore[arg-type]
)
ans.tool_choice = choice
if disable_parallel_tool_calls is not None:
ans.disable_parallel_tool_calls = disable_parallel_tool_calls
return ans


def denormalize_tools(tools: PromptToolsV1, model_provider: str) -> list[dict[str, Any]]:
Expand Down
14 changes: 11 additions & 3 deletions src/phoenix/server/api/mutations/prompt_mutations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Mapping, Optional, Union, cast

import strawberry
from fastapi import Request
Expand Down Expand Up @@ -84,9 +84,13 @@ async def create_chat_prompt(

input_prompt_version = input.prompt_version
tool_definitions = [tool.definition for tool in input_prompt_version.tools]
tool_choice = cast(
Optional[Union[str, dict[str, Any]]],
cast(Mapping[str, Any], input.prompt_version.invocation_parameters).get("tool_choice"),
)
try:
tools = (
normalize_tools(tool_definitions, input_prompt_version.model_provider)
normalize_tools(tool_definitions, input_prompt_version.model_provider, tool_choice)
if tool_definitions
else None
)
Expand Down Expand Up @@ -142,9 +146,13 @@ async def create_chat_prompt_version(

input_prompt_version = input.prompt_version
tool_definitions = [tool.definition for tool in input.prompt_version.tools]
tool_choice = cast(
Optional[Union[str, dict[str, Any]]],
cast(Mapping[str, Any], input.prompt_version.invocation_parameters).get("tool_choice"),
)
try:
tools = (
normalize_tools(tool_definitions, input_prompt_version.model_provider)
normalize_tools(tool_definitions, input_prompt_version.model_provider, tool_choice)
if tool_definitions
else None
)
Expand Down
87 changes: 84 additions & 3 deletions tests/integration/prompts/test_prompts.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,27 @@
from __future__ import annotations

import json
from enum import Enum
from secrets import token_hex
from types import MappingProxyType
from typing import Any, Literal, Mapping, Sequence, cast
from typing import Any, Iterable, Literal, Mapping, Sequence, cast

import phoenix as px
import pytest
from anthropic.types import (
ToolChoiceAnyParam,
ToolChoiceAutoParam,
ToolChoiceParam,
ToolChoiceToolParam,
ToolParam,
)
from deepdiff.diff import DeepDiff
from openai import pydantic_function_tool
from openai.lib._parsing import type_to_response_format_param
from openai.types.chat import (
ChatCompletionNamedToolChoiceParam,
ChatCompletionToolChoiceOptionParam,
)
from openai.types.shared_params import ResponseFormatJSONSchema
from phoenix.client.__generated__.v1 import PromptVersion
from phoenix.client.utils import to_chat_messages_and_kwargs
Expand All @@ -36,7 +48,7 @@ def test_user_message(
) -> None:
u = _get_user(_MEMBER).log_in()
monkeypatch.setenv("PHOENIX_API_KEY", u.create_api_key())
prompt = _create_chat_prompt(u)
prompt = _create_chat_prompt(u, template_format="FSTRING")
x = token_hex(4)
messages, _ = to_chat_messages_and_kwargs(prompt, variables={"x": x})
assert not DeepDiff(messages, [{"role": "user", "content": f"hello {x}"}])
Expand All @@ -47,6 +59,11 @@ class _GetWeather(BaseModel):
country: str


class _GetPopulation(BaseModel):
    """Pydantic model whose JSON schema serves as a second tool-parameter
    definition in the tool/tool-choice tests below (alongside ``_GetWeather``)."""

    # Fields only shape the generated schema; the model is not instantiated here.
    country: str
    year: int


class TestTools:
@pytest.mark.parametrize(
"types_",
Expand Down Expand Up @@ -82,6 +99,70 @@ def test_openai(
assert not DeepDiff(actual, expected)


class TestToolChoice:
    """Round-trip integration tests: a ``tool_choice`` supplied in the prompt's
    invocation parameters at creation time must come back unchanged from
    ``to_chat_messages_and_kwargs``."""

    @pytest.mark.parametrize(
        "expected",
        [
            "none",
            "auto",
            "required",
            ChatCompletionNamedToolChoiceParam(type="function", function={"name": "_GetWeather"}),
        ],
    )
    def test_openai(
        self,
        expected: ChatCompletionToolChoiceOptionParam,
        _get_user: _GetUser,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        logged_in_user = _get_user(_MEMBER).log_in()
        monkeypatch.setenv("PHOENIX_API_KEY", logged_in_user.create_api_key())
        # Tool definitions built from pydantic models via the OpenAI helper;
        # the json round-trip normalizes the payload to plain JSON types.
        tool_inputs = [
            ToolDefinitionInput(definition=json.loads(json.dumps(pydantic_function_tool(model))))
            for model in cast(Iterable[type[BaseModel]], [_GetWeather, _GetPopulation])
        ]
        prompt = _create_chat_prompt(
            logged_in_user,
            tools=tool_inputs,
            invocation_parameters={"tool_choice": expected},
        )
        _, kwargs = to_chat_messages_and_kwargs(prompt)
        assert "tool_choice" in kwargs
        assert not DeepDiff(expected, kwargs["tool_choice"])

    @pytest.mark.parametrize(
        "expected",
        [
            ToolChoiceAutoParam(type="auto"),
            ToolChoiceAnyParam(type="any"),
            ToolChoiceToolParam(type="tool", name="_GetWeather"),
        ],
    )
    def test_anthropic(
        self,
        expected: ToolChoiceParam,
        _get_user: _GetUser,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        logged_in_user = _get_user(_MEMBER).log_in()
        monkeypatch.setenv("PHOENIX_API_KEY", logged_in_user.create_api_key())
        # Anthropic tools are plain dicts of name + JSON schema.
        tool_inputs = [
            ToolDefinitionInput(
                definition=dict(ToolParam(name=model.__name__, input_schema=model.model_json_schema()))
            )
            for model in cast(Iterable[type[BaseModel]], [_GetWeather, _GetPopulation])
        ]
        prompt = _create_chat_prompt(
            logged_in_user,
            tools=tool_inputs,
            invocation_parameters={"tool_choice": expected},
            model_provider="ANTHROPIC",
        )
        _, kwargs = to_chat_messages_and_kwargs(prompt)
        assert "tool_choice" in kwargs
        assert not DeepDiff(expected, kwargs["tool_choice"])


class _UIType(str, Enum):
div = "div"
button = "button"
Expand Down Expand Up @@ -142,7 +223,7 @@ def _create_chat_prompt(
response_format: ResponseFormatInput | None = None,
tools: Sequence[ToolDefinitionInput] = (),
invocation_parameters: Mapping[str, Any] = MappingProxyType({}),
template_format: Literal["FSTRING", "MUSTACHE", "NONE"] = "FSTRING",
template_format: Literal["FSTRING", "MUSTACHE", "NONE"] = "NONE",
) -> PromptVersion:
messages = list(messages) or [
PromptMessageInput(
Expand Down

0 comments on commit e9ed1bc

Please sign in to comment.