diff --git a/python/instrumentation/openinference-instrumentation-agno/pyproject.toml b/python/instrumentation/openinference-instrumentation-agno/pyproject.toml
index ad326371fc..2fac5a7259 100644
--- a/python/instrumentation/openinference-instrumentation-agno/pyproject.toml
+++ b/python/instrumentation/openinference-instrumentation-agno/pyproject.toml
@@ -35,10 +35,10 @@ dependencies = [
 
 [project.optional-dependencies]
 instruments = [
-  "agno>=1.5.2",
+  "agno>=2.1.2",
 ]
 test = [
-  "agno==1.5.2",
+  "agno==2.1.2",
   "opentelemetry-sdk",
   "pytest-recording",
   "openai",
diff --git a/python/instrumentation/openinference-instrumentation-agno/src/openinference/instrumentation/agno/_wrappers.py b/python/instrumentation/openinference-instrumentation-agno/src/openinference/instrumentation/agno/_wrappers.py
index 574fb510a6..96274798a5 100644
--- a/python/instrumentation/openinference-instrumentation-agno/src/openinference/instrumentation/agno/_wrappers.py
+++ b/python/instrumentation/openinference-instrumentation-agno/src/openinference/instrumentation/agno/_wrappers.py
@@ -25,6 +25,8 @@
 from agno.team import Team
 from agno.tools.function import Function, FunctionCall
 from agno.tools.toolkit import Toolkit
+from agno.run.messages import RunMessages
+from agno.agent import RunOutput
 from openinference.instrumentation import get_attributes_from_context, safe_json_dumps
 from openinference.semconv.trace import (
     MessageAttributes,
@@ -57,11 +59,27 @@ def _flatten(mapping: Optional[Mapping[str, Any]]) -> Iterator[Tuple[str, Attrib
             yield key, value
 
 
-def _get_input_value(method: Callable[..., Any], *args: Any, **kwargs: Any) -> str:
+def _get_user_message_content(method: Callable[..., Any], *args: Any, **kwargs: Any) -> str:
     arguments = _bind_arguments(method, *args, **kwargs)
     arguments = _strip_method_args(arguments)
-    return safe_json_dumps(arguments)
-
+
+    # Try to get input from run_response.input.input_content
+    run_response: RunOutput = arguments.get("run_response")
+    if run_response and hasattr(run_response, 'input') and run_response.input:
+        if hasattr(run_response.input, 'input_content') and run_response.input.input_content:
+            return run_response.input.input_content
+
+    # Fallback: try run_messages approach
+    run_messages: RunMessages = arguments.get("run_messages")
+    if run_messages and run_messages.user_message:
+        return run_messages.user_message.content
+
+    return ""
+
+def _extract_run_response_output(run_response: RunOutput) -> str:
+    if run_response and run_response.content:
+        return run_response.content
+    return ""
 
 def _bind_arguments(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Dict[str, Any]:
     method_signature = signature(method)
@@ -93,7 +111,6 @@ def _run_arguments(arguments: Mapping[str, Any]) -> Iterator[Tuple[str, Attribut
 
     if session_id:
         yield SESSION_ID, session_id
-
     if user_id:
         yield USER_ID, user_id
 
@@ -197,7 +214,7 @@ def run(
                     {
                         OPENINFERENCE_SPAN_KIND: AGENT,
                         GRAPH_NODE_ID: node_id,
-                        INPUT_VALUE: _get_input_value(
+                        INPUT_VALUE: _get_user_message_content(
                             wrapped,
                             *args,
                             **kwargs,
@@ -212,9 +229,9 @@ def run(
             team_token = _setup_team_context(agent, node_id)
 
             try:
-                run_response = wrapped(*args, **kwargs)
+                run_response: RunOutput = wrapped(*args, **kwargs)
                 span.set_status(trace_api.StatusCode.OK)
-                span.set_attribute(OUTPUT_VALUE, run_response.to_json())
+                span.set_attribute(OUTPUT_VALUE, _extract_run_response_output(run_response))
                 span.set_attribute(OUTPUT_MIME_TYPE, JSON)
                 return run_response
 
@@ -246,7 +263,6 @@ def run_stream(
         # Generate unique node ID for this execution
         node_id = _generate_node_id()
         arguments = _bind_arguments(wrapped, *args, **kwargs)
-
         with self._tracer.start_as_current_span(
             span_name,
             attributes=dict(
@@ -254,7 +270,7 @@ def run_stream(
                     {
                         OPENINFERENCE_SPAN_KIND: AGENT,
                         GRAPH_NODE_ID: node_id,
-                        INPUT_VALUE: _get_input_value(
+                        INPUT_VALUE: _get_user_message_content(
                             wrapped,
                             *args,
                             **kwargs,
@@ -269,22 +285,40 @@ def run_stream(
             team_token = _setup_team_context(agent, node_id)
 
             try:
-                yield from wrapped(*args, **kwargs)
-                # Use get_last_run_output instead of removed agent.run_response
-                session_id = None
-                try:
-                    session_id = arguments.get("session_id")
-                except Exception:
-                    session_id = None
-
-                run_response = None
-                if hasattr(agent, "get_last_run_output"):
-                    run_response = agent.get_last_run_output(session_id=session_id)
+                current_run_id = None
+                for response in wrapped(*args, **kwargs):
+                    if hasattr(response, "run_id"):
+                        current_run_id = response.run_id
+                    yield response
+                if "session" in arguments and arguments.get("session") and len(arguments.get("session").runs) > 0:
+                    for run in arguments.get("session").runs:
+                        if run.run_id == current_run_id and run.content:
+                            if isinstance(run.content, str):
+                                span.set_attribute(OUTPUT_VALUE, run.content)
+                            else:
+                                span.set_attribute(OUTPUT_VALUE, run.content.model_dump_json())
+                            span.set_attribute(OUTPUT_MIME_TYPE, JSON)
+                            span.set_status(trace_api.StatusCode.OK)
+                            break
 
-                span.set_status(trace_api.StatusCode.OK)
-                if run_response is not None:
-                    span.set_attribute(OUTPUT_VALUE, run_response.to_json())
-                    span.set_attribute(OUTPUT_MIME_TYPE, JSON)
+                else:
+                    # Extract session_id from the session object
+                    session_id = None
+                    try:
+                        session = arguments.get("session")
+                        if session and hasattr(session, 'session_id'):
+                            session_id = session.session_id
+                    except Exception:
+                        session_id = None
+
+                    run_response = None
+                    if hasattr(agent, "get_last_run_output"):
+                        run_response = agent.get_last_run_output(session_id=session_id)
+
+                    span.set_status(trace_api.StatusCode.OK)
+                    if run_response is not None:
+                        span.set_attribute(OUTPUT_VALUE, run_response.to_json())
+                        span.set_attribute(OUTPUT_MIME_TYPE, JSON)
 
             except Exception as e:
                 span.set_status(trace_api.StatusCode.ERROR, str(e))
@@ -324,7 +358,7 @@ async def arun(
                     {
                         OPENINFERENCE_SPAN_KIND: AGENT,
                         GRAPH_NODE_ID: node_id,
-                        INPUT_VALUE: _get_input_value(
+                        INPUT_VALUE: _get_user_message_content(
                             wrapped,
                             *args,
                             **kwargs,
@@ -341,7 +375,7 @@ async def arun(
             try:
                 run_response = await wrapped(*args, **kwargs)
                 span.set_status(trace_api.StatusCode.OK)
-                span.set_attribute(OUTPUT_VALUE, run_response.to_json())
+                span.set_attribute(OUTPUT_VALUE, _extract_run_response_output(run_response))
                 span.set_attribute(OUTPUT_MIME_TYPE, JSON)
                 return run_response
             except Exception as e:
@@ -382,7 +416,7 @@ async def arun_stream(
                     {
                         OPENINFERENCE_SPAN_KIND: AGENT,
                         GRAPH_NODE_ID: node_id,
-                        INPUT_VALUE: _get_input_value(
+                        INPUT_VALUE: _get_user_message_content(
                             wrapped,
                             *args,
                             **kwargs,
@@ -397,24 +431,41 @@ async def arun_stream(
             team_token = _setup_team_context(agent, node_id)
 
             try:
+                current_run_id = None
                 async for response in wrapped(*args, **kwargs):  # type: ignore[attr-defined]
+                    if hasattr(response, "run_id"):
+                        current_run_id = response.run_id
                     yield response
 
-                # Use get_last_run_output instead of removed agent.run_response
-                session_id = None
-                try:
-                    session_id = arguments.get("session_id")
-                except Exception:
-                    session_id = None
-
-                run_response = None
-                if hasattr(agent, "get_last_run_output"):
-                    run_response = agent.get_last_run_output(session_id=session_id)
+                if arguments.get("session") and len(arguments.get("session").runs) > 0:
+                    for run in arguments.get("session").runs:
+                        if run.run_id == current_run_id and run.content:
+                            if isinstance(run.content, str):
+                                span.set_attribute(OUTPUT_VALUE, run.content)
+                            else:
+                                span.set_attribute(OUTPUT_VALUE, run.content.model_dump_json())
+                            span.set_attribute(OUTPUT_MIME_TYPE, JSON)
+                            span.set_status(trace_api.StatusCode.OK)
+                            break
 
-                span.set_status(trace_api.StatusCode.OK)
-                if run_response is not None:
-                    span.set_attribute(OUTPUT_VALUE, run_response.to_json())
-                    span.set_attribute(OUTPUT_MIME_TYPE, JSON)
+                else:
+                    # Extract session_id from the session object
+                    session_id = None
+                    try:
+                        session = arguments.get("session")
+                        if session and hasattr(session, 'session_id'):
+                            session_id = session.session_id
+                    except Exception:
+                        session_id = None
+
+                    run_response = None
+                    if hasattr(agent, "get_last_run_output"):
+                        run_response = agent.get_last_run_output(session_id=session_id)
+
+                    span.set_status(trace_api.StatusCode.OK)
+                    if run_response is not None:
+                        span.set_attribute(OUTPUT_VALUE, _extract_run_response_output(run_response))
+                        span.set_attribute(OUTPUT_MIME_TYPE, JSON)
 
             except Exception as e:
                 span.set_status(trace_api.StatusCode.ERROR, str(e))
@@ -426,16 +477,32 @@ async def arun_stream(
 
 
 def _llm_input_messages(arguments: Mapping[str, Any]) -> Iterator[Tuple[str, Any]]:
-    def process_message(idx: int, role: str, content: str) -> Iterator[Tuple[str, Any]]:
-        yield f"{LLM_INPUT_MESSAGES}.{idx}.{MESSAGE_ROLE}", role
-        yield f"{LLM_INPUT_MESSAGES}.{idx}.{MESSAGE_CONTENT}", content
+
+    def process_message(idx: int, message: Any) -> Iterator[Tuple[str, Any]]:
+        yield f"{LLM_INPUT_MESSAGES}.{idx}.{MESSAGE_ROLE}", message.role
+        if message.content:
+            yield f"{LLM_INPUT_MESSAGES}.{idx}.{MESSAGE_CONTENT}", message.get_content_string()
+        if message.tool_calls:
+            for tool_call_index, tool_call in enumerate(message.tool_calls):
+                yield (
+                    f"{LLM_INPUT_MESSAGES}.{idx}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_ID}",
+                    tool_call.get("id"),
+                )
+                yield (
+                    f"{LLM_INPUT_MESSAGES}.{idx}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_NAME}",
+                    tool_call.get("function", {}).get("name"),
+                )
+                yield (
+                    f"{LLM_INPUT_MESSAGES}.{idx}.{MESSAGE_TOOL_CALLS}.{tool_call_index}.{TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
+                    safe_json_dumps(tool_call.get("function", {}).get("arguments", {})),
+                )
+
 
     messages = arguments.get("messages", [])
     for i, message in enumerate(messages):
-        role, content = message.role, message.get_content_string()
-        if content:
-            yield from process_message(i, role, content)
-
+        if message.role in ["system", "user", "assistant", "tool"]:
+            yield from process_message(i, message)
     tools = arguments.get("tools", [])
     for tool_index, tool in enumerate(tools):
         yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", safe_json_dumps(tool)
@@ -493,22 +560,118 @@ def _filter_sensitive_params(params: Dict[str, Any]) -> Dict[str, Any]:
 
 def _input_value_and_mime_type(arguments: Mapping[str, Any]) -> Iterator[Tuple[str, Any]]:
     yield INPUT_MIME_TYPE, JSON
-    yield INPUT_VALUE, safe_json_dumps(arguments)
+
+    cleaned_input = []
+    for message in arguments.get("messages", []):
+        message_dict = message.to_dict()
+        message_dict = {k: v for k, v in message_dict.items() if v is not None}
+        cleaned_input.append(message_dict)
+    yield INPUT_VALUE, safe_json_dumps({"messages": cleaned_input})
 
 
 def _output_value_and_mime_type(output: str) -> Iterator[Tuple[str, Any]]:
     yield OUTPUT_MIME_TYPE, JSON
-    yield OUTPUT_VALUE, output
+
+    # Try to parse the output and extract LLM_OUTPUT_MESSAGES
+    try:
+        output_data = json.loads(output)
+        if isinstance(output_data, dict):
+            # Extract message information for LLM_OUTPUT_MESSAGES (only core message fields)
+            messages = []
+            message = {}
+
+            if role := output_data.get("role"):
+                message["role"] = role
+
+            if content := output_data.get("content"):
+                message["content"] = content
+
+            # Only include tool_calls if they exist and are not empty
+            if tool_calls := output_data.get("tool_calls"):
+                if tool_calls:  # Only include if not empty list
+                    message["tool_calls"] = tool_calls
+
+            messages.append(message)
+            for i,message in enumerate(messages):
+                yield f"{LLM_OUTPUT_MESSAGES}.{i}", safe_json_dumps(message)
+
+            yield OUTPUT_VALUE, safe_json_dumps(messages)
+
+    except (json.JSONDecodeError, TypeError):
+        # Fall back to the original output if parsing fails
+        yield OUTPUT_VALUE, output
 
 
 def _parse_model_output(output: Any) -> str:
-    if hasattr(output, "model_dump_json"):
-        return output.model_dump_json()  # type: ignore[no-any-return]
-    elif isinstance(output, dict):
-        return json.dumps(output)
-    else:
-        return str(output)
-
+    if hasattr(output, 'role') or hasattr(output, 'content') or hasattr(output, 'tool_calls'):
+        try:
+            result_dict = {
+                "created_at": getattr(output, 'created_at', None),
+            }
+
+            if hasattr(output, 'role'):
+                result_dict["role"] = output.role
+            if hasattr(output, 'content'):
+                result_dict["content"] = output.content
+            if hasattr(output, 'tool_calls'):
+                result_dict["tool_calls"] = output.tool_calls
+
+            # Add response_usage if available
+            if hasattr(output, 'response_usage') and output.response_usage:
+                result_dict["response_usage"] = {
+                    "input_tokens": getattr(output.response_usage, 'input_tokens', None),
+                    "output_tokens": getattr(output.response_usage, 'output_tokens', None),
+                    "total_tokens": getattr(output.response_usage, 'total_tokens', None),
+                }
+
+            return json.dumps(result_dict)
+        except Exception:
+            pass
+
+    return json.dumps(output) if isinstance(output, dict) else str(output)
+
+
+def _parse_model_output_stream(output: Any) -> dict:
+
+    # Accumulate all content and tool calls across chunks
+    accumulated_content = ""
+    all_tool_calls = []
+
+    for chunk in output:
+
+        # Accumulate content from this chunk
+        if chunk.content:
+            accumulated_content += chunk.content
+
+        # Collect tool calls from this chunk
+        if chunk.tool_calls:
+            for tool_call in chunk.tool_calls:
+                if tool_call.id:
+                    tool_call_dict = {
+                        "id": tool_call.id,
+                        "type": tool_call.type
+                    }
+                    if hasattr(tool_call, 'function'):
+                        tool_call_dict["function"] = {
+                            "name": tool_call.function.name,
+                            "arguments": tool_call.function.arguments,
+                        }
+                    all_tool_calls.append(tool_call_dict)
+
+    # Create single message with accumulated content and all tool calls
+    messages = []
+    if accumulated_content or all_tool_calls:
+        result_dict = {"role": "assistant"}
+
+        if accumulated_content:
+            result_dict["content"] = accumulated_content
+
+        if all_tool_calls:
+            result_dict["tool_calls"] = all_tool_calls
+
+        messages.append(result_dict)
+
+    return messages
 
 class _ModelWrapper:
     def __init__(self, tracer: trace_api.Tracer) -> None:
@@ -535,8 +698,8 @@ def run(
             attributes={
                 OPENINFERENCE_SPAN_KIND: LLM,
                 **dict(_input_value_and_mime_type(arguments)),
-                **dict(_llm_invocation_parameters(model, arguments)),
                 **dict(_llm_input_messages(arguments)),
+                **dict(_llm_invocation_parameters(model, arguments)),
                 **dict(get_attributes_from_context()),
             },
         ) as span:
@@ -546,8 +709,27 @@ def run(
 
             response = wrapped(*args, **kwargs)
             output_message = _parse_model_output(response)
-            span.set_attributes(dict(_output_value_and_mime_type(output_message)))
+
+            # Extract and set token usage from the response
+            if hasattr(response, "response_usage") and response.response_usage:
+                metrics = response.response_usage
+
+                # Set token usage attributes
+                if hasattr(metrics, "input_tokens") and metrics.input_tokens:
+                    span.set_attribute("llm.token_count.prompt", metrics.input_tokens)
+
+                if hasattr(metrics, "output_tokens") and metrics.output_tokens:
+                    span.set_attribute("llm.token_count.completion", metrics.output_tokens)
+
+                # Set cache-related tokens if available
+                if hasattr(metrics, "cache_read_tokens") and metrics.cache_read_tokens:
+                    span.set_attribute("llm.token_count.cache_read", metrics.cache_read_tokens)
+
+                if hasattr(metrics, "cache_write_tokens") and metrics.cache_write_tokens:
+                    span.set_attribute("llm.token_count.cache_write", metrics.cache_write_tokens)
+
+            return response
 
     def run_stream(
@@ -561,7 +743,6 @@ def run_stream(
             return wrapped(*args, **kwargs)
 
         arguments = _bind_arguments(wrapped, *args, **kwargs)
-
         model = instance
         model_name = model.name
         span_name = f"{model_name}.invoke_stream"
@@ -579,13 +760,43 @@
             span.set_status(trace_api.StatusCode.OK)
             span.set_attribute(LLM_MODEL_NAME, model.id)
             span.set_attribute(LLM_PROVIDER, model.provider)
+            # Token usage will be set after streaming completes based on final response
            responses = []
             for chunk in wrapped(*args, **kwargs):
                 responses.append(chunk)
                 yield chunk
 
-            output_message = json.dumps([_parse_model_output(response) for response in responses])
-            span.set_attributes(dict(_output_value_and_mime_type(output_message)))
+
+
+            output_message = _parse_model_output_stream(responses)
+            output_message = json.dumps(output_message)
+            span.set_attribute(OUTPUT_MIME_TYPE, JSON)
+            span.set_attribute(OUTPUT_VALUE, output_message)
+
+            # Find the final response with complete metrics (usually the last one with response_usage)
+            final_response_with_metrics = None
+            for response in reversed(responses):  # Check from last to first
+                if hasattr(response, "response_usage") and response.response_usage:
+                    final_response_with_metrics = response
+                    break
+
+            # Extract and set token usage from the final response
+            if final_response_with_metrics and final_response_with_metrics.response_usage:
+                metrics = final_response_with_metrics.response_usage
+
+                # Set token usage attributes
+                if hasattr(metrics, "input_tokens") and metrics.input_tokens:
+                    span.set_attribute("llm.token_count.prompt", metrics.input_tokens)
+
+                if hasattr(metrics, "output_tokens") and metrics.output_tokens:
+                    span.set_attribute("llm.token_count.completion", metrics.output_tokens)
+
+                # Set cache-related tokens if available
+                if hasattr(metrics, "cache_read_tokens") and metrics.cache_read_tokens:
+                    span.set_attribute("llm.token_count.cache_read", metrics.cache_read_tokens)
+
+                if hasattr(metrics, "cache_write_tokens") and metrics.cache_write_tokens:
+                    span.set_attribute("llm.token_count.cache_write", metrics.cache_write_tokens)
 
     async def arun(
         self,
@@ -620,6 +831,28 @@ async def arun(
 
             response = await wrapped(*args, **kwargs)
             output_message = _parse_model_output(response)
+            # Extract and set token usage from the response
+            if hasattr(response, "response_usage") and response.response_usage:
+                metrics = response.response_usage
+
+                # Set token usage attributes
+                if hasattr(metrics, "input_tokens") and metrics.input_tokens:
+                    span.set_attribute(LLM_TOKEN_COUNT_PROMPT, metrics.input_tokens)
+
+                if hasattr(metrics, "output_tokens") and metrics.output_tokens:
+                    span.set_attribute(LLM_TOKEN_COUNT_COMPLETION, metrics.output_tokens)
+
+                if hasattr(metrics, "total_tokens") and metrics.total_tokens:
+                    span.set_attribute(LLM_TOKEN_COUNT_TOTAL, metrics.total_tokens)
+
+
+                # Set cache-related tokens if available
+                if hasattr(metrics, "cache_read_tokens") and metrics.cache_read_tokens:
+                    span.set_attribute(LLM_COST_PROMPT_DETAILS_CACHE_READ, metrics.cache_read_tokens)
+
+                if hasattr(metrics, "cache_write_tokens") and metrics.cache_write_tokens:
+                    span.set_attribute(LLM_COST_PROMPT_DETAILS_CACHE_WRITE, metrics.cache_write_tokens)
+
             span.set_attributes(dict(_output_value_and_mime_type(output_message)))
             return response
 
@@ -654,13 +887,42 @@ async def arun_stream(
             span.set_status(trace_api.StatusCode.OK)
             span.set_attribute(LLM_MODEL_NAME, model.id)
             span.set_attribute(LLM_PROVIDER, model.provider)
+            # Token usage will be set after streaming completes based on final response
             responses = []
             async for chunk in wrapped(*args, **kwargs):  # type: ignore[attr-defined]
                 responses.append(chunk)
                 yield chunk
 
-            output_message = json.dumps([_parse_model_output(response) for response in responses])
-            span.set_attributes(dict(_output_value_and_mime_type(output_message)))
+
+            output_message = _parse_model_output_stream(responses)
+            output_message = json.dumps(output_message)
+            span.set_attribute(OUTPUT_MIME_TYPE, JSON)
+            span.set_attribute(OUTPUT_VALUE, output_message)
+
+            # Find the final response with complete metrics (usually the last one with response_usage)
+            final_response_with_metrics = None
+            for response in reversed(responses):  # Check from last to first
+                if hasattr(response, "response_usage") and response.response_usage:
+                    final_response_with_metrics = response
+                    break
+
+            # Extract and set token usage from the final response
+            if final_response_with_metrics and final_response_with_metrics.response_usage:
+                metrics = final_response_with_metrics.response_usage
+
+                # Set token usage attributes
+                if hasattr(metrics, "input_tokens") and metrics.input_tokens:
+                    span.set_attribute("llm.token_count.prompt", metrics.input_tokens)
+
+                if hasattr(metrics, "output_tokens") and metrics.output_tokens:
+                    span.set_attribute("llm.token_count.completion", metrics.output_tokens)
+
+                # Set cache-related tokens if available
+                if hasattr(metrics, "cache_read_tokens") and metrics.cache_read_tokens:
+                    span.set_attribute("llm.token_count.cache_read", metrics.cache_read_tokens)
+
+                if hasattr(metrics, "cache_write_tokens") and metrics.cache_write_tokens:
+                    span.set_attribute("llm.token_count.cache_write", metrics.cache_write_tokens)
 
 
 def _function_call_attributes(function_call: FunctionCall) -> Iterator[Tuple[str, Any]]:
@@ -838,3 +1100,9 @@ async def arun(
 TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
 TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
 TOOL_CALL_ID = ToolCallAttributes.TOOL_CALL_ID
+
+# token count attributes
+LLM_TOKEN_COUNT_PROMPT = SpanAttributes.LLM_TOKEN_COUNT_PROMPT
+LLM_TOKEN_COUNT_COMPLETION = SpanAttributes.LLM_TOKEN_COUNT_COMPLETION
+LLM_COST_PROMPT_DETAILS_CACHE_READ = SpanAttributes.LLM_COST_PROMPT_DETAILS_CACHE_READ
+LLM_COST_PROMPT_DETAILS_CACHE_WRITE = SpanAttributes.LLM_COST_PROMPT_DETAILS_CACHE_WRITE
diff --git a/python/instrumentation/openinference-instrumentation-agno/tests/test_instrumentor.py b/python/instrumentation/openinference-instrumentation-agno/tests/test_instrumentor.py
index 5779ef130f..611193a70e 100644
--- a/python/instrumentation/openinference-instrumentation-agno/tests/test_instrumentor.py
+++ b/python/instrumentation/openinference-instrumentation-agno/tests/test_instrumentor.py
@@ -12,6 +12,7 @@
 from opentelemetry.sdk.trace.export import SimpleSpanProcessor
 from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
 from opentelemetry.util._importlib_metadata import entry_points
+import importlib
 
 from openinference.instrumentation import OITracer
 from openinference.instrumentation.agno import AgnoInstrumentor
@@ -24,7 +25,6 @@
     match_on=["uri", "method"],
 )
 
-
 @pytest.fixture()
 def in_memory_span_exporter() -> InMemorySpanExporter:
     return InMemorySpanExporter()