Skip to content

(do not merge) link insights for langgraph integration #14317

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions ddtrace/contrib/internal/langgraph/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from ddtrace.contrib.trace_utils import wrap
from ddtrace.internal.utils import get_argument_value
from ddtrace.llmobs._integrations.langgraph import LangGraphIntegration
from ddtrace.llmobs._integrations.langgraph import LangGraphRoutingContext
from ddtrace.trace import Pin


Expand Down Expand Up @@ -206,6 +207,23 @@ def patched_pregel_loop_tick(langgraph, pin, func, instance, args, kwargs):
return result


@with_traced_module
def traced_runnable_callable_invoke(langgraph, pin, func, instance, args, kwargs):
    """Trace a ``RunnableCallable`` node invocation for LLM-influenced control flow.

    Wraps the node call in a routing context so reads of the incoming state are
    recorded, then coerces the node's output into an ``LLMObsState`` so writes
    can be attributed back to this node.
    """
    integration: LangGraphIntegration = langgraph._datadog_integration

    # Feature-gated: fall straight through to the unwrapped call when LLMObs or
    # the influenced-control feature is disabled.
    if not integration.llmobs_enabled or not integration.llm_influenced_control_enabled:
        return func(*args, **kwargs)

    node_name = getattr(instance, "name", None) or getattr(getattr(instance, "func", None), "__name__", None)
    with integration.routing_context(node_name, args, kwargs) as ctx:
        if isinstance(ctx, LangGraphRoutingContext):
            # Use the (possibly wrapped) LLMObsState the routing context built,
            # both for the call and for output attribution below. Passing the
            # raw positional argument instead would drop provenance whenever
            # the incoming state was a plain dict (get_llmobs_state requires
            # an LLMObsState input to do any tracking).
            fn_args = ctx.get_args()
            input_state = ctx.state
        else:
            fn_args = args
            input_state = get_argument_value(args, kwargs, 0, "input")
        result = func(*fn_args, **kwargs)

    return integration.get_llmobs_state(node_name, input_state=input_state, output_state=result)


def patch():
if getattr(langgraph, "_datadog_patch", False):
return
Expand All @@ -218,6 +236,7 @@ def patch():

from langgraph.pregel import Pregel
from langgraph.pregel.loop import PregelLoop
from langgraph.utils.runnable import RunnableCallable
from langgraph.utils.runnable import RunnableSeq

wrap(RunnableSeq, "invoke", traced_runnable_seq_invoke(langgraph))
Expand All @@ -226,6 +245,8 @@ def patch():
wrap(Pregel, "astream", traced_pregel_astream(langgraph))
wrap(PregelLoop, "tick", patched_pregel_loop_tick(langgraph))

wrap(RunnableCallable, "invoke", traced_runnable_callable_invoke(langgraph))


def unpatch():
if not getattr(langgraph, "_datadog_patch", False):
Expand All @@ -235,6 +256,7 @@ def unpatch():

from langgraph.pregel import Pregel
from langgraph.pregel.loop import PregelLoop
from langgraph.utils.runnable import RunnableCallable
from langgraph.utils.runnable import RunnableSeq

unwrap(RunnableSeq, "invoke")
Expand All @@ -243,4 +265,6 @@ def unpatch():
unwrap(Pregel, "astream")
unwrap(PregelLoop, "tick")

unwrap(RunnableCallable, "invoke")

delattr(langgraph, "_datadog_integration")
5 changes: 5 additions & 0 deletions ddtrace/llmobs/_integrations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ def llmobs_enabled(self) -> bool:
"""Return whether submitting llmobs payloads is enabled."""
return LLMObs.enabled

@property
def llm_influenced_control_enabled(self) -> bool:
    """Return whether LLM influenced control is enabled.

    Reads the private ``_llmobs_llm_influenced_control_enabled`` flag from the
    tracer config; integrations gate state-provenance tracking on this flag in
    addition to ``llmobs_enabled``.
    """
    return config._llmobs_llm_influenced_control_enabled

