Skip to content

Commit

Permalink
langchain + langgraph
Browse files Browse the repository at this point in the history
  • Loading branch information
sabrenner committed Feb 3, 2025
1 parent af9098c commit 39df84d
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 2 deletions.
37 changes: 37 additions & 0 deletions ddtrace/contrib/internal/langchain/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,9 @@ def traced_llm_generate(langchain, pin, func, instance, args, kwargs):
api_key=_extract_api_key(instance),
)
completions = None

integration.record_instance(instance, span)

try:
if integration.is_pc_sampled_span(span):
for idx, prompt in enumerate(prompts):
Expand Down Expand Up @@ -283,6 +286,9 @@ async def traced_llm_agenerate(langchain, pin, func, instance, args, kwargs):
model=model,
api_key=_extract_api_key(instance),
)

integration.record_instance(instance, span)

completions = None
try:
if integration.is_pc_sampled_span(span):
Expand Down Expand Up @@ -353,6 +359,9 @@ def traced_chat_model_generate(langchain, pin, func, instance, args, kwargs):
model=_extract_model_name(instance),
api_key=_extract_api_key(instance),
)

integration.record_instance(instance, span)

chat_completions = None
try:
for message_set_idx, message_set in enumerate(chat_messages):
Expand Down Expand Up @@ -479,6 +488,9 @@ async def traced_chat_model_agenerate(langchain, pin, func, instance, args, kwar
model=_extract_model_name(instance),
api_key=_extract_api_key(instance),
)

integration.record_instance(instance, span)

chat_completions = None
try:
for message_set_idx, message_set in enumerate(chat_messages):
Expand Down Expand Up @@ -612,6 +624,9 @@ def traced_embedding(langchain, pin, func, instance, args, kwargs):
model=_extract_model_name(instance),
api_key=_extract_api_key(instance),
)

integration.record_instance(instance, span)

embeddings = None
try:
if isinstance(input_texts, str):
Expand Down Expand Up @@ -763,6 +778,12 @@ async def traced_chain_acall(langchain, pin, func, instance, args, kwargs):
return final_outputs


def _extract_bound(instance):
if hasattr(instance, "bound"):
return instance.bound
return instance


