Extract helper

Yun-Kim committed Dec 10, 2024
1 parent a9abc46 commit 23e706b
Showing 5 changed files with 97 additions and 20 deletions.
4 changes: 3 additions & 1 deletion ddtrace/contrib/internal/openai/_endpoint_hooks.py
@@ -255,7 +255,9 @@ def _record_request(self, pin, integration, span, args, kwargs):
             span.set_tag_str("openai.request.messages.%d.content" % idx, integration.trunc(str(content)))
             span.set_tag_str("openai.request.messages.%d.role" % idx, str(role))
             span.set_tag_str("openai.request.messages.%d.name" % idx, str(name))
-        if kwargs.get("stream") and not kwargs.get("stream_options", {}).get("include_usage", False):
+        if parse_version(OPENAI_VERSION) >= (1, 26) and kwargs.get("stream"):
+            if kwargs.get("stream_options", {}).get("include_usage", None) is not None:
+                return
             span._set_ctx_item("openai_stream_magic", True)
             stream_options = kwargs.get("stream_options", {})
             stream_options["include_usage"] = True
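For reference, this is the upstream behavior the hook opts into (a sketch, assuming openai >= 1.26 and a configured API key; not part of the commit): with stream_options={"include_usage": True}, the API appends one final chunk whose choices list is empty and whose usage field is populated. That trailing chunk is what the helper added in utils.py below extracts.

import openai

client = openai.OpenAI()
stream = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Who won the world series in 2020?"}],
    stream=True,
    stream_options={"include_usage": True},
)
for chunk in stream:
    if chunk.choices:  # ordinary content deltas
        print(chunk.choices[0].delta.content or "", end="")
    elif chunk.usage is not None:  # final usage-only chunk
        print("\ntotal tokens:", chunk.usage.total_tokens)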
44 changes: 33 additions & 11 deletions ddtrace/contrib/internal/openai/utils.py
@@ -38,6 +38,19 @@ def __init__(self, wrapped, integration, span, kwargs, is_completion=False):
         self._is_completion = is_completion
         self._kwargs = kwargs
 
+    def _extract_token_chunk(self, chunk):
+        """Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
+        if not self._dd_span._get_ctx_item("openai_stream_magic"):
+            return
+        choice = getattr(chunk, "choices", [None])[0]
+        if not getattr(choice, "finish_reason", None):
+            return
+        try:
+            usage_chunk = next(self.__wrapped__)
+            self._streamed_chunks[0].insert(0, usage_chunk)
+        except (StopIteration, GeneratorExit):
+            pass
+
 
 class TracedOpenAIStream(BaseTracedOpenAIStream):
     def __enter__(self):
@@ -51,11 +64,7 @@ def __iter__(self):
         exception_raised = False
         try:
             for chunk in self.__wrapped__:
-                if self._dd_span._get_ctx_item("openai_stream_magic"):
-                    choice = getattr(chunk, "choices", [None])[0]
-                    if getattr(choice, "finish_reason", None):
-                        usage_chunk = next(self.__wrapped__)
-                        self._streamed_chunks[0].insert(0, usage_chunk)
+                self._extract_token_chunk(chunk)
                 yield chunk
                 _loop_handler(self._dd_span, chunk, self._streamed_chunks)
         except Exception:
@@ -73,11 +82,7 @@ def __iter__(self):
     def __next__(self):
         try:
             chunk = next(self.__wrapped__)
-            if self._dd_span._get_ctx_item("openai_stream_magic"):
-                choice = getattr(chunk, "choices", [None])[0]
-                if getattr(choice, "finish_reason", None):
-                    usage_chunk = next(self.__wrapped__)
-                    self._streamed_chunks[0].insert(0, usage_chunk)
+            self._extract_token_chunk(chunk)
             _loop_handler(self._dd_span, chunk, self._streamed_chunks)
             return chunk
         except StopIteration:
@@ -103,11 +108,28 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
         await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb)
 
     def __aiter__(self):
-        return self
+        exception_raised = False
+        try:
+            for chunk in self.__wrapped__:
+                self._extract_token_chunk(chunk)
+                yield chunk
+                _loop_handler(self._dd_span, chunk, self._streamed_chunks)
+        except Exception:
+            self._dd_span.set_exc_info(*sys.exc_info())
+            exception_raised = True
+            raise
+        finally:
+            if not exception_raised:
+                _process_finished_stream(
+                    self._dd_integration, self._dd_span, self._kwargs, self._streamed_chunks, self._is_completion
+                )
+            self._dd_span.finish()
+            self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)
 
     async def __anext__(self):
         try:
             chunk = await self.__wrapped__.__anext__()
+            self._extract_token_chunk(chunk)
             _loop_handler(self._dd_span, chunk, self._streamed_chunks)
             return chunk
         except StopAsyncIteration:
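One caveat with the async class as shown: self.__wrapped__ is an openai.AsyncStream, which supports only async iteration, so the plain for loop in __aiter__ and the shared helper's next(self.__wrapped__) cannot drain it. The async path presumably needs async def / async for plus an awaitable variant of the helper. A minimal sketch of such an override (assumes only that the wrapped stream exposes __anext__; this is not part of the commit):

class TracedOpenAIAsyncStream(BaseTracedOpenAIStream):
    async def _extract_token_chunk(self, chunk):
        """Async variant of the helper: peek at the trailing usage chunk via __anext__ rather than next()."""
        if not self._dd_span._get_ctx_item("openai_stream_magic"):
            return
        choice = getattr(chunk, "choices", [None])[0]
        if not getattr(choice, "finish_reason", None):
            return
        try:
            usage_chunk = await self.__wrapped__.__anext__()
            self._streamed_chunks[0].insert(0, usage_chunk)
        except (StopAsyncIteration, GeneratorExit):
            pass

__aiter__ would then iterate with async for chunk in self.__wrapped__ and call await self._extract_token_chunk(chunk), as would __anext__.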
3 changes: 0 additions & 3 deletions ddtrace/llmobs/_integrations/openai.py
@@ -191,9 +191,6 @@ def _llmobs_set_meta_tags_from_chat(span: Span, kwargs: Dict[str, Any], messages
             input_messages.append({"content": str(_get_attr(m, "content", "")), "role": str(_get_attr(m, "role", ""))})
         span.set_tag_str(INPUT_MESSAGES, safe_json(input_messages))
 
-        if span._get_ctx_item("openai_stream_magic"):
-            kwargs.pop("stream_options", None)
-
         parameters = {k: v for k, v in kwargs.items() if k not in ("model", "messages", "tools", "functions")}
         span.set_tag_str(METADATA, safe_json(parameters))

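Net effect of this deletion (a sketch; the expected values mirror the updated llmobs test below): stream_options is no longer stripped before metadata is captured, so the auto-injected option now shows up in the span's METADATA tag.

kwargs = {"model": "gpt-3.5-turbo", "stream": True, "stream_options": {"include_usage": True}, "user": "ddtrace-test"}
parameters = {k: v for k, v in kwargs.items() if k not in ("model", "messages", "tools", "functions")}
assert parameters == {"stream": True, "stream_options": {"include_usage": True}, "user": "ddtrace-test"}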
10 changes: 7 additions & 3 deletions tests/contrib/openai/test_openai_llmobs.py
@@ -518,12 +518,16 @@ async def test_chat_completion_azure_async(
             )
         )
 
+    @pytest.mark.skipif(
+        parse_version(openai_module.version.VERSION) < (1, 26), reason="Stream options only available openai >= 1.26"
+    )
     def test_chat_completion_stream(self, openai, ddtrace_global_config, mock_llmobs_writer, mock_tracer):
         """Ensure llmobs records are emitted for chat completion endpoints when configured.
         Also ensure the llmobs records have the correct tagging including trace/span ID for trace correlation.
         """
-        with get_openai_vcr(subdirectory_name="v1").use_cassette("chat_completion_streamed.yaml"):
+        with get_openai_vcr(subdirectory_name="v1").use_cassette("chat_completion_streamed_tokens.yaml"):
             with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding:
                 with mock.patch("ddtrace.contrib.internal.openai.utils._est_tokens") as mock_est:
                     mock_encoding.return_value.encode.side_effect = lambda x: [1, 2, 3, 4, 5, 6, 7, 8]
@@ -547,8 +551,8 @@ def test_chat_completion_stream(self, openai, ddtrace_global_config, mock_llmobs
model_provider="openai",
input_messages=input_messages,
output_messages=[{"content": expected_completion, "role": "assistant"}],
metadata={"stream": True, "user": "ddtrace-test"},
token_metrics={"input_tokens": 8, "output_tokens": 8, "total_tokens": 16},
metadata={"stream": True, "stream_options": {"include_usage": True}, "user": "ddtrace-test"},
token_metrics={"input_tokens": 17, "output_tokens": 19, "total_tokens": 36},
tags={"ml_app": "<ml-app-name>", "service": "tests.contrib.openai"},
)
)
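Where the new expected numbers come from (a sketch; values as recorded in the chat_completion_streamed_tokens.yaml cassette): the token metrics are now read from the server-reported usage chunk, so the mocked tiktoken encoder (8 tokens per message, hence the old 8/8/16) no longer drives them.

usage = {"prompt_tokens": 17, "completion_tokens": 19, "total_tokens": 36}  # reported by the API in the cassette
token_metrics = {
    "input_tokens": usage["prompt_tokens"],
    "output_tokens": usage["completion_tokens"],
    "total_tokens": usage["total_tokens"],
}
assert token_metrics == {"input_tokens": 17, "output_tokens": 19, "total_tokens": 36}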
56 changes: 54 additions & 2 deletions tests/contrib/openai/test_openai_v1.py
@@ -1042,7 +1042,59 @@ def test_completion_stream_context_manager(openai, openai_vcr, mock_metrics, moc
assert mock.call.distribution("tokens.total", mock.ANY, tags=expected_tags) in mock_metrics.mock_calls


@pytest.mark.skipif(
parse_version(openai_module.version.VERSION) < (1, 26), reason="Stream options only available openai >= 1.26"
)
def test_chat_completion_stream(openai, openai_vcr, mock_metrics, snapshot_tracer):
with openai_vcr.use_cassette("chat_completion_streamed_tokens.yaml"):
with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding:
mock_encoding.return_value.encode.side_effect = lambda x: [1, 2, 3, 4, 5, 6, 7, 8]
expected_completion = "The Los Angeles Dodgers won the World Series in 2020."
client = openai.OpenAI()
resp = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{"role": "user", "content": "Who won the world series in 2020?"},
],
stream=True,
user="ddtrace-test",
n=None,
)
span = snapshot_tracer.current_span()
chunks = [c for c in resp]
assert len(chunks) == 15
completion = "".join([c.choices[0].delta.content for c in chunks if c.choices[0].delta.content is not None])
assert completion == expected_completion

assert span.get_tag("openai.response.choices.0.message.content") == expected_completion
assert span.get_tag("openai.response.choices.0.message.role") == "assistant"
assert span.get_tag("openai.response.choices.0.finish_reason") == "stop"

expected_tags = [
"version:",
"env:",
"service:tests.contrib.openai",
"openai.request.model:gpt-3.5-turbo",
"model:gpt-3.5-turbo",
"openai.request.endpoint:/v1/chat/completions",
"openai.request.method:POST",
"openai.organization.id:",
"openai.organization.name:datadog-4",
"openai.user.api_key:sk-...key>",
"error:0",
]
assert mock.call.distribution("request.duration", span.duration_ns, tags=expected_tags) in mock_metrics.mock_calls
assert mock.call.gauge("ratelimit.requests", 3000, tags=expected_tags) in mock_metrics.mock_calls
assert mock.call.gauge("ratelimit.remaining.requests", 2999, tags=expected_tags) in mock_metrics.mock_calls
assert mock.call.distribution("tokens.prompt", 17, tags=expected_tags) in mock_metrics.mock_calls
assert mock.call.distribution("tokens.completion", 19, tags=expected_tags) in mock_metrics.mock_calls
assert mock.call.distribution("tokens.total", 36, tags=expected_tags) in mock_metrics.mock_calls


@pytest.mark.skipif(
parse_version(openai_module.version.VERSION) < (1, 26), reason="Stream options only available openai >= 1.26"
)
def test_chat_completion_stream_explicit_no_tokens(openai, openai_vcr, mock_metrics, snapshot_tracer):
with openai_vcr.use_cassette("chat_completion_streamed.yaml"):
with mock.patch("ddtrace.contrib.internal.openai.utils.encoding_for_model", create=True) as mock_encoding:
mock_encoding.return_value.encode.side_effect = lambda x: [1, 2, 3, 4, 5, 6, 7, 8]
@@ -1054,10 +1106,10 @@ def test_chat_completion_stream(openai, openai_vcr, mock_metrics, snapshot_trace
{"role": "user", "content": "Who won the world series in 2020?"},
],
stream=True,
stream_options={"include_usage": False},
user="ddtrace-test",
n=None,
)
prompt_tokens = 8
span = snapshot_tracer.current_span()
chunks = [c for c in resp]
assert len(chunks) == 15
@@ -1087,7 +1139,7 @@ def test_chat_completion_stream(openai, openai_vcr, mock_metrics, snapshot_trace
        expected_tags += ["openai.estimated:true"]
     if TIKTOKEN_AVAILABLE:
         expected_tags = expected_tags[:-1]
-    assert mock.call.distribution("tokens.prompt", prompt_tokens, tags=expected_tags) in mock_metrics.mock_calls
+    assert mock.call.distribution("tokens.prompt", 8, tags=expected_tags) in mock_metrics.mock_calls
     assert mock.call.distribution("tokens.completion", mock.ANY, tags=expected_tags) in mock_metrics.mock_calls
     assert mock.call.distribution("tokens.total", mock.ANY, tags=expected_tags) in mock_metrics.mock_calls

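Note that both tests assert len(chunks) == 15 even though the tokens cassette carries a trailing usage chunk: the helper consumes that chunk off the underlying iterator before it can be yielded, so callers see the same chunk sequence either way. The pattern in isolation (a self-contained sketch using dict-shaped chunks, not the integration's types):

def traced(stream, collected):
    it = iter(stream)
    for chunk in it:
        if chunk.get("finish_reason"):
            try:
                collected.insert(0, next(it))  # usage-only chunk, hidden from the caller
            except StopIteration:
                pass
        yield chunk

collected = []
chunks = [{"delta": "Hi"}, {"delta": "!", "finish_reason": "stop"}, {"usage": {"total_tokens": 5}}]
assert list(traced(chunks, collected)) == chunks[:2]
assert collected[0]["usage"]["total_tokens"] == 5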
