From e4742671776a9e09a29846a976c5804b08d252b1 Mon Sep 17 00:00:00 2001
From: Yun Kim <35776586+Yun-Kim@users.noreply.github.com>
Date: Thu, 12 Dec 2024 18:34:35 -0500
Subject: [PATCH] chore(llmobs): use span store instead of temporary tags (#11543)

This PR performs some cleanup refactors on the LLMObs SDK and associated integrations, specifically around how LLMObs span metadata/metrics/tags/IO are stored:

- Stop storing these as temporary span tags and instead use the span store field, which accepts arbitrary key-value pairs but is not submitted to Datadog. This removes the risk of a temporary tag escaping extraction and still being submitted as an APM span tag (see the sketch at the end of this description).
- Stop calling `safe_json()` (i.e. `json.dumps()`) on the above data at storage time, an expensive operation whose cost adds up with the number of separate calls. Instead, store the raw values in the store field and call `safe_json()` only once, at payload encoding time.

Things to look out for:

- Previously we called `safe_json()` every time we stored data as string span tags. One risk is type errors during span processing: code that expects a string may now receive a dictionary/object from the span store field.
- By avoiding JSON processing before encode time, a small edge case surfaced in the LLMObs SDK decorators, which auto-annotate non-LLM spans with a map of the input function arguments. In Python 3.8, the `bind_partial().arguments` call used to extract the function arguments returns an `OrderedDict`, while Python >= 3.9 returns a regular `dict`. This broke some tests because we simply cast the object to a string when storing the input/output value. The fix is to cast `bind_partial().arguments` to a `dict` before annotating (reproduced below).

## Next Steps

This is a great first step, but there are still plenty of performance improvements we can make to our encoding/writing. Most notably, we currently call `json.dumps()` on span events more than once (to calculate the payload size before adding them to the buffer).
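For illustration, here is a minimal, self-contained sketch of the before/after storage pattern. The `Span` class, the constant value, and the `safe_json` helper below are simplified stand-ins for the real ddtrace internals (`Span._set_ctx_item`, `ddtrace.llmobs._utils.safe_json`, etc.), not the actual implementation:

```python
import json
from typing import Any, Dict


def safe_json(obj: Any) -> str:
    # Stand-in for ddtrace.llmobs._utils.safe_json: best-effort json.dumps
    # that stringifies values the encoder cannot otherwise handle.
    return json.dumps(obj, default=lambda o: "[Unserializable object: {}]".format(repr(o)))


class Span:
    # Toy model of the two storage surfaces involved in this PR.
    def __init__(self) -> None:
        self._meta: Dict[str, str] = {}   # APM span tags: submitted to Datadog
        self._store: Dict[str, Any] = {}  # span store: local-only, never submitted

    def set_tag_str(self, key: str, value: str) -> None:
        self._meta[key] = value

    def _set_ctx_item(self, key: str, value: Any) -> None:
        self._store[key] = value

    def _get_ctx_item(self, key: str) -> Any:
        return self._store.get(key)


INPUT_MESSAGES = "_ml_obs.meta.input.messages"  # illustrative constant value

span = Span()
messages = [{"role": "user", "content": "Hello World!"}]

# Old pattern: serialize on every annotation call and park the result in a
# temporary "_ml_obs.*" tag, relying on it being extracted (and popped) later.
span.set_tag_str(INPUT_MESSAGES, safe_json(messages))

# New pattern: store the raw object in the span store; nothing is serialized
# up front and nothing can leak into the submitted APM tags. safe_json() then
# runs once per payload at encoding time.
span._set_ctx_item(INPUT_MESSAGES, messages)
payload = safe_json(
    {"spans": [{"meta": {"input": {"messages": span._get_ctx_item(INPUT_MESSAGES)}}}]}
)
```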
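And a quick reproduction of the Python 3.8 edge case mentioned above. `greet` is a hypothetical decorated function; the assumption is that auto-annotation ends up storing `str()` of the bound-arguments mapping as the span's input value:

```python
from inspect import signature


def greet(name, punctuation="!"):  # hypothetical traced function
    return "Hello {}{}".format(name, punctuation)


bound_args = signature(greet).bind_partial("World")

# Python 3.8: bound_args.arguments is an OrderedDict, so str() yields
# "OrderedDict([('name', 'World')])". Python >= 3.9 returns a plain dict,
# so str() yields "{'name': 'World'}".
print(str(bound_args.arguments))

# Casting to dict first normalizes the stored string across versions:
print(str(dict(bound_args.arguments)))  # "{'name': 'World'}" everywhere
```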
## Checklist - [x] PR author has checked that all the criteria below are met - The PR description includes an overview of the change - The PR description articulates the motivation for the change - The change includes tests OR the PR description describes a testing strategy - The PR description notes risks associated with the change, if any - Newly-added code is easy to change - The change follows the [library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) - The change includes or references documentation updates if necessary - Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) ## Reviewer Checklist - [x] Reviewer has checked that all the criteria below are met - Title is accurate - All changes are related to the pull request's stated goal - Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - Testing strategy adequately addresses listed risks - Newly-added code is easy to change - Release note makes sense to a user of the library - If necessary, author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --- ddtrace/llmobs/_integrations/anthropic.py | 27 +- ddtrace/llmobs/_integrations/bedrock.py | 31 +- ddtrace/llmobs/_integrations/gemini.py | 27 +- ddtrace/llmobs/_integrations/langchain.py | 132 +++++---- ddtrace/llmobs/_integrations/openai.py | 57 ++-- ddtrace/llmobs/_integrations/vertexai.py | 30 +- ddtrace/llmobs/_llmobs.py | 49 ++-- ddtrace/llmobs/_trace_processor.py | 74 ++--- ddtrace/llmobs/_utils.py | 12 +- ddtrace/llmobs/_writer.py | 14 +- ddtrace/llmobs/decorators.py | 12 +- .../anthropic/test_anthropic_llmobs.py | 32 --- tests/contrib/openai/test_openai_llmobs.py | 32 --- tests/llmobs/_utils.py | 52 +++- tests/llmobs/test_llmobs_decorators.py | 8 +- tests/llmobs/test_llmobs_service.py | 271 ++++++------------ .../test_llmobs_span_agentless_writer.py | 28 +- tests/llmobs/test_llmobs_span_encoder.py | 72 +++++ tests/llmobs/test_llmobs_trace_processor.py | 98 ++++--- 19 files changed, 495 insertions(+), 563 deletions(-) create mode 100644 tests/llmobs/test_llmobs_span_encoder.py diff --git a/ddtrace/llmobs/_integrations/anthropic.py b/ddtrace/llmobs/_integrations/anthropic.py index 0747d68e77b..dfb39c0f7e9 100644 --- a/ddtrace/llmobs/_integrations/anthropic.py +++ b/ddtrace/llmobs/_integrations/anthropic.py @@ -19,7 +19,6 @@ from ddtrace.llmobs._constants import TOTAL_TOKENS_METRIC_KEY from ddtrace.llmobs._integrations.base import BaseLLMIntegration from ddtrace.llmobs._utils import _get_attr -from ddtrace.llmobs._utils import safe_json log = get_logger(__name__) @@ -66,21 +65,21 @@ def _llmobs_set_tags( system_prompt = kwargs.get("system") input_messages = self._extract_input_message(messages, system_prompt) - span.set_tag_str(SPAN_KIND, "llm") - span.set_tag_str(MODEL_NAME, span.get_tag("anthropic.request.model") or "") - span.set_tag_str(INPUT_MESSAGES, safe_json(input_messages)) - span.set_tag_str(METADATA, safe_json(parameters)) - span.set_tag_str(MODEL_PROVIDER, "anthropic") - - if span.error or response is None: - span.set_tag_str(OUTPUT_MESSAGES, json.dumps([{"content": ""}])) - else: + output_messages = [{"content": ""}] + if not span.error and response is not None: output_messages = 
self._extract_output_message(response) - span.set_tag_str(OUTPUT_MESSAGES, safe_json(output_messages)) - usage = self._get_llmobs_metrics_tags(span) - if usage: - span.set_tag_str(METRICS, safe_json(usage)) + span._set_ctx_items( + { + SPAN_KIND: "llm", + MODEL_NAME: span.get_tag("anthropic.request.model") or "", + MODEL_PROVIDER: "anthropic", + INPUT_MESSAGES: input_messages, + METADATA: parameters, + OUTPUT_MESSAGES: output_messages, + METRICS: self._get_llmobs_metrics_tags(span), + } + ) def _extract_input_message(self, messages, system_prompt=None): """Extract input messages from the stored prompt. diff --git a/ddtrace/llmobs/_integrations/bedrock.py b/ddtrace/llmobs/_integrations/bedrock.py index 78798ae4f98..bf8b020ebea 100644 --- a/ddtrace/llmobs/_integrations/bedrock.py +++ b/ddtrace/llmobs/_integrations/bedrock.py @@ -19,7 +19,6 @@ from ddtrace.llmobs._constants import TOTAL_TOKENS_METRIC_KEY from ddtrace.llmobs._integrations import BaseLLMIntegration from ddtrace.llmobs._utils import _get_llmobs_parent_id -from ddtrace.llmobs._utils import safe_json log = get_logger(__name__) @@ -37,9 +36,9 @@ def _llmobs_set_tags( operation: str = "", ) -> None: """Extract prompt/response tags from a completion and set them as temporary "_ml_obs.*" tags.""" - if span.get_tag(PROPAGATED_PARENT_ID_KEY) is None: + if span._get_ctx_item(PROPAGATED_PARENT_ID_KEY) is None: parent_id = _get_llmobs_parent_id(span) or "undefined" - span.set_tag(PARENT_ID_KEY, parent_id) + span._set_ctx_item(PARENT_ID_KEY, parent_id) parameters = {} if span.get_tag("bedrock.request.temperature"): parameters["temperature"] = float(span.get_tag("bedrock.request.temperature") or 0.0) @@ -48,20 +47,20 @@ def _llmobs_set_tags( prompt = kwargs.get("prompt", "") input_messages = self._extract_input_message(prompt) - - span.set_tag_str(SPAN_KIND, "llm") - span.set_tag_str(MODEL_NAME, span.get_tag("bedrock.request.model") or "") - span.set_tag_str(MODEL_PROVIDER, span.get_tag("bedrock.request.model_provider") or "") - - span.set_tag_str(INPUT_MESSAGES, safe_json(input_messages)) - span.set_tag_str(METADATA, safe_json(parameters)) - if span.error or response is None: - span.set_tag_str(OUTPUT_MESSAGES, safe_json([{"content": ""}])) - else: + output_messages = [{"content": ""}] + if not span.error and response is not None: output_messages = self._extract_output_message(response) - span.set_tag_str(OUTPUT_MESSAGES, safe_json(output_messages)) - metrics = self._llmobs_metrics(span, response) - span.set_tag_str(METRICS, safe_json(metrics)) + span._set_ctx_items( + { + SPAN_KIND: "llm", + MODEL_NAME: span.get_tag("bedrock.request.model") or "", + MODEL_PROVIDER: span.get_tag("bedrock.request.model_provider") or "", + INPUT_MESSAGES: input_messages, + METADATA: parameters, + METRICS: self._llmobs_metrics(span, response), + OUTPUT_MESSAGES: output_messages, + } + ) @staticmethod def _llmobs_metrics(span: Span, response: Optional[Dict[str, Any]]) -> Dict[str, Any]: diff --git a/ddtrace/llmobs/_integrations/gemini.py b/ddtrace/llmobs/_integrations/gemini.py index f1a4730812f..491187475f0 100644 --- a/ddtrace/llmobs/_integrations/gemini.py +++ b/ddtrace/llmobs/_integrations/gemini.py @@ -19,7 +19,6 @@ from ddtrace.llmobs._integrations.utils import get_system_instructions_from_google_model from ddtrace.llmobs._integrations.utils import llmobs_get_metadata_google from ddtrace.llmobs._utils import _get_attr -from ddtrace.llmobs._utils import safe_json class GeminiIntegration(BaseLLMIntegration): @@ -41,28 +40,28 @@ def _llmobs_set_tags( 
response: Optional[Any] = None, operation: str = "", ) -> None: - span.set_tag_str(SPAN_KIND, "llm") - span.set_tag_str(MODEL_NAME, span.get_tag("google_generativeai.request.model") or "") - span.set_tag_str(MODEL_PROVIDER, span.get_tag("google_generativeai.request.provider") or "") - instance = kwargs.get("instance", None) metadata = llmobs_get_metadata_google(kwargs, instance) - span.set_tag_str(METADATA, safe_json(metadata)) system_instruction = get_system_instructions_from_google_model(instance) input_contents = get_argument_value(args, kwargs, 0, "contents") input_messages = self._extract_input_message(input_contents, system_instruction) - span.set_tag_str(INPUT_MESSAGES, safe_json(input_messages)) - if span.error or response is None: - span.set_tag_str(OUTPUT_MESSAGES, safe_json([{"content": ""}])) - else: + output_messages = [{"content": ""}] + if not span.error and response is not None: output_messages = self._extract_output_message(response) - span.set_tag_str(OUTPUT_MESSAGES, safe_json(output_messages)) - usage = get_llmobs_metrics_tags_google("google_generativeai", span) - if usage: - span.set_tag_str(METRICS, safe_json(usage)) + span._set_ctx_items( + { + SPAN_KIND: "llm", + MODEL_NAME: span.get_tag("google_generativeai.request.model") or "", + MODEL_PROVIDER: span.get_tag("google_generativeai.request.provider") or "", + METADATA: metadata, + INPUT_MESSAGES: input_messages, + OUTPUT_MESSAGES: output_messages, + METRICS: get_llmobs_metrics_tags_google("google_generativeai", span), + } + ) def _extract_input_message(self, contents, system_instruction=None): messages = [] diff --git a/ddtrace/llmobs/_integrations/langchain.py b/ddtrace/llmobs/_integrations/langchain.py index 2128458253d..1fce3d11804 100644 --- a/ddtrace/llmobs/_integrations/langchain.py +++ b/ddtrace/llmobs/_integrations/langchain.py @@ -28,7 +28,6 @@ from ddtrace.llmobs._constants import SPAN_KIND from ddtrace.llmobs._constants import TOTAL_TOKENS_METRIC_KEY from ddtrace.llmobs._integrations.base import BaseLLMIntegration -from ddtrace.llmobs._utils import safe_json from ddtrace.llmobs.utils import Document @@ -130,15 +129,11 @@ def _llmobs_set_metadata(self, span: Span, model_provider: Optional[str] = None) if max_tokens is not None and max_tokens != "None": metadata["max_tokens"] = int(max_tokens) if metadata: - span.set_tag_str(METADATA, safe_json(metadata)) + span._set_ctx_item(METADATA, metadata) def _llmobs_set_tags_from_llm( self, span: Span, args: List[Any], kwargs: Dict[str, Any], completions: Any, is_workflow: bool = False ) -> None: - span.set_tag_str(SPAN_KIND, "workflow" if is_workflow else "llm") - span.set_tag_str(MODEL_NAME, span.get_tag(MODEL) or "") - span.set_tag_str(MODEL_PROVIDER, span.get_tag(PROVIDER) or "") - input_tag_key = INPUT_VALUE if is_workflow else INPUT_MESSAGES output_tag_key = OUTPUT_VALUE if is_workflow else OUTPUT_MESSAGES stream = span.get_tag("langchain.request.stream") @@ -146,21 +141,28 @@ def _llmobs_set_tags_from_llm( prompts = get_argument_value(args, kwargs, 0, "input" if stream else "prompts") if isinstance(prompts, str) or not isinstance(prompts, list): prompts = [prompts] - if stream: # chat and llm take the same input types for streamed calls - span.set_tag_str(input_tag_key, safe_json(self._handle_stream_input_messages(prompts))) + input_messages = self._handle_stream_input_messages(prompts) else: - span.set_tag_str(input_tag_key, safe_json([{"content": str(prompt)} for prompt in prompts])) + input_messages = [{"content": str(prompt)} for prompt in prompts] + + 
span._set_ctx_items( + { + SPAN_KIND: "workflow" if is_workflow else "llm", + MODEL_NAME: span.get_tag(MODEL) or "", + MODEL_PROVIDER: span.get_tag(PROVIDER) or "", + input_tag_key: input_messages, + } + ) if span.error: - span.set_tag_str(output_tag_key, safe_json([{"content": ""}])) + span._set_ctx_item(output_tag_key, [{"content": ""}]) return if stream: message_content = [{"content": completions}] # single completion for streams else: message_content = [{"content": completion[0].text} for completion in completions.generations] - if not is_workflow: input_tokens, output_tokens, total_tokens = self.check_token_usage_chat_or_llm_result(completions) if total_tokens > 0: @@ -169,8 +171,8 @@ def _llmobs_set_tags_from_llm( OUTPUT_TOKENS_METRIC_KEY: output_tokens, TOTAL_TOKENS_METRIC_KEY: total_tokens, } - span.set_tag_str(METRICS, safe_json(metrics)) - span.set_tag_str(output_tag_key, safe_json(message_content)) + span._set_ctx_item(METRICS, metrics) + span._set_ctx_item(output_tag_key, message_content) def _llmobs_set_tags_from_chat_model( self, @@ -180,10 +182,13 @@ def _llmobs_set_tags_from_chat_model( chat_completions: Any, is_workflow: bool = False, ) -> None: - span.set_tag_str(SPAN_KIND, "workflow" if is_workflow else "llm") - span.set_tag_str(MODEL_NAME, span.get_tag(MODEL) or "") - span.set_tag_str(MODEL_PROVIDER, span.get_tag(PROVIDER) or "") - + span._set_ctx_items( + { + SPAN_KIND: "workflow" if is_workflow else "llm", + MODEL_NAME: span.get_tag(MODEL) or "", + MODEL_PROVIDER: span.get_tag(PROVIDER) or "", + } + ) input_tag_key = INPUT_VALUE if is_workflow else INPUT_MESSAGES output_tag_key = OUTPUT_VALUE if is_workflow else OUTPUT_MESSAGES stream = span.get_tag("langchain.request.stream") @@ -203,17 +208,17 @@ def _llmobs_set_tags_from_chat_model( ) role = getattr(message, "role", ROLE_MAPPING.get(message.type, "")) input_messages.append({"content": str(content), "role": str(role)}) - span.set_tag_str(input_tag_key, safe_json(input_messages)) + span._set_ctx_item(input_tag_key, input_messages) if span.error: - span.set_tag_str(output_tag_key, json.dumps([{"content": ""}])) + span._set_ctx_item(output_tag_key, [{"content": ""}]) return output_messages = [] if stream: content = chat_completions.content role = chat_completions.__class__.__name__.replace("MessageChunk", "").lower() # AIMessageChunk --> ai - span.set_tag_str(output_tag_key, safe_json([{"content": content, "role": ROLE_MAPPING.get(role, "")}])) + span._set_ctx_item(output_tag_key, [{"content": content, "role": ROLE_MAPPING.get(role, "")}]) return input_tokens, output_tokens, total_tokens = 0, 0, 0 @@ -249,7 +254,7 @@ def _llmobs_set_tags_from_chat_model( output_tokens = sum(v["output_tokens"] for v in tokens_per_choice_run_id.values()) total_tokens = sum(v["total_tokens"] for v in tokens_per_choice_run_id.values()) - span.set_tag_str(output_tag_key, safe_json(output_messages)) + span._set_ctx_item(output_tag_key, output_messages) if not is_workflow and total_tokens > 0: metrics = { @@ -257,7 +262,7 @@ def _llmobs_set_tags_from_chat_model( OUTPUT_TOKENS_METRIC_KEY: output_tokens, TOTAL_TOKENS_METRIC_KEY: total_tokens, } - span.set_tag_str(METRICS, safe_json(metrics)) + span._set_ctx_item(METRICS, metrics) def _extract_tool_calls(self, chat_completion_msg: Any) -> List[Dict[str, Any]]: """Extracts tool calls from a langchain chat completion.""" @@ -301,20 +306,17 @@ def _handle_stream_input_messages(self, inputs): return input_messages def _llmobs_set_meta_tags_from_chain(self, span: Span, args, kwargs, outputs: Any) -> 
None: - span.set_tag_str(SPAN_KIND, "workflow") - stream = span.get_tag("langchain.request.stream") - if stream: + if span.get_tag("langchain.request.stream"): inputs = get_argument_value(args, kwargs, 0, "input") else: inputs = kwargs + formatted_inputs = "" if inputs is not None: formatted_inputs = self.format_io(inputs) - span.set_tag_str(INPUT_VALUE, safe_json(formatted_inputs)) - if span.error or outputs is None: - span.set_tag_str(OUTPUT_VALUE, "") - return - formatted_outputs = self.format_io(outputs) - span.set_tag_str(OUTPUT_VALUE, safe_json(formatted_outputs)) + formatted_outputs = "" + if not span.error and outputs is not None: + formatted_outputs = self.format_io(outputs) + span._set_ctx_items({SPAN_KIND: "workflow", INPUT_VALUE: formatted_inputs, OUTPUT_VALUE: formatted_outputs}) def _llmobs_set_meta_tags_from_embedding( self, @@ -324,13 +326,15 @@ def _llmobs_set_meta_tags_from_embedding( output_embedding: Union[List[float], List[List[float]], None], is_workflow: bool = False, ) -> None: - span.set_tag_str(SPAN_KIND, "workflow" if is_workflow else "embedding") - span.set_tag_str(MODEL_NAME, span.get_tag(MODEL) or "") - span.set_tag_str(MODEL_PROVIDER, span.get_tag(PROVIDER) or "") - + span._set_ctx_items( + { + SPAN_KIND: "workflow" if is_workflow else "embedding", + MODEL_NAME: span.get_tag(MODEL) or "", + MODEL_PROVIDER: span.get_tag(PROVIDER) or "", + } + ) input_tag_key = INPUT_VALUE if is_workflow else INPUT_DOCUMENTS output_tag_key = OUTPUT_VALUE - output_values: Any try: @@ -343,16 +347,16 @@ def _llmobs_set_meta_tags_from_embedding( ): if is_workflow: formatted_inputs = self.format_io(input_texts) - span.set_tag_str(input_tag_key, safe_json(formatted_inputs)) + span._set_ctx_item(input_tag_key, formatted_inputs) else: if isinstance(input_texts, str): input_texts = [input_texts] input_documents = [Document(text=str(doc)) for doc in input_texts] - span.set_tag_str(input_tag_key, safe_json(input_documents)) + span._set_ctx_item(input_tag_key, input_documents) except TypeError: log.warning("Failed to serialize embedding input data to JSON") if span.error or output_embedding is None: - span.set_tag_str(output_tag_key, "") + span._set_ctx_item(output_tag_key, "") return try: if isinstance(output_embedding[0], float): @@ -364,7 +368,7 @@ def _llmobs_set_meta_tags_from_embedding( output_values = output_embedding embeddings_count = len(output_embedding) embedding_dim = len(output_values[0]) - span.set_tag_str( + span._set_ctx_item( output_tag_key, "[{} embedding(s) returned with size {}]".format(embeddings_count, embedding_dim), ) @@ -379,19 +383,22 @@ def _llmobs_set_meta_tags_from_similarity_search( output_documents: Union[List[Any], None], is_workflow: bool = False, ) -> None: - span.set_tag_str(SPAN_KIND, "workflow" if is_workflow else "retrieval") - span.set_tag_str(MODEL_NAME, span.get_tag(MODEL) or "") - span.set_tag_str(MODEL_PROVIDER, span.get_tag(PROVIDER) or "") - + span._set_ctx_items( + { + SPAN_KIND: "workflow" if is_workflow else "retrieval", + MODEL_NAME: span.get_tag(MODEL) or "", + MODEL_PROVIDER: span.get_tag(PROVIDER) or "", + } + ) input_query = get_argument_value(args, kwargs, 0, "query") if input_query is not None: formatted_inputs = self.format_io(input_query) - span.set_tag_str(INPUT_VALUE, safe_json(formatted_inputs)) + span._set_ctx_item(INPUT_VALUE, formatted_inputs) if span.error or not output_documents or not isinstance(output_documents, list): - span.set_tag_str(OUTPUT_VALUE, "") + span._set_ctx_item(OUTPUT_VALUE, "") return if is_workflow: - 
span.set_tag_str(OUTPUT_VALUE, "[{} document(s) retrieved]".format(len(output_documents))) + span._set_ctx_item(OUTPUT_VALUE, "[{} document(s) retrieved]".format(len(output_documents))) return documents = [] for d in output_documents: @@ -400,32 +407,31 @@ def _llmobs_set_meta_tags_from_similarity_search( metadata = getattr(d, "metadata", {}) doc["name"] = metadata.get("name", doc["id"]) documents.append(doc) - span.set_tag_str(OUTPUT_DOCUMENTS, safe_json(self.format_io(documents))) + span._set_ctx_item(OUTPUT_DOCUMENTS, self.format_io(documents)) # we set the value as well to ensure that the UI would display it in case the span was the root - span.set_tag_str(OUTPUT_VALUE, "[{} document(s) retrieved]".format(len(documents))) + span._set_ctx_item(OUTPUT_VALUE, "[{} document(s) retrieved]".format(len(documents))) def _llmobs_set_meta_tags_from_tool(self, span: Span, tool_inputs: Dict[str, Any], tool_output: object) -> None: - if span.get_tag(METADATA): - metadata = json.loads(str(span.get_tag(METADATA))) - else: - metadata = {} - - span.set_tag_str(SPAN_KIND, "tool") + metadata = json.loads(str(span.get_tag(METADATA))) if span.get_tag(METADATA) else {} + formatted_input = "" if tool_inputs is not None: tool_input = tool_inputs.get("input") if tool_inputs.get("config"): metadata["tool_config"] = tool_inputs.get("config") if tool_inputs.get("info"): metadata["tool_info"] = tool_inputs.get("info") - if metadata: - span.set_tag_str(METADATA, safe_json(metadata)) formatted_input = self.format_io(tool_input) - span.set_tag_str(INPUT_VALUE, safe_json(formatted_input)) - if span.error or tool_output is None: - span.set_tag_str(OUTPUT_VALUE, "") - return - formatted_outputs = self.format_io(tool_output) - span.set_tag_str(OUTPUT_VALUE, safe_json(formatted_outputs)) + formatted_outputs = "" + if not span.error and tool_output is not None: + formatted_outputs = self.format_io(tool_output) + span._set_ctx_items( + { + SPAN_KIND: "tool", + METADATA: metadata, + INPUT_VALUE: formatted_input, + OUTPUT_VALUE: formatted_outputs, + } + ) def _set_base_span_tags( # type: ignore[override] self, diff --git a/ddtrace/llmobs/_integrations/openai.py b/ddtrace/llmobs/_integrations/openai.py index 5c9e73eaca7..bd727b1a5a2 100644 --- a/ddtrace/llmobs/_integrations/openai.py +++ b/ddtrace/llmobs/_integrations/openai.py @@ -23,7 +23,6 @@ from ddtrace.llmobs._constants import TOTAL_TOKENS_METRIC_KEY from ddtrace.llmobs._integrations.base import BaseLLMIntegration from ddtrace.llmobs._utils import _get_attr -from ddtrace.llmobs._utils import safe_json from ddtrace.llmobs.utils import Document from ddtrace.pin import Pin @@ -148,19 +147,18 @@ def _llmobs_set_tags( ) -> None: """Sets meta tags and metrics for span events to be sent to LLMObs.""" span_kind = "embedding" if operation == "embedding" else "llm" - span.set_tag_str(SPAN_KIND, span_kind) model_name = span.get_tag("openai.response.model") or span.get_tag("openai.request.model") - span.set_tag_str(MODEL_NAME, model_name or "") model_provider = "azure_openai" if self._is_azure_openai(span) else "openai" - span.set_tag_str(MODEL_PROVIDER, model_provider) if operation == "completion": self._llmobs_set_meta_tags_from_completion(span, kwargs, response) elif operation == "chat": self._llmobs_set_meta_tags_from_chat(span, kwargs, response) elif operation == "embedding": self._llmobs_set_meta_tags_from_embedding(span, kwargs, response) - metrics = self._set_llmobs_metrics_tags(span, response) - span.set_tag_str(METRICS, safe_json(metrics)) + metrics = 
self._extract_llmobs_metrics_tags(span, response) + span._set_ctx_items( + {SPAN_KIND: span_kind, MODEL_NAME: model_name or "", MODEL_PROVIDER: model_provider, METRICS: metrics} + ) @staticmethod def _llmobs_set_meta_tags_from_completion(span: Span, kwargs: Dict[str, Any], completions: Any) -> None: @@ -168,20 +166,18 @@ def _llmobs_set_meta_tags_from_completion(span: Span, kwargs: Dict[str, Any], co prompt = kwargs.get("prompt", "") if isinstance(prompt, str): prompt = [prompt] - span.set_tag_str(INPUT_MESSAGES, safe_json([{"content": str(p)} for p in prompt])) - parameters = {k: v for k, v in kwargs.items() if k not in ("model", "prompt")} - span.set_tag_str(METADATA, safe_json(parameters)) - - if span.error or not completions: - span.set_tag_str(OUTPUT_MESSAGES, safe_json([{"content": ""}])) - return - if hasattr(completions, "choices"): # non-streaming response - choices = completions.choices - else: # streamed response - choices = completions - messages = [{"content": _get_attr(choice, "text", "")} for choice in choices] - span.set_tag_str(OUTPUT_MESSAGES, safe_json(messages)) + output_messages = [{"content": ""}] + if not span.error and completions: + choices = getattr(completions, "choices", completions) + output_messages = [{"content": _get_attr(choice, "text", "")} for choice in choices] + span._set_ctx_items( + { + INPUT_MESSAGES: [{"content": str(p)} for p in prompt], + METADATA: parameters, + OUTPUT_MESSAGES: output_messages, + } + ) @staticmethod def _llmobs_set_meta_tags_from_chat(span: Span, kwargs: Dict[str, Any], messages: Optional[Any]) -> None: @@ -189,16 +185,14 @@ def _llmobs_set_meta_tags_from_chat(span: Span, kwargs: Dict[str, Any], messages input_messages = [] for m in kwargs.get("messages", []): input_messages.append({"content": str(_get_attr(m, "content", "")), "role": str(_get_attr(m, "role", ""))}) - span.set_tag_str(INPUT_MESSAGES, safe_json(input_messages)) - parameters = {k: v for k, v in kwargs.items() if k not in ("model", "messages", "tools", "functions")} - span.set_tag_str(METADATA, safe_json(parameters)) + span._set_ctx_items({INPUT_MESSAGES: input_messages, METADATA: parameters}) if span.error or not messages: - span.set_tag_str(OUTPUT_MESSAGES, safe_json([{"content": ""}])) + span._set_ctx_item(OUTPUT_MESSAGES, [{"content": ""}]) return - output_messages = [] if isinstance(messages, list): # streamed response + output_messages = [] for streamed_message in messages: message = {"content": streamed_message["content"], "role": streamed_message["role"]} tool_calls = streamed_message.get("tool_calls", []) @@ -213,9 +207,10 @@ def _llmobs_set_meta_tags_from_chat(span: Span, kwargs: Dict[str, Any], messages for tool_call in tool_calls ] output_messages.append(message) - span.set_tag_str(OUTPUT_MESSAGES, safe_json(output_messages)) + span._set_ctx_item(OUTPUT_MESSAGES, output_messages) return choices = _get_attr(messages, "choices", []) + output_messages = [] for idx, choice in enumerate(choices): tool_calls_info = [] choice_message = _get_attr(choice, "message", {}) @@ -241,7 +236,7 @@ def _llmobs_set_meta_tags_from_chat(span: Span, kwargs: Dict[str, Any], messages output_messages.append({"content": content, "role": role, "tool_calls": tool_calls_info}) continue output_messages.append({"content": content, "role": role}) - span.set_tag_str(OUTPUT_MESSAGES, safe_json(output_messages)) + span._set_ctx_item(OUTPUT_MESSAGES, output_messages) @staticmethod def _llmobs_set_meta_tags_from_embedding(span: Span, kwargs: Dict[str, Any], resp: Any) -> None: @@ -250,7 
+245,6 @@ def _llmobs_set_meta_tags_from_embedding(span: Span, kwargs: Dict[str, Any], res metadata = {"encoding_format": encoding_format} if kwargs.get("dimensions"): metadata["dimensions"] = kwargs.get("dimensions") - span.set_tag_str(METADATA, safe_json(metadata)) embedding_inputs = kwargs.get("input", "") if isinstance(embedding_inputs, str) or isinstance(embedding_inputs[0], int): @@ -258,20 +252,19 @@ def _llmobs_set_meta_tags_from_embedding(span: Span, kwargs: Dict[str, Any], res input_documents = [] for doc in embedding_inputs: input_documents.append(Document(text=str(doc))) - span.set_tag_str(INPUT_DOCUMENTS, safe_json(input_documents)) - + span._set_ctx_items({METADATA: metadata, INPUT_DOCUMENTS: input_documents}) if span.error: return if encoding_format == "float": embedding_dim = len(resp.data[0].embedding) - span.set_tag_str( + span._set_ctx_item( OUTPUT_VALUE, "[{} embedding(s) returned with size {}]".format(len(resp.data), embedding_dim) ) return - span.set_tag_str(OUTPUT_VALUE, "[{} embedding(s) returned]".format(len(resp.data))) + span._set_ctx_item(OUTPUT_VALUE, "[{} embedding(s) returned]".format(len(resp.data))) @staticmethod - def _set_llmobs_metrics_tags(span: Span, resp: Any) -> Dict[str, Any]: + def _extract_llmobs_metrics_tags(span: Span, resp: Any) -> Dict[str, Any]: """Extract metrics from a chat/completion and set them as a temporary "_ml_obs.metrics" tag.""" token_usage = _get_attr(resp, "usage", None) if token_usage is not None: diff --git a/ddtrace/llmobs/_integrations/vertexai.py b/ddtrace/llmobs/_integrations/vertexai.py index 69fdc7eb665..4019268e0c4 100644 --- a/ddtrace/llmobs/_integrations/vertexai.py +++ b/ddtrace/llmobs/_integrations/vertexai.py @@ -19,7 +19,6 @@ from ddtrace.llmobs._integrations.utils import get_system_instructions_from_google_model from ddtrace.llmobs._integrations.utils import llmobs_get_metadata_google from ddtrace.llmobs._utils import _get_attr -from ddtrace.llmobs._utils import safe_json class VertexAIIntegration(BaseLLMIntegration): @@ -41,30 +40,29 @@ def _llmobs_set_tags( response: Optional[Any] = None, operation: str = "", ) -> None: - span.set_tag_str(SPAN_KIND, "llm") - span.set_tag_str(MODEL_NAME, span.get_tag("vertexai.request.model") or "") - span.set_tag_str(MODEL_PROVIDER, span.get_tag("vertexai.request.provider") or "") - instance = kwargs.get("instance", None) history = kwargs.get("history", []) metadata = llmobs_get_metadata_google(kwargs, instance) - span.set_tag_str(METADATA, safe_json(metadata)) system_instruction = get_system_instructions_from_google_model(instance) input_contents = get_argument_value(args, kwargs, 0, "contents") input_messages = self._extract_input_message(input_contents, history, system_instruction) - span.set_tag_str(INPUT_MESSAGES, safe_json(input_messages)) - - if span.error or response is None: - span.set_tag_str(OUTPUT_MESSAGES, safe_json([{"content": ""}])) - return - output_messages = self._extract_output_message(response) - span.set_tag_str(OUTPUT_MESSAGES, safe_json(output_messages)) + output_messages = [{"content": ""}] + if not span.error and response is not None: + output_messages = self._extract_output_message(response) - usage = get_llmobs_metrics_tags_google("vertexai", span) - if usage: - span.set_tag_str(METRICS, safe_json(usage)) + span._set_ctx_items( + { + SPAN_KIND: "llm", + MODEL_NAME: span.get_tag("vertexai.request.model") or "", + MODEL_PROVIDER: span.get_tag("vertexai.request.provider") or "", + METADATA: metadata, + INPUT_MESSAGES: input_messages, + OUTPUT_MESSAGES: 
output_messages, + METRICS: get_llmobs_metrics_tags_google("vertexai", span), + } + ) def _extract_input_message(self, contents, history, system_instruction=None): from vertexai.generative_models._generative_models import Part diff --git a/ddtrace/llmobs/_llmobs.py b/ddtrace/llmobs/_llmobs.py index 808cee89e0f..867edbdca4f 100644 --- a/ddtrace/llmobs/_llmobs.py +++ b/ddtrace/llmobs/_llmobs.py @@ -399,23 +399,23 @@ def _start_span( if name is None: name = operation_kind span = self.tracer.trace(name, resource=operation_kind, span_type=SpanTypes.LLM) - span.set_tag_str(SPAN_KIND, operation_kind) + span._set_ctx_item(SPAN_KIND, operation_kind) if model_name is not None: - span.set_tag_str(MODEL_NAME, model_name) + span._set_ctx_item(MODEL_NAME, model_name) if model_provider is not None: - span.set_tag_str(MODEL_PROVIDER, model_provider) + span._set_ctx_item(MODEL_PROVIDER, model_provider) session_id = session_id if session_id is not None else _get_session_id(span) if session_id is not None: - span.set_tag_str(SESSION_ID, session_id) + span._set_ctx_item(SESSION_ID, session_id) if ml_app is None: ml_app = _get_ml_app(span) - span.set_tag_str(ML_APP, ml_app) - if span.get_tag(PROPAGATED_PARENT_ID_KEY) is None: + span._set_ctx_item(ML_APP, ml_app) + if span._get_ctx_item(PROPAGATED_PARENT_ID_KEY) is None: # For non-distributed traces or spans in the first service of a distributed trace, # The LLMObs parent ID tag is not set at span start time. We need to manually set the parent ID tag now # in these cases to avoid conflicting with the later propagated tags. parent_id = _get_llmobs_parent_id(span) or "undefined" - span.set_tag_str(PARENT_ID_KEY, str(parent_id)) + span._set_ctx_item(PARENT_ID_KEY, str(parent_id)) return span @classmethod @@ -638,7 +638,7 @@ def annotate( cls._tag_metrics(span, metrics) if tags is not None: cls._tag_span_tags(span, tags) - span_kind = span.get_tag(SPAN_KIND) + span_kind = span._get_ctx_item(SPAN_KIND) if parameters is not None: log.warning("Setting parameters is deprecated, please set parameters and other metadata as tags instead.") cls._tag_params(span, parameters) @@ -664,7 +664,7 @@ def _tag_prompt(span, prompt: dict) -> None: """Tags a given LLMObs span with a prompt""" try: validated_prompt = validate_prompt(prompt) - span.set_tag_str(INPUT_PROMPT, safe_json(validated_prompt)) + span._set_ctx_item(INPUT_PROMPT, validated_prompt) except TypeError: log.warning("Failed to validate prompt with error: ", exc_info=True) return @@ -677,7 +677,7 @@ def _tag_params(span: Span, params: Dict[str, Any]) -> None: if not isinstance(params, dict): log.warning("parameters must be a dictionary of key-value pairs.") return - span.set_tag_str(INPUT_PARAMETERS, safe_json(params)) + span._set_ctx_item(INPUT_PARAMETERS, params) @classmethod def _tag_llm_io(cls, span, input_messages=None, output_messages=None): @@ -689,7 +689,7 @@ def _tag_llm_io(cls, span, input_messages=None, output_messages=None): if not isinstance(input_messages, Messages): input_messages = Messages(input_messages) if input_messages.messages: - span.set_tag_str(INPUT_MESSAGES, safe_json(input_messages.messages)) + span._set_ctx_item(INPUT_MESSAGES, input_messages.messages) except TypeError: log.warning("Failed to parse input messages.", exc_info=True) if output_messages is None: @@ -699,7 +699,7 @@ def _tag_llm_io(cls, span, input_messages=None, output_messages=None): output_messages = Messages(output_messages) if not output_messages.messages: return - span.set_tag_str(OUTPUT_MESSAGES, 
safe_json(output_messages.messages)) + span._set_ctx_item(OUTPUT_MESSAGES, output_messages.messages) except TypeError: log.warning("Failed to parse output messages.", exc_info=True) @@ -713,12 +713,12 @@ def _tag_embedding_io(cls, span, input_documents=None, output_text=None): if not isinstance(input_documents, Documents): input_documents = Documents(input_documents) if input_documents.documents: - span.set_tag_str(INPUT_DOCUMENTS, safe_json(input_documents.documents)) + span._set_ctx_item(INPUT_DOCUMENTS, input_documents.documents) except TypeError: log.warning("Failed to parse input documents.", exc_info=True) if output_text is None: return - span.set_tag_str(OUTPUT_VALUE, safe_json(output_text)) + span._set_ctx_item(OUTPUT_VALUE, str(output_text)) @classmethod def _tag_retrieval_io(cls, span, input_text=None, output_documents=None): @@ -726,7 +726,7 @@ def _tag_retrieval_io(cls, span, input_text=None, output_documents=None): Will be mapped to span's `meta.{input,output}.text` fields. """ if input_text is not None: - span.set_tag_str(INPUT_VALUE, safe_json(input_text)) + span._set_ctx_item(INPUT_VALUE, str(input_text)) if output_documents is None: return try: @@ -734,7 +734,7 @@ def _tag_retrieval_io(cls, span, input_text=None, output_documents=None): output_documents = Documents(output_documents) if not output_documents.documents: return - span.set_tag_str(OUTPUT_DOCUMENTS, safe_json(output_documents.documents)) + span._set_ctx_item(OUTPUT_DOCUMENTS, output_documents.documents) except TypeError: log.warning("Failed to parse output documents.", exc_info=True) @@ -744,9 +744,9 @@ def _tag_text_io(cls, span, input_value=None, output_value=None): Will be mapped to span's `meta.{input,output}.values` fields. """ if input_value is not None: - span.set_tag_str(INPUT_VALUE, safe_json(input_value)) + span._set_ctx_item(INPUT_VALUE, str(input_value)) if output_value is not None: - span.set_tag_str(OUTPUT_VALUE, safe_json(output_value)) + span._set_ctx_item(OUTPUT_VALUE, str(output_value)) @staticmethod def _tag_span_tags(span: Span, span_tags: Dict[str, Any]) -> None: @@ -759,12 +759,9 @@ def _tag_span_tags(span: Span, span_tags: Dict[str, Any]) -> None: log.warning("span_tags must be a dictionary of string key - primitive value pairs.") return try: - current_tags_str = span.get_tag(TAGS) - if current_tags_str: - current_tags = json.loads(current_tags_str) - current_tags.update(span_tags) - span_tags = current_tags - span.set_tag_str(TAGS, safe_json(span_tags)) + existing_tags = span._get_ctx_item(TAGS) or {} + existing_tags.update(span_tags) + span._set_ctx_item(TAGS, existing_tags) except Exception: log.warning("Failed to parse tags.", exc_info=True) @@ -776,7 +773,7 @@ def _tag_metadata(span: Span, metadata: Dict[str, Any]) -> None: if not isinstance(metadata, dict): log.warning("metadata must be a dictionary of string key-value pairs.") return - span.set_tag_str(METADATA, safe_json(metadata)) + span._set_ctx_item(METADATA, metadata) @staticmethod def _tag_metrics(span: Span, metrics: Dict[str, Any]) -> None: @@ -786,7 +783,7 @@ def _tag_metrics(span: Span, metrics: Dict[str, Any]) -> None: if not isinstance(metrics, dict): log.warning("metrics must be a dictionary of string key - numeric value pairs.") return - span.set_tag_str(METRICS, safe_json(metrics)) + span._set_ctx_item(METRICS, metrics) @classmethod def submit_evaluation( diff --git a/ddtrace/llmobs/_trace_processor.py b/ddtrace/llmobs/_trace_processor.py index b4af0c5ffd1..231d53d7626 100644 --- a/ddtrace/llmobs/_trace_processor.py 
+++ b/ddtrace/llmobs/_trace_processor.py @@ -1,4 +1,3 @@ -import json from typing import Any from typing import Dict from typing import List @@ -27,7 +26,6 @@ from ddtrace.llmobs._constants import OUTPUT_DOCUMENTS from ddtrace.llmobs._constants import OUTPUT_MESSAGES from ddtrace.llmobs._constants import OUTPUT_VALUE -from ddtrace.llmobs._constants import PARENT_ID_KEY from ddtrace.llmobs._constants import RAGAS_ML_APP_PREFIX from ddtrace.llmobs._constants import RUNNER_IS_INTEGRATION_SPAN_TAG from ddtrace.llmobs._constants import SESSION_ID @@ -37,6 +35,7 @@ from ddtrace.llmobs._utils import _get_ml_app from ddtrace.llmobs._utils import _get_session_id from ddtrace.llmobs._utils import _get_span_name +from ddtrace.llmobs._utils import safe_json log = get_logger(__name__) @@ -62,7 +61,7 @@ def process_trace(self, trace: List[Span]) -> Optional[List[Span]]: def submit_llmobs_span(self, span: Span) -> None: """Generate and submit an LLMObs span event to be sent to LLMObs.""" span_event = None - is_llm_span = span.get_tag(SPAN_KIND) == "llm" + is_llm_span = span._get_ctx_item(SPAN_KIND) == "llm" is_ragas_integration_span = False try: span_event, is_ragas_integration_span = self._llmobs_span_event(span) @@ -77,44 +76,49 @@ def submit_llmobs_span(self, span: Span) -> None: def _llmobs_span_event(self, span: Span) -> Tuple[Dict[str, Any], bool]: """Span event object structure.""" - span_kind = span._meta.pop(SPAN_KIND) + span_kind = span._get_ctx_item(SPAN_KIND) + if not span_kind: + raise KeyError("Span kind not found in span context") meta: Dict[str, Any] = {"span.kind": span_kind, "input": {}, "output": {}} - if span_kind in ("llm", "embedding") and span.get_tag(MODEL_NAME) is not None: - meta["model_name"] = span._meta.pop(MODEL_NAME) - meta["model_provider"] = span._meta.pop(MODEL_PROVIDER, "custom").lower() - if span.get_tag(METADATA) is not None: - meta["metadata"] = json.loads(span._meta.pop(METADATA)) - if span.get_tag(INPUT_PARAMETERS): - meta["input"]["parameters"] = json.loads(span._meta.pop(INPUT_PARAMETERS)) - if span_kind == "llm" and span.get_tag(INPUT_MESSAGES) is not None: - meta["input"]["messages"] = json.loads(span._meta.pop(INPUT_MESSAGES)) - if span.get_tag(INPUT_VALUE) is not None: - meta["input"]["value"] = span._meta.pop(INPUT_VALUE) - if span_kind == "llm" and span.get_tag(OUTPUT_MESSAGES) is not None: - meta["output"]["messages"] = json.loads(span._meta.pop(OUTPUT_MESSAGES)) - if span_kind == "embedding" and span.get_tag(INPUT_DOCUMENTS) is not None: - meta["input"]["documents"] = json.loads(span._meta.pop(INPUT_DOCUMENTS)) - if span.get_tag(OUTPUT_VALUE) is not None: - meta["output"]["value"] = span._meta.pop(OUTPUT_VALUE) - if span_kind == "retrieval" and span.get_tag(OUTPUT_DOCUMENTS) is not None: - meta["output"]["documents"] = json.loads(span._meta.pop(OUTPUT_DOCUMENTS)) - if span.get_tag(INPUT_PROMPT) is not None: - prompt_json_str = span._meta.pop(INPUT_PROMPT) + if span_kind in ("llm", "embedding") and span._get_ctx_item(MODEL_NAME) is not None: + meta["model_name"] = span._get_ctx_item(MODEL_NAME) + meta["model_provider"] = (span._get_ctx_item(MODEL_PROVIDER) or "custom").lower() + meta["metadata"] = span._get_ctx_item(METADATA) or {} + if span._get_ctx_item(INPUT_PARAMETERS): + meta["input"]["parameters"] = span._get_ctx_item(INPUT_PARAMETERS) + if span_kind == "llm" and span._get_ctx_item(INPUT_MESSAGES) is not None: + meta["input"]["messages"] = span._get_ctx_item(INPUT_MESSAGES) + if span._get_ctx_item(INPUT_VALUE) is not None: + meta["input"]["value"] = 
safe_json(span._get_ctx_item(INPUT_VALUE)) + if span_kind == "llm" and span._get_ctx_item(OUTPUT_MESSAGES) is not None: + meta["output"]["messages"] = span._get_ctx_item(OUTPUT_MESSAGES) + if span_kind == "embedding" and span._get_ctx_item(INPUT_DOCUMENTS) is not None: + meta["input"]["documents"] = span._get_ctx_item(INPUT_DOCUMENTS) + if span._get_ctx_item(OUTPUT_VALUE) is not None: + meta["output"]["value"] = safe_json(span._get_ctx_item(OUTPUT_VALUE)) + if span_kind == "retrieval" and span._get_ctx_item(OUTPUT_DOCUMENTS) is not None: + meta["output"]["documents"] = span._get_ctx_item(OUTPUT_DOCUMENTS) + if span._get_ctx_item(INPUT_PROMPT) is not None: + prompt_json_str = span._get_ctx_item(INPUT_PROMPT) if span_kind != "llm": log.warning( "Dropping prompt on non-LLM span kind, annotating prompts is only supported for LLM span kinds." ) else: - meta["input"]["prompt"] = json.loads(prompt_json_str) + meta["input"]["prompt"] = prompt_json_str if span.error: - meta[ERROR_MSG] = span.get_tag(ERROR_MSG) - meta[ERROR_STACK] = span.get_tag(ERROR_STACK) - meta[ERROR_TYPE] = span.get_tag(ERROR_TYPE) + meta.update( + { + ERROR_MSG: span.get_tag(ERROR_MSG), + ERROR_STACK: span.get_tag(ERROR_STACK), + ERROR_TYPE: span.get_tag(ERROR_TYPE), + } + ) if not meta["input"]: meta.pop("input") if not meta["output"]: meta.pop("output") - metrics = json.loads(span._meta.pop(METRICS, "{}")) + metrics = span._get_ctx_item(METRICS) or {} ml_app = _get_ml_app(span) is_ragas_integration_span = False @@ -122,10 +126,8 @@ def _llmobs_span_event(self, span: Span) -> Tuple[Dict[str, Any], bool]: if ml_app.startswith(RAGAS_ML_APP_PREFIX): is_ragas_integration_span = True - span.set_tag_str(ML_APP, ml_app) - + span._set_ctx_item(ML_APP, ml_app) parent_id = str(_get_llmobs_parent_id(span) or "undefined") - span._meta.pop(PARENT_ID_KEY, None) llmobs_span_event = { "trace_id": "{:x}".format(span.trace_id), @@ -140,7 +142,7 @@ def _llmobs_span_event(self, span: Span) -> Tuple[Dict[str, Any], bool]: } session_id = _get_session_id(span) if session_id is not None: - span.set_tag_str(SESSION_ID, session_id) + span._set_ctx_item(SESSION_ID, session_id) llmobs_span_event["session_id"] = session_id llmobs_span_event["tags"] = self._llmobs_tags( @@ -169,7 +171,7 @@ def _llmobs_tags( tags["session_id"] = session_id if is_ragas_integration_span: tags[RUNNER_IS_INTEGRATION_SPAN_TAG] = "ragas" - existing_tags = span._meta.pop(TAGS, None) + existing_tags = span._get_ctx_item(TAGS) if existing_tags is not None: - tags.update(json.loads(existing_tags)) + tags.update(existing_tags) return ["{}:{}".format(k, v) for k, v in tags.items()] diff --git a/ddtrace/llmobs/_utils.py b/ddtrace/llmobs/_utils.py index 8813788f0a3..c1b1c4a776c 100644 --- a/ddtrace/llmobs/_utils.py +++ b/ddtrace/llmobs/_utils.py @@ -110,8 +110,8 @@ def _get_llmobs_parent_id(span: Span) -> Optional[str]: """Return the span ID of the nearest LLMObs-type span in the span's ancestor tree. In priority order: manually set parent ID tag, nearest LLMObs ancestor, local root's propagated parent ID tag. """ - if span.get_tag(PARENT_ID_KEY): - return span.get_tag(PARENT_ID_KEY) + if span._get_ctx_item(PARENT_ID_KEY): + return span._get_ctx_item(PARENT_ID_KEY) nearest_llmobs_ancestor = _get_nearest_llmobs_ancestor(span) if nearest_llmobs_ancestor: return str(nearest_llmobs_ancestor.span_id) @@ -132,12 +132,12 @@ def _get_ml_app(span: Span) -> str: Return the ML app name for a given span, by checking the span's nearest LLMObs span ancestor. 
Default to the global config LLMObs ML app name otherwise. """ - ml_app = span.get_tag(ML_APP) + ml_app = span._get_ctx_item(ML_APP) if ml_app: return ml_app nearest_llmobs_ancestor = _get_nearest_llmobs_ancestor(span) if nearest_llmobs_ancestor: - ml_app = nearest_llmobs_ancestor.get_tag(ML_APP) + ml_app = nearest_llmobs_ancestor._get_ctx_item(ML_APP) return ml_app or config._llmobs_ml_app or "unknown-ml-app" @@ -146,12 +146,12 @@ def _get_session_id(span: Span) -> Optional[str]: Return the session ID for a given span, by checking the span's nearest LLMObs span ancestor. Default to the span's trace ID. """ - session_id = span.get_tag(SESSION_ID) + session_id = span._get_ctx_item(SESSION_ID) if session_id: return session_id nearest_llmobs_ancestor = _get_nearest_llmobs_ancestor(span) if nearest_llmobs_ancestor: - session_id = nearest_llmobs_ancestor.get_tag(SESSION_ID) + session_id = nearest_llmobs_ancestor._get_ctx_item(SESSION_ID) return session_id diff --git a/ddtrace/llmobs/_writer.py b/ddtrace/llmobs/_writer.py index 6496de96cfe..5a293f05c4e 100644 --- a/ddtrace/llmobs/_writer.py +++ b/ddtrace/llmobs/_writer.py @@ -1,5 +1,4 @@ import atexit -import json from typing import Any from typing import Dict from typing import List @@ -32,6 +31,7 @@ from ddtrace.llmobs._constants import EVP_PROXY_AGENT_ENDPOINT from ddtrace.llmobs._constants import EVP_SUBDOMAIN_HEADER_NAME from ddtrace.llmobs._constants import EVP_SUBDOMAIN_HEADER_VALUE +from ddtrace.llmobs._utils import safe_json logger = get_logger(__name__) @@ -108,11 +108,7 @@ def periodic(self) -> None: self._buffer = [] data = self._data(events) - try: - enc_llm_events = json.dumps(data) - except TypeError: - logger.error("failed to encode %d LLMObs %s events", len(events), self._event_type, exc_info=True) - return + enc_llm_events = safe_json(data) conn = httplib.HTTPSConnection(self._intake, 443, timeout=self._timeout) try: conn.request("POST", self._endpoint, enc_llm_events, self._headers) @@ -197,7 +193,7 @@ def put(self, events: List[LLMObsSpanEvent]): ) return self._buffer.extend(events) - self.buffer_size += len(json.dumps(events)) + self.buffer_size += len(safe_json(events)) def encode(self): with self._lock: @@ -207,7 +203,7 @@ def encode(self): self._init_buffer() data = {"_dd.stage": "raw", "_dd.tracer_version": ddtrace.__version__, "event_type": "span", "spans": events} try: - enc_llm_events = json.dumps(data) + enc_llm_events = safe_json(data) logger.debug("encode %d LLMObs span events to be sent", len(events)) except TypeError: logger.error("failed to encode %d LLMObs span events", len(events), exc_info=True) @@ -277,7 +273,7 @@ def stop(self, timeout=None): super(LLMObsSpanWriter, self).stop(timeout=timeout) def enqueue(self, event: LLMObsSpanEvent) -> None: - event_size = len(json.dumps(event)) + event_size = len(safe_json(event)) if event_size >= EVP_EVENT_SIZE_LIMIT: logger.warning( diff --git a/ddtrace/llmobs/decorators.py b/ddtrace/llmobs/decorators.py index 93f329f2889..7e61f9b4e18 100644 --- a/ddtrace/llmobs/decorators.py +++ b/ddtrace/llmobs/decorators.py @@ -172,7 +172,7 @@ def generator_wrapper(*args, **kwargs): func_signature = signature(func) bound_args = func_signature.bind_partial(*args, **kwargs) if _automatic_io_annotation and bound_args.arguments: - LLMObs.annotate(span=span, input_data=bound_args.arguments) + LLMObs.annotate(span=span, input_data=dict(bound_args.arguments)) return yield_from_async_gen(func, span, args, kwargs) @wraps(func) @@ -186,13 +186,13 @@ async def wrapper(*args, **kwargs): 
func_signature = signature(func) bound_args = func_signature.bind_partial(*args, **kwargs) if _automatic_io_annotation and bound_args.arguments: - LLMObs.annotate(span=span, input_data=bound_args.arguments) + LLMObs.annotate(span=span, input_data=dict(bound_args.arguments)) resp = await func(*args, **kwargs) if ( _automatic_io_annotation and resp and operation_kind != "retrieval" - and span.get_tag(OUTPUT_VALUE) is None + and span._get_ctx_item(OUTPUT_VALUE) is None ): LLMObs.annotate(span=span, output_data=resp) return resp @@ -211,7 +211,7 @@ def generator_wrapper(*args, **kwargs): func_signature = signature(func) bound_args = func_signature.bind_partial(*args, **kwargs) if _automatic_io_annotation and bound_args.arguments: - LLMObs.annotate(span=span, input_data=bound_args.arguments) + LLMObs.annotate(span=span, input_data=dict(bound_args.arguments)) try: yield from func(*args, **kwargs) except (StopIteration, GeneratorExit): @@ -234,13 +234,13 @@ def wrapper(*args, **kwargs): func_signature = signature(func) bound_args = func_signature.bind_partial(*args, **kwargs) if _automatic_io_annotation and bound_args.arguments: - LLMObs.annotate(span=span, input_data=bound_args.arguments) + LLMObs.annotate(span=span, input_data=dict(bound_args.arguments)) resp = func(*args, **kwargs) if ( _automatic_io_annotation and resp and operation_kind != "retrieval" - and span.get_tag(OUTPUT_VALUE) is None + and span._get_ctx_item(OUTPUT_VALUE) is None ): LLMObs.annotate(span=span, output_data=resp) return resp diff --git a/tests/contrib/anthropic/test_anthropic_llmobs.py b/tests/contrib/anthropic/test_anthropic_llmobs.py index f286a890209..e2850a4157f 100644 --- a/tests/contrib/anthropic/test_anthropic_llmobs.py +++ b/tests/contrib/anthropic/test_anthropic_llmobs.py @@ -1,6 +1,5 @@ from pathlib import Path -import mock import pytest from tests.llmobs._utils import _expected_llmobs_llm_span_event @@ -117,37 +116,6 @@ def test_error(self, anthropic, ddtrace_global_config, mock_llmobs_writer, mock_ ) ) - def test_error_unserializable_arg( - self, anthropic, ddtrace_global_config, mock_llmobs_writer, mock_tracer, request_vcr - ): - """Ensure we handle unserializable arguments correctly and still emit llmobs records.""" - llm = anthropic.Anthropic() - with pytest.raises(Exception): - llm.messages.create( - model="claude-3-opus-20240229", - max_tokens=object(), - temperature=0.8, - messages=[{"role": "user", "content": "Hello World!"}], - ) - - span = mock_tracer.pop_traces()[0][0] - assert mock_llmobs_writer.enqueue.call_count == 1 - expected_span = _expected_llmobs_llm_span_event( - span, - model_name="claude-3-opus-20240229", - model_provider="anthropic", - input_messages=[{"content": "Hello World!", "role": "user"}], - output_messages=[{"content": ""}], - error=span.get_tag("error.type"), - error_message=span.get_tag("error.message"), - error_stack=span.get_tag("error.stack"), - metadata={"temperature": 0.8, "max_tokens": mock.ANY}, - tags={"ml_app": "", "service": "tests.contrib.anthropic"}, - ) - mock_llmobs_writer.enqueue.assert_called_with(expected_span) - actual_span = mock_llmobs_writer.enqueue.call_args[0][0] - assert "[Unserializable object: ", "service": "tests.contrib.openai"}, - ) - mock_llmobs_writer.enqueue.assert_called_with(expected_span) - actual_span = mock_llmobs_writer.enqueue.call_args[0][0] - assert "[Unserializable object: