diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 278d5dc5..69968446 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -35,6 +35,7 @@ env: COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }} FIRECRAWL_API_KEY: ${{ secrets.FIRECRAWL_API_KEY }} SERPERDEV_API_KEY: ${{ secrets.SERPERDEV_API_KEY }} + HF_API_TOKEN: ${{ secrets.HF_API_TOKEN }} OLLAMA_LLM_FOR_TESTS: "llama3.2:3b" jobs: diff --git a/docs/pydoc/config/generators_api.yml b/docs/pydoc/config/generators_api.yml index 3ef80214..ed071cd8 100644 --- a/docs/pydoc/config/generators_api.yml +++ b/docs/pydoc/config/generators_api.yml @@ -2,6 +2,7 @@ loaders: - type: haystack_pydoc_tools.loaders.CustomPythonLoader search_path: [../../../] modules: ["haystack_experimental.components.generators.chat.openai", + "haystack_experimental.components.generators.chat.hugging_face_api", "haystack_experimental.components.generators.ollama.chat.chat_generator"] ignore_when_discovered: ["__init__"] processors: diff --git a/haystack_experimental/components/__init__.py b/haystack_experimental/components/__init__.py index a6323214..db85c6e6 100644 --- a/haystack_experimental/components/__init__.py +++ b/haystack_experimental/components/__init__.py @@ -4,7 +4,7 @@ from .extractors import LLMMetadataExtractor -from .generators.chat import OpenAIChatGenerator +from .generators.chat import HuggingFaceAPIChatGenerator, OpenAIChatGenerator from .generators.ollama.chat.chat_generator import OllamaChatGenerator from .retrievers.auto_merging_retriever import AutoMergingRetriever from .retrievers.chat_message_retriever import ChatMessageRetriever @@ -16,6 +16,7 @@ "AutoMergingRetriever", "ChatMessageWriter", "ChatMessageRetriever", + "HuggingFaceAPIChatGenerator", "OllamaChatGenerator", "OpenAIChatGenerator", "LLMMetadataExtractor", diff --git a/haystack_experimental/components/generators/chat/__init__.py b/haystack_experimental/components/generators/chat/__init__.py index 594bd56e..1d3da10d 100644 --- a/haystack_experimental/components/generators/chat/__init__.py +++ b/haystack_experimental/components/generators/chat/__init__.py @@ -6,6 +6,9 @@ OpenAIChatGenerator, ) +from haystack_experimental.components.generators.chat.hugging_face_api import HuggingFaceAPIChatGenerator + __all__ = [ + "HuggingFaceAPIChatGenerator", "OpenAIChatGenerator", ] diff --git a/haystack_experimental/components/generators/chat/hugging_face_api.py b/haystack_experimental/components/generators/chat/hugging_face_api.py new file mode 100644 index 00000000..84fa1a48 --- /dev/null +++ b/haystack_experimental/components/generators/chat/hugging_face_api.py @@ -0,0 +1,351 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Dict, Iterable, List, Optional, Union + +from haystack import component, default_from_dict, logging +from haystack.dataclasses import StreamingChunk +from haystack.lazy_imports import LazyImport +from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace +from haystack.utils.hf import HFGenerationAPIType + +with LazyImport(message="Run 'pip install \"huggingface_hub[inference]>=0.23.0\"'") as huggingface_hub_import: + from huggingface_hub import ( + ChatCompletionInputTool, + ChatCompletionOutput, + ChatCompletionStreamOutput, + ) + +from haystack.components.generators.chat.hugging_face_api import ( + HuggingFaceAPIChatGenerator as HuggingFaceAPIChatGeneratorBase, +) + +from haystack_experimental.dataclasses import ChatMessage, 
ToolCall +from haystack_experimental.dataclasses.tool import Tool, deserialize_tools_inplace + +logger = logging.getLogger(__name__) + + +def _convert_message_to_hfapi_format(message: ChatMessage) -> Dict[str, Any]: + """ + Convert a message to the format expected by Hugging Face API. + """ + text_contents = message.texts + tool_calls = message.tool_calls + tool_call_results = message.tool_call_results + + if not text_contents and not tool_calls and not tool_call_results: + raise ValueError("A `ChatMessage` must contain at least one `TextContent`, `ToolCall`, or `ToolCallResult`.") + elif len(text_contents) + len(tool_call_results) > 1: + raise ValueError("A `ChatMessage` can only contain one `TextContent` or one `ToolCallResult`.") + + # HF API always expects a content field, even if it is empty + hfapi_msg: Dict[str, Any] = {"role": message._role.value, "content": ""} + + if tool_call_results: + result = tool_call_results[0] + hfapi_msg["content"] = result.result + if tc_id := result.origin.id: + hfapi_msg["tool_call_id"] = tc_id + # HF API does not provide a way to communicate errors in tool invocations, so we ignore the error field + return hfapi_msg + + if text_contents: + hfapi_msg["content"] = text_contents[0] + if tool_calls: + hfapi_tool_calls = [] + for tc in tool_calls: + hfapi_tool_call = { + "type": "function", + "function": {"name": tc.tool_name, "arguments": tc.arguments}, + } + if tc.id is not None: + hfapi_tool_call["id"] = tc.id + hfapi_tool_calls.append(hfapi_tool_call) + hfapi_msg["tool_calls"] = hfapi_tool_calls + + return hfapi_msg + + +@component +class HuggingFaceAPIChatGenerator(HuggingFaceAPIChatGeneratorBase): + """ + Completes chats using Hugging Face APIs. + + HuggingFaceAPIChatGenerator uses the [ChatMessage](https://docs.haystack.deepset.ai/docs/data-classes#chatmessage) + format for input and output. 
Use it to generate text with Hugging Face APIs:
+    - [Free Serverless Inference API](https://huggingface.co/inference-api)
+    - [Paid Inference Endpoints](https://huggingface.co/inference-endpoints)
+    - [Self-hosted Text Generation Inference](https://github.com/huggingface/text-generation-inference)
+
+    ### Usage examples
+
+    #### With the free serverless inference API
+
+    ```python
+    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
+    from haystack.dataclasses import ChatMessage
+    from haystack.utils import Secret
+    from haystack.utils.hf import HFGenerationAPIType
+
+    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
+                ChatMessage.from_user("What's Natural Language Processing?")]
+
+    # the api_type can be expressed using the HFGenerationAPIType enum or as a string
+    api_type = HFGenerationAPIType.SERVERLESS_INFERENCE_API
+    api_type = "serverless_inference_api" # this is equivalent to the above
+
+    generator = HuggingFaceAPIChatGenerator(api_type=api_type,
+                                            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
+                                            token=Secret.from_token("<your-api-key>"))
+
+    result = generator.run(messages)
+    print(result)
+    ```
+
+    #### With paid inference endpoints
+
+    ```python
+    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
+    from haystack.dataclasses import ChatMessage
+    from haystack.utils import Secret
+
+    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
+                ChatMessage.from_user("What's Natural Language Processing?")]
+
+    generator = HuggingFaceAPIChatGenerator(api_type="inference_endpoints",
+                                            api_params={"url": "<your-inference-endpoint-url>"},
+                                            token=Secret.from_token("<your-api-key>"))
+
+    result = generator.run(messages)
+    print(result)
+    ```
+
+    #### With self-hosted text generation inference
+
+    ```python
+    from haystack.components.generators.chat import HuggingFaceAPIChatGenerator
+    from haystack.dataclasses import ChatMessage
+
+    messages = [ChatMessage.from_system("\\nYou are a helpful, respectful and honest assistant"),
+                ChatMessage.from_user("What's Natural Language Processing?")]
+
+    generator = HuggingFaceAPIChatGenerator(api_type="text_generation_inference",
+                                            api_params={"url": "http://localhost:8080"})
+
+    result = generator.run(messages)
+    print(result)
+    ```
+    """
+
+    def __init__(
+        self,
+        api_type: Union[HFGenerationAPIType, str],
+        api_params: Dict[str, str],
+        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
+        generation_kwargs: Optional[Dict[str, Any]] = None,
+        stop_words: Optional[List[str]] = None,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
+        tools: Optional[List[Tool]] = None,
+    ):
+        """
+        Initialize the HuggingFaceAPIChatGenerator instance.
+
+        :param api_type:
+            The type of Hugging Face API to use. Available types:
+            - `text_generation_inference`: See [TGI](https://github.com/huggingface/text-generation-inference).
+            - `inference_endpoints`: See [Inference Endpoints](https://huggingface.co/inference-endpoints).
+            - `serverless_inference_api`: See [Serverless Inference API](https://huggingface.co/inference-api).
+        :param api_params:
+            A dictionary with the following keys:
+            - `model`: Hugging Face model ID. Required when `api_type` is `SERVERLESS_INFERENCE_API`.
+            - `url`: URL of the inference endpoint. Required when `api_type` is `INFERENCE_ENDPOINTS` or
+              `TEXT_GENERATION_INFERENCE`.
+        :param token:
+            The Hugging Face token to use as HTTP bearer authorization.
+ Check your HF token in your [account settings](https://huggingface.co/settings/tokens). + :param generation_kwargs: + A dictionary with keyword arguments to customize text generation. + Some examples: `max_tokens`, `temperature`, `top_p`. + For details, see [Hugging Face chat_completion documentation](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion). + :param stop_words: + An optional list of strings representing the stop words. + :param streaming_callback: + An optional callable for handling streaming responses. + :param tools: + A list of tools for which the model can prepare calls. + The chosen model should support tool/function calling, according to the model card. + Support for tools in the Hugging Face API and TGI is not yet fully refined and you may experience + unexpected behavior. + """ + + # the base class __init__ also checks the hugingface_hub lazy import + super(HuggingFaceAPIChatGenerator, self).__init__( + api_type=api_type, + api_params=api_params, + token=token, + generation_kwargs=generation_kwargs, + stop_words=stop_words, + streaming_callback=streaming_callback, + ) + + if tools: + tool_names = [tool.name for tool in tools] + duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1} + if duplicate_tool_names: + raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}") + self.tools = tools + + if tools and streaming_callback is not None: + raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.") + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + + :returns: + A dictionary containing the serialized component. + """ + serialized = super(HuggingFaceAPIChatGenerator, self).to_dict() + serialized["init_parameters"]["tools"] = [tool.to_dict() for tool in self.tools] if self.tools else None + return serialized + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "HuggingFaceAPIChatGenerator": + """ + Deserialize this component from a dictionary. + """ + deserialize_secrets_inplace(data["init_parameters"], keys=["token"]) + deserialize_tools_inplace(data["init_parameters"], key="tools") + init_params = data.get("init_parameters", {}) + serialized_callback_handler = init_params.get("streaming_callback") + if serialized_callback_handler: + data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler) + + return default_from_dict(cls, data) + + @component.output_types(replies=List[ChatMessage]) + def run( + self, + messages: List[ChatMessage], + generation_kwargs: Optional[Dict[str, Any]] = None, + tools: Optional[List[Tool]] = None, + ): + """ + Invoke the text generation inference based on the provided messages and generation parameters. + + :param messages: + A list of ChatMessage objects representing the input messages. + :param generation_kwargs: + Additional keyword arguments for text generation. + :param tools: + A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set + during component initialization. + :returns: A dictionary with the following keys: + - `replies`: A list containing the generated responses as ChatMessage objects. 
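+
+        A short usage sketch of overriding `tools` at run time (the model name is illustrative,
+        the `weather` tool is a stub, and the token is read from the `HF_API_TOKEN`/`HF_TOKEN`
+        environment variables):
+
+        ```python
+        from haystack_experimental.components.generators.chat import HuggingFaceAPIChatGenerator
+        from haystack_experimental.dataclasses import ChatMessage, Tool
+
+        weather = Tool(
+            name="weather",
+            description="Get the weather for a city",
+            parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
+            function=lambda city: f"Sunny in {city}",  # stub implementation
+        )
+        generator = HuggingFaceAPIChatGenerator(
+            api_type="serverless_inference_api",
+            api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},
+        )
+        result = generator.run(
+            messages=[ChatMessage.from_user("What's the weather in Paris?")],
+            tools=[weather],  # overrides tools passed to the constructor, if any
+        )
+        print(result["replies"][0].tool_calls)
+        ```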
+ """ + + # update generation kwargs by merging with the default ones + generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})} + + formatted_messages = [_convert_message_to_hfapi_format(message) for message in messages] + + tools = tools or self.tools + if tools: + tool_names = [tool.name for tool in tools] + duplicate_tool_names = {name for name in tool_names if tool_names.count(name) > 1} + if duplicate_tool_names: + raise ValueError(f"Duplicate tool names found: {duplicate_tool_names}") + + if tools and self.streaming_callback is not None: + raise ValueError("Using tools and streaming at the same time is not supported. Please choose one.") + + if self.streaming_callback: + return self._run_streaming(formatted_messages, generation_kwargs) + + hf_tools = None + if tools: + hf_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools] + + return self._run_non_streaming(formatted_messages, generation_kwargs, hf_tools) + + def _run_streaming(self, messages: List[Dict[str, str]], generation_kwargs: Dict[str, Any]): + api_output: Iterable[ChatCompletionStreamOutput] = self._client.chat_completion( + messages, stream=True, **generation_kwargs + ) + + generated_text = "" + + for chunk in api_output: + # n is unused, so the API always returns only one choice + # the argument is probably allowed for compatibility with OpenAI + # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n + choice = chunk.choices[0] + + text = choice.delta.content + if text: + generated_text += text + + finish_reason = choice.finish_reason + + meta = {} + if finish_reason: + meta["finish_reason"] = finish_reason + + stream_chunk = StreamingChunk(text, meta) + self.streaming_callback(stream_chunk) + + message = ChatMessage.from_assistant(text=generated_text) + message.meta.update( + { + "model": self._client.model, + "finish_reason": finish_reason, + "index": 0, + "usage": {"prompt_tokens": 0, "completion_tokens": 0}, # not available in streaming + } + ) + + return {"replies": [message]} + + def _run_non_streaming( + self, + messages: List[Dict[str, str]], + generation_kwargs: Dict[str, Any], + tools: Optional[List["ChatCompletionInputTool"]] = None, + ) -> Dict[str, List[ChatMessage]]: + api_chat_output: ChatCompletionOutput = self._client.chat_completion( + messages=messages, tools=tools, **generation_kwargs + ) + + if len(api_chat_output.choices) == 0: + return {"replies": []} + + # n is unused, so the API always returns only one choice + # the argument is probably allowed for compatibility with OpenAI + # see https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.n + choice = api_chat_output.choices[0] + + text = choice.message.content + tool_calls = [] + + if hfapi_tool_calls := choice.message.tool_calls: + for hfapi_tc in hfapi_tool_calls: + tool_call = ToolCall( + tool_name=hfapi_tc.function.name, + arguments=hfapi_tc.function.arguments, + id=hfapi_tc.id, + ) + tool_calls.append(tool_call) + + meta = { + "model": self._client.model, + "finish_reason": choice.finish_reason, + "index": choice.index, + "usage": { + "prompt_tokens": api_chat_output.usage.prompt_tokens, + "completion_tokens": api_chat_output.usage.completion_tokens, + }, + } + + message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta) + return {"replies": [message]} diff --git a/test/components/generators/chat/test_hugging_face_api.py 
b/test/components/generators/chat/test_hugging_face_api.py new file mode 100644 index 00000000..c244bd8d --- /dev/null +++ b/test/components/generators/chat/test_hugging_face_api.py @@ -0,0 +1,605 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from unittest.mock import MagicMock, Mock, patch + +import pytest +from haystack import Pipeline +from haystack.dataclasses import StreamingChunk +from haystack.utils.auth import Secret +from haystack.utils.hf import HFGenerationAPIType +from huggingface_hub import ( + ChatCompletionOutput, + ChatCompletionOutputComplete, + ChatCompletionOutputFunctionDefinition, + ChatCompletionOutputMessage, + ChatCompletionOutputToolCall, + ChatCompletionOutputUsage, + ChatCompletionStreamOutput, + ChatCompletionStreamOutputChoice, + ChatCompletionStreamOutputDelta, +) +from huggingface_hub.utils import RepositoryNotFoundError + +from haystack_experimental.components.generators.chat.hugging_face_api import ( + HuggingFaceAPIChatGenerator, + _convert_message_to_hfapi_format, +) +from haystack_experimental.dataclasses import ChatMessage, ChatRole, TextContent, Tool, ToolCall + + +@pytest.fixture +def chat_messages(): + return [ + ChatMessage.from_system("You are a helpful assistant speaking A2 level of English"), + ChatMessage.from_user("Tell me about Berlin"), + ] + + +@pytest.fixture +def tools(): + tool_parameters = {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]} + tool = Tool( + name="weather", + description="useful to determine the weather in a given location", + parameters=tool_parameters, + function=lambda x: x, + ) + + return [tool] + + +@pytest.fixture +def mock_check_valid_model(): + with patch( + "haystack.components.generators.chat.hugging_face_api.check_valid_model", MagicMock(return_value=None) + ) as mock: + yield mock + + +@pytest.fixture +def mock_chat_completion(): + # https://huggingface.co/docs/huggingface_hub/package_reference/inference_client#huggingface_hub.InferenceClient.chat_completion.example + + with patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as mock_chat_completion: + completion = ChatCompletionOutput( + choices=[ + ChatCompletionOutputComplete( + finish_reason="eos_token", + index=0, + message=ChatCompletionOutputMessage(content="The capital of France is Paris.", role="assistant"), + ) + ], + id="some_id", + model="some_model", + system_fingerprint="some_fingerprint", + usage=ChatCompletionOutputUsage(completion_tokens=8, prompt_tokens=17, total_tokens=25), + created=1710498360, + ) + + mock_chat_completion.return_value = completion + yield mock_chat_completion + + +# used to test serialization of streaming_callback +def streaming_callback_handler(x): + return x + + +def test_convert_message_to_hfapi_format(): + message = ChatMessage.from_system("You are good assistant") + assert _convert_message_to_hfapi_format(message) == {"role": "system", "content": "You are good assistant"} + + message = ChatMessage.from_user("I have a question") + assert _convert_message_to_hfapi_format(message) == {"role": "user", "content": "I have a question"} + + message = ChatMessage.from_assistant(text="I have an answer", meta={"finish_reason": "stop"}) + assert _convert_message_to_hfapi_format(message) == {"role": "assistant", "content": "I have an answer"} + + message = ChatMessage.from_assistant( + tool_calls=[ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"})] + ) + assert 
_convert_message_to_hfapi_format(message) == { + "role": "assistant", + "content": "", + "tool_calls": [ + {"id": "123", "type": "function", "function": {"name": "weather", "arguments": {"city": "Paris"}}} + ], + } + + message = ChatMessage.from_assistant(tool_calls=[ToolCall(tool_name="weather", arguments={"city": "Paris"})]) + assert _convert_message_to_hfapi_format(message) == { + "role": "assistant", + "content": "", + "tool_calls": [{"type": "function", "function": {"name": "weather", "arguments": {"city": "Paris"}}}], + } + + tool_result = {"weather": "sunny", "temperature": "25"} + message = ChatMessage.from_tool( + tool_result=tool_result, origin=ToolCall(id="123", tool_name="weather", arguments={"city": "Paris"}) + ) + assert _convert_message_to_hfapi_format(message) == {"role": "tool", "content": tool_result, "tool_call_id": "123"} + + message = ChatMessage.from_tool( + tool_result=tool_result, origin=ToolCall(tool_name="weather", arguments={"city": "Paris"}) + ) + assert _convert_message_to_hfapi_format(message) == {"role": "tool", "content": tool_result} + + +def test_convert_message_to_hfapi_invalid(): + message = ChatMessage(_role=ChatRole.ASSISTANT, _content=[]) + with pytest.raises(ValueError): + _convert_message_to_hfapi_format(message) + + message = ChatMessage( + _role=ChatRole.ASSISTANT, + _content=[TextContent(text="I have an answer"), TextContent(text="I have another answer")], + ) + with pytest.raises(ValueError): + _convert_message_to_hfapi_format(message) + + +class TestHuggingFaceAPIChatGenerator: + def test_init_invalid_api_type(self): + with pytest.raises(ValueError): + HuggingFaceAPIChatGenerator(api_type="invalid_api_type", api_params={}) + + def test_init_serverless(self, mock_check_valid_model): + model = "HuggingFaceH4/zephyr-7b-alpha" + generation_kwargs = {"temperature": 0.6} + stop_words = ["stop"] + streaming_callback = None + + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": model}, + token=None, + generation_kwargs=generation_kwargs, + stop_words=stop_words, + streaming_callback=streaming_callback, + ) + + assert generator.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API + assert generator.api_params == {"model": model} + assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}} + assert generator.streaming_callback == streaming_callback + assert generator.tools is None + + def test_init_serverless_with_tools(self, mock_check_valid_model, tools): + model = "HuggingFaceH4/zephyr-7b-alpha" + generation_kwargs = {"temperature": 0.6} + stop_words = ["stop"] + streaming_callback = None + + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": model}, + token=None, + generation_kwargs=generation_kwargs, + stop_words=stop_words, + streaming_callback=streaming_callback, + tools=tools, + ) + + assert generator.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API + assert generator.api_params == {"model": model} + assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}} + assert generator.streaming_callback == streaming_callback + assert generator.tools == tools + + def test_init_serverless_invalid_model(self, mock_check_valid_model): + mock_check_valid_model.side_effect = RepositoryNotFoundError("Invalid model id") + with pytest.raises(RepositoryNotFoundError): + HuggingFaceAPIChatGenerator( + 
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"model": "invalid_model_id"} + ) + + def test_init_serverless_no_model(self): + with pytest.raises(ValueError): + HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, api_params={"param": "irrelevant"} + ) + + def test_init_tgi(self): + url = "https://some_model.com" + generation_kwargs = {"temperature": 0.6} + stop_words = ["stop"] + streaming_callback = None + + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, + api_params={"url": url}, + token=None, + generation_kwargs=generation_kwargs, + stop_words=stop_words, + streaming_callback=streaming_callback, + ) + + assert generator.api_type == HFGenerationAPIType.TEXT_GENERATION_INFERENCE + assert generator.api_params == {"url": url} + assert generator.generation_kwargs == {**generation_kwargs, **{"stop": ["stop"]}, **{"max_tokens": 512}} + assert generator.streaming_callback == streaming_callback + assert generator.tools is None + + def test_init_tgi_invalid_url(self): + with pytest.raises(ValueError): + HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"url": "invalid_url"} + ) + + def test_init_tgi_no_url(self): + with pytest.raises(ValueError): + HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.TEXT_GENERATION_INFERENCE, api_params={"param": "irrelevant"} + ) + + def test_init_fail_with_duplicate_tool_names(self, mock_check_valid_model, tools): + duplicate_tools = [tools[0], tools[0]] + with pytest.raises(ValueError): + HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "irrelevant"}, + tools=duplicate_tools, + ) + + def test_init_fail_with_tools_and_streaming(self, mock_check_valid_model, tools): + with pytest.raises(ValueError): + HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "irrelevant"}, + tools=tools, + streaming_callback=streaming_callback_handler, + ) + + def test_to_dict(self, mock_check_valid_model): + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, + generation_kwargs={"temperature": 0.6}, + stop_words=["stop", "words"], + tools=[tool], + ) + + result = generator.to_dict() + init_params = result["init_parameters"] + + assert init_params["api_type"] == "serverless_inference_api" + assert init_params["api_params"] == {"model": "HuggingFaceH4/zephyr-7b-beta"} + assert init_params["token"] == {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"} + assert init_params["generation_kwargs"] == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512} + assert init_params["streaming_callback"] is None + assert init_params["tools"] == [ + { + "description": "description", + "function": "builtins.print", + "name": "name", + "parameters": { + "x": { + "type": "string", + }, + }, + } + ] + + def test_from_dict(self, mock_check_valid_model): + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, + token=Secret.from_env_var("ENV_VAR", strict=False), + 
generation_kwargs={"temperature": 0.6}, + stop_words=["stop", "words"], + tools=[tool], + ) + result = generator.to_dict() + + # now deserialize, call from_dict + generator_2 = HuggingFaceAPIChatGenerator.from_dict(result) + assert generator_2.api_type == HFGenerationAPIType.SERVERLESS_INFERENCE_API + assert generator_2.api_params == {"model": "HuggingFaceH4/zephyr-7b-beta"} + assert generator_2.token == Secret.from_env_var("ENV_VAR", strict=False) + assert generator_2.generation_kwargs == {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512} + assert generator_2.streaming_callback is None + assert generator_2.tools == [tool] + + def test_serde_in_pipeline(self, mock_check_valid_model): + tool = Tool(name="name", description="description", parameters={"x": {"type": "string"}}, function=print) + + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, + token=Secret.from_env_var("ENV_VAR", strict=False), + generation_kwargs={"temperature": 0.6}, + stop_words=["stop", "words"], + tools=[tool], + ) + + pipeline = Pipeline() + pipeline.add_component("generator", generator) + + pipeline_dict = pipeline.to_dict() + assert pipeline_dict == { + "metadata": {}, + "max_runs_per_component": 100, + "components": { + "generator": { + "type": "haystack_experimental.components.generators.chat.hugging_face_api.HuggingFaceAPIChatGenerator", + "init_parameters": { + "api_type": "serverless_inference_api", + "api_params": {"model": "HuggingFaceH4/zephyr-7b-beta"}, + "token": {"type": "env_var", "env_vars": ["ENV_VAR"], "strict": False}, + "generation_kwargs": {"temperature": 0.6, "stop": ["stop", "words"], "max_tokens": 512}, + "streaming_callback": None, + "tools": [ + { + "name": "name", + "description": "description", + "parameters": {"x": {"type": "string"}}, + "function": "builtins.print", + } + ], + }, + } + }, + "connections": [], + } + + pipeline_yaml = pipeline.dumps() + + new_pipeline = Pipeline.loads(pipeline_yaml) + assert new_pipeline == pipeline + + def test_run(self, mock_check_valid_model, mock_chat_completion, chat_messages): + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "meta-llama/Llama-2-13b-chat-hf"}, + generation_kwargs={"temperature": 0.6}, + stop_words=["stop", "words"], + streaming_callback=None, + ) + + response = generator.run(messages=chat_messages) + + # check kwargs passed to chat_completion + _, kwargs = mock_chat_completion.call_args + hf_messages = [ + {"role": "system", "content": "You are a helpful assistant speaking A2 level of English"}, + {"role": "user", "content": "Tell me about Berlin"}, + ] + assert kwargs == { + "temperature": 0.6, + "stop": ["stop", "words"], + "max_tokens": 512, + "tools": None, + "messages": hf_messages, + } + + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + def test_run_with_streaming_callback(self, mock_check_valid_model, mock_chat_completion, chat_messages): + streaming_call_count = 0 + + # Define the streaming callback function + def streaming_callback_fn(chunk: StreamingChunk): + nonlocal streaming_call_count + streaming_call_count += 1 + assert isinstance(chunk, StreamingChunk) + + generator = HuggingFaceAPIChatGenerator( + 
api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "meta-llama/Llama-2-13b-chat-hf"}, + streaming_callback=streaming_callback_fn, + ) + + # Create a fake streamed response + # self needed here, don't remove + def mock_iter(self): + yield ChatCompletionStreamOutput( + choices=[ + ChatCompletionStreamOutputChoice( + delta=ChatCompletionStreamOutputDelta(content="The", role="assistant"), + index=0, + finish_reason=None, + ) + ], + id="some_id", + model="some_model", + system_fingerprint="some_fingerprint", + created=1710498504, + ) + + yield ChatCompletionStreamOutput( + choices=[ + ChatCompletionStreamOutputChoice( + delta=ChatCompletionStreamOutputDelta(content=None, role=None), index=0, finish_reason="length" + ) + ], + id="some_id", + model="some_model", + system_fingerprint="some_fingerprint", + created=1710498504, + ) + + mock_response = Mock(**{"__iter__": mock_iter}) + mock_chat_completion.return_value = mock_response + + # Generate text response with streaming callback + response = generator.run(chat_messages) + + # check kwargs passed to text_generation + _, kwargs = mock_chat_completion.call_args + assert kwargs == {"stop": [], "stream": True, "max_tokens": 512} + + # Assert that the streaming callback was called twice + assert streaming_call_count == 2 + + # Assert that the response contains the generated replies + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) > 0 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + + def test_run_fail_with_tools_and_streaming(self, tools): + component = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "meta-llama/Llama-2-13b-chat-hf"}, + streaming_callback=streaming_callback_handler, + ) + + with pytest.raises(ValueError): + message = ChatMessage.from_user("irrelevant") + component.run([message], tools=tools) + + def test_run_with_tools(self, mock_check_valid_model, tools): + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "meta-llama/Llama-3.1-70B-Instruct"}, + tools=tools, + ) + + with patch("huggingface_hub.InferenceClient.chat_completion", autospec=True) as mock_chat_completion: + completion = ChatCompletionOutput( + choices=[ + ChatCompletionOutputComplete( + finish_reason="stop", + index=0, + message=ChatCompletionOutputMessage( + role="assistant", + content=None, + tool_calls=[ + ChatCompletionOutputToolCall( + function=ChatCompletionOutputFunctionDefinition( + arguments={"city": "Paris"}, name="weather", description=None + ), + id="0", + type="function", + ) + ], + ), + logprobs=None, + ) + ], + created=1729074760, + id="", + model="meta-llama/Llama-3.1-70B-Instruct", + system_fingerprint="2.3.2-dev0-sha-28bb7ae", + usage=ChatCompletionOutputUsage(completion_tokens=30, prompt_tokens=426, total_tokens=456), + ) + mock_chat_completion.return_value = completion + + messages = [ChatMessage.from_user("What is the weather in Paris?")] + response = generator.run(messages=messages) + + assert isinstance(response, dict) + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) == 1 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + assert response["replies"][0].tool_calls[0].tool_name == "weather" + assert response["replies"][0].tool_calls[0].arguments == {"city": "Paris"} + assert response["replies"][0].tool_calls[0].id == 
"0" + assert response["replies"][0].meta == { + "finish_reason": "stop", + "index": 0, + "model": "meta-llama/Llama-3.1-70B-Instruct", + "usage": {"completion_tokens": 30, "prompt_tokens": 426}, + } + + @pytest.mark.integration + @pytest.mark.skipif( + not os.environ.get("HF_API_TOKEN", None), + reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", + ) + def test_live_run_serverless(self): + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, + generation_kwargs={"max_tokens": 20}, + ) + + messages = [ChatMessage.from_user("What is the capital of France?")] + response = generator.run(messages=messages) + + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) > 0 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + assert "usage" in response["replies"][0].meta + assert "prompt_tokens" in response["replies"][0].meta["usage"] + assert "completion_tokens" in response["replies"][0].meta["usage"] + + @pytest.mark.integration + @pytest.mark.skipif( + not os.environ.get("HF_API_TOKEN", None), + reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", + ) + def test_live_run_serverless_streaming(self): + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, + generation_kwargs={"max_tokens": 20}, + streaming_callback=streaming_callback_handler, + ) + + messages = [ChatMessage.from_user("What is the capital of France?")] + response = generator.run(messages=messages) + + assert "replies" in response + assert isinstance(response["replies"], list) + assert len(response["replies"]) > 0 + assert [isinstance(reply, ChatMessage) for reply in response["replies"]] + assert "usage" in response["replies"][0].meta + assert "prompt_tokens" in response["replies"][0].meta["usage"] + assert "completion_tokens" in response["replies"][0].meta["usage"] + + @pytest.mark.integration + @pytest.mark.skipif( + not os.environ.get("HF_API_TOKEN", None), + reason="Export an env var called HF_API_TOKEN containing the Hugging Face token to run this test.", + ) + @pytest.mark.integration + def test_live_run_with_tools(self, tools): + """ + We test the round trip: generate tool call, pass tool message, generate response. + + The model used here (zephyr-7b-beta) is always available and not gated. + Even if it does not officially support tools, TGI+HF API make it work. 
+ """ + + chat_messages = [ChatMessage.from_user("What's the weather like in Paris and Munich?")] + generator = HuggingFaceAPIChatGenerator( + api_type=HFGenerationAPIType.SERVERLESS_INFERENCE_API, + api_params={"model": "HuggingFaceH4/zephyr-7b-beta"}, + generation_kwargs={"temperature": 0.5}, + ) + + results = generator.run(chat_messages, tools=tools) + assert len(results["replies"]) == 1 + message = results["replies"][0] + + assert message.tool_calls + tool_call = message.tool_call + assert isinstance(tool_call, ToolCall) + assert tool_call.tool_name == "weather" + assert "city" in tool_call.arguments + assert "Paris" in tool_call.arguments["city"] + assert message.meta["finish_reason"] == "stop" + + new_messages = chat_messages + [message, ChatMessage.from_tool(tool_result="22° C", origin=tool_call)] + + # the model tends to make tool calls if provided with tools, so we don't pass them here + results = generator.run(new_messages, generation_kwargs={"max_tokens": 50}) + + assert len(results["replies"]) == 1 + final_message = results["replies"][0] + assert not final_message.tool_calls + assert len(final_message.text) > 0 + assert "paris" in final_message.text.lower()
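
For reviewers who want to try the new component end to end, here is a minimal sketch of the tool-calling round trip that `test_live_run_with_tools` exercises. It assumes `HF_API_TOKEN` (or `HF_TOKEN`) is set in the environment; the model name is the same non-gated one used in the tests, and the `weather` tool is a hard-coded stub rather than a real API.

```python
from haystack_experimental.components.generators.chat import HuggingFaceAPIChatGenerator
from haystack_experimental.dataclasses import ChatMessage, Tool

# Hard-coded stub tool; a real tool would call a weather service.
weather = Tool(
    name="weather",
    description="Get the current weather for a city",
    parameters={"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]},
    function=lambda city: f"22° C and sunny in {city}",
)

generator = HuggingFaceAPIChatGenerator(
    api_type="serverless_inference_api",
    api_params={"model": "HuggingFaceH4/zephyr-7b-beta"},  # illustrative, non-gated model
)

messages = [ChatMessage.from_user("What's the weather like in Paris?")]
reply = generator.run(messages, tools=[weather])["replies"][0]

# With tools available, the model is expected to answer with a tool call rather than text.
tool_call = reply.tool_call
tool_result = weather.function(**tool_call.arguments)

# Feed the tool result back (without tools) so the model can phrase the final answer.
messages += [reply, ChatMessage.from_tool(tool_result=tool_result, origin=tool_call)]
final_reply = generator.run(messages, generation_kwargs={"max_tokens": 50})["replies"][0]
print(final_reply.text)
```

The first call should come back with a `ToolCall` instead of text; passing the tool result back via `ChatMessage.from_tool` lets the second call produce the final natural-language answer.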