From ddce4510780687cd1a9caab6caa8bace0b81784a Mon Sep 17 00:00:00 2001 From: Yun Kim <35776586+Yun-Kim@users.noreply.github.com> Date: Thu, 29 Feb 2024 12:47:06 -0500 Subject: [PATCH 1/8] fix(setup): re-introduce tiktoken optional dependency (#8558) Reverts an inadvertent removal of the `tiktoken` optional dependency in #6529. The OpenAI integration uses the `tiktoken` package to compute completion tokens from OpenAI chat/completion responses, and as a convenience in #6470 the `tiktoken` package was added as an optional dependency using the `ddtrace[openai]` extra package. This was inadvertently removed in #6529, so we're re-introducing this into setup.py. ## Checklist - [x] Change(s) are motivated and described in the PR description - [x] Testing strategy is described if automated tests are not included in the PR - [x] Risks are described (performance impact, potential for breakage, maintainability) - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) are followed or label `changelog/no-changelog` is set - [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)) - [x] Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) - [x] If this PR changes the public interface, I've notified `@DataDog/apm-tees`. - [x] If change touches code that signs or publishes builds or packages, or handles credentials of any kind, I've requested a review from `@DataDog/security-design-and-guidance`. ## Reviewer Checklist - [x] Title is accurate - [x] All changes are related to the pull request's stated goal - [x] Description motivates each change - [x] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - [x] Testing strategy adequately addresses listed risks - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] Release note makes sense to a user of the library - [x] Author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - [x] Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 4aad92ad702..e85443160ae 100644 --- a/setup.py +++ b/setup.py @@ -534,6 +534,7 @@ def get_exts_for(name): # users can include opentracing by having: # install_requires=['ddtrace[opentracing]', ...] "opentracing": ["opentracing>=2.0.0"], + "openai": ["tiktoken"], }, tests_require=["flake8"], cmdclass={ From 3e2dfe8cd39b112a8342dd38b5b5fc0b6747b9d7 Mon Sep 17 00:00:00 2001 From: Emmett Butler <723615+emmettbutler@users.noreply.github.com> Date: Thu, 29 Feb 2024 10:27:01 -0800 Subject: [PATCH 2/8] ci: work around a bunch of unreliable test failures (#8556) This change adds workarounds for four unreliable test failures recently observed on the main branch. 
* https://app.circleci.com/pipelines/github/DataDog/dd-trace-py/56696/workflows/2a0f2e96-d1fc-4d40-8b5b-c414adfa40db/jobs/3599020 - loosens assertion to accept multiple calls, which have been observed * https://app.circleci.com/pipelines/github/DataDog/dd-trace-py/56703/workflows/de876b66-0a76-405e-aaba-efb28c79701b/jobs/3599380 - marks as flaky, because the failure mode appears to be incompatible with the test's intent * https://app.circleci.com/pipelines/github/DataDog/dd-trace-py/56727/workflows/a7ef12fa-bd08-4e0b-b126-cc6115051d27/jobs/3600333 - disables the pytest coverage plugin, which is implicated in the traceback * https://app.circleci.com/pipelines/github/DataDog/dd-trace-py/56760/workflows/7d0322a2-30bf-4afa-b501-b4ba3eb44b84/jobs/3601489 - makes `wait_for_num_traces` in snapshot tests permissive of extra traces, an approach similar to the one taken for other snapshot code paths ## Checklist - [x] Change(s) are motivated and described in the PR description - [x] Testing strategy is described if automated tests are not included in the PR - [x] Risks are described (performance impact, potential for breakage, maintainability) - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) are followed or label `changelog/no-changelog` is set - [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)) - [x] Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) - [x] If this PR changes the public interface, I've notified `@DataDog/apm-tees`. - [x] If change touches code that signs or publishes builds or packages, or handles credentials of any kind, I've requested a review from `@DataDog/security-design-and-guidance`. 
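As additional context on the `wait_for_num_traces` change above: the snapshot helper now treats the requested trace count as a lower bound rather than an exact match. A simplified, illustrative sketch of that polling pattern follows (the real loop lives in `snapshot_context` in `tests/utils.py`; the connection and endpoint path here are stand-ins):

```python
# Simplified sketch of the relaxed snapshot wait; illustrative only.
import json
import time


def wait_for_traces(conn, path, wait_for_num_traces, timeout=10.0):
    """Poll the test agent until at least `wait_for_num_traces` traces have arrived."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            conn.request("GET", path)  # path of the test agent's traces endpoint (stand-in)
            resp = conn.getresponse()
            body = resp.read()  # drain the response so the connection can be reused
            if resp.status == 200:
                traces = json.loads(body)
                # ">=" rather than "==": stray traces from unrelated code paths no longer hang the wait
                if len(traces) >= wait_for_num_traces:
                    return traces
        except Exception:
            pass
        time.sleep(0.1)
    raise AssertionError("timed out waiting for %d traces" % wait_for_num_traces)
```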
## Reviewer Checklist - [ ] Title is accurate - [ ] All changes are related to the pull request's stated goal - [ ] Description motivates each change - [ ] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - [ ] Testing strategy adequately addresses listed risks - [ ] Change is maintainable (easy to change, telemetry, documentation) - [ ] Release note makes sense to a user of the library - [ ] Author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - [ ] Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --- riotfile.py | 2 +- tests/profiling/collector/test_memalloc.py | 2 ++ tests/tracer/runtime/test_container.py | 2 +- tests/utils.py | 2 +- 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/riotfile.py b/riotfile.py index f1ede06bf1e..ba565179ce3 100644 --- a/riotfile.py +++ b/riotfile.py @@ -1552,7 +1552,7 @@ def select_pys(min_version=MIN_PYTHON_VERSION, max_version=MAX_PYTHON_VERSION): ), Venv( name="pytest", - command="pytest --no-ddtrace {cmdargs} tests/contrib/pytest/", + command="pytest --no-ddtrace --no-cov {cmdargs} tests/contrib/pytest/", pkgs={ "pytest-randomly": latest, }, diff --git a/tests/profiling/collector/test_memalloc.py b/tests/profiling/collector/test_memalloc.py index 712f6f93662..17120bd491a 100644 --- a/tests/profiling/collector/test_memalloc.py +++ b/tests/profiling/collector/test_memalloc.py @@ -9,6 +9,7 @@ from ddtrace.profiling.event import DDFrame from ddtrace.settings.profiling import ProfilingConfig from ddtrace.settings.profiling import _derive_default_heap_sample_size +from tests.utils import flaky try: @@ -154,6 +155,7 @@ def test_iter_events_multi_thread(): assert count_thread >= 1000 +@flaky(1719591602) def test_memory_collector(): r = recorder.Recorder() mc = memalloc.MemoryCollector(r) diff --git a/tests/tracer/runtime/test_container.py b/tests/tracer/runtime/test_container.py index 9d033f125ab..eb13a4d6c1d 100644 --- a/tests/tracer/runtime/test_container.py +++ b/tests/tracer/runtime/test_container.py @@ -310,7 +310,7 @@ def test_get_container_info(file_contents, container_id, node_inode): else: assert isinstance(info.node_inode, int) - mock_open.assert_called_once_with("/proc/self/cgroup", mode="r") + mock_open.assert_called_with("/proc/self/cgroup", mode="r") @pytest.mark.parametrize( diff --git a/tests/utils.py b/tests/utils.py index 1ca5ea40003..7ea7c16ae59 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1046,7 +1046,7 @@ def snapshot_context( r = conn.getresponse() if r.status == 200: traces = json.loads(r.read()) - if len(traces) == wait_for_num_traces: + if len(traces) >= wait_for_num_traces: break except Exception: pass From eff123f9e9e80b2df15bbc456c4d23b0afa63fc7 Mon Sep 17 00:00:00 2001 From: Yun Kim <35776586+Yun-Kim@users.noreply.github.com> Date: Thu, 29 Feb 2024 14:22:52 -0500 Subject: [PATCH 3/8] feat(openai): use tiktoken to calculate streamed completion tokens (#8557) This is a follow-up to #8521. This PR adds using tiktoken to calculate tokens for streamed completion/chat completion responses. We were previously using the number of stream chunks (which roughly correlated to tokens) to estimate streamed completion tokens, but we can use tiktoken (if installed in the user's env) instead for a more accurate computation. 
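As a point of reference, the exact count comes from tiktoken's per-model encodings. A minimal standalone sketch of that computation follows (the model name and the fallback heuristic are illustrative; the integration's own helper also handles token-ID lists and falls back to a regex-based estimate):

```python
# Minimal sketch of the tiktoken-backed token count; illustrative only, see
# _compute_token_count in ddtrace/contrib/openai/utils.py for the full logic.
import tiktoken


def count_tokens(text, model="gpt-3.5-turbo"):  # model name is only an example
    try:
        enc = tiktoken.encoding_for_model(model)
        return len(enc.encode(text)), False  # exact count, not estimated
    except KeyError:
        # tiktoken has no tokenizer for this model, so fall back to a rough estimate.
        return max(1, len(text) // 4), True


print(count_tokens("The Los Angeles Dodgers won the World Series in 2020."))
```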
This is done by introducing a few private utility helper methods to: - Append each streamed chunk into the correct index in a streamed_completions list - Stitch together the streamed completion/chat_completion chunks into completion/messages that we can calculate token counts from. - Refactor the overcrowded `_handle_streamed_response()` helper to set token metrics based on the request prompt/input as well as set token metrics based on the output completion/chat_messages. ## Checklist - [x] Change(s) are motivated and described in the PR description - [x] Testing strategy is described if automated tests are not included in the PR - [x] Risks are described (performance impact, potential for breakage, maintainability) - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) are followed or label `changelog/no-changelog` is set - [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)) - [x] Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) - [x] If this PR changes the public interface, I've notified `@DataDog/apm-tees`. - [x] If change touches code that signs or publishes builds or packages, or handles credentials of any kind, I've requested a review from `@DataDog/security-design-and-guidance`. ## Reviewer Checklist - [x] Title is accurate - [x] All changes are related to the pull request's stated goal - [x] Description motivates each change - [x] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - [x] Testing strategy adequately addresses listed risks - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] Release note makes sense to a user of the library - [x] Author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - [x] Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --- ddtrace/contrib/openai/_endpoint_hooks.py | 106 +++++++-------- ddtrace/contrib/openai/utils.py | 155 +++++++++++++++++----- ddtrace/llmobs/_integrations/openai.py | 10 +- tests/contrib/openai/test_openai_v0.py | 40 ++---- tests/contrib/openai/test_openai_v1.py | 96 ++++++-------- 5 files changed, 227 insertions(+), 180 deletions(-) diff --git a/ddtrace/contrib/openai/_endpoint_hooks.py b/ddtrace/contrib/openai/_endpoint_hooks.py index 5b1372089f8..08ce9238040 100644 --- a/ddtrace/contrib/openai/_endpoint_hooks.py +++ b/ddtrace/contrib/openai/_endpoint_hooks.py @@ -1,10 +1,15 @@ from ddtrace.ext import SpanTypes -from .utils import _compute_prompt_token_count +from .utils import _construct_completion_from_streamed_chunks +from .utils import _construct_message_from_streamed_chunks from .utils import _format_openai_api_key from .utils import _is_async_generator from .utils import _is_generator +from .utils import _loop_handler +from .utils import _set_metrics_on_request +from .utils import _set_metrics_on_streamed_response from .utils import _tag_streamed_chat_completion_response +from .utils import _tag_streamed_completion_response from .utils import _tag_tool_calls @@ -89,13 +94,9 @@ def _record_response(self, pin, integration, span, args, kwargs, resp, error): class _BaseCompletionHook(_EndpointHook): - """ - Share streamed 
response handling logic between Completion and ChatCompletion endpoints. - """ - _request_arg_params = ("api_key", "api_base", "api_type", "request_id", "api_version", "organization") - def _handle_streamed_response(self, integration, span, args, kwargs, resp): + def _handle_streamed_response(self, integration, span, kwargs, resp, operation_id): """Handle streamed response objects returned from endpoint calls. This method helps with streamed responses by wrapping the generator returned with a @@ -103,66 +104,46 @@ def _handle_streamed_response(self, integration, span, args, kwargs, resp): """ def shared_gen(): + completions = None + messages = None try: - num_prompt_tokens = span.get_metric("openai.response.usage.prompt_tokens") or 0 streamed_chunks = yield - num_completion_tokens = sum(len(streamed_choice) for streamed_choice in streamed_chunks) - span.set_metric("openai.response.usage.completion_tokens", num_completion_tokens) - total_tokens = num_prompt_tokens + num_completion_tokens - span.set_metric("openai.response.usage.total_tokens", total_tokens) - if span.get_metric("openai.request.prompt_tokens_estimated") == 0: - integration.metric(span, "dist", "tokens.prompt", num_prompt_tokens) - else: - integration.metric(span, "dist", "tokens.prompt", num_prompt_tokens, tags=["openai.estimated:true"]) - integration.metric( - span, "dist", "tokens.completion", num_completion_tokens, tags=["openai.estimated:true"] + _set_metrics_on_request( + integration, span, kwargs, prompts=kwargs.get("prompt", None), messages=kwargs.get("messages", None) ) - integration.metric(span, "dist", "tokens.total", total_tokens, tags=["openai.estimated:true"]) + if operation_id == _CompletionHook.OPERATION_ID: + completions = [_construct_completion_from_streamed_chunks(choice) for choice in streamed_chunks] + else: + messages = [_construct_message_from_streamed_chunks(choice) for choice in streamed_chunks] + _set_metrics_on_streamed_response(integration, span, completions=completions, messages=messages) finally: - if integration.is_pc_sampled_span(span): - _tag_streamed_chat_completion_response(integration, span, streamed_chunks) - if integration.is_pc_sampled_llmobs(span): - if span.resource == _ChatCompletionHook.OPERATION_ID: - integration.llmobs_set_tags("chat", resp, None, span, kwargs, streamed_resp=streamed_chunks) - elif span.resource == _CompletionHook.OPERATION_ID: - integration.llmobs_set_tags( - "completion", resp, None, span, kwargs, streamed_resp=streamed_chunks - ) + if operation_id == _CompletionHook.OPERATION_ID: + if integration.is_pc_sampled_span(span): + _tag_streamed_completion_response(integration, span, completions) + if integration.is_pc_sampled_llmobs(span): + integration.llmobs_set_tags("completion", resp, span, kwargs, streamed_resp=streamed_chunks) + else: + if integration.is_pc_sampled_span(span): + _tag_streamed_chat_completion_response(integration, span, messages) + if integration.is_pc_sampled_llmobs(span): + integration.llmobs_set_tags("chat", resp, span, kwargs, streamed_resp=streamed_chunks) span.finish() integration.metric(span, "dist", "request.duration", span.duration_ns) - num_prompt_tokens = 0 - estimated = False - prompt = kwargs.get("prompt", None) - messages = kwargs.get("messages", None) - if prompt is not None: - if isinstance(prompt, str) or isinstance(prompt, list) and isinstance(prompt[0], int): - prompt = [prompt] - for p in prompt: - estimated, prompt_tokens = _compute_prompt_token_count(p, kwargs.get("model")) - num_prompt_tokens += prompt_tokens - if messages is 
not None: - for m in messages: - estimated, prompt_tokens = _compute_prompt_token_count(m.get("content", ""), kwargs.get("model")) - num_prompt_tokens += prompt_tokens - span.set_metric("openai.request.prompt_tokens_estimated", int(estimated)) - span.set_metric("openai.response.usage.prompt_tokens", num_prompt_tokens) - - # A chunk corresponds to a token: - # https://community.openai.com/t/how-to-get-total-tokens-from-a-stream-of-completioncreaterequests/110700 - # https://community.openai.com/t/openai-api-get-usage-tokens-in-response-when-set-stream-true/141866 if _is_async_generator(resp): async def traced_streamed_response(): g = shared_gen() g.send(None) - streamed_chunks = [[] for _ in range(kwargs.get("n", 1))] + n = kwargs.get("n", 1) + if operation_id == _CompletionHook.OPERATION_ID: + prompts = kwargs.get("prompt", "") + if isinstance(prompts, list) and not isinstance(prompts[0], int): + n *= len(prompts) + streamed_chunks = [[] for _ in range(n)] try: async for chunk in resp: - if span.get_tag("openai.response.model") is None: - span.set_tag_str("openai.response.model", chunk.model) - for choice in chunk.choices: - streamed_chunks[choice.index].append(choice) + _loop_handler(span, chunk, streamed_chunks) yield chunk finally: try: @@ -177,13 +158,15 @@ async def traced_streamed_response(): def traced_streamed_response(): g = shared_gen() g.send(None) - streamed_chunks = [[] for _ in range(kwargs.get("n", 1))] + n = kwargs.get("n", 1) + if operation_id == _CompletionHook.OPERATION_ID: + prompts = kwargs.get("prompt", "") + if isinstance(prompts, list) and not isinstance(prompts[0], int): + n *= len(prompts) + streamed_chunks = [[] for _ in range(n)] try: for chunk in resp: - if span.get_tag("openai.response.model") is None: - span.set_tag_str("openai.response.model", chunk.model) - for choice in chunk.choices: - streamed_chunks[choice.index].append(choice) + _loop_handler(span, chunk, streamed_chunks) yield chunk finally: try: @@ -233,7 +216,7 @@ def _record_request(self, pin, integration, span, args, kwargs): def _record_response(self, pin, integration, span, args, kwargs, resp, error): resp = super()._record_response(pin, integration, span, args, kwargs, resp, error) if kwargs.get("stream") and error is None: - return self._handle_streamed_response(integration, span, args, kwargs, resp) + return self._handle_streamed_response(integration, span, kwargs, resp, operation_id=self.OPERATION_ID) if integration.is_pc_sampled_log(span): attrs_dict = {"prompt": kwargs.get("prompt", "")} if error is None: @@ -245,7 +228,7 @@ def _record_response(self, pin, integration, span, args, kwargs, resp, error): span, "info" if error is None else "error", "sampled %s" % self.OPERATION_ID, attrs=attrs_dict ) if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags("completion", resp, error, span, kwargs) + integration.llmobs_set_tags("completion", resp, span, kwargs, err=error) if not resp: return for choice in resp.choices: @@ -257,6 +240,7 @@ def _record_response(self, pin, integration, span, args, kwargs, resp, error): class _ChatCompletionHook(_BaseCompletionHook): + _request_arg_params = ("api_key", "api_base", "api_type", "request_id", "api_version", "organization") _request_kwarg_params = ( "model", "engine", @@ -290,7 +274,7 @@ def _record_request(self, pin, integration, span, args, kwargs): def _record_response(self, pin, integration, span, args, kwargs, resp, error): resp = super()._record_response(pin, integration, span, args, kwargs, resp, error) if kwargs.get("stream") 
and error is None: - return self._handle_streamed_response(integration, span, args, kwargs, resp) + return self._handle_streamed_response(integration, span, kwargs, resp, operation_id=self.OPERATION_ID) if integration.is_pc_sampled_log(span): log_choices = resp.choices if hasattr(resp.choices[0], "model_dump"): @@ -300,7 +284,7 @@ def _record_response(self, pin, integration, span, args, kwargs, resp, error): span, "info" if error is None else "error", "sampled %s" % self.OPERATION_ID, attrs=attrs_dict ) if integration.is_pc_sampled_llmobs(span): - integration.llmobs_set_tags("chat", resp, error, span, kwargs) + integration.llmobs_set_tags("chat", resp, span, kwargs, err=error) if not resp: return for choice in resp.choices: diff --git a/ddtrace/contrib/openai/utils.py b/ddtrace/contrib/openai/utils.py index 249dc0bb711..09cfbf0036c 100644 --- a/ddtrace/contrib/openai/utils.py +++ b/ddtrace/contrib/openai/utils.py @@ -1,10 +1,9 @@ import re -from typing import AsyncGenerator # noqa:F401 -from typing import Generator # noqa:F401 -from typing import List # noqa:F401 -from typing import Optional # noqa:F401 -from typing import Tuple # noqa:F401 -from typing import Union # noqa:F401 +from typing import Any +from typing import AsyncGenerator +from typing import Dict +from typing import Generator +from typing import List from ddtrace.internal.logger import get_logger @@ -22,10 +21,10 @@ _punc_regex = re.compile(r"[\w']+|[.,!?;~@#$%^&*()+/-]") -def _compute_prompt_token_count(prompt, model): +def _compute_token_count(content, model): # type: (Union[str, List[int]], Optional[str]) -> Tuple[bool, int] """ - Takes in a prompt(s) and model pair, and returns a tuple of whether or not the number of prompt + Takes in prompt/response(s) and model pair, and returns a tuple of whether or not the number of prompt tokens was estimated, and the estimated/calculated prompt token count. 
""" num_prompt_tokens = 0 @@ -33,10 +32,10 @@ def _compute_prompt_token_count(prompt, model): if model is not None and tiktoken_available is True: try: enc = encoding_for_model(model) - if isinstance(prompt, str): - num_prompt_tokens += len(enc.encode(prompt)) - elif isinstance(prompt, list) and isinstance(prompt[0], int): - num_prompt_tokens += len(prompt) + if isinstance(content, str): + num_prompt_tokens += len(enc.encode(content)) + elif isinstance(content, list) and isinstance(content[0], int): + num_prompt_tokens += len(content) return estimated, num_prompt_tokens except KeyError: # tiktoken.encoding_for_model() will raise a KeyError if it doesn't have a tokenizer for the model @@ -45,7 +44,7 @@ def _compute_prompt_token_count(prompt, model): estimated = True # If model is unavailable or tiktoken is not imported, then provide a very rough estimate of the number of tokens - return estimated, _est_tokens(prompt) + return estimated, _est_tokens(content) def _est_tokens(prompt): @@ -105,6 +104,114 @@ def _is_async_generator(resp): return False +def _construct_completion_from_streamed_chunks(streamed_chunks: List[Any]) -> Dict[str, str]: + """Constructs a completion dictionary of form {"text": "...", "finish_reason": "..."} from streamed chunks.""" + completion = {"text": "".join(c.text for c in streamed_chunks if getattr(c, "text", None))} + if streamed_chunks[-1].finish_reason is not None: + completion["finish_reason"] = streamed_chunks[-1].finish_reason + return completion + + +def _construct_message_from_streamed_chunks(streamed_chunks: List[Any]) -> Dict[str, str]: + """Constructs a chat completion message dictionary from streamed chunks. + The resulting message dictionary is of form {"content": "...", "role": "...", "finish_reason": "..."} + """ + message = {} + content = "".join(c.delta.content for c in streamed_chunks if getattr(c.delta, "content", None)) + if getattr(streamed_chunks[0].delta, "tool_calls", None): + content = "".join( + c.delta.tool_calls.function.arguments for c in streamed_chunks if getattr(c.delta, "tool_calls", None) + ) + elif getattr(streamed_chunks[0].delta, "function_call", None): + content = "".join( + c.delta.function_call.arguments for c in streamed_chunks if getattr(c.delta, "function_call", None) + ) + message["role"] = streamed_chunks[0].delta.role + if streamed_chunks[-1].finish_reason is not None: + message["finish_reason"] = streamed_chunks[-1].finish_reason + message["content"] = content + return message + + +def _tag_streamed_completion_response(integration, span, completions): + """Tagging logic for streamed completions.""" + for idx, choice in enumerate(completions): + span.set_tag_str("openai.response.choices.%d.text" % idx, integration.trunc(choice["text"])) + if choice.get("finish_reason") is not None: + span.set_tag_str("openai.response.choices.%d.finish_reason" % idx, choice["finish_reason"]) + + +def _tag_streamed_chat_completion_response(integration, span, messages): + """Tagging logic for streamed chat completions.""" + for idx, message in enumerate(messages): + span.set_tag_str("openai.response.choices.%d.message.content" % idx, integration.trunc(message["content"])) + span.set_tag_str("openai.response.choices.%d.message.role" % idx, message["role"]) + if message.get("finish_reason") is not None: + span.set_tag_str("openai.response.choices.%d.finish_reason" % idx, message["finish_reason"]) + + +def _set_metrics_on_request(integration, span, kwargs, prompts=None, messages=None): + """Set token span metrics on streamed chat/completion 
requests.""" + num_prompt_tokens = 0 + estimated = False + if messages is not None: + for m in messages: + estimated, prompt_tokens = _compute_token_count(m.get("content", ""), kwargs.get("model")) + num_prompt_tokens += prompt_tokens + elif prompts is not None: + if isinstance(prompts, str) or isinstance(prompts, list) and isinstance(prompts[0], int): + prompts = [prompts] + for prompt in prompts: + estimated, prompt_tokens = _compute_token_count(prompt, kwargs.get("model")) + num_prompt_tokens += prompt_tokens + span.set_metric("openai.request.prompt_tokens_estimated", int(estimated)) + span.set_metric("openai.response.usage.prompt_tokens", num_prompt_tokens) + if not estimated: + integration.metric(span, "dist", "tokens.prompt", num_prompt_tokens) + else: + integration.metric(span, "dist", "tokens.prompt", num_prompt_tokens, tags=["openai.estimated:true"]) + + +def _set_metrics_on_streamed_response(integration, span, completions=None, messages=None): + """Set token span metrics on streamed chat/completion responses.""" + num_completion_tokens = 0 + estimated = False + if messages is not None: + for m in messages: + estimated, completion_tokens = _compute_token_count( + m.get("content", ""), span.get_tag("openai.response.model") + ) + num_completion_tokens += completion_tokens + elif completions is not None: + for c in completions: + estimated, completion_tokens = _compute_token_count( + c.get("text", ""), span.get_tag("openai.response.model") + ) + num_completion_tokens += completion_tokens + span.set_metric("openai.response.completion_tokens_estimated", int(estimated)) + span.set_metric("openai.response.usage.completion_tokens", num_completion_tokens) + num_prompt_tokens = span.get_metric("openai.response.usage.prompt_tokens") or 0 + total_tokens = num_prompt_tokens + num_completion_tokens + span.set_metric("openai.response.usage.total_tokens", total_tokens) + if not estimated: + integration.metric(span, "dist", "tokens.completion", num_completion_tokens) + integration.metric(span, "dist", "tokens.total", total_tokens) + else: + integration.metric(span, "dist", "tokens.completion", num_completion_tokens, tags=["openai.estimated:true"]) + integration.metric(span, "dist", "tokens.total", total_tokens, tags=["openai.estimated:true"]) + + +def _loop_handler(span, chunk, streamed_chunks): + """Sets the openai model tag and appends the chunk to the correct index in the streamed_chunks list. + + When handling a streamed chat/completion response, this function is called for each chunk in the streamed response. + """ + if span.get_tag("openai.response.model") is None: + span.set_tag("openai.response.model", chunk.model) + for choice in chunk.choices: + streamed_chunks[choice.index].append(choice) + + def _tag_tool_calls(integration, span, tool_calls, choice_idx): # type: (...) 
-> None """ @@ -120,25 +227,3 @@ def _tag_tool_calls(integration, span, tool_calls, choice_idx): integration.trunc(str(tool_call.arguments)), ) span.set_tag("openai.response.choices.%d.message.tool_calls.%d.name" % (choice_idx, idy), str(tool_call.name)) - - -def _tag_streamed_chat_completion_response(integration, span, streamed_chunks): - """Tagging logic for streamed chat/completions.""" - for idx, choice in enumerate(streamed_chunks): - if span.resource == "createChatCompletion": - content = "".join(c.delta.content for c in choice if getattr(c.delta, "content", None)) - if getattr(choice[0].delta, "tool_calls", None): - content = "".join( - c.delta.tool_calls.function.arguments for c in choice if getattr(c.delta, "tool_calls", None) - ) - elif getattr(choice[0].delta, "function_call", None): - content = "".join( - c.delta.function_call.arguments for c in choice if getattr(c.delta, "function_call", None) - ) - span.set_tag_str("openai.response.choices.%d.message.content" % idx, integration.trunc(content)) - span.set_tag_str("openai.response.choices.%d.message.role" % idx, choice[0].delta.role) - else: - content = "".join(c.text for c in choice if c.text is not None) - span.set_tag_str("openai.response.choices.%d.text" % idx, integration.trunc(content)) - if choice[-1].finish_reason is not None: - span.set_tag_str("openai.response.choices.%d.finish_reason" % idx, choice[-1].finish_reason) diff --git a/ddtrace/llmobs/_integrations/openai.py b/ddtrace/llmobs/_integrations/openai.py index f7041050965..4512bd28826 100644 --- a/ddtrace/llmobs/_integrations/openai.py +++ b/ddtrace/llmobs/_integrations/openai.py @@ -117,10 +117,10 @@ def llmobs_set_tags( self, record_type: str, resp: Any, - err: Any, span: Span, kwargs: Dict[str, Any], streamed_resp: Optional[Any] = None, + err: Optional[Any] = None, ) -> None: """Sets meta tags and metrics for span events to be sent to LLMObs.""" if not self.llmobs_enabled: @@ -133,7 +133,7 @@ def llmobs_set_tags( self._llmobs_set_meta_tags_from_completion(resp, err, kwargs, streamed_resp, span) elif record_type == "chat": self._llmobs_set_meta_tags_from_chat(resp, err, kwargs, streamed_resp, span) - span.set_tag_str(METRICS, json.dumps(self._set_llmobs_metrics_tags(span, resp, streamed_resp))) + span.set_tag_str(METRICS, json.dumps(self._set_llmobs_metrics_tags(span, resp, streamed_resp is not None))) @staticmethod def _llmobs_set_meta_tags_from_completion( @@ -201,12 +201,12 @@ def _llmobs_set_meta_tags_from_chat( span.set_tag_str(OUTPUT_MESSAGES, json.dumps(output_messages)) @staticmethod - def _set_llmobs_metrics_tags(span: Span, resp: Any, streamed_resp: Optional[Any]) -> Dict[str, Any]: + def _set_llmobs_metrics_tags(span: Span, resp: Any, streamed: bool = False) -> Dict[str, Any]: """Extract metrics from a chat/completion and set them as a temporary "_ml_obs.metrics" tag.""" metrics = {} - if streamed_resp: + if streamed: prompt_tokens = span.get_metric("openai.response.usage.prompt_tokens") or 0 - completion_tokens = sum(len(choice) for choice in streamed_resp) + completion_tokens = span.get_metric("openai.response.usage.completion_tokens") or 0 metrics.update( { "prompt_tokens": prompt_tokens, diff --git a/tests/contrib/openai/test_openai_v0.py b/tests/contrib/openai/test_openai_v0.py index 0e6b034a86b..ce8fb79b528 100644 --- a/tests/contrib/openai/test_openai_v0.py +++ b/tests/contrib/openai/test_openai_v0.py @@ -1612,12 +1612,10 @@ async def test_completion_async_stream(openai, openai_vcr, mock_metrics, mock_tr "openai.estimated:true", ] if 
TIKTOKEN_AVAILABLE: - prompt_expected_tags = expected_tags[:-1] - else: - prompt_expected_tags = expected_tags - assert mock.call.distribution("tokens.prompt", 2, tags=prompt_expected_tags) in mock_metrics.mock_calls - assert mock.call.distribution("tokens.completion", len(chunks), tags=expected_tags) in mock_metrics.mock_calls - assert mock.call.distribution("tokens.total", len(chunks) + 2, tags=expected_tags) in mock_metrics.mock_calls + expected_tags = expected_tags[:-1] + assert mock.call.distribution("tokens.prompt", 2, tags=expected_tags) in mock_metrics.mock_calls + assert mock.call.distribution("tokens.completion", 15, tags=expected_tags) in mock_metrics.mock_calls + assert mock.call.distribution("tokens.total", 17, tags=expected_tags) in mock_metrics.mock_calls def test_chat_completion_stream(openai, openai_vcr, mock_metrics, snapshot_tracer): @@ -1664,15 +1662,10 @@ def test_chat_completion_stream(openai, openai_vcr, mock_metrics, snapshot_trace assert mock.call.gauge("ratelimit.remaining.requests", 2, tags=expected_tags) in mock_metrics.mock_calls expected_tags += ["openai.estimated:true"] if TIKTOKEN_AVAILABLE: - prompt_expected_tags = expected_tags[:-1] - else: - prompt_expected_tags = expected_tags - assert mock.call.distribution("tokens.prompt", prompt_tokens, tags=prompt_expected_tags) in mock_metrics.mock_calls - assert mock.call.distribution("tokens.completion", len(chunks), tags=expected_tags) in mock_metrics.mock_calls - assert ( - mock.call.distribution("tokens.total", len(chunks) + prompt_tokens, tags=expected_tags) - in mock_metrics.mock_calls - ) + expected_tags = expected_tags[:-1] + assert mock.call.distribution("tokens.prompt", prompt_tokens, tags=expected_tags) in mock_metrics.mock_calls + assert mock.call.distribution("tokens.completion", 12, tags=expected_tags) in mock_metrics.mock_calls + assert mock.call.distribution("tokens.total", 12 + prompt_tokens, tags=expected_tags) in mock_metrics.mock_calls @pytest.mark.asyncio @@ -1721,15 +1714,10 @@ async def test_chat_completion_async_stream(openai, openai_vcr, mock_metrics, sn assert mock.call.gauge("ratelimit.remaining.tokens", 89971, tags=expected_tags) in mock_metrics.mock_calls expected_tags += ["openai.estimated:true"] if TIKTOKEN_AVAILABLE: - prompt_expected_tags = expected_tags[:-1] - else: - prompt_expected_tags = expected_tags - assert mock.call.distribution("tokens.prompt", prompt_tokens, tags=prompt_expected_tags) in mock_metrics.mock_calls - assert mock.call.distribution("tokens.completion", len(chunks), tags=expected_tags) in mock_metrics.mock_calls - assert ( - mock.call.distribution("tokens.total", len(chunks) + prompt_tokens, tags=expected_tags) - in mock_metrics.mock_calls - ) + expected_tags = expected_tags[:-1] + assert mock.call.distribution("tokens.prompt", prompt_tokens, tags=expected_tags) in mock_metrics.mock_calls + assert mock.call.distribution("tokens.completion", 35, tags=expected_tags) in mock_metrics.mock_calls + assert mock.call.distribution("tokens.total", 35 + prompt_tokens, tags=expected_tags) in mock_metrics.mock_calls @pytest.mark.snapshot( @@ -2333,7 +2321,7 @@ async def test_llmobs_chat_completion_stream( input_messages=input_messages, output_messages=[{"content": expected_completion, "role": "assistant"}], parameters={"temperature": 0}, - token_metrics={"prompt_tokens": 8, "completion_tokens": 15, "total_tokens": 23}, + token_metrics={"prompt_tokens": 8, "completion_tokens": 12, "total_tokens": 20}, ) ), ] @@ -2410,7 +2398,7 @@ def 
test_llmobs_chat_completion_function_call_stream( input_messages=[{"content": chat_completion_input_description, "role": "user"}], output_messages=[{"content": expected_output, "role": "assistant"}], parameters={"temperature": 0}, - token_metrics={"prompt_tokens": 63, "completion_tokens": 35, "total_tokens": 98}, + token_metrics={"prompt_tokens": 63, "completion_tokens": 33, "total_tokens": 96}, ) ), ] diff --git a/tests/contrib/openai/test_openai_v1.py b/tests/contrib/openai/test_openai_v1.py index f7efebfc3c0..bf7935317ac 100644 --- a/tests/contrib/openai/test_openai_v1.py +++ b/tests/contrib/openai/test_openai_v1.py @@ -1219,12 +1219,10 @@ def test_completion_stream(openai, openai_vcr, mock_metrics, mock_tracer): "openai.estimated:true", ] if TIKTOKEN_AVAILABLE: - prompt_expected_tags = expected_tags[:-1] - else: - prompt_expected_tags = expected_tags - assert mock.call.distribution("tokens.prompt", 2, tags=prompt_expected_tags) in mock_metrics.mock_calls - assert mock.call.distribution("tokens.completion", len(chunks), tags=expected_tags) in mock_metrics.mock_calls - assert mock.call.distribution("tokens.total", len(chunks) + 2, tags=expected_tags) in mock_metrics.mock_calls + expected_tags = expected_tags[:-1] + assert mock.call.distribution("tokens.prompt", 2, tags=expected_tags) in mock_metrics.mock_calls + assert mock.call.distribution("tokens.completion", mock.ANY, tags=expected_tags) in mock_metrics.mock_calls + assert mock.call.distribution("tokens.total", mock.ANY, tags=expected_tags) in mock_metrics.mock_calls @pytest.mark.asyncio @@ -1261,12 +1259,10 @@ async def test_completion_async_stream(openai, openai_vcr, mock_metrics, mock_tr "openai.estimated:true", ] if TIKTOKEN_AVAILABLE: - prompt_expected_tags = expected_tags[:-1] - else: - prompt_expected_tags = expected_tags - assert mock.call.distribution("tokens.prompt", 2, tags=prompt_expected_tags) in mock_metrics.mock_calls - assert mock.call.distribution("tokens.completion", len(chunks), tags=expected_tags) in mock_metrics.mock_calls - assert mock.call.distribution("tokens.total", len(chunks) + 2, tags=expected_tags) in mock_metrics.mock_calls + expected_tags = expected_tags[:-1] + assert mock.call.distribution("tokens.prompt", 2, tags=expected_tags) in mock_metrics.mock_calls + assert mock.call.distribution("tokens.completion", mock.ANY, tags=expected_tags) in mock_metrics.mock_calls + assert mock.call.distribution("tokens.total", mock.ANY, tags=expected_tags) in mock_metrics.mock_calls def test_chat_completion_stream(openai, openai_vcr, mock_metrics, snapshot_tracer): @@ -1312,15 +1308,10 @@ def test_chat_completion_stream(openai, openai_vcr, mock_metrics, snapshot_trace assert mock.call.gauge("ratelimit.remaining.requests", 2999, tags=expected_tags) in mock_metrics.mock_calls expected_tags += ["openai.estimated:true"] if TIKTOKEN_AVAILABLE: - prompt_expected_tags = expected_tags[:-1] - else: - prompt_expected_tags = expected_tags - assert mock.call.distribution("tokens.prompt", prompt_tokens, tags=prompt_expected_tags) in mock_metrics.mock_calls - assert mock.call.distribution("tokens.completion", len(chunks), tags=expected_tags) in mock_metrics.mock_calls - assert ( - mock.call.distribution("tokens.total", len(chunks) + prompt_tokens, tags=expected_tags) - in mock_metrics.mock_calls - ) + expected_tags = expected_tags[:-1] + assert mock.call.distribution("tokens.prompt", prompt_tokens, tags=expected_tags) in mock_metrics.mock_calls + assert mock.call.distribution("tokens.completion", mock.ANY, tags=expected_tags) in 
mock_metrics.mock_calls + assert mock.call.distribution("tokens.total", mock.ANY, tags=expected_tags) in mock_metrics.mock_calls @pytest.mark.asyncio @@ -1367,15 +1358,10 @@ async def test_chat_completion_async_stream(openai, openai_vcr, mock_metrics, sn assert mock.call.gauge("ratelimit.remaining.requests", 2999, tags=expected_tags) in mock_metrics.mock_calls expected_tags += ["openai.estimated:true"] if TIKTOKEN_AVAILABLE: - prompt_expected_tags = expected_tags[:-1] - else: - prompt_expected_tags = expected_tags - assert mock.call.distribution("tokens.prompt", prompt_tokens, tags=prompt_expected_tags) in mock_metrics.mock_calls - assert mock.call.distribution("tokens.completion", len(chunks), tags=expected_tags) in mock_metrics.mock_calls - assert ( - mock.call.distribution("tokens.total", len(chunks) + prompt_tokens, tags=expected_tags) - in mock_metrics.mock_calls - ) + expected_tags = expected_tags[:-1] + assert mock.call.distribution("tokens.prompt", prompt_tokens, tags=expected_tags) in mock_metrics.mock_calls + assert mock.call.distribution("tokens.completion", mock.ANY, tags=expected_tags) in mock_metrics.mock_calls + assert mock.call.distribution("tokens.total", mock.ANY, tags=expected_tags) in mock_metrics.mock_calls @pytest.mark.snapshot( @@ -1914,13 +1900,15 @@ def test_llmobs_completion(openai_vcr, openai, ddtrace_global_config, mock_llmob def test_llmobs_completion_stream(openai_vcr, openai, ddtrace_global_config, mock_llmobs_writer, mock_tracer): with openai_vcr.use_cassette("completion_streamed.yaml"): with mock.patch("ddtrace.contrib.openai.utils.encoding_for_model", create=True) as mock_encoding: - mock_encoding.return_value.encode.side_effect = lambda x: [1, 2] - model = "ada" - expected_completion = '! ... A page layouts page drawer? ... Interesting. The "Tools" is' - client = openai.OpenAI() - resp = client.completions.create(model=model, prompt="Hello world", stream=True) - for _ in resp: - pass + with mock.patch("ddtrace.contrib.openai.utils._est_tokens") as mock_est: + mock_encoding.return_value.encode.side_effect = lambda x: [1, 2] + mock_est.return_value = 2 + model = "ada" + expected_completion = '! ... A page layouts page drawer? ... Interesting. The "Tools" is' + client = openai.OpenAI() + resp = client.completions.create(model=model, prompt="Hello world", stream=True) + for _ in resp: + pass span = mock_tracer.pop_traces()[0][0] assert mock_llmobs_writer.enqueue.call_count == 1 mock_llmobs_writer.assert_has_calls( @@ -1933,7 +1921,7 @@ def test_llmobs_completion_stream(openai_vcr, openai, ddtrace_global_config, moc input_messages=[{"content": "Hello world"}], output_messages=[{"content": expected_completion}], parameters={"temperature": 0}, - token_metrics={"prompt_tokens": 2, "completion_tokens": 16, "total_tokens": 18}, + token_metrics={"prompt_tokens": 2, "completion_tokens": 2, "total_tokens": 4}, ), ), ] @@ -1991,20 +1979,22 @@ def test_llmobs_chat_completion_stream(openai_vcr, openai, ddtrace_global_config """ with openai_vcr.use_cassette("chat_completion_streamed.yaml"): with mock.patch("ddtrace.contrib.openai.utils.encoding_for_model", create=True) as mock_encoding: - model = "gpt-3.5-turbo" - resp_model = model - input_messages = [{"role": "user", "content": "Who won the world series in 2020?"}] - mock_encoding.return_value.encode.side_effect = lambda x: [1, 2, 3, 4, 5, 6, 7, 8] - expected_completion = "The Los Angeles Dodgers won the World Series in 2020." 
- client = openai.OpenAI() - resp = client.chat.completions.create( - model=model, - messages=input_messages, - stream=True, - user="ddtrace-test", - ) - for chunk in resp: - resp_model = chunk.model + with mock.patch("ddtrace.contrib.openai.utils._est_tokens") as mock_est: + mock_encoding.return_value.encode.side_effect = lambda x: [1, 2, 3, 4, 5, 6, 7, 8] + mock_est.return_value = 8 + model = "gpt-3.5-turbo" + resp_model = model + input_messages = [{"role": "user", "content": "Who won the world series in 2020?"}] + expected_completion = "The Los Angeles Dodgers won the World Series in 2020." + client = openai.OpenAI() + resp = client.chat.completions.create( + model=model, + messages=input_messages, + stream=True, + user="ddtrace-test", + ) + for chunk in resp: + resp_model = chunk.model span = mock_tracer.pop_traces()[0][0] assert mock_llmobs_writer.enqueue.call_count == 1 mock_llmobs_writer.assert_has_calls( @@ -2017,7 +2007,7 @@ def test_llmobs_chat_completion_stream(openai_vcr, openai, ddtrace_global_config input_messages=input_messages, output_messages=[{"content": expected_completion, "role": "assistant"}], parameters={"temperature": 0}, - token_metrics={"prompt_tokens": 8, "completion_tokens": 15, "total_tokens": 23}, + token_metrics={"prompt_tokens": 8, "completion_tokens": 8, "total_tokens": 16}, ) ), ] From 78888e1bf177e96fda6e28adb15f18bc06d5f988 Mon Sep 17 00:00:00 2001 From: Zarir Hamza Date: Thu, 29 Feb 2024 14:44:28 -0500 Subject: [PATCH 4/8] ci(sqlite3): upgrade tests to support latest versions of sqlite3 (#8532) Adds testing for internal package `sqlite3` without upgrading python versions. Note the comment further down in the PR about potential issues on running CI locally. Seeing as how this is a widely used integration, change is necessary to test across a range of values rather than just one ## Checklist - [x] Change(s) are motivated and described in the PR description - [x] Testing strategy is described if automated tests are not included in the PR - [x] Risks are described (performance impact, potential for breakage, maintainability) - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) are followed or label `changelog/no-changelog` is set - [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)) - [x] Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) - [x] If this PR changes the public interface, I've notified `@DataDog/apm-tees`. - [x] If change touches code that signs or publishes builds or packages, or handles credentials of any kind, I've requested a review from `@DataDog/security-design-and-guidance`. 
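The key trick, visible in the test-module diffs below, is aliasing the pip-installable `pysqlite3` package over the stdlib `sqlite3` entry in `sys.modules` before anything imports it, so the same test code can run against a range of SQLite versions. A standalone sketch of how it works:

```python
# Sketch of the module substitution used by the sqlite3 test modules; pysqlite3-binary
# is an API-compatible build of the sqlite3 module that pip can install on Linux.
import sys

try:
    sys.modules["sqlite3"] = __import__("pysqlite3")  # use pysqlite3 if it is installed
except ImportError:
    pass  # otherwise fall back to the interpreter's bundled sqlite3

import sqlite3

print(sqlite3.sqlite_version)  # reports the SQLite library version actually linked in
```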
## Reviewer Checklist - [x] Title is accurate - [x] All changes are related to the pull request's stated goal - [x] Description motivates each change - [x] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - [x] Testing strategy adequately addresses listed risks - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] Release note makes sense to a user of the library - [x] Author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - [x] Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --- .riot/requirements/{188ca09.txt => 1544815.txt} | 13 +++++++------ .riot/requirements/{1fccf5b.txt => 16eb426.txt} | 12 ++++++------ .riot/requirements/{f4c3852.txt => 1c64cfc.txt} | 11 ++++++----- .riot/requirements/1d73048.txt | 4 ++-- .riot/requirements/{11921fa.txt => 1e311f5.txt} | 11 ++++++----- .riot/requirements/{da724fd.txt => 1fc50b1.txt} | 11 ++++++----- riotfile.py | 7 +++++-- tests/contrib/sqlite3/test_sqlite3.py | 8 ++++++++ tests/contrib/sqlite3/test_sqlite3_patch.py | 7 +++++++ 9 files changed, 53 insertions(+), 31 deletions(-) rename .riot/requirements/{188ca09.txt => 1544815.txt} (63%) rename .riot/requirements/{1fccf5b.txt => 16eb426.txt} (67%) rename .riot/requirements/{f4c3852.txt => 1c64cfc.txt} (63%) rename .riot/requirements/{11921fa.txt => 1e311f5.txt} (63%) rename .riot/requirements/{da724fd.txt => 1fc50b1.txt} (65%) diff --git a/.riot/requirements/188ca09.txt b/.riot/requirements/1544815.txt similarity index 63% rename from .riot/requirements/188ca09.txt rename to .riot/requirements/1544815.txt index 2b537cb86cf..18868464432 100644 --- a/.riot/requirements/188ca09.txt +++ b/.riot/requirements/1544815.txt @@ -2,19 +2,20 @@ # This file is autogenerated by pip-compile with Python 3.9 # by the following command: # -# pip-compile --no-annotate .riot/requirements/188ca09.in +# pip-compile --no-annotate .riot/requirements/1544815.in # -attrs==23.1.0 -coverage[toml]==7.3.4 +attrs==23.2.0 +coverage[toml]==7.4.3 exceptiongroup==1.2.0 hypothesis==6.45.0 -importlib-metadata==7.0.0 +importlib-metadata==7.0.1 iniconfig==2.0.0 mock==5.1.0 opentracing==2.4.0 packaging==23.2 -pluggy==1.3.0 -pytest==7.4.3 +pluggy==1.4.0 +pysqlite3-binary==0.5.2.post3 +pytest==8.0.2 pytest-cov==4.1.0 pytest-mock==3.12.0 pytest-randomly==3.15.0 diff --git a/.riot/requirements/1fccf5b.txt b/.riot/requirements/16eb426.txt similarity index 67% rename from .riot/requirements/1fccf5b.txt rename to .riot/requirements/16eb426.txt index f308ba1f91b..e1072294f88 100644 --- a/.riot/requirements/1fccf5b.txt +++ b/.riot/requirements/16eb426.txt @@ -2,19 +2,19 @@ # This file is autogenerated by pip-compile with Python 3.8 # by the following command: # -# pip-compile --no-annotate .riot/requirements/1fccf5b.in +# pip-compile --no-annotate .riot/requirements/16eb426.in # -attrs==23.1.0 -coverage[toml]==7.3.4 +attrs==23.2.0 +coverage[toml]==7.4.3 exceptiongroup==1.2.0 hypothesis==6.45.0 -importlib-metadata==7.0.0 +importlib-metadata==7.0.1 iniconfig==2.0.0 mock==5.1.0 opentracing==2.4.0 packaging==23.2 -pluggy==1.3.0 -pytest==7.4.3 +pluggy==1.4.0 +pytest==8.0.2 pytest-cov==4.1.0 pytest-mock==3.12.0 pytest-randomly==3.15.0 diff --git a/.riot/requirements/f4c3852.txt b/.riot/requirements/1c64cfc.txt similarity index 63% rename from .riot/requirements/f4c3852.txt rename to 
.riot/requirements/1c64cfc.txt index 6ecde539fd4..f99df92dedb 100644 --- a/.riot/requirements/f4c3852.txt +++ b/.riot/requirements/1c64cfc.txt @@ -2,17 +2,18 @@ # This file is autogenerated by pip-compile with Python 3.12 # by the following command: # -# pip-compile --no-annotate .riot/requirements/f4c3852.in +# pip-compile --no-annotate .riot/requirements/1c64cfc.in # -attrs==23.1.0 -coverage[toml]==7.3.4 +attrs==23.2.0 +coverage[toml]==7.4.3 hypothesis==6.45.0 iniconfig==2.0.0 mock==5.1.0 opentracing==2.4.0 packaging==23.2 -pluggy==1.3.0 -pytest==7.4.3 +pluggy==1.4.0 +pysqlite3-binary==0.5.2.post3 +pytest==8.0.2 pytest-cov==4.1.0 pytest-mock==3.12.0 pytest-randomly==3.15.0 diff --git a/.riot/requirements/1d73048.txt b/.riot/requirements/1d73048.txt index cee30ee902f..0d8a063b5d1 100644 --- a/.riot/requirements/1d73048.txt +++ b/.riot/requirements/1d73048.txt @@ -4,7 +4,7 @@ # # pip-compile --config=pyproject.toml --no-annotate --resolver=backtracking .riot/requirements/1d73048.in # -attrs==23.1.0 +attrs==23.2.0 coverage[toml]==7.2.7 exceptiongroup==1.2.0 hypothesis==6.45.0 @@ -14,7 +14,7 @@ mock==5.1.0 opentracing==2.4.0 packaging==23.2 pluggy==1.2.0 -pytest==7.4.3 +pytest==7.4.4 pytest-cov==4.1.0 pytest-mock==3.11.1 pytest-randomly==3.12.0 diff --git a/.riot/requirements/11921fa.txt b/.riot/requirements/1e311f5.txt similarity index 63% rename from .riot/requirements/11921fa.txt rename to .riot/requirements/1e311f5.txt index 7f7a4b8d721..8ce95230664 100644 --- a/.riot/requirements/11921fa.txt +++ b/.riot/requirements/1e311f5.txt @@ -2,17 +2,18 @@ # This file is autogenerated by pip-compile with Python 3.11 # by the following command: # -# pip-compile --no-annotate .riot/requirements/11921fa.in +# pip-compile --no-annotate .riot/requirements/1e311f5.in # -attrs==23.1.0 -coverage[toml]==7.3.4 +attrs==23.2.0 +coverage[toml]==7.4.3 hypothesis==6.45.0 iniconfig==2.0.0 mock==5.1.0 opentracing==2.4.0 packaging==23.2 -pluggy==1.3.0 -pytest==7.4.3 +pluggy==1.4.0 +pysqlite3-binary==0.5.2.post3 +pytest==8.0.2 pytest-cov==4.1.0 pytest-mock==3.12.0 pytest-randomly==3.15.0 diff --git a/.riot/requirements/da724fd.txt b/.riot/requirements/1fc50b1.txt similarity index 65% rename from .riot/requirements/da724fd.txt rename to .riot/requirements/1fc50b1.txt index 2957d4c7b1c..bfe0a355da3 100644 --- a/.riot/requirements/da724fd.txt +++ b/.riot/requirements/1fc50b1.txt @@ -2,18 +2,19 @@ # This file is autogenerated by pip-compile with Python 3.10 # by the following command: # -# pip-compile --no-annotate .riot/requirements/da724fd.in +# pip-compile --no-annotate .riot/requirements/1fc50b1.in # -attrs==23.1.0 -coverage[toml]==7.3.4 +attrs==23.2.0 +coverage[toml]==7.4.3 exceptiongroup==1.2.0 hypothesis==6.45.0 iniconfig==2.0.0 mock==5.1.0 opentracing==2.4.0 packaging==23.2 -pluggy==1.3.0 -pytest==7.4.3 +pluggy==1.4.0 +pysqlite3-binary==0.5.2.post3 +pytest==8.0.2 pytest-cov==4.1.0 pytest-mock==3.12.0 pytest-randomly==3.15.0 diff --git a/riotfile.py b/riotfile.py index ba565179ce3..f30b6f14063 100644 --- a/riotfile.py +++ b/riotfile.py @@ -2204,8 +2204,11 @@ def select_pys(min_version=MIN_PYTHON_VERSION, max_version=MAX_PYTHON_VERSION): "pytest-randomly": latest, }, venvs=[ - Venv(pys=select_pys(min_version="3.8")), - Venv(pys=["3.7"], pkgs={"importlib-metadata": latest}), + # sqlite3 is tied to the Python version and is not installable via pip + # To test a range of versions without updating Python, we use Linux only pysqlite3-binary package + # Remove pysqlite3-binary on Python 3.9+ locally on non-linux 
machines + Venv(pys=select_pys(min_version="3.9"), pkgs={"pysqlite3-binary": [latest]}), + Venv(pys=select_pys(max_version="3.8"), pkgs={"importlib-metadata": latest}), ], ), Venv( diff --git a/tests/contrib/sqlite3/test_sqlite3.py b/tests/contrib/sqlite3/test_sqlite3.py index d3a8fede21e..bdbdd1436af 100644 --- a/tests/contrib/sqlite3/test_sqlite3.py +++ b/tests/contrib/sqlite3/test_sqlite3.py @@ -1,3 +1,11 @@ +import sys + + +try: + sys.modules["sqlite3"] = __import__("pysqlite3") +except ImportError: + pass + import sqlite3 import time from typing import TYPE_CHECKING # noqa:F401 diff --git a/tests/contrib/sqlite3/test_sqlite3_patch.py b/tests/contrib/sqlite3/test_sqlite3_patch.py index ab19d63c141..9515011f308 100644 --- a/tests/contrib/sqlite3/test_sqlite3_patch.py +++ b/tests/contrib/sqlite3/test_sqlite3_patch.py @@ -2,6 +2,13 @@ # script. If you want to make changes to it, you should make sure that you have # removed the ``_generated`` suffix from the file name, to prevent the content # from being overwritten by future re-generations. +import sys + + +try: + sys.modules["sqlite3"] = __import__("pysqlite3") +except ImportError: + pass from ddtrace.contrib.sqlite3 import get_version from ddtrace.contrib.sqlite3.patch import patch From 67e86b2dfd0599ebf20bc786611b621c2d691df9 Mon Sep 17 00:00:00 2001 From: Emmett Butler <723615+emmettbutler@users.noreply.github.com> Date: Thu, 29 Feb 2024 12:10:34 -0800 Subject: [PATCH 5/8] ci: fix more instances of telemetry test writer flakiness (#8560) This pull request resolves [this CI failure from the main branch](https://app.circleci.com/pipelines/github/DataDog/dd-trace-py/56923/workflows/d0751cf1-1d7b-42f3-a0df-8cbf4b6dcce2/jobs/3609971) as well as all of the similar-looking cases in the same file. The approach is similar to other recent telemetry test changes, avoiding the assumption that the test agent session is empty of events at the start of a test. ## Checklist - [x] Change(s) are motivated and described in the PR description - [x] Testing strategy is described if automated tests are not included in the PR - [x] Risks are described (performance impact, potential for breakage, maintainability) - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) are followed or label `changelog/no-changelog` is set - [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)) - [x] Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) - [x] If this PR changes the public interface, I've notified `@DataDog/apm-tees`. - [x] If change touches code that signs or publishes builds or packages, or handles credentials of any kind, I've requested a review from `@DataDog/security-design-and-guidance`. 
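The fix reduces to one pattern applied to each affected test: record how many events are already in the test-agent session, then assert on the delta instead of an absolute count. A condensed illustration follows (the test name is made up; the fixtures and writer calls are the ones used in the diff):

```python
# Condensed illustration of the relative-count pattern; not a real test in the suite.
def test_event_count_is_relative(telemetry_writer, test_agent_session):
    initial_event_count = len(test_agent_session.get_events())

    telemetry_writer.add_configuration("appsec_enabled", True)  # queue one configuration change
    telemetry_writer.periodic()  # force a flush

    events = test_agent_session.get_events()
    # Tolerant of events left over from other tests sharing the session.
    assert len(events) == initial_event_count + 1
```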
## Reviewer Checklist - [x] Title is accurate - [x] All changes are related to the pull request's stated goal - [x] Description motivates each change - [x] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - [x] Testing strategy adequately addresses listed risks - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] Release note makes sense to a user of the library - [x] Author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - [x] Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --- tests/telemetry/test_writer.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/telemetry/test_writer.py b/tests/telemetry/test_writer.py index 4d4255136b5..e8697e4c164 100644 --- a/tests/telemetry/test_writer.py +++ b/tests/telemetry/test_writer.py @@ -60,6 +60,7 @@ def test_add_event_disabled_writer(telemetry_writer, test_agent_session): def test_app_started_event(telemetry_writer, test_agent_session, mock_time): """asserts that _app_started_event() queues a valid telemetry request which is then sent by periodic()""" with override_global_config(dict(_telemetry_dependency_collection=False)): + initial_event_count = len(test_agent_session.get_events()) # queue an app started event telemetry_writer._app_started_event() # force a flush @@ -70,7 +71,7 @@ def test_app_started_event(telemetry_writer, test_agent_session, mock_time): assert requests[0]["headers"]["DD-Telemetry-Request-Type"] == "app-started" events = test_agent_session.get_events() - assert len(events) == 1 + assert len(events) == initial_event_count + 1 events[0]["payload"]["configuration"].sort(key=lambda c: c["name"]) @@ -327,6 +328,7 @@ def test_update_dependencies_event(telemetry_writer, test_agent_session, mock_ti def test_update_dependencies_event_when_disabled(telemetry_writer, test_agent_session, mock_time): with override_global_config(dict(_telemetry_dependency_collection=False)): + initial_event_count = len(test_agent_session.get_events()) TelemetryWriterModuleWatchdog._initial = False TelemetryWriterModuleWatchdog._new_imported.clear() @@ -337,7 +339,7 @@ def test_update_dependencies_event_when_disabled(telemetry_writer, test_agent_se # force a flush telemetry_writer.periodic() events = test_agent_session.get_events() - assert len(events) <= 1 # could have a heartbeat + assert initial_event_count <= len(events) <= initial_event_count + 1 # could have a heartbeat if events: assert events[0]["request_type"] != "app-dependencies-loaded" @@ -437,6 +439,7 @@ def test_add_integration(telemetry_writer, test_agent_session, mock_time): def test_app_client_configuration_changed_event(telemetry_writer, test_agent_session, mock_time): """asserts that queuing a configuration sends a valid telemetry request""" with override_global_config(dict(_telemetry_dependency_collection=False)): + initial_event_count = len(test_agent_session.get_events()) telemetry_writer.add_configuration("appsec_enabled", True) telemetry_writer.add_configuration("DD_TRACE_PROPAGATION_STYLE_EXTRACT", "datadog") telemetry_writer.add_configuration("appsec_enabled", False, "env_var") @@ -444,7 +447,7 @@ def test_app_client_configuration_changed_event(telemetry_writer, test_agent_ses telemetry_writer.periodic() events = test_agent_session.get_events() - assert len(events) == 1 + assert 
len(events) == initial_event_count + 1 assert events[0]["request_type"] == "app-client-configuration-change" received_configurations = events[0]["payload"]["configuration"] # Sort the configuration list by name @@ -498,6 +501,7 @@ def test_send_failing_request(mock_status, telemetry_writer): def test_telemetry_graceful_shutdown(telemetry_writer, test_agent_session, mock_time): with override_global_config(dict(_telemetry_dependency_collection=False)): + initial_event_count = len(test_agent_session.get_events()) try: telemetry_writer.start() except ServiceStatusError: @@ -508,7 +512,7 @@ def test_telemetry_graceful_shutdown(telemetry_writer, test_agent_session, mock_ telemetry_writer.app_shutdown() events = test_agent_session.get_events() - assert len(events) == 1 + assert len(events) == initial_event_count + 1 # Reverse chronological order assert events[0]["request_type"] == "app-closing" From 1d75efa7ad354d8d67eeb8f8aee8634cc96ddd18 Mon Sep 17 00:00:00 2001 From: Emmett Butler <723615+emmettbutler@users.noreply.github.com> Date: Thu, 29 Feb 2024 13:55:21 -0800 Subject: [PATCH 6/8] ci: honor xfails that happen in subprocess snapshot tests (#8561) This pull request resolves failures like [this one](https://app.circleci.com/pipelines/github/DataDog/dd-trace-py/56935/workflows/ae60ef0e-2215-43e7-be75-004e35ad97f8/jobs/3610594) recently observed on main by propagating xfails that happen in subprocess tests back to the main pytest session. ## Checklist - [x] Change(s) are motivated and described in the PR description - [x] Testing strategy is described if automated tests are not included in the PR - [x] Risks are described (performance impact, potential for breakage, maintainability) - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) are followed or label `changelog/no-changelog` is set - [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)) - [x] Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) - [x] If this PR changes the public interface, I've notified `@DataDog/apm-tees`. - [x] If change touches code that signs or publishes builds or packages, or handles credentials of any kind, I've requested a review from `@DataDog/security-design-and-guidance`. 
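The workaround inspects the subprocess stderr for the pytest XFailed outcome before comparing exit statuses, and reports it as an xfail in the parent session. A simplified sketch of the idea, with `run_subprocess_test` as an illustrative stand-in for the wrapper in `tests/conftest.py`:

```python
import subprocess

import pytest


def run_subprocess_test(args, expected_status=0):
    # Run the test script in a child interpreter and capture its output.
    proc = subprocess.run(args, capture_output=True)
    # An uncaught pytest.xfail() in the child surfaces as an
    # _pytest.outcomes.XFailed traceback on stderr with exit status 1;
    # propagate it as an xfail here instead of treating it as a failure.
    if b"_pytest.outcomes.XFailed" in proc.stderr and proc.returncode == 1:
        pytest.xfail("subprocess test resulted in XFail")
    assert proc.returncode == expected_status, proc.stderr
```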
## Reviewer Checklist - [x] Title is accurate - [x] All changes are related to the pull request's stated goal - [x] Description motivates each change - [x] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - [x] Testing strategy adequately addresses listed risks - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] Release note makes sense to a user of the library - [x] Author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - [x] Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --- tests/conftest.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 38f422e6493..7c54d51c349 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -235,6 +235,11 @@ def run_function_from_file(item, params=None): def _subprocess_wrapper(): out, err, status, _ = call_program(*args, env=env, cwd=cwd, timeout=timeout) + xfailed = b"_pytest.outcomes.XFailed" in err and status == 1 + if xfailed: + pytest.xfail("subprocess test resulted in XFail") + return + if status != expected_status: raise AssertionError( "Expected status %s, got %s." From f3ac6494c86cd844bdaf53d98815ec8c0f6a3a30 Mon Sep 17 00:00:00 2001 From: Munir Abdinur Date: Thu, 29 Feb 2024 17:36:41 -0500 Subject: [PATCH 7/8] feat: kafka trace consume (#8511) Adds tracing and DSM support for https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/index.html#confluent_kafka.Consumer.consume ## Checklist - [x] Change(s) are motivated and described in the PR description - [x] Testing strategy is described if automated tests are not included in the PR - [x] Risks are described (performance impact, potential for breakage, maintainability) - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) are followed or label `changelog/no-changelog` is set - [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)) - [x] Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) - [x] If this PR changes the public interface, I've notified `@DataDog/apm-tees`. - [x] If change touches code that signs or publishes builds or packages, or handles credentials of any kind, I've requested a review from `@DataDog/security-design-and-guidance`. 
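With the integration patched, `Consumer.consume()` is traced the same way as `poll()`: one `kafka.consume` span per call, enriched from the first message in the returned batch. A rough usage sketch, assuming a broker at `localhost:29092` (as in the test snapshots) and an illustrative topic:

```python
import confluent_kafka

from ddtrace import patch


patch(kafka=True)  # wraps produce, poll, commit and, with this change, consume

consumer = confluent_kafka.Consumer(
    {"bootstrap.servers": "localhost:29092", "group.id": "test_group"}
)
consumer.subscribe(["my_topic"])

# consume() returns a list of messages; a single kafka.consume span is
# emitted for the whole batch, with context extracted from the first message.
messages = consumer.consume(num_messages=10, timeout=1.0)
for message in messages:
    print(message.value())
consumer.commit()
```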
## Reviewer Checklist - [x] Title is accurate - [x] All changes are related to the pull request's stated goal - [x] Description motivates each change - [x] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - [x] Testing strategy adequately addresses listed risks - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] Release note makes sense to a user of the library - [x] Author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - [x] Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --------- Co-authored-by: Yun Kim <35776586+Yun-Kim@users.noreply.github.com> --- ddtrace/contrib/kafka/patch.py | 134 ++++++++++-------- .../trace-kafka-consume-10a797a1305b1cd9.yaml | 4 + tests/contrib/kafka/test_kafka.py | 60 ++++++-- ...st_commit_with_consume_single_message.json | 74 ++++++++++ ...commit_with_consume_with_error[False].json | 71 ++++++++++ ...t_with_consume_with_multiple_messages.json | 110 ++++++++++++++ 6 files changed, 383 insertions(+), 70 deletions(-) create mode 100644 releasenotes/notes/trace-kafka-consume-10a797a1305b1cd9.yaml create mode 100644 tests/snapshots/tests.contrib.kafka.test_kafka.test_commit_with_consume_single_message.json create mode 100644 tests/snapshots/tests.contrib.kafka.test_kafka.test_commit_with_consume_with_error[False].json create mode 100644 tests/snapshots/tests.contrib.kafka.test_kafka.test_commit_with_consume_with_multiple_messages.json diff --git a/ddtrace/contrib/kafka/patch.py b/ddtrace/contrib/kafka/patch.py index 0d9c6860031..5c31e33f8d0 100644 --- a/ddtrace/contrib/kafka/patch.py +++ b/ddtrace/contrib/kafka/patch.py @@ -1,4 +1,5 @@ import os +import sys import confluent_kafka @@ -115,8 +116,11 @@ def patch(): for producer in (TracedProducer, TracedSerializingProducer): trace_utils.wrap(producer, "produce", traced_produce) for consumer in (TracedConsumer, TracedDeserializingConsumer): - trace_utils.wrap(consumer, "poll", traced_poll) + trace_utils.wrap(consumer, "poll", traced_poll_or_consume) trace_utils.wrap(consumer, "commit", traced_commit) + + # Consume is not implemented in deserializing consumers + trace_utils.wrap(TracedConsumer, "consume", traced_poll_or_consume) Pin().onto(confluent_kafka.Producer) Pin().onto(confluent_kafka.Consumer) Pin().onto(confluent_kafka.SerializingProducer) @@ -136,6 +140,10 @@ def unpatch(): if trace_utils.iswrapped(consumer.commit): trace_utils.unwrap(consumer, "commit") + # Consume is not implemented in deserializing consumers + if trace_utils.iswrapped(TracedConsumer.consume): + trace_utils.unwrap(TracedConsumer, "consume") + confluent_kafka.Producer = _Producer confluent_kafka.Consumer = _Consumer if _SerializingProducer is not None: @@ -194,7 +202,7 @@ def traced_produce(func, instance, args, kwargs): return func(*args, **kwargs) -def traced_poll(func, instance, args, kwargs): +def traced_poll_or_consume(func, instance, args, kwargs): pin = Pin.get_from(instance) if not pin or not pin.enabled(): return func(*args, **kwargs) @@ -204,67 +212,79 @@ def traced_poll(func, instance, args, kwargs): start_ns = time_ns() # wrap in a try catch and raise exception after span is started err = None + result = None try: - message = func(*args, **kwargs) + result = func(*args, **kwargs) + return result except Exception as e: err = e + raise err + finally: + if 
isinstance(result, confluent_kafka.Message): + # poll returns a single message + _instrument_message([result], pin, start_ns, instance, err) + elif isinstance(result, list): + # consume returns a list of messages, + _instrument_message(result, pin, start_ns, instance, err) + elif config.kafka.trace_empty_poll_enabled: + _instrument_message([None], pin, start_ns, instance, err) + + +def _instrument_message(messages, pin, start_ns, instance, err): ctx = None - if message and config.kafka.distributed_tracing_enabled and message.headers(): - ctx = Propagator.extract(dict(message.headers())) - if message or config.kafka.trace_empty_poll_enabled: - with pin.tracer.start_span( - name=schematize_messaging_operation(kafkax.CONSUME, provider="kafka", direction=SpanDirection.PROCESSING), - service=trace_utils.ext_service(pin, config.kafka), - span_type=SpanTypes.WORKER, - child_of=ctx if ctx is not None else pin.tracer.context_provider.active(), - activate=True, - ) as span: - # reset span start time to before function call - span.start_ns = start_ns - - span.set_tag_str(MESSAGING_SYSTEM, kafkax.SERVICE) - span.set_tag_str(COMPONENT, config.kafka.integration_name) - span.set_tag_str(SPAN_KIND, SpanKind.CONSUMER) - span.set_tag_str(kafkax.RECEIVED_MESSAGE, str(message is not None)) - span.set_tag_str(kafkax.GROUP_ID, instance._group_id) + # First message is used to extract context and enrich datadog spans + # This approach aligns with the opentelemetry confluent kafka semantics + first_message = messages[0] + if first_message and config.kafka.distributed_tracing_enabled and first_message.headers(): + ctx = Propagator.extract(dict(first_message.headers())) + with pin.tracer.start_span( + name=schematize_messaging_operation(kafkax.CONSUME, provider="kafka", direction=SpanDirection.PROCESSING), + service=trace_utils.ext_service(pin, config.kafka), + span_type=SpanTypes.WORKER, + child_of=ctx if ctx is not None else pin.tracer.context_provider.active(), + activate=True, + ) as span: + # reset span start time to before function call + span.start_ns = start_ns + + for message in messages: if message is not None: - core.set_item("kafka_topic", message.topic()) - core.dispatch("kafka.consume.start", (instance, message, span)) - - message_key = message.key() or "" - message_offset = message.offset() or -1 - span.set_tag_str(kafkax.TOPIC, message.topic()) - - # If this is a deserializing consumer, do not set the key as a tag since we - # do not have the serialization function - if ( - (_DeserializingConsumer is not None and not isinstance(instance, _DeserializingConsumer)) - or isinstance(message_key, str) - or isinstance(message_key, bytes) - ): - span.set_tag_str(kafkax.MESSAGE_KEY, message_key) - span.set_tag(kafkax.PARTITION, message.partition()) - is_tombstone = False - try: - is_tombstone = len(message) == 0 - except TypeError: # https://github.com/confluentinc/confluent-kafka-python/issues/1192 - pass - span.set_tag_str(kafkax.TOMBSTONE, str(is_tombstone)) - span.set_tag(kafkax.MESSAGE_OFFSET, message_offset) - span.set_tag(SPAN_MEASURED_KEY) - rate = config.kafka.get_analytics_sample_rate() - if rate is not None: - span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, rate) - - # raise exception if one was encountered - if err is not None: - raise err - return message - else: + core.set_item("kafka_topic", first_message.topic()) + core.dispatch("kafka.consume.start", (instance, first_message, span)) + + span.set_tag_str(MESSAGING_SYSTEM, kafkax.SERVICE) + span.set_tag_str(COMPONENT, config.kafka.integration_name) + 
span.set_tag_str(SPAN_KIND, SpanKind.CONSUMER) + span.set_tag_str(kafkax.RECEIVED_MESSAGE, str(first_message is not None)) + span.set_tag_str(kafkax.GROUP_ID, instance._group_id) + if messages[0] is not None: + message_key = messages[0].key() or "" + message_offset = messages[0].offset() or -1 + span.set_tag_str(kafkax.TOPIC, messages[0].topic()) + + # If this is a deserializing consumer, do not set the key as a tag since we + # do not have the serialization function + if ( + (_DeserializingConsumer is not None and not isinstance(instance, _DeserializingConsumer)) + or isinstance(message_key, str) + or isinstance(message_key, bytes) + ): + span.set_tag_str(kafkax.MESSAGE_KEY, message_key) + span.set_tag(kafkax.PARTITION, messages[0].partition()) + is_tombstone = False + try: + is_tombstone = len(messages[0]) == 0 + except TypeError: # https://github.com/confluentinc/confluent-kafka-python/issues/1192 + pass + span.set_tag_str(kafkax.TOMBSTONE, str(is_tombstone)) + span.set_tag(kafkax.MESSAGE_OFFSET, message_offset) + span.set_tag(SPAN_MEASURED_KEY) + rate = config.kafka.get_analytics_sample_rate() + if rate is not None: + span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, rate) + if err is not None: - raise err - else: - return message + span.set_exc_info(*sys.exc_info()) def traced_commit(func, instance, args, kwargs): diff --git a/releasenotes/notes/trace-kafka-consume-10a797a1305b1cd9.yaml b/releasenotes/notes/trace-kafka-consume-10a797a1305b1cd9.yaml new file mode 100644 index 00000000000..7b020ad52a7 --- /dev/null +++ b/releasenotes/notes/trace-kafka-consume-10a797a1305b1cd9.yaml @@ -0,0 +1,4 @@ +--- +features: + - | + kafka: Adds tracing and DSM support for ``confluent_kafka.Consumer.consume()``. Previously only `confluent_kafka.Consumer.poll` was instrumented. diff --git a/tests/contrib/kafka/test_kafka.py b/tests/contrib/kafka/test_kafka.py index a3c86fa4a61..3e7acb3b5e4 100644 --- a/tests/contrib/kafka/test_kafka.py +++ b/tests/contrib/kafka/test_kafka.py @@ -90,10 +90,16 @@ def dummy_tracer(): @pytest.fixture -def tracer(): +def should_filter_empty_polls(): + yield True + + +@pytest.fixture +def tracer(should_filter_empty_polls): patch() t = Tracer() - t.configure(settings={"FILTERS": [KafkaConsumerPollFilter()]}) + if should_filter_empty_polls: + t.configure(settings={"FILTERS": [KafkaConsumerPollFilter()]}) # disable backoff because it makes these tests less reliable t._writer._send_payload_with_backoff = t._writer._send_payload try: @@ -266,6 +272,42 @@ def test_commit(producer, consumer, kafka_topic): consumer.commit(message) +@pytest.mark.snapshot(ignores=["metrics.kafka.message_offset"]) +def test_commit_with_consume_single_message(producer, consumer, kafka_topic): + with override_config("kafka", dict(trace_empty_poll_enabled=False)): + producer.produce(kafka_topic, PAYLOAD, key=KEY) + producer.flush() + # One message is consumed and one span is generated. 
+ messages = consumer.consume(num_messages=1) + assert len(messages) == 1 + consumer.commit(messages[0]) + + +@pytest.mark.snapshot(ignores=["metrics.kafka.message_offset"]) +def test_commit_with_consume_with_multiple_messages(producer, consumer, kafka_topic): + with override_config("kafka", dict(trace_empty_poll_enabled=False)): + producer.produce(kafka_topic, PAYLOAD, key=KEY) + producer.produce(kafka_topic, PAYLOAD, key=KEY) + producer.flush() + # Two messages are consumed but only ONE span is generated + messages = consumer.consume(num_messages=2) + assert len(messages) == 2 + + +@pytest.mark.snapshot(ignores=["metrics.kafka.message_offset", "meta.error.stack"]) +@pytest.mark.parametrize("should_filter_empty_polls", [False]) +def test_commit_with_consume_with_error(producer, consumer, kafka_topic): + producer.produce(kafka_topic, PAYLOAD, key=KEY) + producer.flush() + # Raises an exception by consuming messages after the consumer has been closed + with pytest.raises(TypeError): + # Empty poll spans are filtered out by the KafkaConsumerPollFilter. We need to disable + # it to test error spans. + # Allowing empty poll spans could introduce flakiness in the test. + with override_config("kafka", dict(trace_empty_poll_enabled=True)): + consumer.consume(num_messages=1, invalid_args="invalid_args") + + @pytest.mark.snapshot(ignores=["metrics.kafka.message_offset"]) def test_commit_with_offset(producer, consumer, kafka_topic): with override_config("kafka", dict(trace_empty_poll_enabled=False)): @@ -415,20 +457,10 @@ def _generate_in_subprocess(random_topic): import ddtrace from ddtrace.contrib.kafka.patch import patch from ddtrace.contrib.kafka.patch import unpatch - from ddtrace.filters import TraceFilter + from tests.contrib.kafka.test_kafka import KafkaConsumerPollFilter PAYLOAD = bytes("hueh hueh hueh", encoding="utf-8") - class KafkaConsumerPollFilter(TraceFilter): - def process_trace(self, trace): - # Filter out all poll spans that have no received message - return ( - None - if trace[0].name in {"kafka.consume", "kafka.process"} - and trace[0].get_tag("kafka.received_message") == "False" - else trace - ) - ddtrace.tracer.configure(settings={"FILTERS": [KafkaConsumerPollFilter()]}) # disable backoff because it makes these tests less reliable ddtrace.tracer._writer._send_payload_with_backoff = ddtrace.tracer._writer._send_payload @@ -733,6 +765,7 @@ def test_tracing_context_is_propagated_when_enabled(ddtrace_run_python_code_in_s from tests.contrib.kafka.test_kafka import kafka_topic from tests.contrib.kafka.test_kafka import producer from tests.contrib.kafka.test_kafka import tracer +from tests.contrib.kafka.test_kafka import should_filter_empty_polls from tests.utils import DummyTracer def test(consumer, producer, kafka_topic): @@ -923,6 +956,7 @@ def test_does_not_trace_empty_poll_when_disabled(ddtrace_run_python_code_in_subp from tests.contrib.kafka.test_kafka import kafka_topic from tests.contrib.kafka.test_kafka import producer from tests.contrib.kafka.test_kafka import tracer +from tests.contrib.kafka.test_kafka import should_filter_empty_polls from tests.utils import DummyTracer def test(consumer, producer, kafka_topic): diff --git a/tests/snapshots/tests.contrib.kafka.test_kafka.test_commit_with_consume_single_message.json b/tests/snapshots/tests.contrib.kafka.test_kafka.test_commit_with_consume_single_message.json new file mode 100644 index 00000000000..ec0e59e04c2 --- /dev/null +++ b/tests/snapshots/tests.contrib.kafka.test_kafka.test_commit_with_consume_single_message.json @@ 
-0,0 +1,74 @@ +[[ + { + "name": "kafka.consume", + "service": "kafka", + "resource": "kafka.consume", + "trace_id": 0, + "span_id": 1, + "parent_id": 0, + "type": "worker", + "error": 0, + "meta": { + "_dd.base_service": "", + "_dd.p.dm": "-0", + "_dd.p.tid": "65dcd1fd00000000", + "component": "kafka", + "kafka.group_id": "test_group", + "kafka.message_key": "test_key", + "kafka.received_message": "True", + "kafka.tombstone": "False", + "kafka.topic": "test_commit_with_consume_single_message", + "language": "python", + "messaging.system": "kafka", + "pathway.hash": "7964333589438960939", + "runtime-id": "ff074b2cc3b34b63bbdabbfb5bafd0a4", + "span.kind": "consumer" + }, + "metrics": { + "_dd.measured": 1, + "_dd.top_level": 1, + "_dd.tracer_kr": 1.0, + "_sampling_priority_v1": 1, + "kafka.message_offset": -1, + "kafka.partition": 0, + "process_id": 96733 + }, + "duration": 3198787000, + "start": 1708970490483150000 + }], +[ + { + "name": "kafka.produce", + "service": "kafka", + "resource": "kafka.produce", + "trace_id": 1, + "span_id": 1, + "parent_id": 0, + "type": "worker", + "error": 0, + "meta": { + "_dd.base_service": "", + "_dd.p.dm": "-0", + "_dd.p.tid": "65dcd1f900000000", + "component": "kafka", + "kafka.message_key": "test_key", + "kafka.tombstone": "False", + "kafka.topic": "test_commit_with_consume_single_message", + "language": "python", + "messaging.kafka.bootstrap.servers": "localhost:29092", + "messaging.system": "kafka", + "pathway.hash": "8904226842384519559", + "runtime-id": "ff074b2cc3b34b63bbdabbfb5bafd0a4", + "span.kind": "producer" + }, + "metrics": { + "_dd.measured": 1, + "_dd.top_level": 1, + "_dd.tracer_kr": 1.0, + "_sampling_priority_v1": 1, + "kafka.partition": -1, + "process_id": 96733 + }, + "duration": 356000, + "start": 1708970489477615000 + }]] diff --git a/tests/snapshots/tests.contrib.kafka.test_kafka.test_commit_with_consume_with_error[False].json b/tests/snapshots/tests.contrib.kafka.test_kafka.test_commit_with_consume_with_error[False].json new file mode 100644 index 00000000000..21bdb597019 --- /dev/null +++ b/tests/snapshots/tests.contrib.kafka.test_kafka.test_commit_with_consume_with_error[False].json @@ -0,0 +1,71 @@ +[[ + { + "name": "kafka.consume", + "service": "kafka", + "resource": "kafka.consume", + "trace_id": 0, + "span_id": 1, + "parent_id": 0, + "type": "worker", + "error": 1, + "meta": { + "_dd.base_service": "", + "_dd.p.dm": "-0", + "_dd.p.tid": "65df90df00000000", + "component": "kafka", + "error.message": "'invalid_args' is an invalid keyword argument for this function", + "error.stack": "Traceback (most recent call last):\n File \"/Users/munirabdinur/go/src/github.com/DataDog/dd-trace-py/ddtrace/contrib/kafka/patch.py\", line 222, in traced_poll_or_consume\n raise err\n File \"/Users/munirabdinur/go/src/github.com/DataDog/dd-trace-py/ddtrace/contrib/kafka/patch.py\", line 218, in traced_poll_or_consume\n result = func(*args, **kwargs)\nTypeError: 'invalid_args' is an invalid keyword argument for this function\n", + "error.type": "builtins.TypeError", + "kafka.group_id": "test_group", + "kafka.received_message": "False", + "language": "python", + "messaging.system": "kafka", + "runtime-id": "60490ef14dac4fccae1050b4a5837a51", + "span.kind": "consumer" + }, + "metrics": { + "_dd.measured": 1, + "_dd.top_level": 1, + "_dd.tracer_kr": 1.0, + "_sampling_priority_v1": 1, + "process_id": 58905 + }, + "duration": 1366041000, + "start": 1709150430267359000 + }], +[ + { + "name": "kafka.produce", + "service": "kafka", + "resource": 
"kafka.produce", + "trace_id": 1, + "span_id": 1, + "parent_id": 0, + "type": "worker", + "error": 0, + "meta": { + "_dd.base_service": "", + "_dd.p.dm": "-0", + "_dd.p.tid": "65df90dd00000000", + "component": "kafka", + "kafka.message_key": "test_key", + "kafka.tombstone": "False", + "kafka.topic": "test_commit_with_consume_with_error_False", + "language": "python", + "messaging.kafka.bootstrap.servers": "localhost:29092", + "messaging.system": "kafka", + "pathway.hash": "8223615727003867653", + "runtime-id": "60490ef14dac4fccae1050b4a5837a51", + "span.kind": "producer" + }, + "metrics": { + "_dd.measured": 1, + "_dd.top_level": 1, + "_dd.tracer_kr": 1.0, + "_sampling_priority_v1": 1, + "kafka.partition": -1, + "process_id": 58905 + }, + "duration": 542000, + "start": 1709150429264547000 + }]] diff --git a/tests/snapshots/tests.contrib.kafka.test_kafka.test_commit_with_consume_with_multiple_messages.json b/tests/snapshots/tests.contrib.kafka.test_kafka.test_commit_with_consume_with_multiple_messages.json new file mode 100644 index 00000000000..4c55c127ba9 --- /dev/null +++ b/tests/snapshots/tests.contrib.kafka.test_kafka.test_commit_with_consume_with_multiple_messages.json @@ -0,0 +1,110 @@ +[[ + { + "name": "kafka.consume", + "service": "kafka", + "resource": "kafka.consume", + "trace_id": 0, + "span_id": 1, + "parent_id": 0, + "type": "worker", + "error": 0, + "meta": { + "_dd.base_service": "", + "_dd.p.dm": "-0", + "_dd.p.tid": "65e0d12500000000", + "component": "kafka", + "kafka.group_id": "test_group", + "kafka.message_key": "test_key", + "kafka.received_message": "True", + "kafka.tombstone": "False", + "kafka.topic": "test_commit_with_consume_with_multiple_messages", + "language": "python", + "messaging.system": "kafka", + "pathway.hash": "4441628080899120654", + "runtime-id": "92caf7109ecc49ee9f9658b2b8a0e917", + "span.kind": "consumer" + }, + "metrics": { + "_dd.measured": 1, + "_dd.top_level": 1, + "_dd.tracer_kr": 1.0, + "_sampling_priority_v1": 1, + "kafka.message_offset": -1, + "kafka.partition": 0, + "process_id": 12413 + }, + "duration": 4197487000, + "start": 1709232417758567000 + }], +[ + { + "name": "kafka.produce", + "service": "kafka", + "resource": "kafka.produce", + "trace_id": 1, + "span_id": 1, + "parent_id": 0, + "type": "worker", + "error": 0, + "meta": { + "_dd.base_service": "", + "_dd.p.dm": "-0", + "_dd.p.tid": "65e0d12000000000", + "component": "kafka", + "kafka.message_key": "test_key", + "kafka.tombstone": "False", + "kafka.topic": "test_commit_with_consume_with_multiple_messages", + "language": "python", + "messaging.kafka.bootstrap.servers": "localhost:29092", + "messaging.system": "kafka", + "pathway.hash": "6303792236934717500", + "runtime-id": "92caf7109ecc49ee9f9658b2b8a0e917", + "span.kind": "producer" + }, + "metrics": { + "_dd.measured": 1, + "_dd.top_level": 1, + "_dd.tracer_kr": 1.0, + "_sampling_priority_v1": 1, + "kafka.partition": -1, + "process_id": 12413 + }, + "duration": 281000, + "start": 1709232416778413000 + }], +[ + { + "name": "kafka.produce", + "service": "kafka", + "resource": "kafka.produce", + "trace_id": 2, + "span_id": 1, + "parent_id": 0, + "type": "worker", + "error": 0, + "meta": { + "_dd.base_service": "", + "_dd.p.dm": "-0", + "_dd.p.tid": "65e0d12000000000", + "component": "kafka", + "kafka.message_key": "test_key", + "kafka.tombstone": "False", + "kafka.topic": "test_commit_with_consume_with_multiple_messages", + "language": "python", + "messaging.kafka.bootstrap.servers": "localhost:29092", + "messaging.system": 
"kafka", + "pathway.hash": "6303792236934717500", + "runtime-id": "92caf7109ecc49ee9f9658b2b8a0e917", + "span.kind": "producer" + }, + "metrics": { + "_dd.measured": 1, + "_dd.top_level": 1, + "_dd.tracer_kr": 1.0, + "_sampling_priority_v1": 1, + "kafka.partition": -1, + "process_id": 12413 + }, + "duration": 210000, + "start": 1709232416779010000 + }]] From 6bc3fd9dbb1e3d53d4437b953c02513c25c9e585 Mon Sep 17 00:00:00 2001 From: Yun Kim <35776586+Yun-Kim@users.noreply.github.com> Date: Thu, 29 Feb 2024 19:40:37 -0500 Subject: [PATCH 8/8] feat(llmobs): add function decorators to LLMObs SDK (#8534) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds `ddtrace.llmobs.decorators.{agent, workflow, tool, task, llm}` function decorators to the LLMObs SDK for users to instrument their LLM applications. This PR also refactors the existing LLMObs test cases (including in `tests/openai, tests/botocore/test_bedrock, tests/llmobs`) to reuse a helper method that creates expected llm span events). These decorators provides the same functionality as the inline methods introduced in #8476, but now users can instrument their applications using function decorators which is less intrusive than the inline methods. These decorators accept the following arguments: - `model_name` (Required; only for llm()): The name of the invoked LLM. - `model_provider` (Optional; only for llm()): The name of the invoked LLM provider (OpenAI, Bedrock, etc…). If not provided, a default value of “custom” will be set. - `name`: (Optional) The name of the operation. If not provided, the traced function name will be set as the name of the span by default. - `session_id` (Optional): The ID of the underlying user session. Required for tracking sessions, must be set on root spans and spans created in new processes/threads. An example use case: ```python from ddtrace.llmobs.decorators import llm @llm(model_name="my_model_name", name="my_llm_function", model_provider="custom", session_id="my_sesion_id") def make_llm_call(prompt): # user application logic to call llm return response ``` ### Notes: - For all decorators except `llm` (i.e. `agent, workflow, tool. task`) they can be used without any arguments, i.e. `@agent` vs `@agent()`. - These methods do not return a Span object. To annotate spans traced by these function decorators, use `LLMObs.annotate()` without a specified span in the function before/after starting another traced operation. - These function decorators currently do not support wrapping around generators. The inline methods must be used to trace generator functions. ## Checklist - [x] Change(s) are motivated and described in the PR description - [x] Testing strategy is described if automated tests are not included in the PR - [x] Risks are described (performance impact, potential for breakage, maintainability) - [x] Change is maintainable (easy to change, telemetry, documentation) - [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) are followed or label `changelog/no-changelog` is set - [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)) - [x] Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) - [x] If this PR changes the public interface, I've notified `@DataDog/apm-tees`. 
- [x] If change touches code that signs or publishes builds or packages, or handles credentials of any kind, I've requested a review from `@DataDog/security-design-and-guidance`. ## Reviewer Checklist - [ ] Title is accurate - [ ] All changes are related to the pull request's stated goal - [ ] Description motivates each change - [ ] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - [ ] Testing strategy adequately addresses listed risks - [ ] Change is maintainable (easy to change, telemetry, documentation) - [ ] Release note makes sense to a user of the library - [ ] Author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - [ ] Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --- ddtrace/llmobs/decorators.py | 60 ++++++ tests/contrib/botocore/test_bedrock.py | 93 +++----- tests/contrib/openai/test_openai_v0.py | 42 ++-- tests/contrib/openai/test_openai_v1.py | 37 ++-- tests/contrib/openai/utils.py | 51 ----- tests/llmobs/_utils.py | 153 +++++++++++++ tests/llmobs/conftest.py | 2 +- tests/llmobs/test_llmobs_decorators.py | 286 +++++++++++++++++++++++++ tests/llmobs/test_llmobs_service.py | 145 ++----------- tests/llmobs/test_llmobs_writer.py | 2 +- tests/llmobs/test_logger.py | 2 +- tests/llmobs/utils.py | 13 -- 12 files changed, 595 insertions(+), 291 deletions(-) create mode 100644 ddtrace/llmobs/decorators.py create mode 100644 tests/llmobs/_utils.py create mode 100644 tests/llmobs/test_llmobs_decorators.py delete mode 100644 tests/llmobs/utils.py diff --git a/ddtrace/llmobs/decorators.py b/ddtrace/llmobs/decorators.py new file mode 100644 index 00000000000..af2d7462a7c --- /dev/null +++ b/ddtrace/llmobs/decorators.py @@ -0,0 +1,60 @@ +from functools import wraps + +from ddtrace.internal.logger import get_logger +from ddtrace.llmobs import LLMObs + + +log = get_logger(__name__) + + +def llm(model_name, model_provider=None, name=None, session_id=None): + def inner(func): + @wraps(func) + def wrapper(*args, **kwargs): + if not LLMObs.enabled or LLMObs._instance is None: + log.warning("LLMObs.llm() cannot be used while LLMObs is disabled.") + return func(*args, **kwargs) + span_name = name + if span_name is None: + span_name = func.__name__ + with LLMObs.llm( + model_name=model_name, + model_provider=model_provider, + name=span_name, + session_id=session_id, + ): + return func(*args, **kwargs) + + return wrapper + + return inner + + +def llmobs_decorator(operation_kind): + def decorator(original_func=None, name=None, session_id=None): + def inner(func): + @wraps(func) + def wrapper(*args, **kwargs): + if not LLMObs.enabled or LLMObs._instance is None: + log.warning("LLMObs.{}() cannot be used while LLMObs is disabled.", operation_kind) + return func(*args, **kwargs) + span_name = name + if span_name is None: + span_name = func.__name__ + traced_operation = getattr(LLMObs, operation_kind, "workflow") + with traced_operation(name=span_name, session_id=session_id): + return func(*args, **kwargs) + + return wrapper + + if original_func and callable(original_func): + return inner(original_func) + return inner + + return decorator + + +workflow = llmobs_decorator("workflow") +task = llmobs_decorator("task") +tool = llmobs_decorator("tool") +agent = llmobs_decorator("agent") diff --git a/tests/contrib/botocore/test_bedrock.py 
b/tests/contrib/botocore/test_bedrock.py index b77a246db2d..77ac109a8e0 100644 --- a/tests/contrib/botocore/test_bedrock.py +++ b/tests/contrib/botocore/test_bedrock.py @@ -8,6 +8,7 @@ from ddtrace.contrib.botocore.patch import patch from ddtrace.contrib.botocore.patch import unpatch from ddtrace.llmobs import LLMObs +from tests.llmobs._utils import _expected_llmobs_llm_span_event from tests.subprocesstest import SubprocessTestCase from tests.subprocesstest import run_in_subprocess from tests.utils import DummyTracer @@ -403,47 +404,26 @@ class TestLLMObsBedrock: def _expected_llmobs_calls(span, n_output): prompt_tokens = int(span.get_tag("bedrock.usage.prompt_tokens")) completion_tokens = int(span.get_tag("bedrock.usage.completion_tokens")) - - expected_tags = [ - "version:", - "env:", - "service:aws.bedrock-runtime", - "source:integration", - "ml_app:unnamed-ml-app", - "error:0", - ] expected_llmobs_writer_calls = [mock.call.start()] expected_llmobs_writer_calls += [ mock.call.enqueue( - { - "span_id": str(span.span_id), - "trace_id": "{:x}".format(span.trace_id), - "parent_id": "", - "session_id": "{:x}".format(span.trace_id), - "name": span.name, - "tags": expected_tags, - "start_ns": span.start_ns, - "duration": span.duration_ns, - "error": 0, - "meta": { - "span.kind": "llm", - "model_name": span.get_tag("bedrock.request.model"), - "model_provider": span.get_tag("bedrock.request.model_provider"), - "input": { - "messages": [{"content": mock.ANY}], - "parameters": { - "temperature": float(span.get_tag("bedrock.request.temperature")), - "max_tokens": int(span.get_tag("bedrock.request.max_tokens")), - }, - }, - "output": {"messages": [{"content": mock.ANY} for _ in range(n_output)]}, + _expected_llmobs_llm_span_event( + span, + model_name=span.get_tag("bedrock.request.model"), + model_provider=span.get_tag("bedrock.request.model_provider"), + input_messages=[{"content": mock.ANY}], + output_messages=[{"content": mock.ANY} for _ in range(n_output)], + parameters={ + "temperature": float(span.get_tag("bedrock.request.temperature")), + "max_tokens": int(span.get_tag("bedrock.request.max_tokens")), }, - "metrics": { + token_metrics={ "prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": prompt_tokens + completion_tokens, }, - }, + tags={"service": "aws.bedrock-runtime"}, + ) ) ] return expected_llmobs_writer_calls @@ -583,44 +563,23 @@ def test_llmobs_error(self, ddtrace_global_config, bedrock_client, mock_llmobs_w json.loads(response.get("body").read()) span = mock_tracer.pop_traces()[0][0] - expected_tags = [ - "version:", - "env:", - "service:aws.bedrock-runtime", - "source:integration", - "ml_app:unnamed-ml-app", - "error:1", - "error_type:%s" % span.get_tag("error.type"), - ] expected_llmobs_writer_calls = [ mock.call.start(), mock.call.enqueue( - { - "span_id": str(span.span_id), - "trace_id": "{:x}".format(span.trace_id), - "parent_id": "", - "session_id": "{:x}".format(span.trace_id), - "name": span.name, - "tags": expected_tags, - "start_ns": span.start_ns, - "duration": span.duration_ns, - "error": 1, - "meta": { - "span.kind": "llm", - "error.message": span.get_tag("error.message"), - "model_name": span.get_tag("bedrock.request.model"), - "model_provider": span.get_tag("bedrock.request.model_provider"), - "input": { - "messages": [{"content": mock.ANY}], - "parameters": { - "temperature": float(span.get_tag("bedrock.request.temperature")), - "max_tokens": int(span.get_tag("bedrock.request.max_tokens")), - }, - }, - "output": {"messages": 
[{"content": ""}]}, + _expected_llmobs_llm_span_event( + span, + model_name=span.get_tag("bedrock.request.model"), + model_provider=span.get_tag("bedrock.request.model_provider"), + input_messages=[{"content": mock.ANY}], + parameters={ + "temperature": float(span.get_tag("bedrock.request.temperature")), + "max_tokens": int(span.get_tag("bedrock.request.max_tokens")), }, - "metrics": {}, - }, + output_messages=[{"content": ""}], + error=span.get_tag("error.type"), + error_message=span.get_tag("error.message"), + tags={"service": "aws.bedrock-runtime"}, + ) ), ] diff --git a/tests/contrib/openai/test_openai_v0.py b/tests/contrib/openai/test_openai_v0.py index ce8fb79b528..cb468f6f15a 100644 --- a/tests/contrib/openai/test_openai_v0.py +++ b/tests/contrib/openai/test_openai_v0.py @@ -12,9 +12,9 @@ from ddtrace import patch from ddtrace.contrib.openai.utils import _est_tokens from ddtrace.internal.utils.version import parse_version -from tests.contrib.openai.utils import _expected_llmobs_span_event from tests.contrib.openai.utils import get_openai_vcr from tests.contrib.openai.utils import iswrapped +from tests.llmobs._utils import _expected_llmobs_llm_span_event from tests.utils import override_global_config from tests.utils import snapshot_context @@ -2200,9 +2200,10 @@ def test_llmobs_completion(openai_vcr, openai, ddtrace_global_config, mock_llmob [ mock.call.start(), mock.call.enqueue( - _expected_llmobs_span_event( + _expected_llmobs_llm_span_event( span, - model=model, + model_name=model, + model_provider="openai", input_messages=[{"content": "Hello world"}], output_messages=[{"content": ", relax!” I said to my laptop"}, {"content": " (1"}], parameters={"temperature": 0.8, "max_tokens": 10}, @@ -2227,9 +2228,10 @@ def test_llmobs_completion_stream(openai_vcr, openai, ddtrace_global_config, moc [ mock.call.start(), mock.call.enqueue( - _expected_llmobs_span_event( + _expected_llmobs_llm_span_event( span, - model=model, + model_name=model, + model_provider="openai", input_messages=[{"content": "Hello world"}], output_messages=[{"content": expected_completion}], parameters={"temperature": 0}, @@ -2269,9 +2271,10 @@ def test_llmobs_chat_completion(openai_vcr, openai, ddtrace_global_config, mock_ [ mock.call.start(), mock.call.enqueue( - _expected_llmobs_span_event( + _expected_llmobs_llm_span_event( span, - model=resp.model, + model_name=resp.model, + model_provider="openai", input_messages=input_messages, output_messages=[ {"role": "assistant", "content": choice.message.content} for choice in resp.choices @@ -2315,9 +2318,10 @@ async def test_llmobs_chat_completion_stream( [ mock.call.start(), mock.call.enqueue( - _expected_llmobs_span_event( + _expected_llmobs_llm_span_event( span, - model=resp_model, + model_name=resp_model, + model_provider="openai", input_messages=input_messages, output_messages=[{"content": expected_completion, "role": "assistant"}], parameters={"temperature": 0}, @@ -2350,9 +2354,10 @@ def test_llmobs_chat_completion_function_call( [ mock.call.start(), mock.call.enqueue( - _expected_llmobs_span_event( + _expected_llmobs_llm_span_event( span, - model=resp.model, + model_name=resp.model, + model_provider="openai", input_messages=[{"content": chat_completion_input_description, "role": "user"}], output_messages=[{"content": resp.choices[0].message.function_call.arguments, "role": "assistant"}], parameters={"temperature": 0}, @@ -2392,9 +2397,10 @@ def test_llmobs_chat_completion_function_call_stream( [ mock.call.start(), mock.call.enqueue( - _expected_llmobs_span_event( 
+ _expected_llmobs_llm_span_event( span, - model=resp_model, + model_name=resp_model, + model_provider="openai", input_messages=[{"content": chat_completion_input_description, "role": "user"}], output_messages=[{"content": expected_output, "role": "assistant"}], parameters={"temperature": 0}, @@ -2420,9 +2426,10 @@ def test_llmobs_completion_error(openai_vcr, openai, ddtrace_global_config, mock [ mock.call.start(), mock.call.enqueue( - _expected_llmobs_span_event( + _expected_llmobs_llm_span_event( span, - model=model, + model_name=model, + model_provider="openai", input_messages=[{"content": "Hello world"}], output_messages=[{"content": ""}], parameters={"temperature": 0.8, "max_tokens": 10}, @@ -2462,9 +2469,10 @@ def test_llmobs_chat_completion_error(openai_vcr, openai, ddtrace_global_config, [ mock.call.start(), mock.call.enqueue( - _expected_llmobs_span_event( + _expected_llmobs_llm_span_event( span, - model=model, + model_name=model, + model_provider="openai", input_messages=input_messages, output_messages=[{"content": ""}], parameters={"temperature": 0}, diff --git a/tests/contrib/openai/test_openai_v1.py b/tests/contrib/openai/test_openai_v1.py index bf7935317ac..b015314160b 100644 --- a/tests/contrib/openai/test_openai_v1.py +++ b/tests/contrib/openai/test_openai_v1.py @@ -10,9 +10,9 @@ from ddtrace import patch from ddtrace.contrib.openai.utils import _est_tokens from ddtrace.internal.utils.version import parse_version -from tests.contrib.openai.utils import _expected_llmobs_span_event from tests.contrib.openai.utils import get_openai_vcr from tests.contrib.openai.utils import iswrapped +from tests.llmobs._utils import _expected_llmobs_llm_span_event from tests.utils import override_global_config from tests.utils import snapshot_context @@ -1883,9 +1883,10 @@ def test_llmobs_completion(openai_vcr, openai, ddtrace_global_config, mock_llmob [ mock.call.start(), mock.call.enqueue( - _expected_llmobs_span_event( + _expected_llmobs_llm_span_event( span, - model=model, + model_name=model, + model_provider="openai", input_messages=[{"content": "Hello world"}], output_messages=[{"content": ", relax!” I said to my laptop"}, {"content": " (1"}], parameters={"temperature": 0.8, "max_tokens": 10}, @@ -1915,9 +1916,10 @@ def test_llmobs_completion_stream(openai_vcr, openai, ddtrace_global_config, moc [ mock.call.start(), mock.call.enqueue( - _expected_llmobs_span_event( + _expected_llmobs_llm_span_event( span, - model=model, + model_name=model, + model_provider="openai", input_messages=[{"content": "Hello world"}], output_messages=[{"content": expected_completion}], parameters={"temperature": 0}, @@ -1956,9 +1958,10 @@ def test_llmobs_chat_completion(openai_vcr, openai, ddtrace_global_config, mock_ [ mock.call.start(), mock.call.enqueue( - _expected_llmobs_span_event( + _expected_llmobs_llm_span_event( span, - model=resp.model, + model_name=resp.model, + model_provider="openai", input_messages=input_messages, output_messages=[ {"role": "assistant", "content": choice.message.content} for choice in resp.choices @@ -2001,9 +2004,10 @@ def test_llmobs_chat_completion_stream(openai_vcr, openai, ddtrace_global_config [ mock.call.start(), mock.call.enqueue( - _expected_llmobs_span_event( + _expected_llmobs_llm_span_event( span, - model=resp_model, + model_name=resp_model, + model_provider="openai", input_messages=input_messages, output_messages=[{"content": expected_completion, "role": "assistant"}], parameters={"temperature": 0}, @@ -2035,9 +2039,10 @@ def test_llmobs_chat_completion_function_call( 
[ mock.call.start(), mock.call.enqueue( - _expected_llmobs_span_event( + _expected_llmobs_llm_span_event( span, - model=resp.model, + model_name=resp.model, + model_provider="openai", input_messages=[{"content": chat_completion_input_description, "role": "user"}], output_messages=[{"content": resp.choices[0].message.function_call.arguments, "role": "assistant"}], parameters={"temperature": 0}, @@ -2070,9 +2075,10 @@ def test_llmobs_completion_error(openai_vcr, openai, ddtrace_global_config, mock [ mock.call.start(), mock.call.enqueue( - _expected_llmobs_span_event( + _expected_llmobs_llm_span_event( span, - model=model, + model_name=model, + model_provider="openai", input_messages=[{"content": "Hello world"}], output_messages=[{"content": ""}], parameters={"temperature": 0.8, "max_tokens": 10}, @@ -2111,9 +2117,10 @@ def test_llmobs_chat_completion_error(openai_vcr, openai, ddtrace_global_config, [ mock.call.start(), mock.call.enqueue( - _expected_llmobs_span_event( + _expected_llmobs_llm_span_event( span, - model=model, + model_name=model, + model_provider="openai", input_messages=input_messages, output_messages=[{"content": ""}], parameters={"temperature": 0}, diff --git a/tests/contrib/openai/utils.py b/tests/contrib/openai/utils.py index cd29bc41b75..40dc1d8ec17 100644 --- a/tests/contrib/openai/utils.py +++ b/tests/contrib/openai/utils.py @@ -28,54 +28,3 @@ def get_openai_vcr(subdirectory_name=""): # Ignore requests to the agent ignore_localhost=True, ) - - -def _expected_llmobs_tags(error=None): - expected_tags = [ - "version:", - "env:", - "service:", - "source:integration", - "ml_app:unnamed-ml-app", - ] - if error: - expected_tags.append("error:1") - expected_tags.append("error_type:{}".format(error)) - else: - expected_tags.append("error:0") - return expected_tags - - -def _expected_llmobs_span_event( - span, - model, - input_messages, - output_messages, - parameters, - token_metrics, - session_id=None, - error=None, - error_message=None, -): - span_event = { - "span_id": str(span.span_id), - "trace_id": "{:x}".format(span.trace_id), - "parent_id": "", - "session_id": "{:x}".format(session_id or span.trace_id), - "name": span.name, - "tags": _expected_llmobs_tags(error=error), - "start_ns": span.start_ns, - "duration": span.duration_ns, - "error": 1 if error else 0, - "meta": { - "span.kind": "llm", - "model_name": model, - "model_provider": "openai", - "input": {"messages": input_messages, "parameters": parameters}, - "output": {"messages": output_messages}, - }, - "metrics": token_metrics, - } - if error: - span_event["meta"]["error.message"] = error_message - return span_event diff --git a/tests/llmobs/_utils.py b/tests/llmobs/_utils.py new file mode 100644 index 00000000000..b020eb03d8f --- /dev/null +++ b/tests/llmobs/_utils.py @@ -0,0 +1,153 @@ +import os + +import vcr + + +logs_vcr = vcr.VCR( + cassette_library_dir=os.path.join(os.path.dirname(__file__), "llmobs_cassettes/"), + record_mode="once", + match_on=["path"], + filter_headers=[("DD-API-KEY", "XXXXXX")], + # Ignore requests to the agent + ignore_localhost=True, +) + + +def _expected_llmobs_tags(error=None, tags=None): + if tags is None: + tags = {} + expected_tags = [ + "version:{}".format(tags.get("version", "")), + "env:{}".format(tags.get("env", "")), + "service:{}".format(tags.get("service", "")), + "source:integration", + "ml_app:{}".format(tags.get("ml_app", "unnamed-ml-app")), + ] + if error: + expected_tags.append("error:1") + expected_tags.append("error_type:{}".format(error)) + else: + 
expected_tags.append("error:0") + if tags: + expected_tags.extend( + "{}:{}".format(k, v) for k, v in tags.items() if k not in ("version", "env", "service", "ml_app") + ) + return expected_tags + + +def _expected_llmobs_llm_span_event( + span, + span_kind="llm", + input_messages=None, + output_messages=None, + parameters=None, + token_metrics=None, + model_name=None, + model_provider=None, + tags=None, + session_id=None, + error=None, + error_message=None, +): + """ + Helper function to create an expected LLM span event. + span_kind: either "llm" or "agent" + input_messages: list of input messages in format {"content": "...", "optional_role", "..."} + output_messages: list of output messages in format {"content": "...", "optional_role", "..."} + parameters: dict of input parameters + token_metrics: dict of token metrics (e.g. prompt_tokens, completion_tokens, total_tokens) + model_name: name of the model + model_provider: name of the model provider + tags: dict of tags to add/override on span + session_id: session ID + error: error type + error_message: error message + """ + span_event = _llmobs_base_span_event(span, span_kind, tags, session_id, error, error_message) + meta_dict = {"input": {}, "output": {}} + if input_messages is not None: + meta_dict["input"].update({"messages": input_messages}) + if output_messages is not None: + meta_dict["output"].update({"messages": output_messages}) + if parameters is not None: + meta_dict["input"].update({"parameters": parameters}) + if model_name is not None: + meta_dict.update({"model_name": model_name}) + if model_provider is not None: + meta_dict.update({"model_provider": model_provider}) + if not meta_dict["input"]: + meta_dict.pop("input") + if not meta_dict["output"]: + meta_dict.pop("output") + span_event["meta"].update(meta_dict) + if token_metrics is not None: + span_event["metrics"].update(token_metrics) + return span_event + + +def _expected_llmobs_non_llm_span_event( + span, + span_kind, + input_value=None, + output_value=None, + parameters=None, + token_metrics=None, + tags=None, + session_id=None, + error=None, + error_message=None, +): + """ + Helper function to create an expected span event of type (workflow, task, tool). + span_kind: one of "workflow", "task", "tool" + input_value: input value string + output_value: output value string + parameters: dict of input parameters + token_metrics: dict of token metrics (e.g. 
prompt_tokens, completion_tokens, total_tokens) + tags: dict of tags to add/override on span + session_id: session ID + error: error type + error_message: error message + """ + span_event = _llmobs_base_span_event(span, span_kind, tags, session_id, error, error_message) + meta_dict = {"input": {}, "output": {}} + if input_value is not None: + meta_dict["input"].update({"value": input_value}) + if parameters is not None: + meta_dict["input"].update({"parameters": parameters}) + if output_value is not None: + meta_dict["output"].update({"messages": output_value}) + if not meta_dict["input"]: + meta_dict.pop("input") + if not meta_dict["output"]: + meta_dict.pop("output") + span_event["meta"].update(meta_dict) + if token_metrics is not None: + span_event["metrics"].update(token_metrics) + return span_event + + +def _llmobs_base_span_event( + span, + span_kind, + tags=None, + session_id=None, + error=None, + error_message=None, +): + span_event = { + "span_id": str(span.span_id), + "trace_id": "{:x}".format(span.trace_id), + "parent_id": "", + "session_id": session_id or "{:x}".format(span.trace_id), + "name": span.name, + "tags": _expected_llmobs_tags(tags=tags, error=error), + "start_ns": span.start_ns, + "duration": span.duration_ns, + "error": 1 if error else 0, + "meta": {"span.kind": span_kind}, + "metrics": {}, + } + if error: + span_event["meta"]["error.message"] = error_message + return span_event diff --git a/tests/llmobs/conftest.py b/tests/llmobs/conftest.py index 5339409bff9..b0ad5094d02 100644 --- a/tests/llmobs/conftest.py +++ b/tests/llmobs/conftest.py @@ -4,7 +4,7 @@ import pytest from ddtrace.llmobs import LLMObs as llmobs_service -from tests.llmobs.utils import logs_vcr +from tests.llmobs._utils import logs_vcr from tests.utils import DummyTracer from tests.utils import override_global_config from tests.utils import request_token diff --git a/tests/llmobs/test_llmobs_decorators.py b/tests/llmobs/test_llmobs_decorators.py new file mode 100644 index 00000000000..11e615b5471 --- /dev/null +++ b/tests/llmobs/test_llmobs_decorators.py @@ -0,0 +1,286 @@ +import mock +import pytest + +from ddtrace.llmobs.decorators import agent +from ddtrace.llmobs.decorators import llm +from ddtrace.llmobs.decorators import task +from ddtrace.llmobs.decorators import tool +from ddtrace.llmobs.decorators import workflow +from tests.llmobs._utils import _expected_llmobs_llm_span_event +from tests.llmobs._utils import _expected_llmobs_non_llm_span_event + + +@pytest.fixture +def mock_logs(): + with mock.patch("ddtrace.llmobs.decorators.log") as mock_logs: + yield mock_logs + + +def test_llm_decorator_with_llmobs_disabled_logs_warning(LLMObs, mock_logs): + @llm(model_name="test_model", model_provider="test_provider", name="test_function", session_id="test_session_id") + def f(): + pass + + LLMObs.disable() + f() + mock_logs.warning.assert_called_with("LLMObs.llm() cannot be used while LLMObs is disabled.") + + +def test_non_llm_decorator_with_llmobs_disabled_logs_warning(LLMObs, mock_logs): + for decorator_name, decorator in [("task", task), ("workflow", workflow), ("tool", tool), ("agent", agent)]: + + @decorator(name="test_function", session_id="test_session_id") + def f(): + pass + + LLMObs.disable() + f() + mock_logs.warning.assert_called_with("LLMObs.{}() cannot be used while LLMObs is disabled.", decorator_name) + mock_logs.reset_mock() + + +def test_llm_decorator(LLMObs, mock_llmobs_writer): + @llm(model_name="test_model", model_provider="test_provider", name="test_function", 
session_id="test_session_id") + def f(): + pass + + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_writer.enqueue.assert_called_with( + _expected_llmobs_llm_span_event( + span, "llm", model_name="test_model", model_provider="test_provider", session_id="test_session_id" + ) + ) + + +def test_llm_decorator_no_model_name_raises_error(LLMObs, mock_llmobs_writer): + with pytest.raises(TypeError): + + @llm(model_provider="test_provider", name="test_function", session_id="test_session_id") + def f(): + pass + + +def test_llm_decorator_default_kwargs(LLMObs, mock_llmobs_writer): + @llm(model_name="test_model") + def f(): + pass + + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_writer.enqueue.assert_called_with( + _expected_llmobs_llm_span_event(span, "llm", model_name="test_model", model_provider="custom") + ) + + +def test_task_decorator(LLMObs, mock_llmobs_writer): + @task(name="test_function", session_id="test_session_id") + def f(): + pass + + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_writer.enqueue.assert_called_with( + _expected_llmobs_non_llm_span_event(span, "task", session_id="test_session_id") + ) + + +def test_task_decorator_default_kwargs(LLMObs, mock_llmobs_writer): + @task() + def f(): + pass + + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "task")) + + +def test_tool_decorator(LLMObs, mock_llmobs_writer): + @tool(name="test_function", session_id="test_session_id") + def f(): + pass + + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_writer.enqueue.assert_called_with( + _expected_llmobs_non_llm_span_event(span, "tool", session_id="test_session_id") + ) + + +def test_tool_decorator_default_kwargs(LLMObs, mock_llmobs_writer): + @tool() + def f(): + pass + + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "tool")) + + +def test_workflow_decorator(LLMObs, mock_llmobs_writer): + @workflow(name="test_function", session_id="test_session_id") + def f(): + pass + + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_writer.enqueue.assert_called_with( + _expected_llmobs_non_llm_span_event(span, "workflow", session_id="test_session_id") + ) + + +def test_workflow_decorator_default_kwargs(LLMObs, mock_llmobs_writer): + @workflow() + def f(): + pass + + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "workflow")) + + +def test_agent_decorator(LLMObs, mock_llmobs_writer): + @agent(name="test_function", session_id="test_session_id") + def f(): + pass + + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_writer.enqueue.assert_called_with( + _expected_llmobs_llm_span_event(span, "agent", session_id="test_session_id") + ) + + +def test_agent_decorator_default_kwargs(LLMObs, mock_llmobs_writer): + @agent() + def f(): + pass + + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_writer.enqueue.assert_called_with(_expected_llmobs_llm_span_event(span, "agent")) + + +def test_llm_decorator_with_error(LLMObs, mock_llmobs_writer): + @llm(model_name="test_model", model_provider="test_provider", name="test_function", session_id="test_session_id") + def f(): + raise ValueError("test_error") + + with pytest.raises(ValueError): + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_writer.enqueue.assert_called_with( + _expected_llmobs_llm_span_event( + span, + 
"llm", + model_name="test_model", + model_provider="test_provider", + session_id="test_session_id", + error="builtins.ValueError", + error_message="test_error", + ) + ) + + +def test_non_llm_decorators_with_error(LLMObs, mock_llmobs_writer): + for decorator_name, decorator in [("task", task), ("workflow", workflow), ("tool", tool), ("agent", agent)]: + + @decorator(name="test_function", session_id="test_session_id") + def f(): + raise ValueError("test_error") + + with pytest.raises(ValueError): + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_writer.enqueue.assert_called_with( + _expected_llmobs_non_llm_span_event( + span, + decorator_name, + session_id="test_session_id", + error="builtins.ValueError", + error_message="test_error", + ) + ) + + +def test_llm_annotate(LLMObs, mock_llmobs_writer): + @llm(model_name="test_model", model_provider="test_provider", name="test_function", session_id="test_session_id") + def f(): + LLMObs.annotate( + parameters={"temperature": 0.9, "max_tokens": 50}, + input_data=[{"content": "test_prompt"}], + output_data=[{"content": "test_response"}], + tags={"custom_tag": "tag_value"}, + metrics={"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + ) + + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_writer.enqueue.assert_called_with( + _expected_llmobs_llm_span_event( + span, + "llm", + model_name="test_model", + model_provider="test_provider", + input_messages=[{"content": "test_prompt"}], + output_messages=[{"content": "test_response"}], + parameters={"temperature": 0.9, "max_tokens": 50}, + token_metrics={"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + tags={"custom_tag": "tag_value"}, + session_id="test_session_id", + ) + ) + + +def test_llm_annotate_raw_string_io(LLMObs, mock_llmobs_writer): + @llm(model_name="test_model", model_provider="test_provider", name="test_function", session_id="test_session_id") + def f(): + LLMObs.annotate( + parameters={"temperature": 0.9, "max_tokens": 50}, + input_data="test_prompt", + output_data="test_response", + tags={"custom_tag": "tag_value"}, + metrics={"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + ) + + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_writer.enqueue.assert_called_with( + _expected_llmobs_llm_span_event( + span, + "llm", + model_name="test_model", + model_provider="test_provider", + input_messages=[{"content": "test_prompt"}], + output_messages=[{"content": "test_response"}], + parameters={"temperature": 0.9, "max_tokens": 50}, + token_metrics={"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + tags={"custom_tag": "tag_value"}, + session_id="test_session_id", + ) + ) + + +def test_non_llm_decorators_no_args(LLMObs, mock_llmobs_writer): + """Test that using the decorators without any arguments, i.e. @tool, works the same as @tool(...).""" + for decorator_name, decorator in [("task", task), ("workflow", workflow), ("tool", tool)]: + + @decorator + def f(): + pass + + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, decorator_name)) + + +def test_agent_decorator_no_args(LLMObs, mock_llmobs_writer): + """Test that using agent decorator without any arguments, i.e. 
@agent, works the same as @agent(...).""" + + @agent + def f(): + pass + + f() + span = LLMObs._instance.tracer.pop()[0] + mock_llmobs_writer.enqueue.assert_called_with(_expected_llmobs_llm_span_event(span, "agent")) diff --git a/tests/llmobs/test_llmobs_service.py b/tests/llmobs/test_llmobs_service.py index ee4b533caf6..4d74a4ea79a 100644 --- a/tests/llmobs/test_llmobs_service.py +++ b/tests/llmobs/test_llmobs_service.py @@ -16,6 +16,8 @@ from ddtrace.llmobs._constants import SPAN_KIND from ddtrace.llmobs._constants import TAGS from ddtrace.llmobs._llmobs import LLMObsTraceProcessor +from tests.llmobs._utils import _expected_llmobs_llm_span_event +from tests.llmobs._utils import _expected_llmobs_non_llm_span_event from tests.utils import DummyTracer from tests.utils import override_global_config @@ -120,19 +122,7 @@ def test_llmobs_session_id_becomes_top_level_field(LLMObs, mock_llmobs_writer): with LLMObs.task(session_id=session_id) as span: pass mock_llmobs_writer.enqueue.assert_called_with( - { - "trace_id": "{:x}".format(span.trace_id), - "span_id": str(span.span_id), - "parent_id": "", - "session_id": session_id, - "name": span.name, - "tags": mock.ANY, - "start_ns": span.start_ns, - "duration": span.duration_ns, - "error": 0, - "meta": {"span.kind": "task"}, - "metrics": mock.ANY, - }, + _expected_llmobs_non_llm_span_event(span, "task", session_id=session_id) ) @@ -147,19 +137,7 @@ def test_llmobs_llm_span(LLMObs, mock_llmobs_writer): assert span.get_tag(SESSION_ID) is None mock_llmobs_writer.enqueue.assert_called_with( - { - "trace_id": "{:x}".format(span.trace_id), - "span_id": str(span.span_id), - "parent_id": "", - "session_id": "{:x}".format(span.trace_id), - "name": span.name, - "tags": mock.ANY, - "start_ns": span.start_ns, - "duration": span.duration_ns, - "error": 0, - "meta": {"span.kind": "llm", "model_name": "test_model", "model_provider": "test_provider"}, - "metrics": mock.ANY, - }, + _expected_llmobs_llm_span_event(span, "llm", model_name="test_model", model_provider="test_provider") ) @@ -192,21 +170,7 @@ def test_llmobs_tool_span(LLMObs, mock_llmobs_writer): assert span.span_type == "llm" assert span.get_tag(SPAN_KIND) == "tool" assert span.get_tag(SESSION_ID) is None - mock_llmobs_writer.enqueue.assert_called_with( - { - "span_id": str(span.span_id), - "trace_id": "{:x}".format(span.trace_id), - "parent_id": "", - "session_id": "{:x}".format(span.trace_id), - "name": span.name, - "tags": mock.ANY, - "start_ns": span.start_ns, - "duration": span.duration_ns, - "error": 0, - "meta": {"span.kind": "tool"}, - "metrics": mock.ANY, - }, - ) + mock_llmobs_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "tool")) def test_llmobs_task_span(LLMObs, mock_llmobs_writer): @@ -216,21 +180,7 @@ def test_llmobs_task_span(LLMObs, mock_llmobs_writer): assert span.span_type == "llm" assert span.get_tag(SPAN_KIND) == "task" assert span.get_tag(SESSION_ID) is None - mock_llmobs_writer.enqueue.assert_called_with( - { - "span_id": str(span.span_id), - "trace_id": "{:x}".format(span.trace_id), - "parent_id": "", - "session_id": "{:x}".format(span.trace_id), - "name": span.name, - "tags": mock.ANY, - "start_ns": span.start_ns, - "duration": span.duration_ns, - "error": 0, - "meta": {"span.kind": "task"}, - "metrics": mock.ANY, - }, - ) + mock_llmobs_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "task")) def test_llmobs_workflow_span(LLMObs, mock_llmobs_writer): @@ -240,21 +190,7 @@ def test_llmobs_workflow_span(LLMObs, 
mock_llmobs_writer): assert span.span_type == "llm" assert span.get_tag(SPAN_KIND) == "workflow" assert span.get_tag(SESSION_ID) is None - mock_llmobs_writer.enqueue.assert_called_with( - { - "span_id": str(span.span_id), - "trace_id": "{:x}".format(span.trace_id), - "parent_id": "", - "session_id": "{:x}".format(span.trace_id), - "name": span.name, - "tags": mock.ANY, - "start_ns": span.start_ns, - "duration": span.duration_ns, - "error": 0, - "meta": {"span.kind": "workflow"}, - "metrics": mock.ANY, - }, - ) + mock_llmobs_writer.enqueue.assert_called_with(_expected_llmobs_non_llm_span_event(span, "workflow")) def test_llmobs_agent_span(LLMObs, mock_llmobs_writer): @@ -264,21 +200,7 @@ def test_llmobs_agent_span(LLMObs, mock_llmobs_writer): assert span.span_type == "llm" assert span.get_tag(SPAN_KIND) == "agent" assert span.get_tag(SESSION_ID) is None - mock_llmobs_writer.enqueue.assert_called_with( - { - "span_id": str(span.span_id), - "trace_id": "{:x}".format(span.trace_id), - "parent_id": "", - "session_id": "{:x}".format(span.trace_id), - "name": span.name, - "tags": mock.ANY, - "start_ns": span.start_ns, - "duration": span.duration_ns, - "error": 0, - "meta": {"span.kind": "agent"}, - "metrics": mock.ANY, - }, - ) + mock_llmobs_writer.enqueue.assert_called_with(_expected_llmobs_llm_span_event(span, "agent")) def test_llmobs_annotate_while_disabled_logs_warning(LLMObs, mock_logs): @@ -389,24 +311,13 @@ def test_llmobs_span_error_sets_error(LLMObs, mock_llmobs_writer): with LLMObs.llm(model_name="test_model", model_provider="test_model_provider") as span: raise ValueError("test error message") mock_llmobs_writer.enqueue.assert_called_with( - { - "trace_id": "{:x}".format(span.trace_id), - "span_id": str(span.span_id), - "parent_id": "", - "session_id": "{:x}".format(span.trace_id), - "name": span.name, - "tags": mock.ANY, - "start_ns": span.start_ns, - "duration": span.duration_ns, - "error": 1, - "meta": { - "span.kind": "llm", - "model_name": "test_model", - "model_provider": "test_model_provider", - "error.message": "test error message", - }, - "metrics": mock.ANY, - }, + _expected_llmobs_llm_span_event( + span, + model_name="test_model", + model_provider="test_model_provider", + error="builtins.ValueError", + error_message="test error message", + ) ) @@ -415,26 +326,10 @@ def test_llmobs_tags(ddtrace_global_config, LLMObs, mock_llmobs_writer, monkeypa monkeypatch.setenv("DD_LLMOBS_APP_NAME", "test_app_name") with LLMObs.task(name="test_task") as span: pass - expected_tags = [ - "version:1.2.3", - "env:test_env", - "service:test_service", - "source:integration", - "ml_app:test_app_name", - "error:0", - ] mock_llmobs_writer.enqueue.assert_called_with( - { - "span_id": str(span.span_id), - "trace_id": "{:x}".format(span.trace_id), - "parent_id": "", - "session_id": "{:x}".format(span.trace_id), - "name": span.name, - "tags": expected_tags, - "start_ns": span.start_ns, - "duration": span.duration_ns, - "error": 0, - "meta": {"span.kind": "task"}, - "metrics": mock.ANY, - }, + _expected_llmobs_non_llm_span_event( + span, + "task", + tags={"version": "1.2.3", "env": "test_env", "service": "test_service", "ml_app": "test_app_name"}, + ) ) diff --git a/tests/llmobs/test_llmobs_writer.py b/tests/llmobs/test_llmobs_writer.py index f42f182eaa0..32248cc1c4b 100644 --- a/tests/llmobs/test_llmobs_writer.py +++ b/tests/llmobs/test_llmobs_writer.py @@ -168,7 +168,7 @@ def test_send_on_exit(mock_logs, run_python_code_in_subprocess): from ddtrace.llmobs._writer import LLMObsWriter from 
tests.llmobs.test_llmobs_writer import _completion_event -from tests.llmobs.utils import logs_vcr +from tests.llmobs._utils import logs_vcr ctx = logs_vcr.use_cassette("tests.llmobs.test_llmobs_writer.test_send_on_exit.yaml") ctx.__enter__() diff --git a/tests/llmobs/test_logger.py b/tests/llmobs/test_logger.py index d3fe0f11967..d33791c4fd4 100644 --- a/tests/llmobs/test_logger.py +++ b/tests/llmobs/test_logger.py @@ -94,7 +94,7 @@ def test_send_on_exit(): import time from ddtrace.llmobs._log_writer import V2LogWriter - from tests.llmobs.utils import logs_vcr + from tests.llmobs._utils import logs_vcr ctx = logs_vcr.use_cassette("tests.llmobs.test_logger.test_send_on_exit.yaml") ctx.__enter__() diff --git a/tests/llmobs/utils.py b/tests/llmobs/utils.py deleted file mode 100644 index 261740843b7..00000000000 --- a/tests/llmobs/utils.py +++ /dev/null @@ -1,13 +0,0 @@ -import os - -import vcr - - -logs_vcr = vcr.VCR( - cassette_library_dir=os.path.join(os.path.dirname(__file__), "llmobs_cassettes/"), - record_mode="once", - match_on=["path"], - filter_headers=[("DD-API-KEY", "XXXXXX")], - # Ignore requests to the agent - ignore_localhost=True, -)