diff --git a/ddtrace/contrib/internal/openai/utils.py b/ddtrace/contrib/internal/openai/utils.py
index 09061aa823d..ad98b518cde 100644
--- a/ddtrace/contrib/internal/openai/utils.py
+++ b/ddtrace/contrib/internal/openai/utils.py
@@ -38,19 +38,6 @@ def __init__(self, wrapped, integration, span, kwargs, is_completion=False):
         self._is_completion = is_completion
         self._kwargs = kwargs
 
-    def _extract_token_chunk(self, chunk):
-        """Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
-        if not self._dd_span._get_ctx_item("openai_stream_magic"):
-            return
-        choice = getattr(chunk, "choices", [None])[0]
-        if not getattr(choice, "finish_reason", None):
-            return
-        try:
-            usage_chunk = next(self.__wrapped__)
-            self._streamed_chunks[0].insert(0, usage_chunk)
-        except (StopIteration, GeneratorExit):
-            pass
-
 
 class TracedOpenAIStream(BaseTracedOpenAIStream):
     def __enter__(self):
@@ -98,6 +85,18 @@ def __next__(self):
             self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)
             raise
 
+    def _extract_token_chunk(self, chunk):
+        """Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
+        if not self._dd_span._get_ctx_item("openai_stream_magic"):
+            return
+        choice = getattr(chunk, "choices", [None])[0]
+        if not getattr(choice, "finish_reason", None):
+            return
+        try:
+            usage_chunk = next(self.__wrapped__)
+            self._streamed_chunks[0].insert(0, usage_chunk)
+        except (StopIteration, GeneratorExit):
+            return
 
 class TracedOpenAIAsyncStream(BaseTracedOpenAIStream):
     async def __aenter__(self):
@@ -107,11 +106,11 @@ async def __aenter__(self):
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         await self.__wrapped__.__aexit__(exc_type, exc_val, exc_tb)
 
-    def __aiter__(self):
+    async def __aiter__(self):
         exception_raised = False
         try:
-            for chunk in self.__wrapped__:
-                self._extract_token_chunk(chunk)
+            async for chunk in self.__wrapped__:
+                await self._extract_token_chunk(chunk)
                 yield chunk
                 _loop_handler(self._dd_span, chunk, self._streamed_chunks)
         except Exception:
@@ -128,8 +127,8 @@ def __aiter__(self):
 
     async def __anext__(self):
         try:
-            chunk = await self.__wrapped__.__anext__()
-            self._extract_token_chunk(chunk)
+            chunk = await anext(self.__wrapped__)
+            await self._extract_token_chunk(chunk)
             _loop_handler(self._dd_span, chunk, self._streamed_chunks)
             return chunk
         except StopAsyncIteration:
@@ -145,6 +144,19 @@ async def __anext__(self):
             self._dd_integration.metric(self._dd_span, "dist", "request.duration", self._dd_span.duration_ns)
             raise
 
+    async def _extract_token_chunk(self, chunk):
+        """Attempt to extract the token chunk (last chunk in the stream) from the streamed response."""
+        if not self._dd_span._get_ctx_item("openai_stream_magic"):
+            return
+        choice = getattr(chunk, "choices", [None])[0]
+        if not getattr(choice, "finish_reason", None):
+            return
+        try:
+            usage_chunk = await anext(self.__wrapped__)
+            self._streamed_chunks[0].insert(0, usage_chunk)
+        except (StopAsyncIteration, GeneratorExit):
+            return
+
 
 def _compute_token_count(content, model):
     # type: (Union[str, List[int]], Optional[str]) -> Tuple[bool, int]
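
Note (not part of the patch): _extract_token_chunk pulls one extra chunk off the wrapped iterator because, when a stream is created with stream_options={"include_usage": True}, the OpenAI API emits one final chunk whose choices list is empty and whose usage field is set, after the chunk carrying finish_reason. Below is a minimal, self-contained sketch of that stream shape and of the peek-ahead step; the Chunk, Choice, Usage, and fake_stream names are hypothetical stand-ins so the example runs offline, not the openai SDK's types.

# Sketch assumptions: chunk objects shaped like the openai v1 SDK's streamed
# chat-completion chunks; the dataclasses below are hypothetical stand-ins.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Choice:
    finish_reason: Optional[str] = None


@dataclass
class Usage:
    prompt_tokens: int = 0
    completion_tokens: int = 0


@dataclass
class Chunk:
    choices: List[Choice] = field(default_factory=list)
    usage: Optional[Usage] = None


def fake_stream():
    # Shape produced with stream_options={"include_usage": True}: a usage-only
    # chunk (empty choices) trails the chunk that carries finish_reason.
    yield Chunk(choices=[Choice()])
    yield Chunk(choices=[Choice(finish_reason="stop")])
    yield Chunk(choices=[], usage=Usage(prompt_tokens=3, completion_tokens=5))


stream = fake_stream()
for chunk in stream:
    choice = (chunk.choices or [None])[0]
    if getattr(choice, "finish_reason", None):
        # Mirrors the patched sync _extract_token_chunk: once finish_reason is
        # seen, consume the trailing usage chunk so the consumer's loop never
        # yields it, while its token counts stay available for the span.
        usage_chunk = next(stream, None)
        print("captured usage:", usage_chunk and usage_chunk.usage)

The async variant in the patch is the same idea, except that it awaits the wrapped async iterator (anext(self.__wrapped__)) and catches StopAsyncIteration instead of StopIteration.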