diff --git a/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py b/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py
index 6cea28d0f2f5..07f3fe74802c 100644
--- a/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py
+++ b/python/packages/autogen-ext/src/autogen_ext/models/semantic_kernel/_sk_chat_completion_adapter.py
@@ -1,5 +1,6 @@
 import json
 from typing import Any, Literal, Mapping, Optional, Sequence
+import warnings
 
 from autogen_core import FunctionCall
 from autogen_core._cancellation_token import CancellationToken
@@ -18,7 +19,6 @@
 from semantic_kernel.contents.chat_history import ChatHistory
 from semantic_kernel.contents.chat_message_content import ChatMessageContent
 from semantic_kernel.contents.function_call_content import FunctionCallContent
-from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
 from semantic_kernel.functions.kernel_plugin import KernelPlugin
 from semantic_kernel.kernel import Kernel
 from typing_extensions import AsyncGenerator, Union
@@ -427,6 +427,28 @@ async def create(
             thought=thought,
         )
 
+    @staticmethod
+    def _merge_function_call_content(existing_call: FunctionCallContent, new_chunk: FunctionCallContent) -> None:
+        """Helper to merge partial argument chunks from new_chunk into existing_call."""
+        if isinstance(existing_call.arguments, str) and isinstance(new_chunk.arguments, str):
+            existing_call.arguments += new_chunk.arguments
+        elif isinstance(existing_call.arguments, dict) and isinstance(new_chunk.arguments, dict):
+            existing_call.arguments.update(new_chunk.arguments)
+        elif not existing_call.arguments or existing_call.arguments in ("{}", ""):
+            # If existing had no arguments yet, just take the new one
+            existing_call.arguments = new_chunk.arguments
+        else:
+            # If there's a mismatch (str vs dict), handle as needed
+            warnings.warn("Mismatch in argument types during merge. Existing arguments retained.", stacklevel=2)
+
+        # Optionally update name/function_name if newly provided
+        if new_chunk.name:
+            existing_call.name = new_chunk.name
+        if new_chunk.plugin_name:
+            existing_call.plugin_name = new_chunk.plugin_name
+        if new_chunk.function_name:
+            existing_call.function_name = new_chunk.function_name
+
     async def create_stream(
         self,
         messages: Sequence[LLMMessage],
@@ -460,6 +482,7 @@ async def create_stream(
         Yields:
             Union[str, CreateResult]: Either a string chunk of the response or a CreateResult containing function calls.
         """
+
         kernel = self._get_kernel(extra_create_args)
         chat_history = self._convert_to_chat_history(messages)
         user_settings = self._get_prompt_settings(extra_create_args)
@@ -468,54 +491,105 @@ async def create_stream(
         prompt_tokens = 0
         completion_tokens = 0
-        accumulated_content = ""
+        accumulated_text = ""
+
+        # Keep track of in-progress function calls. Keyed by ID
+        # because partial chunks for the same function call might arrive separately.
+        function_calls_in_progress: dict[str, FunctionCallContent] = {}
+
+        # Track the ID of the last function call we saw so we can continue
+        # accumulating chunk arguments for that call if new items have id=None
+        last_function_call_id: Optional[str] = None
 
         async for streaming_messages in self._sk_client.get_streaming_chat_message_contents(
             chat_history, settings=settings, kernel=kernel
         ):
             for msg in streaming_messages:
-                if not isinstance(msg, StreamingChatMessageContent):
-                    continue
-
                 # Track token usage
                 if msg.metadata and "usage" in msg.metadata:
                     usage = msg.metadata["usage"]
                     prompt_tokens = getattr(usage, "prompt_tokens", 0)
                     completion_tokens = getattr(usage, "completion_tokens", 0)
 
-                # Check for function calls
-                if any(isinstance(item, FunctionCallContent) for item in msg.items):
-                    function_calls = self._process_tool_calls(msg)
+                # Process function call deltas
+                for item in msg.items:
+                    if isinstance(item, FunctionCallContent):
+                        # If the chunk has a valid ID, we start or continue that ID explicitly
+                        if item.id:
+                            last_function_call_id = item.id
+                            if last_function_call_id not in function_calls_in_progress:
+                                function_calls_in_progress[last_function_call_id] = item
+                            else:
+                                # Merge partial arguments into existing call
+                                existing_call = function_calls_in_progress[last_function_call_id]
+                                self._merge_function_call_content(existing_call, item)
+                        else:
+                            # item.id is None, so we assume it belongs to the last known ID
+                            if not last_function_call_id:
+                                # No call in progress means we can't merge
+                                # You could either skip or raise an error here
+                                warnings.warn(
+                                    "Received function call chunk with no ID and no call in progress.", stacklevel=2
+                                )
+                                continue
+
+                            existing_call = function_calls_in_progress[last_function_call_id]
+                            # Merge partial chunk
+                            self._merge_function_call_content(existing_call, item)
+
+                # Check if the model signaled tool_calls finished
+                if msg.finish_reason == "tool_calls" and function_calls_in_progress:
+                    calls_to_yield: list[FunctionCall] = []
+                    for _, call_content in function_calls_in_progress.items():
+                        plugin_name = call_content.plugin_name or ""
+                        function_name = call_content.function_name
+                        if plugin_name:
+                            full_name = f"{plugin_name}-{function_name}"
+                        else:
+                            full_name = function_name
+
+                        if isinstance(call_content.arguments, dict):
+                            arguments = json.dumps(call_content.arguments)
+                        else:
+                            assert isinstance(call_content.arguments, str)
+                            arguments = call_content.arguments or "{}"
+
+                        calls_to_yield.append(
+                            FunctionCall(
+                                id=call_content.id or "unknown_id",
+                                name=full_name,
+                                arguments=arguments,
+                            )
+                        )
+                    # Yield all function calls in progress
                     yield CreateResult(
-                        content=function_calls,
+                        content=calls_to_yield,
                         finish_reason="function_calls",
                         usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
                         cached=False,
                     )
                     return
 
-                # Handle text content
+                # Handle any plain text in the message
                 if msg.content:
-                    accumulated_content += msg.content
+                    accumulated_text += msg.content
                     yield msg.content
 
-        # Final yield if there was text content
-        if accumulated_content:
-            self._total_prompt_tokens += prompt_tokens
-            self._total_completion_tokens += completion_tokens
-
-            if isinstance(accumulated_content, str) and self._model_info["family"] == ModelFamily.R1:
-                thought, accumulated_content = parse_r1_content(accumulated_content)
-            else:
-                thought = None
-
-            yield CreateResult(
-                content=accumulated_content,
-                finish_reason="stop",
-                usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
-                cached=False,
-                thought=thought,
-            )
+        # If we exit the loop without tool calls finishing, yield whatever text was accumulated
+        self._total_prompt_tokens += prompt_tokens
+        self._total_completion_tokens += completion_tokens
+
+        thought = None
+        if isinstance(accumulated_text, str) and self._model_info["family"] == ModelFamily.R1:
+            thought, accumulated_text = parse_r1_content(accumulated_text)
+
+        yield CreateResult(
+            content=accumulated_text,
+            finish_reason="stop",
+            usage=RequestUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
+            cached=False,
+            thought=thought,
+        )
 
     def actual_usage(self) -> RequestUsage:
         return RequestUsage(prompt_tokens=self._total_prompt_tokens, completion_tokens=self._total_completion_tokens)
diff --git a/python/packages/autogen-ext/tests/models/test_sk_chat_completion_adapter.py b/python/packages/autogen-ext/tests/models/test_sk_chat_completion_adapter.py
index 1b5a6ea03fa5..ce602d8fbad2 100644
--- a/python/packages/autogen-ext/tests/models/test_sk_chat_completion_adapter.py
+++ b/python/packages/autogen-ext/tests/models/test_sk_chat_completion_adapter.py
@@ -7,7 +7,13 @@
 from autogen_core.models import CreateResult, LLMMessage, ModelFamily, ModelInfo, SystemMessage, UserMessage
 from autogen_core.tools import BaseTool
 from autogen_ext.models.semantic_kernel import SKChatCompletionAdapter
-from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta
+from openai.types.chat.chat_completion_chunk import (
+    ChatCompletionChunk,
+    Choice,
+    ChoiceDelta,
+    ChoiceDeltaToolCall,
+    ChoiceDeltaToolCallFunction,
+)
 from openai.types.completion_usage import CompletionUsage
 from pydantic import BaseModel
 from semantic_kernel.connectors.ai.open_ai.services.azure_chat_completion import AzureChatCompletion
@@ -72,7 +78,7 @@ async def mock_get_chat_message_contents(
                     id="call_UwVVI0iGEmcPwmKUigJcuuuF",
                     function_name="calculator",
                     plugin_name=None,
-                    arguments="{}",
+                    arguments='{"a": 2, "b": 2}',
                 )
             ],
             finish_reason=FinishReason.TOOL_CALLS,
@@ -96,30 +102,89 @@ async def mock_get_streaming_chat_message_contents(
     **kwargs: Any,
 ) -> AsyncGenerator[list["StreamingChatMessageContent"], Any]:
     if "What is 2 + 2?" in str(chat_history):
-        # Mock response for calculator tool test - single message with function call
+        # Initial chunk with function call setup
         yield [
             StreamingChatMessageContent(
                 choice_index=0,
-                inner_content=None,
+                inner_content=ChatCompletionChunk(
+                    id="chatcmpl-123",
+                    choices=[
+                        Choice(
+                            delta=ChoiceDelta(
+                                role="assistant",
+                                tool_calls=[
+                                    ChoiceDeltaToolCall(
+                                        index=0,
+                                        id="call_UwVVI0iGEmcPwmKUigJcuuuF",
+                                        function=ChoiceDeltaToolCallFunction(name="calculator", arguments=""),
+                                        type="function",
+                                    )
+                                ],
+                            ),
+                            finish_reason=None,
+                            index=0,
+                        )
+                    ],
+                    created=1736673679,
+                    model="gpt-4o-mini",
+                    object="chat.completion.chunk",
+                ),
                 ai_model_id="gpt-4o-mini",
-                metadata={
-                    "logprobs": None,
-                    "id": "chatcmpl-AooRjGxKtdTke46keWkBQBKg033XW",
-                    "created": 1736673679,
-                    "usage": {"prompt_tokens": 53, "completion_tokens": 13},
-                },
                 role=AuthorRole.ASSISTANT,
-                items=[  # type: ignore
+                items=[
                     FunctionCallContent(
-                        id="call_n8135GXc2kbiaaDdpImsB1VW",
-                        function_name="calculator",
-                        plugin_name=None,
-                        arguments="",
-                        content_type="function_call",  # type: ignore
+                        id="call_UwVVI0iGEmcPwmKUigJcuuuF", function_name="calculator", arguments=""
                     )
                 ],
-                finish_reason=None,
-                function_invoke_attempt=0,
+            )
+        ]
+
+        # Arguments chunks
+        for arg_chunk in ["{", '"a"', ":", " ", "2", ",", " ", '"b"', ":", " ", "2", "}"]:
+            yield [
+                StreamingChatMessageContent(
+                    choice_index=0,
+                    inner_content=ChatCompletionChunk(
+                        id="chatcmpl-123",
+                        choices=[
+                            Choice(
+                                delta=ChoiceDelta(
+                                    tool_calls=[
+                                        ChoiceDeltaToolCall(
+                                            index=0, function=ChoiceDeltaToolCallFunction(arguments=arg_chunk)
+                                        )
+                                    ]
+                                ),
+                                finish_reason=None,
+                                index=0,
+                            )
+                        ],
+                        created=1736673679,
+                        model="gpt-4o-mini",
+                        object="chat.completion.chunk",
+                    ),
+                    ai_model_id="gpt-4o-mini",
+                    role=AuthorRole.ASSISTANT,
+                    items=[FunctionCallContent(function_name="calculator", arguments=arg_chunk)],
+                )
+            ]
+
+        # Final chunk with finish reason
+        yield [
+            StreamingChatMessageContent(  # type: ignore
+                choice_index=0,
+                inner_content=ChatCompletionChunk(
+                    id="chatcmpl-123",
+                    choices=[Choice(delta=ChoiceDelta(), finish_reason="tool_calls", index=0)],
+                    created=1736673679,
+                    model="gpt-4o-mini",
+                    object="chat.completion.chunk",
+                    usage=CompletionUsage(prompt_tokens=53, completion_tokens=13, total_tokens=66),
+                ),
+                ai_model_id="gpt-4o-mini",
+                role=AuthorRole.ASSISTANT,
+                finish_reason=FinishReason.TOOL_CALLS,
+                metadata={"usage": {"prompt_tokens": 53, "completion_tokens": 13}},
             )
         ]
     else:
@@ -449,3 +514,217 @@ async def mock_get_streaming_chat_message_contents(
     assert response_chunks[-1].finish_reason == "stop"
     assert response_chunks[-1].content == "Hello!"
     assert response_chunks[-1].thought == "Reasoning..."
+
+
+@pytest.mark.asyncio
+async def test_sk_chat_completion_stream_with_multiple_function_calls() -> None:
+    """
+    This test returns two distinct function calls via streaming, each one arriving in pieces.
+    We intentionally set name, plugin_name, and function_name in the later partial chunks so
+    that _merge_function_call_content is triggered to update them.
+    """
+
+    async def mock_get_streaming_chat_message_contents(
+        chat_history: ChatHistory,
+        settings: PromptExecutionSettings,
+        **kwargs: Any,
+    ) -> AsyncGenerator[list["StreamingChatMessageContent"], Any]:
+        # First partial chunk for call_1
+        yield [
+            StreamingChatMessageContent(
+                choice_index=0,
+                inner_content=ChatCompletionChunk(
+                    id="chunk-id-1",
+                    choices=[
+                        Choice(
+                            delta=ChoiceDelta(
+                                role="assistant",
+                                tool_calls=[
+                                    ChoiceDeltaToolCall(
+                                        index=0,
+                                        id="call_1",
+                                        function=ChoiceDeltaToolCallFunction(name=None, arguments='{"arg1":'),
+                                        type="function",
+                                    )
+                                ],
+                            ),
+                            finish_reason=None,
+                            index=0,
+                        )
+                    ],
+                    created=1736679999,
+                    model="gpt-4o-mini",
+                    object="chat.completion.chunk",
+                ),
+                ai_model_id="gpt-4o-mini",
+                role=AuthorRole.ASSISTANT,
+                items=[
+                    FunctionCallContent(
+                        id="call_1",
+                        # no plugin_name/function_name yet
+                        name=None,
+                        arguments='{"arg1":',
+                    )
+                ],
+            )
+        ]
+        # Second partial chunk for call_1 (updates plugin_name/function_name)
+        yield [
+            StreamingChatMessageContent(
+                choice_index=0,
+                inner_content=ChatCompletionChunk(
+                    id="chunk-id-2",
+                    choices=[
+                        Choice(
+                            delta=ChoiceDelta(
+                                tool_calls=[
+                                    ChoiceDeltaToolCall(
+                                        index=0,
+                                        function=ChoiceDeltaToolCallFunction(
+                                            # Provide the rest of the arguments
+                                            arguments='"value1"}',
+                                            name="firstFunction",
+                                        ),
+                                    )
+                                ]
+                            ),
+                            finish_reason=None,
+                            index=0,
+                        )
+                    ],
+                    created=1736679999,
+                    model="gpt-4o-mini",
+                    object="chat.completion.chunk",
+                ),
+                ai_model_id="gpt-4o-mini",
+                role=AuthorRole.ASSISTANT,
+                items=[
+                    FunctionCallContent(
+                        id="call_1", plugin_name="myPlugin", function_name="firstFunction", arguments='"value1"}'
+                    )
+                ],
+            )
+        ]
+        # Now partial chunk for a second call, call_2
+        yield [
+            StreamingChatMessageContent(
+                choice_index=0,
+                inner_content=ChatCompletionChunk(
+                    id="chunk-id-3",
+                    choices=[
+                        Choice(
+                            delta=ChoiceDelta(
+                                tool_calls=[
+                                    ChoiceDeltaToolCall(
+                                        index=0,
+                                        id="call_2",
+                                        function=ChoiceDeltaToolCallFunction(name=None, arguments='{"arg2":"another"}'),
+                                        type="function",
+                                    )
+                                ],
+                            ),
+                            finish_reason=None,
+                            index=0,
+                        )
+                    ],
+                    created=1736679999,
+                    model="gpt-4o-mini",
+                    object="chat.completion.chunk",
+                ),
+                ai_model_id="gpt-4o-mini",
+                role=AuthorRole.ASSISTANT,
+                items=[FunctionCallContent(id="call_2", arguments='{"arg2":"another"}')],
+            )
+        ]
+        # Next partial chunk updates name, plugin_name, function_name for call_2
+        yield [
+            StreamingChatMessageContent(
+                choice_index=0,
+                inner_content=ChatCompletionChunk(
+                    id="chunk-id-4",
+                    choices=[
+                        Choice(
+                            delta=ChoiceDelta(
+                                tool_calls=[
+                                    ChoiceDeltaToolCall(
+                                        index=0, function=ChoiceDeltaToolCallFunction(name="secondFunction")
+                                    )
+                                ]
+                            ),
+                            finish_reason=None,
+                            index=0,
+                        )
+                    ],
+                    created=1736679999,
+                    model="gpt-4o-mini",
+                    object="chat.completion.chunk",
+                ),
+                ai_model_id="gpt-4o-mini",
+                role=AuthorRole.ASSISTANT,
+                items=[
+                    FunctionCallContent(
+                        id="call_2",
+                        name="someFancyName",
+                        plugin_name="anotherPlugin",
+                        function_name="secondFunction",
+                    )
+                ],
+            )
+        ]
+        # Final chunk signals finish with tool_calls
+        yield [
+            StreamingChatMessageContent(  # type: ignore
+                choice_index=0,
+                inner_content=ChatCompletionChunk(
+                    id="chunk-id-5",
+                    choices=[Choice(delta=ChoiceDelta(), finish_reason="tool_calls", index=0)],
+                    created=1736679999,
+                    model="gpt-4o-mini",
+                    object="chat.completion.chunk",
+                    usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15),
+                ),
+                ai_model_id="gpt-4o-mini",
+                role=AuthorRole.ASSISTANT,
+                finish_reason=FinishReason.TOOL_CALLS,
+                metadata={"usage": {"prompt_tokens": 10, "completion_tokens": 5}},
+            )
+        ]
+
+    # Mock SK client
+    mock_client = AsyncMock(spec=AzureChatCompletion)
+    mock_client.get_streaming_chat_message_contents = mock_get_streaming_chat_message_contents
+
+    # Create adapter and kernel
+    kernel = Kernel(memory=NullMemory())
+    adapter = SKChatCompletionAdapter(mock_client, kernel=kernel)
+
+    # Call create_stream with no actual tools (we just test the multiple calls)
+    messages: list[LLMMessage] = [
+        SystemMessage(content="You are a helpful assistant."),
+        UserMessage(content="Call two different plugin functions", source="user"),
+    ]
+
+    # Collect streaming outputs
+    response_chunks: list[CreateResult | str] = []
+    async for chunk in adapter.create_stream(messages=messages):
+        response_chunks.append(chunk)
+
+    # The final chunk should be a CreateResult with function_calls
+    assert len(response_chunks) > 0
+    final_chunk = response_chunks[-1]
+    assert isinstance(final_chunk, CreateResult)
+    assert final_chunk.finish_reason == "function_calls"
+    assert isinstance(final_chunk.content, list)
+    assert len(final_chunk.content) == 2  # We expect 2 calls
+
+    # Verify first call merged name + arguments
+    first_call = final_chunk.content[0]
+    assert first_call.id == "call_1"
+    assert first_call.name == "myPlugin-firstFunction"  # pluginName-functionName
+    assert '{"arg1":"value1"}' in first_call.arguments
+
+    # Verify second call also merged everything
+    second_call = final_chunk.content[1]
+    assert second_call.id == "call_2"
+    assert second_call.name == "anotherPlugin-secondFunction"
+    assert '{"arg2":"another"}' in second_call.arguments