Skip to content

Commit

Permalink
ci: add integration test for anthropic tools
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerHYang committed Feb 6, 2025
1 parent 6f6ae97 commit 51fac39
Showing 1 changed file with 32 additions and 3 deletions.
35 changes: 32 additions & 3 deletions tests/integration/prompts/test_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from enum import Enum
from secrets import token_hex
from types import MappingProxyType
from typing import Any, Literal, Mapping, Sequence, cast
from typing import Any, Iterable, Literal, Mapping, Sequence, cast

import phoenix as px
import pytest
from anthropic.types import ToolParam
from deepdiff.diff import DeepDiff
from openai import pydantic_function_tool
from openai.lib._parsing import type_to_response_format_param
Expand Down Expand Up @@ -36,7 +37,7 @@ def test_user_message(
) -> None:
u = _get_user(_MEMBER).log_in()
monkeypatch.setenv("PHOENIX_API_KEY", u.create_api_key())
prompt = _create_chat_prompt(u)
prompt = _create_chat_prompt(u, template_format="FSTRING")
x = token_hex(4)
messages, _ = to_chat_messages_and_kwargs(prompt, variables={"x": x})
assert not DeepDiff(messages, [{"role": "user", "content": f"hello {x}"}])
Expand Down Expand Up @@ -81,6 +82,34 @@ def test_openai(
assert len(expected) == len(tools)
assert not DeepDiff(actual, expected)

@pytest.mark.parametrize(
"types_",
[
[_GetWeather],
],
)
def test_anthropic(
self,
types_: Sequence[type[BaseModel]],
_get_user: _GetUser,
monkeypatch: pytest.MonkeyPatch,
) -> None:
u = _get_user().log_in()
monkeypatch.setenv("PHOENIX_API_KEY", u.create_api_key())
expected: dict[str, ToolParam] = {
t.__name__: ToolParam(
name=t.__name__,
input_schema=t.model_json_schema(),
)
for t in types_
}
tools = [ToolDefinitionInput(definition=dict(v)) for v in expected.values()]
prompt = _create_chat_prompt(u, tools=tools, model_provider="ANTHROPIC")
_, kwargs = to_chat_messages_and_kwargs(prompt)
assert "tools" in kwargs
actual = {t["name"]: t for t in cast(Iterable[ToolParam], kwargs["tools"])}
assert not DeepDiff(expected, actual)


class _UIType(str, Enum):
div = "div"
Expand Down Expand Up @@ -142,7 +171,7 @@ def _create_chat_prompt(
response_format: ResponseFormatInput | None = None,
tools: Sequence[ToolDefinitionInput] = (),
invocation_parameters: Mapping[str, Any] = MappingProxyType({}),
template_format: Literal["FSTRING", "MUSTACHE", "NONE"] = "FSTRING",
template_format: Literal["FSTRING", "MUSTACHE", "NONE"] = "NONE",
) -> PromptVersion:
messages = list(messages) or [
PromptMessageInput(
Expand Down

0 comments on commit 51fac39

Please sign in to comment.