From 23e706bcaf1568ea1318bc127530b77f97525072 Mon Sep 17 00:00:00 2001
From: Yun Kim <yun.kim@datadoghq.com>
Date: Tue, 10 Dec 2024 17:52:18 -0500
Subject: [PATCH] Extract helper for streamed token-usage chunk handling

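Extract the duplicated "grab the trailing usage chunk" logic from the
streamed-response wrappers into an _extract_token_chunk() helper, apply
the same handling to the async stream path, and gate the automatic
injection of stream_options={"include_usage": True} to openai>=1.26,
skipping it when the caller set include_usage explicitly. The llmobs
integration no longer strips stream_options from the reported request
metadata, and the tests now assert the injected option and the token
counts recorded from the stream.

A rough sketch of the resulting user-facing behavior (the model, prompt,
and client setup below are illustrative only):

    import openai

    client = openai.OpenAI()
    # With ddtrace patching active on openai>=1.26, the integration sets
    # stream_options={"include_usage": True} on this request, then consumes
    # the trailing usage chunk for token metrics so it is never yielded here.
    stream = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Who won the world series in 2020?"}],
        stream=True,
    )
    for chunk in stream:
        if chunk.choices:
            print(chunk.choices[0].delta.content or "", end="")
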
---
 .../internal/openai/_endpoint_hooks.py        |  5 +-
 ddtrace/contrib/internal/openai/utils.py      | 62 ++++++++++++-----
 ddtrace/llmobs/_integrations/openai.py        |  3 -
 tests/contrib/openai/test_openai_llmobs.py    | 10 +++-
 tests/contrib/openai/test_openai_v1.py        | 58 ++++++++++++++++++-
 5 files changed, 117 insertions(+), 21 deletions(-)

diff --git a/ddtrace/contrib/internal/openai/_endpoint_hooks.py b/ddtrace/contrib/internal/openai/_endpoint_hooks.py
index f7ee114e450..f7c2ceee6b0 100644
--- a/ddtrace/contrib/internal/openai/_endpoint_hooks.py
+++ b/ddtrace/contrib/internal/openai/_endpoint_hooks.py
@@ -255,7 +255,10 @@ def _record_request(self, pin, integration, span, args, kwargs):
                 span.set_tag_str("openai.request.messages.%d.content" % idx, integration.trunc(str(content)))
             span.set_tag_str("openai.request.messages.%d.role" % idx, str(role))
             span.set_tag_str("openai.request.messages.%d.name" % idx, str(name))
-        if kwargs.get("stream") and not kwargs.get("stream_options", {}).get("include_usage", False):
+        if parse_version(OPENAI_VERSION) >= (1, 26) and kwargs.get("stream"):
+            if kwargs.get("stream_options", {}).get("include_usage", None) is not None:
+                return
             span._set_ctx_item("openai_stream_magic", True)
             stream_options = kwargs.get("stream_options", {})
             stream_options["include_usage"] = True
diff --git a/ddtrace/contrib/internal/openai/utils.py b/ddtrace/contrib/internal/openai/utils.py
index 75db35b42cd..09061aa823d 100644
--- a/ddtrace/contrib/internal/openai/utils.py
+++ b/ddtrace/contrib/internal/openai/utils.py
@@ -38,6 +38,22 @@ def __init__(self, wrapped, integration, span, kwargs, is_completion=False):
         self._is_completion = is_completion
         self._kwargs = kwargs
 
+    def _extract_token_chunk(self, chunk):
+        """Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
+        if not self._dd_span._get_ctx_item("openai_stream_magic"):
+            return
+        choice = getattr(chunk, "choices", [None])[0]
+        if not getattr(choice, "finish_reason", None):
+            return
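+        # The chunk with a finish_reason is the last content chunk; since
+        # include_usage was injected at request time, the next chunk carries
+        # token usage, so consume it for metrics instead of yielding it.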
+        try:
+            usage_chunk = next(self.__wrapped__)
+            self._streamed_chunks[0].insert(0, usage_chunk)
+        except (StopIteration, GeneratorExit):
+            pass
+
 
 class TracedOpenAIStream(BaseTracedOpenAIStream):
     def __enter__(self):
@@ -51,11 +67,7 @@ def __iter__(self):
         exception_raised = False
         try:
             for chunk in self.__wrapped__:
-                if self._dd_span._get_ctx_item("openai_stream_magic"):
-                    choice = getattr(chunk, "choices", [None])[0]
-                    if getattr(choice, "finish_reason", None):
-                        usage_chunk = next(self.__wrapped__)
-                        self._streamed_chunks[0].insert(0, usage_chunk)
+                self._extract_token_chunk(chunk)
                 yield chunk
                 _loop_handler(self._dd_span, chunk, self._streamed_chunks)
         except Exception:
@@ -73,11 +85,7 @@ def __iter__(self):
     def __next__(self):
         try:
             chunk = next(self.__wrapped__)
-            if self._dd_span._get_ctx_item("openai_stream_magic"):
-                choice = getattr(chunk, "choices", [None])[0]
-                if getattr(choice, "finish_reason", None):
-                    usage_chunk = next(self.__wrapped__)
-                    self._streamed_chunks[0].insert(0, usage_chunk)
+            self._extract_token_chunk(chunk)
             _loop_handler(self._dd_span, chunk, self._streamed_chunks)
             return chunk
         except StopIteration:
@@ -103,11 +111,41 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
         await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb)
 
-    def __aiter__(self):
-        return self
+    async def _extract_token_chunk(self, chunk):
+        """Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
+        if not self._dd_span._get_ctx_item("openai_stream_magic"):
+            return
+        choice = getattr(chunk, "choices", [None])[0]
+        if not getattr(choice, "finish_reason", None):
+            return
+        try:
+            usage_chunk = await self.__wrapped__.__anext__()
+            self._streamed_chunks[0].insert(0, usage_chunk)
+        except (StopAsyncIteration, GeneratorExit):
+            pass
+
+    async def __aiter__(self):
+        exception_raised = False
+        try:
+            async for chunk in self.__wrapped__:
+                await self._extract_token_chunk(chunk)
+                yield chunk
+                _loop_handler(self._dd_span, chunk, self._streamed_chunks)
+        except Exception:
+            self._dd_span.set_exc_info(*sys.exc_info())
+            exception_raised = True
+            raise
+        finally:
+            if not exception_raised:
+                _process_finished_stream(
+                    self._dd_integration, self._dd_span, self._kwargs, self._streamed_chunks, self._is_completion
+                )
+            self._dd_span.finish()
+            self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)
 
     async def __anext__(self):
         try:
             chunk = await self.__wrapped__.__anext__()
+            await self._extract_token_chunk(chunk)
             _loop_handler(self._dd_span, chunk, self._streamed_chunks)
             return chunk
         except StopAsyncIteration:
diff --git a/ddtrace/llmobs/_integrations/openai.py b/ddtrace/llmobs/_integrations/openai.py
index 96cd2dfa071..5c9e73eaca7 100644
--- a/ddtrace/llmobs/_integrations/openai.py
+++ b/ddtrace/llmobs/_integrations/openai.py
@@ -191,9 +191,6 @@ def _llmobs_set_meta_tags_from_chat(span: Span, kwargs: Dict[str, Any], messages
             input_messages.append({"content": str(_get_attr(m, "content", "")), "role": str(_get_attr(m, "role", ""))})
         span.set_tag_str(INPUT_MESSAGES, safe_json(input_messages))
 
-        if span._get_ctx_item("openai_stream_magic"):
-            kwargs.pop("stream_options", None)
-
         parameters = {k: v for k, v in kwargs.items() if k not in ("model", "messages", "tools", "functions")}
         span.set_tag_str(METADATA, safe_json(parameters))
 
diff --git a/tests/contrib/openai/test_openai_llmobs.py b/tests/contrib/openai/test_openai_llmobs.py
index ddba259e928..75e4506fdeb 100644
--- a/tests/contrib/openai/test_openai_llmobs.py
+++ b/tests/contrib/openai/test_openai_llmobs.py
@@ -518,12 +518,16 @@ async def test_chat_completion_azure_async(
             )
         )
 
+    @pytest.mark.skipif(
+        parse_version(openai_module.version.VERSION) < (1, 26), reason="Stream options only available openai >= 1.26"
+    )
     def test_chat_completion_stream(self, openai, ddtrace_global_config, mock_llmobs_writer, mock_tracer):
         """Ensure llmobs records are emitted for chat completion endpoints when configured.
 
         Also ensure the llmobs records have the correct tagging including trace/span ID for trace correlation.
         """
-        with get_openai_vcr(subdirectory_name="v1").use_cassette("chat_completion_streamed.yaml"):
+
+        with get_openai_vcr(subdirectory_name="v1").use_cassette("chat_completion_streamed_tokens.yaml"):
             with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding:
                 with mock.patch("ddtrace.contrib.internal.openai.utils._est_tokens") as mock_est:
                     mock_encoding.return_value.encode.side_effect = lambda x: [1, 2, 3, 4, 5, 6, 7, 8]
@@ -547,8 +551,8 @@ def test_chat_completion_stream(self, openai, ddtrace_global_config, mock_llmobs
                 model_provider="openai",
                 input_messages=input_messages,
                 output_messages=[{"content": expected_completion, "role": "assistant"}],
-                metadata={"stream": True, "user": "ddtrace-test"},
-                token_metrics={"input_tokens": 8, "output_tokens": 8, "total_tokens": 16},
+                metadata={"stream": True, "stream_options": {"include_usage": True}, "user": "ddtrace-test"},
+                token_metrics={"input_tokens": 17, "output_tokens": 19, "total_tokens": 36},
                 tags={"ml_app": "<ml-app-name>", "service": "tests.contrib.openai"},
             )
         )
diff --git a/tests/contrib/openai/test_openai_v1.py b/tests/contrib/openai/test_openai_v1.py
index f13de144fc5..57fecab55ca 100644
--- a/tests/contrib/openai/test_openai_v1.py
+++ b/tests/contrib/openai/test_openai_v1.py
@@ -1042,7 +1042,61 @@ def test_completion_stream_context_manager(openai, openai_vcr, mock_metrics, moc
     assert mock.call.distribution("tokens.total", mock.ANY, tags=expected_tags) in mock_metrics.mock_calls
 
 
+@pytest.mark.skipif(
+    parse_version(openai_module.version.VERSION) < (1, 26), reason="Stream options only available openai >= 1.26"
+)
 def test_chat_completion_stream(openai, openai_vcr, mock_metrics, snapshot_tracer):
+    with openai_vcr.use_cassette("chat_completion_streamed_tokens.yaml"):
+        with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding:
+            mock_encoding.return_value.encode.side_effect = lambda x: [1, 2, 3, 4, 5, 6, 7, 8]
+            expected_completion = "The Los Angeles Dodgers won the World Series in 2020."
+            client = openai.OpenAI()
+            resp = client.chat.completions.create(
+                model="gpt-3.5-turbo",
+                messages=[
+                    {"role": "user", "content": "Who won the world series in 2020?"},
+                ],
+                stream=True,
+                user="ddtrace-test",
+                n=None,
+            )
+            span = snapshot_tracer.current_span()
+            chunks = [c for c in resp]
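+            # The injected usage chunk is consumed by the integration, so the
+            # caller still sees the original 15 chunks.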
+            assert len(chunks) == 15
+            completion = "".join([c.choices[0].delta.content for c in chunks if c.choices[0].delta.content is not None])
+            assert completion == expected_completion
+
+    assert span.get_tag("openai.response.choices.0.message.content") == expected_completion
+    assert span.get_tag("openai.response.choices.0.message.role") == "assistant"
+    assert span.get_tag("openai.response.choices.0.finish_reason") == "stop"
+
+    expected_tags = [
+        "version:",
+        "env:",
+        "service:tests.contrib.openai",
+        "openai.request.model:gpt-3.5-turbo",
+        "model:gpt-3.5-turbo",
+        "openai.request.endpoint:/v1/chat/completions",
+        "openai.request.method:POST",
+        "openai.organization.id:",
+        "openai.organization.name:datadog-4",
+        "openai.user.api_key:sk-...key>",
+        "error:0",
+    ]
+    assert mock.call.distribution("request.duration", span.duration_ns, tags=expected_tags) in mock_metrics.mock_calls
+    assert mock.call.gauge("ratelimit.requests", 3000, tags=expected_tags) in mock_metrics.mock_calls
+    assert mock.call.gauge("ratelimit.remaining.requests", 2999, tags=expected_tags) in mock_metrics.mock_calls
+    assert mock.call.distribution("tokens.prompt", 17, tags=expected_tags) in mock_metrics.mock_calls
+    assert mock.call.distribution("tokens.completion", 19, tags=expected_tags) in mock_metrics.mock_calls
+    assert mock.call.distribution("tokens.total", 36, tags=expected_tags) in mock_metrics.mock_calls
+
+
+@pytest.mark.skipif(
+    parse_version(openai_module.version.VERSION) < (1, 26), reason="Stream options only available openai >= 1.26"
+)
+def test_chat_completion_stream_explicit_no_tokens(openai, openai_vcr, mock_metrics, snapshot_tracer):
     with openai_vcr.use_cassette("chat_completion_streamed.yaml"):
         with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding:
             mock_encoding.return_value.encode.side_effect = lambda x: [1, 2, 3, 4, 5, 6, 7, 8]
@@ -1054,10 +1108,10 @@ def test_chat_completion_stream(openai, openai_vcr, mock_metrics, snapshot_trace
                     {"role": "user", "content": "Who won the world series in 2020?"},
                 ],
                 stream=True,
+                stream_options={"include_usage": False},
                 user="ddtrace-test",
                 n=None,
             )
-            prompt_tokens = 8
             span = snapshot_tracer.current_span()
             chunks = [c for c in resp]
             assert len(chunks) == 15
@@ -1087,7 +1141,7 @@ def test_chat_completion_stream(openai, openai_vcr, mock_metrics, snapshot_trace
     expected_tags += ["openai.estimated:true"]
     if TIKTOKEN_AVAILABLE:
         expected_tags = expected_tags[:-1]
-    assert mock.call.distribution("tokens.prompt", prompt_tokens, tags=expected_tags) in mock_metrics.mock_calls
+    assert mock.call.distribution("tokens.prompt", 8, tags=expected_tags) in mock_metrics.mock_calls
     assert mock.call.distribution("tokens.completion", mock.ANY, tags=expected_tags) in mock_metrics.mock_calls
     assert mock.call.distribution("tokens.total", mock.ANY, tags=expected_tags) in mock_metrics.mock_calls