|
| 1 | +"""FastMCP-specific instrumentation logic.""" |
| 2 | + |
| 3 | +import json |
| 4 | +import os |
| 5 | + |
| 6 | +from opentelemetry.trace import Tracer |
| 7 | +from opentelemetry.trace.status import Status, StatusCode |
| 8 | +from opentelemetry.semconv_ai import SpanAttributes, TraceloopSpanKindValues |
| 9 | +from opentelemetry.semconv.attributes.error_attributes import ERROR_TYPE |
| 10 | +from wrapt import register_post_import_hook, wrap_function_wrapper |
| 11 | + |
| 12 | +from .utils import dont_throw |
| 13 | + |
| 14 | + |
| 15 | +class FastMCPInstrumentor: |
| 16 | + """Handles FastMCP-specific instrumentation logic.""" |
| 17 | + |
| 18 | + def __init__(self): |
| 19 | + self._tracer = None |
| 20 | + |
| 21 | + def instrument(self, tracer: Tracer): |
| 22 | + """Apply FastMCP-specific instrumentation.""" |
| 23 | + self._tracer = tracer |
| 24 | + |
| 25 | + # Instrument FastMCP server-side tool execution |
| 26 | + register_post_import_hook( |
| 27 | + lambda _: wrap_function_wrapper( |
| 28 | + "fastmcp.tools.tool_manager", "ToolManager.call_tool", self._fastmcp_tool_wrapper() |
| 29 | + ), |
| 30 | + "fastmcp.tools.tool_manager", |
| 31 | + ) |
| 32 | + |
| 33 | + def uninstrument(self): |
| 34 | + """Remove FastMCP-specific instrumentation.""" |
| 35 | + # Note: wrapt doesn't provide a clean way to unwrap post-import hooks |
| 36 | + # This is a limitation we'll need to document |
| 37 | + pass |
| 38 | + |
| 39 | + def _fastmcp_tool_wrapper(self): |
| 40 | + """Create wrapper for FastMCP tool execution.""" |
| 41 | + @dont_throw |
| 42 | + async def traced_method(wrapped, instance, args, kwargs): |
| 43 | + if not self._tracer: |
| 44 | + return await wrapped(*args, **kwargs) |
| 45 | + |
| 46 | + # Extract tool name from arguments - FastMCP has different call patterns |
| 47 | + tool_key = None |
| 48 | + tool_arguments = {} |
| 49 | + |
| 50 | + # Pattern 1: kwargs with 'key' parameter |
| 51 | + if kwargs and 'key' in kwargs: |
| 52 | + tool_key = kwargs.get('key') |
| 53 | + tool_arguments = kwargs.get('arguments', {}) |
| 54 | + # Pattern 2: positional args (tool_name, arguments) |
| 55 | + elif args and len(args) >= 1: |
| 56 | + tool_key = args[0] |
| 57 | + tool_arguments = args[1] if len(args) > 1 else {} |
| 58 | + |
| 59 | + entity_name = tool_key if tool_key else "unknown_tool" |
| 60 | + span_name = f"{entity_name}.tool" |
| 61 | + |
| 62 | + with self._tracer.start_as_current_span(span_name) as span: |
| 63 | + span.set_attribute(SpanAttributes.TRACELOOP_SPAN_KIND, TraceloopSpanKindValues.TOOL.value) |
| 64 | + span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_NAME, entity_name) |
| 65 | + |
| 66 | + if self._should_send_prompts(): |
| 67 | + try: |
| 68 | + input_data = { |
| 69 | + "tool_name": entity_name, |
| 70 | + "arguments": tool_arguments |
| 71 | + } |
| 72 | + json_input = json.dumps(input_data, cls=self._get_json_encoder()) |
| 73 | + truncated_input = self._truncate_json_if_needed(json_input) |
| 74 | + span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_INPUT, truncated_input) |
| 75 | + except (TypeError, ValueError): |
| 76 | + pass # Skip input logging if serialization fails |
| 77 | + |
| 78 | + try: |
| 79 | + result = await wrapped(*args, **kwargs) |
| 80 | + |
| 81 | + # Add output in traceloop format |
| 82 | + if self._should_send_prompts() and result: |
| 83 | + try: |
| 84 | + # Convert FastMCP Content objects to serializable format |
| 85 | + output_data = [] |
| 86 | + for item in result: |
| 87 | + if hasattr(item, 'text'): |
| 88 | + output_data.append({"type": "text", "content": item.text}) |
| 89 | + elif hasattr(item, '__dict__'): |
| 90 | + output_data.append(item.__dict__) |
| 91 | + else: |
| 92 | + output_data.append(str(item)) |
| 93 | + |
| 94 | + json_output = json.dumps(output_data, cls=self._get_json_encoder()) |
| 95 | + truncated_output = self._truncate_json_if_needed(json_output) |
| 96 | + span.set_attribute(SpanAttributes.TRACELOOP_ENTITY_OUTPUT, truncated_output) |
| 97 | + except (TypeError, ValueError): |
| 98 | + pass # Skip output logging if serialization fails |
| 99 | + |
| 100 | + span.set_status(Status(StatusCode.OK)) |
| 101 | + return result |
| 102 | + |
| 103 | + except Exception as e: |
| 104 | + span.set_attribute(ERROR_TYPE, type(e).__name__) |
| 105 | + span.record_exception(e) |
| 106 | + span.set_status(Status(StatusCode.ERROR, str(e))) |
| 107 | + raise |
| 108 | + |
| 109 | + return traced_method |
| 110 | + |
| 111 | + def _should_send_prompts(self): |
| 112 | + """Check if content tracing is enabled (matches traceloop SDK)""" |
| 113 | + return ( |
| 114 | + os.getenv("TRACELOOP_TRACE_CONTENT") or "true" |
| 115 | + ).lower() == "true" |
| 116 | + |
| 117 | + def _get_json_encoder(self): |
| 118 | + """Get JSON encoder class (simplified - traceloop SDK uses custom JSONEncoder)""" |
| 119 | + return None # Use default JSON encoder |
| 120 | + |
| 121 | + def _truncate_json_if_needed(self, json_str: str) -> str: |
| 122 | + """Truncate JSON if it exceeds OTEL limits (matches traceloop SDK)""" |
| 123 | + limit_str = os.getenv("OTEL_SPAN_ATTRIBUTE_VALUE_LENGTH_LIMIT") |
| 124 | + if limit_str: |
| 125 | + try: |
| 126 | + limit = int(limit_str) |
| 127 | + if limit > 0 and len(json_str) > limit: |
| 128 | + return json_str[:limit] |
| 129 | + except ValueError: |
| 130 | + pass |
| 131 | + return json_str |
0 commit comments