def is_pc_sampled_span(self, span: Span) -> bool:
if span.context.sampling_priority is not None and span.context.sampling_priority <= 0:
return False
Expand Down
110 changes: 101 additions & 9 deletions ddtrace/llmobs/_integrations/langgraph.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from contextlib import suppress
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

from ddtrace.internal.utils import get_argument_value
from ddtrace.internal.utils.formats import format_trace_id
Expand All @@ -17,13 +19,63 @@
from ddtrace.llmobs._integrations.utils import format_langchain_io
from ddtrace.llmobs._utils import _get_attr
from ddtrace.llmobs._utils import _get_nearest_llmobs_ancestor
from ddtrace.llmobs.utils import LLMObsState
from ddtrace.trace import Span


class LangGraphRoutingContext:
    """Context manager that puts a node's input state into "reading" mode.

    While active, key reads on ``state`` are recorded into
    ``current_node_metadata["influenced_by"]`` so span links can be emitted for
    the node later. Supports both ``with`` and ``async with``.
    """

    def __init__(self, state: "LLMObsState", current_node_metadata: Dict[str, Any], args: tuple):
        self.state = state
        self.current_node_metadata = current_node_metadata
        self.args = args

    def get_args(self):
        """Return ``args`` with the first positional argument replaced by the tracked state."""
        return (self.state, *self.args[1:])

    def __enter__(self):
        self._do_enter()
        return self

    async def __aenter__(self):
        # Must be a coroutine function: ``async with`` awaits the value returned
        # by __aenter__, so a plain method here would raise a TypeError at runtime.
        self._do_enter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._do_exit()

    async def __aexit__(self, exc_type, exc_value, traceback):
        self._do_exit()

    def _do_enter(self):
        # Route recorded reads into the node metadata under "influenced_by".
        self.state.set_reading(carrier=self.current_node_metadata, carrier_key="influenced_by")

    def _do_exit(self):
        self.state.stop_reading()


class LangGraphIntegration(BaseLLMIntegration):
_integration_name = "langgraph"
_graph_nodes_by_task_id: Dict[str, Any] = {} # maps task_id to dictionary of name, span, and span_links

def routing_context(self, node_name, args, kwargs) -> Union[LangGraphRoutingContext, suppress]:
    """Return a context manager that tracks state reads for ``node_name``.

    Returns a no-op ``suppress()`` when tracking is disabled, when the node is a
    LangGraph-internal node, when the task re-enters its own node, or when there
    is no input state to track.
    """
    if not self.llmobs_enabled or not self.llm_influenced_control_enabled:
        return suppress()

    state: Optional[Dict[str, Any]] = get_argument_value(args, kwargs, 0, "input")
    config = get_argument_value(args, kwargs, 1, "config", optional=True) or {}
    task_id = config.get("metadata", {}).get("langgraph_checkpoint_ns", "").split(":")[-1]
    current_node_metadata: Dict[str, Any] = self._graph_nodes_by_task_id.get(task_id, {})
    current_node_name = current_node_metadata.get("name", None)

    # Internal plumbing nodes and self-invocations are not routing points.
    if node_name in ("_write", "_route", "_control_branch") or (node_name == current_node_name):
        return suppress()

    llmobs_state = LLMObsState.from_dict(state)
    if llmobs_state is None:
        # No input state to track; entering LangGraphRoutingContext would crash
        # on state.set_reading(), so fall back to a no-op context manager.
        return suppress()

    return LangGraphRoutingContext(
        state=llmobs_state,
        current_node_metadata=current_node_metadata,
        args=args,
    )

