Skip to content

Commit e293136

Browse files
feat(semconv): expand genai span kind
1 parent fd9626b commit e293136

File tree

7 files changed

+540
-53
lines changed

7 files changed

+540
-53
lines changed

packages/opentelemetry-instrumentation-langchain/opentelemetry/instrumentation/langchain/__init__.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,13 @@ def _instrument(self, **kwargs):
8383
wrapper=_BaseCallbackManagerInitWrapper(traceloopCallbackHandler),
8484
)
8585

86+
# Wrap CallbackManager.configure to ensure our handler is included
87+
wrap_function_wrapper(
88+
module="langchain_core.callbacks.manager",
89+
name="CallbackManager.configure",
90+
wrapper=_CallbackManagerConfigureWrapper(traceloopCallbackHandler),
91+
)
92+
8693
if not self.disable_trace_context_propagation:
8794
self._wrap_openai_functions_for_tracing(traceloopCallbackHandler)
8895

@@ -168,6 +175,7 @@ def _wrap_openai_functions_for_tracing(self, traceloopCallbackHandler):
168175

169176
def _uninstrument(self, **kwargs):
170177
unwrap("langchain_core.callbacks", "BaseCallbackManager.__init__")
178+
unwrap("langchain_core.callbacks.manager", "CallbackManager.configure")
171179
if not self.disable_trace_context_propagation:
172180
if is_package_available("langchain_community"):
173181
unwrap("langchain_community.llms.openai", "BaseOpenAI._generate")
@@ -208,6 +216,30 @@ def __call__(
208216
instance.add_handler(self._callback_handler, True)
209217

210218

219+
class _CallbackManagerConfigureWrapper:
220+
def __init__(self, callback_handler: "TraceloopCallbackHandler"):
221+
self._callback_handler = callback_handler
222+
223+
def __call__(
224+
self,
225+
wrapped,
226+
instance,
227+
args,
228+
kwargs,
229+
):
230+
result = wrapped(*args, **kwargs)
231+
232+
if result and hasattr(result, 'add_handler'):
233+
for handler in result.inheritable_handlers:
234+
if isinstance(handler, type(self._callback_handler)):
235+
break
236+
else:
237+
self._callback_handler._callback_manager = result
238+
result.add_handler(self._callback_handler, True)
239+
240+
return result
241+
242+
211243
# This class wraps a function call to inject tracing information (trace headers) into
212244
# OpenAI client requests. It assumes the following:
213245
# 1. The wrapped function includes a `run_manager` keyword argument that contains a `run_id`.

packages/opentelemetry-instrumentation-langchain/opentelemetry/instrumentation/langchain/callback_handler.py

Lines changed: 143 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,9 @@ def _create_llm_span(
359359
_set_span_attribute(span, SpanAttributes.LLM_SYSTEM, vendor)
360360
_set_span_attribute(span, SpanAttributes.LLM_REQUEST_TYPE, request_type.value)
361361

362+
span_kind = self._determine_llm_span_kind(serialized)
363+
_set_span_attribute(span, SpanAttributes.TRACELOOP_SPAN_KIND, span_kind.value)
364+
362365
# we already have an LLM span by this point,
363366
# so skip any downstream instrumentation from here
364367
try:
@@ -375,6 +378,72 @@ def _create_llm_span(
375378

376379
return span
377380

381+
def _determine_llm_span_kind(self, serialized: Optional[dict[str, Any]]) -> TraceloopSpanKindValues:
382+
"""Determine the appropriate span kind for LLM operations based on model type."""
383+
if not serialized:
384+
return TraceloopSpanKindValues.GENERATION
385+
386+
class_name = _extract_class_name_from_serialized(serialized)
387+
class_name_lower = class_name.lower()
388+
389+
if any(keyword in class_name_lower for keyword in ['embedding', 'embed']):
390+
return TraceloopSpanKindValues.EMBEDDING
391+
392+
# Default to generation for other LLM operations
393+
return TraceloopSpanKindValues.GENERATION
394+
395+
def _determine_chain_span_kind(
396+
self,
397+
serialized: dict[str, Any],
398+
name: str,
399+
tags: Optional[list[str]] = None
400+
) -> TraceloopSpanKindValues:
401+
if serialized and "id" in serialized:
402+
class_path = serialized["id"]
403+
if any("agent" in part.lower() for part in class_path):
404+
return TraceloopSpanKindValues.AGENT
405+
406+
if "agent" in name.lower():
407+
return TraceloopSpanKindValues.AGENT
408+
409+
class_name = _extract_class_name_from_serialized(serialized)
410+
name_lower = name.lower()
411+
412+
# Tool detection for RunnableLambda and custom tool chains
413+
if any(keyword in class_name.lower() for keyword in ['tool']):
414+
return TraceloopSpanKindValues.TOOL
415+
416+
# More precise tool detection: exclude operation like `parsers`
417+
if any(keyword in name_lower for keyword in ['tool']) or (
418+
'function' in name_lower and 'parser' not in name_lower
419+
):
420+
return TraceloopSpanKindValues.TOOL
421+
422+
if tags and any('tool' in tag.lower() for tag in tags):
423+
return TraceloopSpanKindValues.TOOL
424+
425+
# Retriever detection for RunnableLambda and custom tool chains
426+
if any(keyword in class_name.lower() for keyword in ['retriever', 'retrieve', 'vectorstore']):
427+
return TraceloopSpanKindValues.RETRIEVER
428+
429+
if any(keyword in name_lower for keyword in ['retriever', 'retrieve', 'search']):
430+
return TraceloopSpanKindValues.RETRIEVER
431+
432+
# Embedding detection for RunnableLambda and custom chains
433+
if any(keyword in class_name.lower() for keyword in ['embedding', 'embed']):
434+
return TraceloopSpanKindValues.EMBEDDING
435+
436+
if any(keyword in name_lower for keyword in ['embedding', 'embed']):
437+
return TraceloopSpanKindValues.EMBEDDING
438+
439+
if any(keyword in class_name.lower() for keyword in ['rerank', 'reorder']):
440+
return TraceloopSpanKindValues.RERANKER
441+
442+
if any(keyword in name_lower for keyword in ['rerank', 'reorder']):
443+
return TraceloopSpanKindValues.RERANKER
444+
445+
return TraceloopSpanKindValues.TASK
446+
378447
@dont_throw
379448
def on_chain_start(
380449
self,
@@ -395,12 +464,18 @@ def on_chain_start(
395464
entity_path = ""
396465

397466
name = self._get_name_from_callback(serialized, **kwargs)
398-
kind = (
467+
468+
base_kind = (
399469
TraceloopSpanKindValues.WORKFLOW
400470
if parent_run_id is None or parent_run_id not in self.spans
401471
else TraceloopSpanKindValues.TASK
402472
)
403473

474+
if base_kind == TraceloopSpanKindValues.TASK:
475+
kind = self._determine_chain_span_kind(serialized, name, tags)
476+
else:
477+
kind = base_kind
478+
404479
if kind == TraceloopSpanKindValues.WORKFLOW:
405480
workflow_name = name
406481
else:
@@ -710,6 +785,73 @@ def on_tool_end(
710785
)
711786
self._end_span(span, run_id)
712787

788+
@dont_throw
789+
def on_retriever_start(
790+
self,
791+
serialized: dict[str, Any],
792+
query: str,
793+
*,
794+
run_id: UUID,
795+
parent_run_id: Optional[UUID] = None,
796+
tags: Optional[list[str]] = None,
797+
metadata: Optional[dict[str, Any]] = None,
798+
**kwargs: Any,
799+
) -> None:
800+
"""Run when retriever starts running."""
801+
if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
802+
return
803+
804+
name = self._get_name_from_callback(serialized, kwargs=kwargs)
805+
workflow_name = self.get_workflow_name(parent_run_id)
806+
entity_path = self.get_entity_path(parent_run_id)
807+
808+
span = self._create_task_span(
809+
run_id,
810+
parent_run_id,
811+
name,
812+
TraceloopSpanKindValues.RETRIEVER,
813+
workflow_name,
814+
name,
815+
entity_path,
816+
)
817+
if not should_emit_events() and should_send_prompts():
818+
span.set_attribute(
819+
SpanAttributes.TRACELOOP_ENTITY_INPUT,
820+
json.dumps(
821+
{
822+
"query": query,
823+
"tags": tags,
824+
"metadata": metadata,
825+
"kwargs": kwargs,
826+
},
827+
cls=CallbackFilteredJSONEncoder,
828+
),
829+
)
830+
831+
@dont_throw
832+
def on_retriever_end(
833+
self,
834+
documents: Any,
835+
*,
836+
run_id: UUID,
837+
parent_run_id: Optional[UUID] = None,
838+
**kwargs: Any,
839+
) -> None:
840+
"""Run when retriever ends running."""
841+
if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
842+
return
843+
844+
span = self._get_span(run_id)
845+
if not should_emit_events() and should_send_prompts():
846+
span.set_attribute(
847+
SpanAttributes.TRACELOOP_ENTITY_OUTPUT,
848+
json.dumps(
849+
{"documents": str(documents)[:1000], "kwargs": kwargs}, # Limit output size
850+
cls=CallbackFilteredJSONEncoder,
851+
),
852+
)
853+
self._end_span(span, run_id)
854+
713855
def get_parent_span(self, parent_run_id: Optional[str] = None):
714856
if parent_run_id is None:
715857
return None

0 commit comments

Comments
 (0)