# feat(openai): add token usage stream options to request (#11606)
This PR adds special casing so that a user's streamed OpenAI
chat/completion requests include token usage in the streamed response by
default, unless the user explicitly specifies otherwise.

### Motivation
OpenAI streamed responses have historically not provided token usage
details as part of the streamed response. However, earlier this year OpenAI
added a `stream_options: {"include_usage": True}` kwarg option to
the chat/completions API that provides token usage details in an
additional stream chunk at the end of the streamed response.

If the user does not specify this kwarg option, OpenAI does not provide
token usage, and our current behavior is a best effort to 1) use the
`tiktoken` library to calculate token counts, or 2) fall back to a very
crude heuristic to estimate token counts. Neither is ideal, as neither
alternative takes function/tool calling into account. **It is simpler and
more accurate to request the token counts from OpenAI directly.**
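
For reference, a minimal sketch of requesting usage directly from the OpenAI Python client (assumes `openai>=1.26`; the model and messages here are illustrative):

```python
import openai

client = openai.OpenAI()  # assumes OPENAI_API_KEY is set in the environment
stream = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Who won the world series in 2020?"}],
    stream=True,
    stream_options={"include_usage": True},  # ask OpenAI to append a usage chunk
)
for chunk in stream:
    # With include_usage enabled, the final chunk has empty choices and a populated usage field.
    if getattr(chunk, "usage", None):
        print(chunk.usage.prompt_tokens, chunk.usage.completion_tokens, chunk.usage.total_tokens)
```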

### Proposed design
There are two major components to this feature:
1. If a user does not specify `stream_options: {"include_usage": True}`
as a kwarg on the chat/completions call, we need to manually insert that
as part of the kwargs before the request is made.
2. If a user does not specify `stream_options: {"include_usage": True}`
as a kwarg on the chat/completions call but we add that option on the
integration-side, the returned streamed response will include an
additional chunk (with empty content) at the end containing token usage
information. To avoid disrupting user applications with one more chunk
(with different content/fields) than expected, the integration should
automatically extract the last chunk under the hood.

Note: if a user explicitly specifies `stream_options:
{"include_usage": False}`, then we must respect their intent and avoid
adding token usage to the kwargs. Our release note will call out that we
cannot guarantee 100% accurate token counts in this case. A simplified
sketch of this injection rule is shown below.
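
A simplified sketch of the kwarg-injection rule described above (the actual implementation lives in `_endpoint_hooks.py`, shown in the diff below; the helper name here is hypothetical):

```python
def _maybe_request_token_usage(span, kwargs):
    """Hypothetical helper mirroring the injection rule: only touch streamed requests."""
    if not kwargs.get("stream"):
        return  # only applies to streamed chat/completion requests
    stream_options = kwargs.get("stream_options", {})
    if stream_options.get("include_usage") is not None:
        # Explicitly set to True or False by the user: respect it and skip auto-extraction.
        return
    # Mark the span so the traced stream knows to consume the extra usage chunk later.
    span._set_ctx_item("_dd.auto_extract_token_chunk", True)
    stream_options["include_usage"] = True
    kwargs["stream_options"] = stream_options
```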

### Streamed reading logic change

Additionally, we change the `__iter__()`/`__aiter__()` methods of our
traced streamed responses. Previously we returned the traced streamed
response itself (and relied on the underlying `__next__()`/`__anext__()`
methods), but to ensure spans are finished even when the streamed response
is not fully consumed, the `__iter__()`/`__aiter__()` methods now implement
the stream consumption themselves using a try/except/finally block.

Note: this only applies in the following cases (see the usage sketch below):
1. When users iterate via `__iter__()`/`__aiter__()`, since directly calling
`__next__()`/`__anext__()` individually does not let us know when the
overall response is fully consumed.
2. When users use `__aiter__()` and break early, they are still
responsible for calling `resp.close()`, since asynchronous generators are
not automatically closed when the context manager exits (cleanup is
deferred until `close()` is called, either manually or by the garbage collector).
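
A hedged usage sketch of point 2 above — the client calls are the public OpenAI API, the traced wrapper is transparent to the caller, and the explicit `close()` follows the guidance in the note:

```python
import asyncio

import openai


async def main():
    client = openai.AsyncOpenAI()  # assumes OPENAI_API_KEY is set in the environment
    resp = await client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Who won the world series in 2020?"}],
        stream=True,
    )
    try:
        async for chunk in resp:  # consumed via the traced __aiter__()
            if chunk.choices and chunk.choices[0].finish_reason:
                break  # breaking early: the stream is not fully consumed here
    finally:
        await resp.close()  # per the note above, still required when breaking early


asyncio.run(main())
```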

### Testing

This PR simplifies the existing OpenAI streamed completion/chat completion
tests (using snapshots where possible instead of large numbers of tedious
assertions) and adds coverage for the token-extraction behavior: existing
tests drop the `include_usage: True` option to assert that automatic
extraction works, and a couple of new tests assert the original behavior
when `include_usage: False` is explicitly set.

## Checklist
- [x] PR author has checked that all the criteria below are met
- The PR description includes an overview of the change
- The PR description articulates the motivation for the change
- The change includes tests OR the PR description describes a testing
strategy
- The PR description notes risks associated with the change, if any
- Newly-added code is easy to change
- The change follows the [library release note
guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html)
- The change includes or references documentation updates if necessary
- Backport labels are set (if
[applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting))

## Reviewer Checklist
- [x] Reviewer has checked that all the criteria below are met 
- Title is accurate
- All changes are related to the pull request's stated goal
- Avoids breaking
[API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces)
changes
- Testing strategy adequately addresses listed risks
- Newly-added code is easy to change
- Release note makes sense to a user of the library
- If necessary, author has acknowledged and discussed the performance
implications of this PR as reported in the benchmarks PR comment
- Backport labels are set in a manner that is consistent with the
[release branch maintenance
policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)
Yun-Kim authored Jan 9, 2025
1 parent 5581f73 commit 35fe7b5
Showing 8 changed files with 305 additions and 237 deletions.
8 changes: 8 additions & 0 deletions ddtrace/contrib/internal/openai/_endpoint_hooks.py
@@ -255,6 +255,14 @@ def _record_request(self, pin, integration, span, args, kwargs):
span.set_tag_str("openai.request.messages.%d.content" % idx, integration.trunc(str(content)))
span.set_tag_str("openai.request.messages.%d.role" % idx, str(role))
span.set_tag_str("openai.request.messages.%d.name" % idx, str(name))
if parse_version(OPENAI_VERSION) >= (1, 26) and kwargs.get("stream"):
if kwargs.get("stream_options", {}).get("include_usage", None) is not None:
# Only perform token chunk auto-extraction if this option is not explicitly set
return
span._set_ctx_item("_dd.auto_extract_token_chunk", True)
stream_options = kwargs.get("stream_options", {})
stream_options["include_usage"] = True
kwargs["stream_options"] = stream_options

def _record_response(self, pin, integration, span, args, kwargs, resp, error):
resp = super()._record_response(pin, integration, span, args, kwargs, resp, error)
69 changes: 66 additions & 3 deletions ddtrace/contrib/internal/openai/utils.py
@@ -48,11 +48,28 @@ def __exit__(self, exc_type, exc_val, exc_tb):
        self.__wrapped__.__exit__(exc_type, exc_val, exc_tb)

    def __iter__(self):
        return self
        exception_raised = False
        try:
            for chunk in self.__wrapped__:
                self._extract_token_chunk(chunk)
                yield chunk
                _loop_handler(self._dd_span, chunk, self._streamed_chunks)
        except Exception:
            self._dd_span.set_exc_info(*sys.exc_info())
            exception_raised = True
            raise
        finally:
            if not exception_raised:
                _process_finished_stream(
                    self._dd_integration, self._dd_span, self._kwargs, self._streamed_chunks, self._is_completion
                )
            self._dd_span.finish()
            self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)

    def __next__(self):
        try:
            chunk = self.__wrapped__.__next__()
            self._extract_token_chunk(chunk)
            _loop_handler(self._dd_span, chunk, self._streamed_chunks)
            return chunk
        except StopIteration:
@@ -68,6 +85,22 @@ def __next__(self):
            self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)
            raise

    def _extract_token_chunk(self, chunk):
        """Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
        if not self._dd_span._get_ctx_item("_dd.auto_extract_token_chunk"):
            return
        choice = getattr(chunk, "choices", [None])[0]
        if not getattr(choice, "finish_reason", None):
            # Only the second-last chunk in the stream with token usage enabled will have finish_reason set
            return
        try:
            # User isn't expecting last token chunk to be present since it's not part of the default streamed response,
            # so we consume it and extract the token usage metadata before it reaches the user.
            usage_chunk = self.__wrapped__.__next__()
            self._streamed_chunks[0].insert(0, usage_chunk)
        except (StopIteration, GeneratorExit):
            return


class TracedOpenAIAsyncStream(BaseTracedOpenAIStream):
    async def __aenter__(self):
@@ -77,12 +110,29 @@ async def __aenter__(self):
    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb)

    def __aiter__(self):
        return self
    async def __aiter__(self):
        exception_raised = False
        try:
            async for chunk in self.__wrapped__:
                await self._extract_token_chunk(chunk)
                yield chunk
                _loop_handler(self._dd_span, chunk, self._streamed_chunks)
        except Exception:
            self._dd_span.set_exc_info(*sys.exc_info())
            exception_raised = True
            raise
        finally:
            if not exception_raised:
                _process_finished_stream(
                    self._dd_integration, self._dd_span, self._kwargs, self._streamed_chunks, self._is_completion
                )
            self._dd_span.finish()
            self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)

    async def __anext__(self):
        try:
            chunk = await self.__wrapped__.__anext__()
            await self._extract_token_chunk(chunk)
            _loop_handler(self._dd_span, chunk, self._streamed_chunks)
            return chunk
        except StopAsyncIteration:
@@ -98,6 +148,19 @@ async def __anext__(self):
            self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)
            raise

    async def _extract_token_chunk(self, chunk):
        """Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
        if not self._dd_span._get_ctx_item("_dd.auto_extract_token_chunk"):
            return
        choice = getattr(chunk, "choices", [None])[0]
        if not getattr(choice, "finish_reason", None):
            return
        try:
            usage_chunk = await self.__wrapped__.__anext__()
            self._streamed_chunks[0].insert(0, usage_chunk)
        except (StopAsyncIteration, GeneratorExit):
            return


def _compute_token_count(content, model):
    # type: (Union[str, List[int]], Optional[str]) -> Tuple[bool, int]
6 changes: 6 additions & 0 deletions (new release note file)
@@ -0,0 +1,6 @@
---
features:
  - |
    openai: Introduces automatic extraction of token usage from streamed chat completions.
    Unless ``stream_options: {"include_usage": False}`` is explicitly set on your streamed chat completion request,
    the OpenAI integration will add ``stream_options: {"include_usage": True}`` to your request and automatically extract the token usage chunk from the streamed response.
27 changes: 15 additions & 12 deletions tests/contrib/openai/test_openai_llmobs.py
@@ -518,11 +518,17 @@ async def test_chat_completion_azure_async(
            )
        )

    def test_chat_completion_stream(self, openai, ddtrace_global_config, mock_llmobs_writer, mock_tracer):
    @pytest.mark.skipif(
        parse_version(openai_module.version.VERSION) < (1, 26), reason="Stream options only available openai >= 1.26"
    )
    def test_chat_completion_stream_explicit_no_tokens(
        self, openai, ddtrace_global_config, mock_llmobs_writer, mock_tracer
    ):
        """Ensure llmobs records are emitted for chat completion endpoints when configured.
        Also ensure the llmobs records have the correct tagging including trace/span ID for trace correlation.
        """

        with get_openai_vcr(subdirectory_name="v1").use_cassette("chat_completion_streamed.yaml"):
            with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding:
                with mock.patch("ddtrace.contrib.internal.openai.utils._est_tokens") as mock_est:
@@ -534,7 +540,11 @@ def test_chat_completion_stream(self, openai, ddtrace_global_config, mock_llmobs
expected_completion = "The Los Angeles Dodgers won the World Series in 2020."
client = openai.OpenAI()
resp = client.chat.completions.create(
model=model, messages=input_messages, stream=True, user="ddtrace-test"
model=model,
messages=input_messages,
stream=True,
user="ddtrace-test",
stream_options={"include_usage": False},
)
for chunk in resp:
resp_model = chunk.model
@@ -547,7 +557,7 @@ def test_chat_completion_stream(self, openai, ddtrace_global_config, mock_llmobs
model_provider="openai",
input_messages=input_messages,
output_messages=[{"content": expected_completion, "role": "assistant"}],
metadata={"stream": True, "user": "ddtrace-test"},
metadata={"stream": True, "stream_options": {"include_usage": False}, "user": "ddtrace-test"},
token_metrics={"input_tokens": 8, "output_tokens": 8, "total_tokens": 16},
tags={"ml_app": "<ml-app-name>", "service": "tests.contrib.openai"},
)
@@ -557,20 +567,14 @@ def test_chat_completion_stream(self, openai, ddtrace_global_config, mock_llmobs
parse_version(openai_module.version.VERSION) < (1, 26, 0), reason="Streamed tokens available in 1.26.0+"
)
def test_chat_completion_stream_tokens(self, openai, ddtrace_global_config, mock_llmobs_writer, mock_tracer):
"""
Ensure llmobs records are emitted for chat completion endpoints when configured
with the `stream_options={"include_usage": True}`.
Also ensure the llmobs records have the correct tagging including trace/span ID for trace correlation.
"""
"""Assert that streamed token chunk extraction logic works when options are not explicitly passed from user."""
with get_openai_vcr(subdirectory_name="v1").use_cassette("chat_completion_streamed_tokens.yaml"):
model = "gpt-3.5-turbo"
resp_model = model
input_messages = [{"role": "user", "content": "Who won the world series in 2020?"}]
expected_completion = "The Los Angeles Dodgers won the World Series in 2020."
client = openai.OpenAI()
resp = client.chat.completions.create(
model=model, messages=input_messages, stream=True, stream_options={"include_usage": True}
)
resp = client.chat.completions.create(model=model, messages=input_messages, stream=True)
for chunk in resp:
resp_model = chunk.model
span = mock_tracer.pop_traces()[0][0]
@@ -671,7 +675,6 @@ def test_chat_completion_tool_call_stream(self, openai, ddtrace_global_config, m
messages=[{"role": "user", "content": chat_completion_input_description}],
user="ddtrace-test",
stream=True,
stream_options={"include_usage": True},
)
for chunk in resp:
resp_model = chunk.model