@with_traced_module
def traced_lcel_runnable_sequence(langchain, pin, func, instance, args, kwargs):
"""
Expand All @@ -786,6 +807,9 @@ def traced_lcel_runnable_sequence(langchain, pin, func, instance, args, kwargs):
)
inputs = None
final_output = None

integration.record_steps(instance, span)

try:
try:
inputs = get_argument_value(args, kwargs, 0, "input")
Expand Down Expand Up @@ -832,6 +856,9 @@ async def traced_lcel_runnable_sequence_async(langchain, pin, func, instance, ar
)
inputs = None
final_output = None

integration.record_steps(instance, span)

try:
try:
inputs = get_argument_value(args, kwargs, 0, "input")
Expand Down Expand Up @@ -878,6 +905,9 @@ def traced_similarity_search(langchain, pin, func, instance, args, kwargs):
provider=provider,
api_key=_extract_api_key(instance),
)

integration.record_instance(instance, span)

documents = []
try:
if integration.is_pc_sampled_span(span):
Expand Down Expand Up @@ -940,6 +970,7 @@ def traced_chain_stream(langchain, pin, func, instance, args, kwargs):
integration: LangChainIntegration = langchain._datadog_integration

def _on_span_started(span: Span):
integration.record_instance(instance, span)
inputs = get_argument_value(args, kwargs, 0, "input")
if not integration.is_pc_sampled_span(span):
return
Expand Down Expand Up @@ -997,6 +1028,7 @@ def traced_chat_stream(langchain, pin, func, instance, args, kwargs):
model = _extract_model_name(instance)

def _on_span_started(span: Span):
integration.record_instance(instance, span)
if not integration.is_pc_sampled_span(span):
return
chat_messages = get_argument_value(args, kwargs, 0, "input")
Expand Down Expand Up @@ -1056,6 +1088,7 @@ def traced_llm_stream(langchain, pin, func, instance, args, kwargs):
model = _extract_model_name(instance)

def _on_span_start(span: Span):
integration.record_instance(instance, span)
if not integration.is_pc_sampled_span(span):
return
inp = get_argument_value(args, kwargs, 0, "input")
Expand Down Expand Up @@ -1103,6 +1136,8 @@ def traced_base_tool_invoke(langchain, pin, func, instance, args, kwargs):
submit_to_llmobs=True,
)

integration.record_instance(instance, span)

tool_output = None
tool_info = {}
try:
Expand Down Expand Up @@ -1154,6 +1189,8 @@ async def traced_base_tool_ainvoke(langchain, pin, func, instance, args, kwargs)
submit_to_llmobs=True,
)

integration.record_instance(instance, span)

tool_output = None
tool_info = {}
try:
Expand Down
108 changes: 106 additions & 2 deletions ddtrace/llmobs/_integrations/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import List
from typing import Optional
from typing import Union
from weakref import WeakKeyDictionary

from ddtrace import config
from ddtrace._trace.span import Span
Expand All @@ -26,9 +27,11 @@
from ddtrace.llmobs._constants import OUTPUT_TOKENS_METRIC_KEY
from ddtrace.llmobs._constants import OUTPUT_VALUE
from ddtrace.llmobs._constants import SPAN_KIND
from ddtrace.llmobs._constants import SPAN_LINKS
from ddtrace.llmobs._constants import TOTAL_TOKENS_METRIC_KEY
from ddtrace.llmobs._integrations.base import BaseLLMIntegration
from ddtrace.llmobs._integrations.utils import format_langchain_io
from ddtrace.llmobs._utils import _get_nearest_llmobs_ancestor
from ddtrace.llmobs.utils import Document


Expand Down Expand Up @@ -56,9 +59,39 @@
SUPPORTED_OPERATIONS = ["llm", "chat", "chain", "embedding", "retrieval", "tool"]


def _extract_bound(instance):
if hasattr(instance, "bound"):
return instance.bound
return instance


class LangChainIntegration(BaseLLMIntegration):
_integration_name = "langchain"

_chain_steps = set() # instance_id
_spans = {} # instance_id --> span
_instances = WeakKeyDictionary() # spans --> instances

def record_steps(self, instance, span):
    """Register every step of a chain runnable, then record the chain itself.

    Each step is unwrapped with ``_extract_bound`` and its ``id()`` added to
    ``_chain_steps`` so ``_set_links`` can later tell whether a span belongs
    to a step of a recorded chain. No-op when LLMObs is disabled.
    """
    if not self.llmobs_enabled:
        return
    for chain_step in getattr(instance, "steps", []):
        self._chain_steps.add(id(_extract_bound(chain_step)))
    self.record_instance(instance, span)

def record_instance(self, instance, span):
    """Associate a (unwrapped) runnable instance with its span, both ways.

    ``_instances`` (weak-keyed on spans) maps span -> instance;
    ``_spans`` maps ``id(instance)`` -> span for reverse lookup.
    NOTE(review): ``_spans`` entries are never evicted — looks like a
    potential unbounded-growth / stale-id issue; confirm lifecycle upstream.
    No-op when LLMObs is disabled.
    """
    if not self.llmobs_enabled:
        return
    unwrapped = _extract_bound(instance)
    self._instances[span] = unwrapped
    self._spans[id(unwrapped)] = span

def _llmobs_set_tags(
self,
span: Span,
Expand All @@ -74,6 +107,8 @@ def _llmobs_set_tags(
log.warning("Unsupported operation : %s", operation)
return

self._set_links(span)

model_provider = span.get_tag(PROVIDER)
self._llmobs_set_metadata(span, model_provider)

Expand Down Expand Up @@ -111,6 +146,75 @@ def _llmobs_set_tags(
elif operation == "tool":
self._llmobs_set_meta_tags_from_tool(span, tool_inputs=kwargs, tool_output=response)

def _set_links(self, span: Span):
    """Attach LLMObs span links between this span and its invoker.

    Adds an input link (invoker -> this span) and an output link
    (this span -> invoker), rewiring the input link to the previous
    recorded chain step when this span belongs to a step of a chain
    previously registered via ``record_steps``.
    """
    # Look up the runnable instance previously recorded for this span.
    instance = self._instances.get(span)  # TODO can maybe just pass instance as part of `kwargs`
    if not instance:
        return

    instance = _extract_bound(instance)
    # True when this instance was registered as a step of some chain.
    is_step = id(instance) in self._chain_steps

    # Default assumption: the nearest LLMObs ancestor span invoked us,
    # feeding its input into our input.
    invoker_span = _get_nearest_llmobs_ancestor(span)
    invoker_link_attributes = {"from": "input", "to": "input"}

    if invoker_span is None:
        return

    links = []

    if is_step:
        # For a chain step, prefer linking from the OUTPUT of the nearest
        # preceding step (of the invoking chain) that produced a span,
        # rather than from the chain's own input.
        chain_instance = _extract_bound(self._instances.get(invoker_span))
        steps = getattr(chain_instance, "steps", [])

        # Locate this instance's position within the chain's steps.
        idx = -1
        for i, step in enumerate(steps):
            step = _extract_bound(step)
            if id(step) == id(instance):
                idx = i
                break

        # Walk backwards to the closest earlier step that has a recorded
        # span; if found, link its output to our input instead.
        # (If idx stayed -1, the range is empty and the default holds.)
        for i in range(idx - 1, -1, -1):
            step = _extract_bound(steps[i])
            if id(step) in self._spans:
                invoker_span = self._spans[id(step)]
                invoker_link_attributes = {"from": "output", "to": "input"}
                break

    links.append(
        {
            "trace_id": "{:x}".format(span.trace_id),
            "span_id": str(invoker_span.span_id),
            "attributes": invoker_link_attributes,
        }
    )

    # Append the inbound link(s) to this span's existing links.
    existing_span_links = span._get_ctx_item(SPAN_LINKS) or []
    span._set_ctx_item(SPAN_LINKS, existing_span_links + links)

    # On the invoker, maintain a single output->output link: find any
    # previously set output->output link so a later chain step can
    # replace an earlier step's link.
    invoker_links = invoker_span._get_ctx_item(SPAN_LINKS) or []
    index = next(
        (
            i
            for i, link in enumerate(invoker_links)
            if link["attributes"]["from"] == "output" and link["attributes"]["to"] == "output"
        ),
        None,
    )
    # Only chain steps displace the prior output->output link; presumably
    # the last step's output is the chain's output — TODO confirm.
    if is_step and index is not None:
        invoker_links.pop(index)

    invoker_span._set_ctx_item(
        SPAN_LINKS,
        invoker_links
        + [
            {
                "trace_id": "{:x}".format(span.trace_id),
                "span_id": str(span.span_id),
                "attributes": {"from": "output", "to": "output"},
            }
        ],
    )

def _llmobs_set_metadata(self, span: Span, model_provider: Optional[str] = None) -> None:
if not model_provider:
return
Expand Down Expand Up @@ -417,8 +521,8 @@ def _llmobs_set_meta_tags_from_tool(self, span: Span, tool_inputs: Dict[str, Any
formatted_input = ""
if tool_inputs is not None:
tool_input = tool_inputs.get("input")
if tool_inputs.get("config"):
metadata["tool_config"] = tool_inputs.get("config")
# if tool_inputs.get("config"):
# metadata["tool_config"] = tool_inputs.get("config")
if tool_inputs.get("info"):
metadata["tool_info"] = tool_inputs.get("info")
formatted_input = format_langchain_io(tool_input)
Expand Down

0 comments on commit 39df84d

Please sign in to comment.