diff --git a/ddtrace/contrib/internal/openai/_endpoint_hooks.py b/ddtrace/contrib/internal/openai/_endpoint_hooks.py
index f7ee114e450..f7c2ceee6b0 100644
--- a/ddtrace/contrib/internal/openai/_endpoint_hooks.py
+++ b/ddtrace/contrib/internal/openai/_endpoint_hooks.py
@@ -255,7 +255,10 @@ def _record_request(self, pin, integration, span, args, kwargs):
             span.set_tag_str("openai.request.messages.%d.content" % idx, integration.trunc(str(content)))
             span.set_tag_str("openai.request.messages.%d.role" % idx, str(role))
             span.set_tag_str("openai.request.messages.%d.name" % idx, str(name))
-        if kwargs.get("stream") and not kwargs.get("stream_options", {}).get("include_usage", False):
+        if parse_version(OPENAI_VERSION) >= (1, 26) and kwargs.get("stream"):
+            if kwargs.get("stream_options", {}).get("include_usage", None) is not None:
+                # include_usage was explicitly set by the user; respect their choice and skip auto-extraction.
+                return
             span._set_ctx_item("openai_stream_magic", True)
             stream_options = kwargs.get("stream_options", {})
             stream_options["include_usage"] = True
diff --git a/ddtrace/contrib/internal/openai/utils.py b/ddtrace/contrib/internal/openai/utils.py
index 75db35b42cd..09061aa823d 100644
--- a/ddtrace/contrib/internal/openai/utils.py
+++ b/ddtrace/contrib/internal/openai/utils.py
@@ -38,6 +38,19 @@ def __init__(self, wrapped, integration, span, kwargs, is_completion=False):
         self._is_completion = is_completion
         self._kwargs = kwargs
 
