@@ -359,6 +359,9 @@ def _create_llm_span(
359
359
_set_span_attribute (span , SpanAttributes .LLM_SYSTEM , vendor )
360
360
_set_span_attribute (span , SpanAttributes .LLM_REQUEST_TYPE , request_type .value )
361
361
362
+ span_kind = self ._determine_llm_span_kind (serialized )
363
+ _set_span_attribute (span , SpanAttributes .TRACELOOP_SPAN_KIND , span_kind .value )
364
+
362
365
# we already have an LLM span by this point,
363
366
# so skip any downstream instrumentation from here
364
367
try :
@@ -375,6 +378,72 @@ def _create_llm_span(
375
378
376
379
return span
377
380
381
+ def _determine_llm_span_kind (self , serialized : Optional [dict [str , Any ]]) -> TraceloopSpanKindValues :
382
+ """Determine the appropriate span kind for LLM operations based on model type."""
383
+ if not serialized :
384
+ return TraceloopSpanKindValues .GENERATION
385
+
386
+ class_name = _extract_class_name_from_serialized (serialized )
387
+ class_name_lower = class_name .lower ()
388
+
389
+ if any (keyword in class_name_lower for keyword in ['embedding' , 'embed' ]):
390
+ return TraceloopSpanKindValues .EMBEDDING
391
+
392
+ # Default to generation for other LLM operations
393
+ return TraceloopSpanKindValues .GENERATION
394
+
395
+ def _determine_chain_span_kind (
396
+ self ,
397
+ serialized : dict [str , Any ],
398
+ name : str ,
399
+ tags : Optional [list [str ]] = None
400
+ ) -> TraceloopSpanKindValues :
401
+ if serialized and "id" in serialized :
402
+ class_path = serialized ["id" ]
403
+ if any ("agent" in part .lower () for part in class_path ):
404
+ return TraceloopSpanKindValues .AGENT
405
+
406
+ if "agent" in name .lower ():
407
+ return TraceloopSpanKindValues .AGENT
408
+
409
+ class_name = _extract_class_name_from_serialized (serialized )
410
+ name_lower = name .lower ()
411
+
412
+ # Tool detection for RunnableLambda and custom tool chains
413
+ if any (keyword in class_name .lower () for keyword in ['tool' ]):
414
+ return TraceloopSpanKindValues .TOOL
415
+
416
+ # More precise tool detection: exclude operation like `parsers`
417
+ if any (keyword in name_lower for keyword in ['tool' ]) or (
418
+ 'function' in name_lower and 'parser' not in name_lower
419
+ ):
420
+ return TraceloopSpanKindValues .TOOL
421
+
422
+ if tags and any ('tool' in tag .lower () for tag in tags ):
423
+ return TraceloopSpanKindValues .TOOL
424
+
425
+ # Retriever detection for RunnableLambda and custom tool chains
426
+ if any (keyword in class_name .lower () for keyword in ['retriever' , 'retrieve' , 'vectorstore' ]):
427
+ return TraceloopSpanKindValues .RETRIEVER
428
+
429
+ if any (keyword in name_lower for keyword in ['retriever' , 'retrieve' , 'search' ]):
430
+ return TraceloopSpanKindValues .RETRIEVER
431
+
432
+ # Embedding detection for RunnableLambda and custom chains
433
+ if any (keyword in class_name .lower () for keyword in ['embedding' , 'embed' ]):
434
+ return TraceloopSpanKindValues .EMBEDDING
435
+
436
+ if any (keyword in name_lower for keyword in ['embedding' , 'embed' ]):
437
+ return TraceloopSpanKindValues .EMBEDDING
438
+
439
+ if any (keyword in class_name .lower () for keyword in ['rerank' , 'reorder' ]):
440
+ return TraceloopSpanKindValues .RERANKER
441
+
442
+ if any (keyword in name_lower for keyword in ['rerank' , 'reorder' ]):
443
+ return TraceloopSpanKindValues .RERANKER
444
+
445
+ return TraceloopSpanKindValues .TASK
446
+
378
447
@dont_throw
379
448
def on_chain_start (
380
449
self ,
@@ -395,12 +464,18 @@ def on_chain_start(
395
464
entity_path = ""
396
465
397
466
name = self ._get_name_from_callback (serialized , ** kwargs )
398
- kind = (
467
+
468
+ base_kind = (
399
469
TraceloopSpanKindValues .WORKFLOW
400
470
if parent_run_id is None or parent_run_id not in self .spans
401
471
else TraceloopSpanKindValues .TASK
402
472
)
403
473
474
+ if base_kind == TraceloopSpanKindValues .TASK :
475
+ kind = self ._determine_chain_span_kind (serialized , name , tags )
476
+ else :
477
+ kind = base_kind
478
+
404
479
if kind == TraceloopSpanKindValues .WORKFLOW :
405
480
workflow_name = name
406
481
else :
@@ -710,6 +785,73 @@ def on_tool_end(
710
785
)
711
786
self ._end_span (span , run_id )
712
787
788
+ @dont_throw
789
+ def on_retriever_start (
790
+ self ,
791
+ serialized : dict [str , Any ],
792
+ query : str ,
793
+ * ,
794
+ run_id : UUID ,
795
+ parent_run_id : Optional [UUID ] = None ,
796
+ tags : Optional [list [str ]] = None ,
797
+ metadata : Optional [dict [str , Any ]] = None ,
798
+ ** kwargs : Any ,
799
+ ) -> None :
800
+ """Run when retriever starts running."""
801
+ if context_api .get_value (_SUPPRESS_INSTRUMENTATION_KEY ):
802
+ return
803
+
804
+ name = self ._get_name_from_callback (serialized , kwargs = kwargs )
805
+ workflow_name = self .get_workflow_name (parent_run_id )
806
+ entity_path = self .get_entity_path (parent_run_id )
807
+
808
+ span = self ._create_task_span (
809
+ run_id ,
810
+ parent_run_id ,
811
+ name ,
812
+ TraceloopSpanKindValues .RETRIEVER ,
813
+ workflow_name ,
814
+ name ,
815
+ entity_path ,
816
+ )
817
+ if not should_emit_events () and should_send_prompts ():
818
+ span .set_attribute (
819
+ SpanAttributes .TRACELOOP_ENTITY_INPUT ,
820
+ json .dumps (
821
+ {
822
+ "query" : query ,
823
+ "tags" : tags ,
824
+ "metadata" : metadata ,
825
+ "kwargs" : kwargs ,
826
+ },
827
+ cls = CallbackFilteredJSONEncoder ,
828
+ ),
829
+ )
830
+
831
+ @dont_throw
832
+ def on_retriever_end (
833
+ self ,
834
+ documents : Any ,
835
+ * ,
836
+ run_id : UUID ,
837
+ parent_run_id : Optional [UUID ] = None ,
838
+ ** kwargs : Any ,
839
+ ) -> None :
840
+ """Run when retriever ends running."""
841
+ if context_api .get_value (_SUPPRESS_INSTRUMENTATION_KEY ):
842
+ return
843
+
844
+ span = self ._get_span (run_id )
845
+ if not should_emit_events () and should_send_prompts ():
846
+ span .set_attribute (
847
+ SpanAttributes .TRACELOOP_ENTITY_OUTPUT ,
848
+ json .dumps (
849
+ {"documents" : str (documents )[:1000 ], "kwargs" : kwargs }, # Limit output size
850
+ cls = CallbackFilteredJSONEncoder ,
851
+ ),
852
+ )
853
+ self ._end_span (span , run_id )
854
+
713
855
def get_parent_span (self , parent_run_id : Optional [str ] = None ):
714
856
if parent_run_id is None :
715
857
return None
0 commit comments