def _llmobs_set_tags(
self,
span: Span,
Expand Down Expand Up @@ -51,12 +103,15 @@ def _llmobs_set_tags(
span._set_ctx_items(
{
SPAN_KIND: "agent" if operation == "graph" else "task",
INPUT_VALUE: format_langchain_io(inputs),
OUTPUT_VALUE: format_langchain_io(response),
INPUT_VALUE: format_langchain_io(inputs.to_state_dict() if isinstance(inputs, LLMObsState) else inputs),
OUTPUT_VALUE: format_langchain_io(
response.to_state_dict() if isinstance(response, LLMObsState) else response
),
NAME: self._graph_nodes_by_task_id.get(instance_id, {}).get("name") or kwargs.get("name", span.name),
SPAN_LINKS: current_span_links + span_links,
}
)

if operation == "graph" and not _is_subgraph(span):
self._graph_nodes_by_task_id.clear()

Expand All @@ -77,20 +132,30 @@ def llmobs_handle_pregel_loop_tick(
return
finished_task_names_to_ids = {task.name: task_id for task_id, task in finished_tasks.items()}
for task_id, task in next_tasks.items():
self._link_task_to_parent(task_id, task, finished_task_names_to_ids)
trigger_node_ids = self._link_task_to_parent(task_id, task, finished_task_names_to_ids)
if self.llm_influenced_control_enabled:
self._set_llmobs_state(task, task_id, next_tasks, finished_tasks, trigger_node_ids)

def _handle_finished_graph(self, graph_span, finished_tasks, is_subgraph_node):
def _handle_finished_graph(self, graph_span: Span, finished_tasks, is_subgraph_node):
"""Create the span links for a finished pregel graph from all finished tasks as the graph span's outputs.
Generate the output-to-output span links for the last nodes in a pregel graph.
If the graph isn't a subgraph, add a span link from the graph span to the calling LLMObs parent span.
Note: is_subgraph_node denotes whether the graph is a subgraph node,
not whether it is a standalone graph (called internally during a node execution).
"""
graph_caller_span = _get_nearest_llmobs_ancestor(graph_span) if graph_span else None
output_span_links = [
{**self._graph_nodes_by_task_id[task_id]["span"], "attributes": {"from": "output", "to": "output"}}
for task_id in finished_tasks.keys()
]
output_span_links = []
for task_id in finished_tasks.keys():
graph_node = self._graph_nodes_by_task_id.get(task_id, {})
graph_node_span = graph_node.get("span")
graph_node_influenced_by = graph_node.get("influenced_by", {})
output_span_links.append(
{
**graph_node_span,
"attributes": {"from": "output", "to": "output"},
}
)
output_span_links.extend(graph_node_influenced_by)
graph_span_span_links = graph_span._get_ctx_item(SPAN_LINKS) or []
graph_span._set_ctx_item(SPAN_LINKS, graph_span_span_links + output_span_links)
if graph_caller_span is not None and not is_subgraph_node:
Expand Down Expand Up @@ -123,7 +188,8 @@ def _link_task_to_parent(self, task_id, task, finished_task_names_to_ids):
queued_node = self._graph_nodes_by_task_id.setdefault(task_id, {})
queued_node["name"] = getattr(task, "name", "")

trigger_node_span = self._graph_nodes_by_task_id.get(node_id, {}).get("span")
trigger_node: dict = self._graph_nodes_by_task_id.get(node_id, {})
trigger_node_span = trigger_node.get("span")
if not trigger_node_span:
# Subgraphs that are called at the start of the graph need to be named, but don't need any span links
continue
Expand All @@ -136,6 +202,32 @@ def _link_task_to_parent(self, task_id, task, finished_task_names_to_ids):
span_links = queued_node.setdefault("span_links", [])
span_links.append(span_link)

span_links.extend(trigger_node.get("influenced_by", []))

return trigger_node_ids

def _set_llmobs_state(self, task, task_id, next_tasks, finished_tasks, trigger_node_ids):
    """Replace a queued task's input with an LLMObsState carrying provenance from its triggers."""
    # Collect the input states of every finished task that triggered this one.
    previous_states = []
    for trigger_id in trigger_node_ids:
        if not trigger_id:
            continue
        trigger_task = finished_tasks.get(trigger_id)
        if trigger_task:
            previous_states.append(getattr(trigger_task, "input", None))

    # Merge the upstream write-provenance into a fresh state wrapper and swap
    # it into the queued task (tasks are namedtuple-like, hence _replace).
    wrapped_input = LLMObsState.from_state(
        llmobs_states=previous_states, state=task.input, service=LLMObs._instance
    )
    next_tasks[task_id] = task._replace(input=wrapped_input)

def get_llmobs_state(self, node_name: str, input_state, output_state: Optional[Dict[str, Any]] = None):
    """Fold a node's dict output into an LLMObsState seeded from its input state.

    Returns ``output_state`` untouched when tracking does not apply (LLMObs off,
    input not already an LLMObsState, non-dict output, or an internal node).
    """
    internal_nodes = ("_write", "_route", "_control_branch")
    trackable = (
        self.llmobs_enabled
        and isinstance(input_state, LLMObsState)
        and isinstance(output_state, dict)
        and node_name not in internal_nodes
    )
    if not trackable:
        return output_state

    # Record the current node as the writer of every key it produced.
    for written_key in output_state:
        input_state._handle_set(written_key)

    # coerce the output state into an LLMObsState
    return LLMObsState.from_state(llmobs_states=[input_state], state=output_state, service=LLMObs._instance)


def _normalize_triggers(triggers, finished_tasks, next_task) -> List[str]:
"""
Expand Down
132 changes: 132 additions & 0 deletions ddtrace/llmobs/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Union

from ddtrace.llmobs._constants import SPAN_LINKS


# TypedDict was added to typing in python 3.8
try:
Expand All @@ -10,6 +14,8 @@
from typing_extensions import TypedDict

from ddtrace.internal.logger import get_logger
from ddtrace.internal.utils.formats import format_trace_id
from ddtrace.trace import Span


log = get_logger(__name__)
Expand Down Expand Up @@ -89,3 +95,129 @@ def __init__(self, documents: Union[List[DocumentType], DocumentType, str]):
raise TypeError("document score must be an integer or float.")
formatted_document["score"] = document_score
self.documents.append(formatted_document)


class LLMObsState(dict):
    """A ``dict`` state wrapper that tracks which spans wrote each key.

    ``proxy`` maps state keys to metadata about the spans that last wrote them.
    While "reading" mode is active (see :meth:`set_reading`), key accesses emit
    span links — either into an external ``carrier`` dict or onto the current
    LLMObs span — so downstream spans can be marked as influenced by upstream
    writers.
    """

    def __init__(self, *args, **kwargs):
        # NOTE: the bookkeeping kwargs ``_proxy``/``llmobs_service`` are forwarded
        # to dict.__init__ and therefore also appear as mapping entries; they are
        # filtered back out by to_state_dict() and _handle_set().
        super().__init__(*args, **kwargs)

        self.llmobs_service = kwargs.get("llmobs_service", None)
        self.proxy: Dict[str, Dict[str, Any]] = kwargs.get("_proxy", {})

        self.reading = False
        # Initialized eagerly so attribute access is always safe even if
        # stop_reading()/_handle_get() run before set_reading() was ever called.
        self.carrier: Optional[Dict[str, Any]] = None
        self.carrier_key: Optional[str] = None

    def set_reading(self, carrier: Optional[Dict[str, Any]] = None, carrier_key: Optional[str] = None):
        """Enter reading mode; emitted links go into ``carrier[carrier_key]`` when given."""
        self.carrier = carrier
        self.carrier_key = carrier_key

        self.reading = True

    def stop_reading(self):
        """Exit reading mode and drop any carrier reference."""
        self.reading = False

        self.carrier = None
        self.carrier_key = None

    def __getitem__(self, key):
        self._handle_get(key)
        return super().__getitem__(key)

    def get(self, key, default=None):
        self._handle_get(key)
        return super().get(key, default)

    def _handle_get(self, key: str):
        """Record span links for ``key`` if reading mode is on and ``key`` has known writers."""
        if not self.reading:
            return

        from_spans_meta: Optional[Dict[str, Any]] = self.proxy.get(key, None)
        if not from_spans_meta:
            return

        from_spans: Optional[List[Dict[str, Any]]] = from_spans_meta.get("spans", None)
        if from_spans is None:
            return

        # Only touch the LLMObs service when the current span is actually
        # needed: fetching it unconditionally crashed with AttributeError for
        # states built via from_dict() (llmobs_service is None there) and was
        # wasted work on the carrier path and on untracked keys.
        use_carrier = self.carrier is not None and self.carrier_key is not None
        current_span = None
        if use_carrier:
            existing_links = self.carrier.get(self.carrier_key, [])
        else:
            current_span = self.llmobs_service._current_span()
            existing_links = current_span._get_ctx_item(SPAN_LINKS) or []

        for span in from_spans:
            if span is None:
                continue
            existing_links.append(
                {
                    "trace_id": span["trace_id"],
                    "span_id": span["span_id"],
                    "attributes": {
                        "source": "influence",
                        "accessed_attribute": key,
                    },
                }
            )

        # Mark the writer records as consumed so the next write resets them.
        from_spans_meta["used"] = True

        if use_carrier:
            self.carrier[self.carrier_key] = existing_links
        else:
            current_span._set_ctx_item(SPAN_LINKS, existing_links)

    def _handle_set(self, key: str):
        """Record the current LLMObs span as the (latest) writer of ``key``."""
        if key in ("_proxy", "llmobs_service"):
            return

        current_span: Span = self.llmobs_service._current_span()
        spans_meta: Dict[str, Any] = self.proxy.setdefault(key, {})
        spans: Optional[List[Dict[str, Any]]] = spans_meta.get("spans", None)
        if spans is None:
            spans_meta["spans"] = [
                {
                    "trace_id": format_trace_id(current_span.trace_id),
                    "span_id": str(current_span.span_id),
                }
            ]
        else:
            # Once the previous writers have been read ("used"), they are
            # superseded: start a fresh writer list for this key.
            if spans_meta.get("used", False):
                spans.clear()
                del spans_meta["used"]
            spans.append(
                {
                    "trace_id": format_trace_id(current_span.trace_id),
                    "span_id": str(current_span.span_id),
                }
            )

    def to_state_dict(self):
        """Return a plain dict of the state, minus internal bookkeeping keys.

        NOTE(review): values are read through ``self[key]``, so this records
        reads when reading mode is active — confirm that is intended.
        """
        dict_keys = [key for key in self.keys() if key not in ("_proxy", "llmobs_service")]
        return {key: self[key] for key in dict_keys}

    @staticmethod
    def from_state(llmobs_states: Union["LLMObsState", List["LLMObsState"]], state: Dict, service):
        """Build an LLMObsState over ``state``, merging writer metadata from prior states."""
        llmobs_proxies = llmobs_states if isinstance(llmobs_states, list) else [llmobs_states]

        # merge spans of all llmobs_states (first proxy's entries are reused by
        # reference; later duplicates extend their span lists in place)
        merged_llmobs_proxies: Dict[str, Dict[str, Any]] = {}
        for llmobs_proxy in llmobs_proxies:
            if llmobs_proxy is None:
                continue
            for key, value in llmobs_proxy.proxy.items():
                if key not in merged_llmobs_proxies:
                    merged_llmobs_proxies[key] = value
                else:
                    merged_llmobs_proxies[key]["spans"].extend(value["spans"])

        return LLMObsState(state, _proxy=merged_llmobs_proxies, llmobs_service=service)

    @staticmethod
    def from_dict(state: Optional[Dict[str, Any]]):
        """Coerce a plain dict into an LLMObsState; passes through None and existing instances.

        NOTE: pops the ``_proxy``/``llmobs_service`` bookkeeping keys out of the
        caller's dict in place.
        """
        if isinstance(state, LLMObsState) or state is None:
            return state

        proxy = state.pop("_proxy", {})
        llmobs_service = state.pop("llmobs_service", None)
        return LLMObsState(state, _proxy=proxy, llmobs_service=llmobs_service)
Loading