+    def _extract_token_chunk(self, chunk):
+        """Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
+        if not self._dd_span._get_ctx_item("openai_stream_magic"):
+            return
+        choice = getattr(chunk, "choices", [None])[0]
+        if not getattr(choice, "finish_reason", None):
+            return
+        try:
+            usage_chunk = next(self.__wrapped__)
+            self._streamed_chunks[0].insert(0, usage_chunk)
+        except (StopIteration, GeneratorExit):
+            pass
+
 
 class TracedOpenAIStream(BaseTracedOpenAIStream):
     def __enter__(self):
@@ -51,11 +64,7 @@ def __iter__(self):
         exception_raised = False
         try:
             for chunk in self.__wrapped__:
-                if self._dd_span._get_ctx_item("openai_stream_magic"):
-                    choice = getattr(chunk, "choices", [None])[0]
-                    if getattr(choice, "finish_reason", None):
-                        usage_chunk = next(self.__wrapped__)
-                        self._streamed_chunks[0].insert(0, usage_chunk)
+                self._extract_token_chunk(chunk)
                 yield chunk
                 _loop_handler(self._dd_span, chunk, self._streamed_chunks)
         except Exception:
@@ -73,11 +82,7 @@
     def __next__(self):
         try:
             chunk = next(self.__wrapped__)
-            if self._dd_span._get_ctx_item("openai_stream_magic"):
-                choice = getattr(chunk, "choices", [None])[0]
-                if getattr(choice, "finish_reason", None):
-                    usage_chunk = next(self.__wrapped__)
-                    self._streamed_chunks[0].insert(0, usage_chunk)
+            self._extract_token_chunk(chunk)
             _loop_handler(self._dd_span, chunk, self._streamed_chunks)
             return chunk
         except StopIteration:
@@ -103,11 +108,41 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
         await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb)
 
-    def __aiter__(self):
-        return self
+    async def _extract_token_chunk(self, chunk):
+        """Attempt to extract the token chunk (last chunk in the stream) from the async streamed response."""
+        if not self._dd_span._get_ctx_item("openai_stream_magic"):
+            return
+        choice = getattr(chunk, "choices", [None])[0]
+        if not getattr(choice, "finish_reason", None):
+            return
+        try:
+            usage_chunk = await self.__wrapped__.__anext__()
+            self._streamed_chunks[0].insert(0, usage_chunk)
+        except (StopAsyncIteration, GeneratorExit):
+            pass
+
+    async def __aiter__(self):
+        exception_raised = False
+        try:
+            async for chunk in self.__wrapped__:
+                await self._extract_token_chunk(chunk)
+                yield chunk
+                _loop_handler(self._dd_span, chunk, self._streamed_chunks)
+        except Exception:
+            self._dd_span.set_exc_info(*sys.exc_info())
+            exception_raised = True
+            raise
+        finally:
+            if not exception_raised:
+                _process_finished_stream(
+                    self._dd_integration, self._dd_span, self._kwargs, self._streamed_chunks, self._is_completion
+                )
+            self._dd_span.finish()
+            self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)
 
     async def __anext__(self):
         try:
             chunk = await self.__wrapped__.__anext__()
+            await self._extract_token_chunk(chunk)
             _loop_handler(self._dd_span, chunk, self._streamed_chunks)
             return chunk
         except StopAsyncIteration:
diff --git a/ddtrace/llmobs/_integrations/openai.py b/ddtrace/llmobs/_integrations/openai.py
index 96cd2dfa071..5c9e73eaca7 100644
--- a/ddtrace/llmobs/_integrations/openai.py
+++ b/ddtrace/llmobs/_integrations/openai.py
@@ -191,9 +191,6 @@ def _llmobs_set_meta_tags_from_chat(span: Span, kwargs: Dict[str, Any], messages
             input_messages.append({"content": str(_get_attr(m, "content", "")), "role": str(_get_attr(m, "role", ""))})
     span.set_tag_str(INPUT_MESSAGES, safe_json(input_messages))
 
-    if span._get_ctx_item("openai_stream_magic"):
-        kwargs.pop("stream_options", None)
-
     parameters = {k: v for k, v in kwargs.items() if k not in ("model", "messages", "tools", "functions")}
     span.set_tag_str(METADATA, safe_json(parameters))
diff --git a/tests/contrib/openai/test_openai_llmobs.py b/tests/contrib/openai/test_openai_llmobs.py
index ddba259e928..75e4506fdeb 100644
--- a/tests/contrib/openai/test_openai_llmobs.py
+++ b/tests/contrib/openai/test_openai_llmobs.py
@@ -518,12 +518,16 @@ async def test_chat_completion_azure_async(
             )
         )
 
+    @pytest.mark.skipif(
+        parse_version(openai_module.version.VERSION) < (1, 26), reason="Stream options only available openai >= 1.26"
+    )
     def test_chat_completion_stream(self, openai, ddtrace_global_config, mock_llmobs_writer, mock_tracer):
         """Ensure llmobs records are emitted for chat completion endpoints when configured.
 
         Also ensure the llmobs records have the correct tagging including trace/span ID for trace correlation.
         """
-        with get_openai_vcr(subdirectory_name="v1").use_cassette("chat_completion_streamed.yaml"):
+
+        with get_openai_vcr(subdirectory_name="v1").use_cassette("chat_completion_streamed_tokens.yaml"):
             with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding:
                 with mock.patch("ddtrace.contrib.internal.openai.utils._est_tokens") as mock_est:
                     mock_encoding.return_value.encode.side_effect = lambda x: [1, 2, 3, 4, 5, 6, 7, 8]
@@ -547,8 +551,8 @@ def test_chat_completion_stream(self, openai, ddtrace_global_config, mock_llmobs
                     model_provider="openai",
                     input_messages=input_messages,
                     output_messages=[{"content": expected_completion, "role": "assistant"}],
-                    metadata={"stream": True, "user": "ddtrace-test"},
-                    token_metrics={"input_tokens": 8, "output_tokens": 8, "total_tokens": 16},
+                    metadata={"stream": True, "stream_options": {"include_usage": True}, "user": "ddtrace-test"},
+                    token_metrics={"input_tokens": 17, "output_tokens": 19, "total_tokens": 36},
                     tags={"ml_app": "<ml-app-name>", "service": "tests.contrib.openai"},
                 )
             )
diff --git a/tests/contrib/openai/test_openai_v1.py b/tests/contrib/openai/test_openai_v1.py
index f13de144fc5..57fecab55ca 100644
--- a/tests/contrib/openai/test_openai_v1.py
+++ b/tests/contrib/openai/test_openai_v1.py
@@ -1042,7 +1042,59 @@ def test_completion_stream_context_manager(openai, openai_vcr, mock_metrics, moc
     assert mock.call.distribution("tokens.total", mock.ANY, tags=expected_tags) in mock_metrics.mock_calls
 
 
+@pytest.mark.skipif(
+    parse_version(openai_module.version.VERSION) < (1, 26), reason="Stream options only available openai >= 1.26"
+)
 def test_chat_completion_stream(openai, openai_vcr, mock_metrics, snapshot_tracer):
+    with openai_vcr.use_cassette("chat_completion_streamed_tokens.yaml"):
mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding: + mock_encoding.return_value.encode.side_effect = lambda x: [1, 2, 3, 4, 5, 6, 7, 8] + expected_completion = "The Los Angeles Dodgers won the World Series in 2020." + client = openai.OpenAI() + resp = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[ + {"role": "user", "content": "Who won the world series in 2020?"}, + ], + stream=True, + user="ddtrace-test", + n=None, + ) + span = snapshot_tracer.current_span() + chunks = [c for c in resp] + assert len(chunks) == 15 + completion = "".join([c.choices[0].delta.content for c in chunks if c.choices[0].delta.content is not None]) + assert completion == expected_completion + + assert span.get_tag("openai.response.choices.0.message.content") == expected_completion + assert span.get_tag("openai.response.choices.0.message.role") == "assistant" + assert span.get_tag("openai.response.choices.0.finish_reason") == "stop" + + expected_tags = [ + "version:", + "env:", + "service:tests.contrib.openai", + "openai.request.model:gpt-3.5-turbo", + "model:gpt-3.5-turbo", + "openai.request.endpoint:/v1/chat/completions", + "openai.request.method:POST", + "openai.organization.id:", + "openai.organization.name:datadog-4", + "openai.user.api_key:sk-...key>", + "error:0", + ] + assert mock.call.distribution("request.duration", span.duration_ns, tags=expected_tags) in mock_metrics.mock_calls + assert mock.call.gauge("ratelimit.requests", 3000, tags=expected_tags) in mock_metrics.mock_calls + assert mock.call.gauge("ratelimit.remaining.requests", 2999, tags=expected_tags) in mock_metrics.mock_calls + assert mock.call.distribution("tokens.prompt", 17, tags=expected_tags) in mock_metrics.mock_calls + assert mock.call.distribution("tokens.completion", 19, tags=expected_tags) in mock_metrics.mock_calls + assert mock.call.distribution("tokens.total", 36, tags=expected_tags) in mock_metrics.mock_calls + + +@pytest.mark.skipif( + parse_version(openai_module.version.VERSION) < (1, 26), reason="Stream options only available openai >= 1.26" +) +def test_chat_completion_stream_explicit_no_tokens(openai, openai_vcr, mock_metrics, snapshot_tracer): with openai_vcr.use_cassette("chat_completion_streamed.yaml"): with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding: mock_encoding.return_value.encode.side_effect = lambda x: [1, 2, 3, 4, 5, 6, 7, 8] @@ -1054,10 +1106,10 @@ def test_chat_completion_stream(openai, openai_vcr, mock_metrics, snapshot_trace {"role": "user", "content": "Who won the world series in 2020?"}, ], stream=True, + stream_options={"include_usage": False}, user="ddtrace-test", n=None, ) - prompt_tokens = 8 span = snapshot_tracer.current_span() chunks = [c for c in resp] assert len(chunks) == 15 @@ -1087,7 +1139,7 @@ def test_chat_completion_stream(openai, openai_vcr, mock_metrics, snapshot_trace expected_tags += ["openai.estimated:true"] if TIKTOKEN_AVAILABLE: expected_tags = expected_tags[:-1] - assert mock.call.distribution("tokens.prompt", prompt_tokens, tags=expected_tags) in mock_metrics.mock_calls + assert mock.call.distribution("tokens.prompt", 8, tags=expected_tags) in mock_metrics.mock_calls assert mock.call.distribution("tokens.completion", mock.ANY, tags=expected_tags) in mock_metrics.mock_calls assert mock.call.distribution("tokens.total", mock.ANY, tags=expected_tags) in mock_metrics.mock_calls