diff --git a/scripts/integration-tests.sh b/scripts/integration-tests.sh
index 9859521674..c2eff5a4ce 100755
--- a/scripts/integration-tests.sh
+++ b/scripts/integration-tests.sh
@@ -215,6 +215,7 @@ if [[ "$STACK_CONFIG" == *"server:"* && "$COLLECT_ONLY" == false ]]; then
   export OTEL_EXPORTER_OTLP_PROTOCOL="http/protobuf"
   export OTEL_BSP_SCHEDULE_DELAY="200"
   export OTEL_BSP_EXPORT_TIMEOUT="2000"
+  export OTEL_METRIC_EXPORT_INTERVAL="200"

   # remove "server:" from STACK_CONFIG
   stack_config=$(echo "$STACK_CONFIG" | sed 's/^server://')
@@ -311,6 +312,9 @@ if [[ "$STACK_CONFIG" == *"docker:"* && "$COLLECT_ONLY" == false ]]; then
   DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e LLAMA_STACK_TEST_INFERENCE_MODE=$INFERENCE_MODE"
   DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e LLAMA_STACK_TEST_STACK_CONFIG_TYPE=server"
   DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:${COLLECTOR_PORT}"
+  DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e OTEL_METRIC_EXPORT_INTERVAL=200"
+  DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e OTEL_BSP_SCHEDULE_DELAY=200"
+  DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e OTEL_BSP_EXPORT_TIMEOUT=2000"

   # Pass through API keys if they exist
   [ -n "${TOGETHER_API_KEY:-}" ] && DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e TOGETHER_API_KEY=$TOGETHER_API_KEY"
diff --git a/src/llama_stack/core/telemetry/telemetry.py b/src/llama_stack/core/telemetry/telemetry.py
index 1ba43724d3..9476c961a7 100644
--- a/src/llama_stack/core/telemetry/telemetry.py
+++ b/src/llama_stack/core/telemetry/telemetry.py
@@ -427,6 +427,7 @@ class QueryMetricsResponse(BaseModel):
     "counters": {},
     "gauges": {},
     "up_down_counters": {},
+    "histograms": {},
 }
 _global_lock = threading.Lock()
 _TRACER_PROVIDER = None
@@ -540,6 +541,16 @@ def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
         )
         return cast(metrics.ObservableGauge, _GLOBAL_STORAGE["gauges"][name])

+    def _get_or_create_histogram(self, name: str, unit: str) -> metrics.Histogram:
+        assert self.meter is not None
+        if name not in _GLOBAL_STORAGE["histograms"]:
+            _GLOBAL_STORAGE["histograms"][name] = self.meter.create_histogram(
+                name=name,
+                unit=unit,
+                description=f"Histogram for {name}",
+            )
+        return cast(metrics.Histogram, _GLOBAL_STORAGE["histograms"][name])
+
     def _log_metric(self, event: MetricEvent) -> None:
         # Add metric as an event to the current span
         try:
@@ -571,7 +582,16 @@ def _log_metric(self, event: MetricEvent) -> None:
         # Log to OpenTelemetry meter if available
         if self.meter is None:
             return
-        if isinstance(event.value, int):
+
+        # Use histograms for token-related metrics (per-request measurements)
+        # Use counters for other cumulative metrics
+        token_metrics = {"prompt_tokens", "completion_tokens", "total_tokens"}
+
+        if event.metric in token_metrics:
+            # Token metrics are per-request measurements, use histogram
+            histogram = self._get_or_create_histogram(event.metric, event.unit)
+            histogram.record(event.value, attributes=_clean_attributes(event.attributes))
+        elif isinstance(event.value, int):
             counter = self._get_or_create_counter(event.metric, event.unit)
             counter.add(event.value, attributes=_clean_attributes(event.attributes))
         elif isinstance(event.value, float):
diff --git a/tests/integration/common/recordings/models-64a2277c90f0f42576f60c1030e3a020403d34a95f56931b792d5939f4cebc57-826d44c3.json b/tests/integration/common/recordings/models-64a2277c90f0f42576f60c1030e3a020403d34a95f56931b792d5939f4cebc57-826d44c3.json
index a5f841baa0..878fcc650c 100644
--- a/tests/integration/common/recordings/models-64a2277c90f0f42576f60c1030e3a020403d34a95f56931b792d5939f4cebc57-826d44c3.json
+++ b/tests/integration/common/recordings/models-64a2277c90f0f42576f60c1030e3a020403d34a95f56931b792d5939f4cebc57-826d44c3.json
@@ -84,5 +84,6 @@
       }
     ],
     "is_streaming": false
-  }
+  },
+  "id_normalization_mapping": {}
 }
diff --git a/tests/integration/telemetry/collectors/base.py b/tests/integration/telemetry/collectors/base.py
index a85e6cf3fb..a0fa803afd 100644
--- a/tests/integration/telemetry/collectors/base.py
+++ b/tests/integration/telemetry/collectors/base.py
@@ -6,20 +6,88 @@
 """Shared helpers for telemetry test collectors."""

+import time
 from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import Any


+@dataclass
+class MetricStub:
+    """Unified metric interface for both in-memory and OTLP collectors."""
+
+    name: str
+    value: Any
+    attributes: dict[str, Any] | None = None
+
+
 @dataclass
 class SpanStub:
+    """Unified span interface for both in-memory and OTLP collectors."""
+
     name: str
-    attributes: dict[str, Any]
+    attributes: dict[str, Any] | None = None
     resource_attributes: dict[str, Any] | None = None
     events: list[dict[str, Any]] | None = None
     trace_id: str | None = None
     span_id: str | None = None

+    @property
+    def context(self):
+        """Provide context-like interface for trace_id compatibility."""
+        if self.trace_id is None:
+            return None
+        return type("Context", (), {"trace_id": int(self.trace_id, 16)})()
+
+    def get_trace_id(self) -> str | None:
+        """Get trace ID in hex format.
+
+        Tries context.trace_id first, then falls back to direct trace_id.
+        """
+        context = getattr(self, "context", None)
+        if context and getattr(context, "trace_id", None) is not None:
+            return f"{context.trace_id:032x}"
+        return getattr(self, "trace_id", None)
+
+    def has_message(self, text: str) -> bool:
+        """Check if span contains a specific message in its args."""
+        if self.attributes is None:
+            return False
+        args = self.attributes.get("__args__")
+        if not args or not isinstance(args, str):
+            return False
+        return text in args
+
+    def is_root_span(self) -> bool:
+        """Check if this is a root span."""
+        if self.attributes is None:
+            return False
+        return self.attributes.get("__root__") is True
+
+    def is_autotraced(self) -> bool:
+        """Check if this span was automatically traced."""
+        if self.attributes is None:
+            return False
+        return self.attributes.get("__autotraced__") is True
+
+    def get_span_type(self) -> str | None:
+        """Get the span type (async, sync, async_generator)."""
+        if self.attributes is None:
+            return None
+        return self.attributes.get("__type__")
+
+    def get_class_method(self) -> tuple[str | None, str | None]:
+        """Get the class and method names for autotraced spans."""
+        if self.attributes is None:
+            return None, None
+        return (self.attributes.get("__class__"), self.attributes.get("__method__"))
+
+    def get_location(self) -> str | None:
+        """Get the location (library_client, server) for root spans."""
+        if self.attributes is None:
+            return None
+        return self.attributes.get("__location__")
+

 def _value_to_python(value: Any) -> Any:
     kind = value.WhichOneof("value")
@@ -56,14 +124,18 @@ def events_to_list(events: Iterable[Any]) -> list[dict[str, Any]]:


 class BaseTelemetryCollector:
+    """Base class for telemetry collectors that ensures consistent return types.
+
+    All collectors must return SpanStub objects to ensure test compatibility
+    across both library-client and server modes.
+ """ + def get_spans( self, expected_count: int | None = None, timeout: float = 5.0, poll_interval: float = 0.05, - ) -> tuple[Any, ...]: - import time - + ) -> tuple[SpanStub, ...]: deadline = time.time() + timeout min_count = expected_count if expected_count is not None else 1 last_len: int | None = None @@ -91,16 +163,206 @@ def get_spans( last_len = len(spans) time.sleep(poll_interval) - def get_metrics(self) -> Any | None: - return self._snapshot_metrics() + def get_metrics( + self, + expected_count: int | None = None, + timeout: float = 5.0, + poll_interval: float = 0.05, + expect_model_id: str | None = None, + ) -> dict[str, MetricStub]: + """Get metrics with polling until metrics are available or timeout is reached.""" + + # metrics need to be collected since get requests delete stored metrics + deadline = time.time() + timeout + min_count = expected_count if expected_count is not None else 1 + accumulated_metrics = {} + count_metrics_with_model_id = 0 + + while time.time() < deadline: + current_metrics = self._snapshot_metrics() + if current_metrics: + for metric in current_metrics: + metric_name = metric.name + if metric_name not in accumulated_metrics: + accumulated_metrics[metric_name] = metric + if ( + expect_model_id + and metric.attributes + and metric.attributes.get("model_id") == expect_model_id + ): + count_metrics_with_model_id += 1 + else: + accumulated_metrics[metric_name] = metric + + # Check if we have enough metrics + if len(accumulated_metrics) >= min_count: + if not expect_model_id: + return accumulated_metrics + if count_metrics_with_model_id >= min_count: + return accumulated_metrics + + time.sleep(poll_interval) + + return accumulated_metrics + + @staticmethod + def _convert_attributes_to_dict(attrs: Any) -> dict[str, Any]: + """Convert various attribute types to a consistent dictionary format. + + Handles mappingproxy, dict, and other attribute types. + """ + if attrs is None: + return {} + + try: + return dict(attrs.items()) # type: ignore[attr-defined] + except AttributeError: + try: + return dict(attrs) + except TypeError: + return dict(attrs) if attrs else {} + + @staticmethod + def _extract_trace_span_ids(span: Any) -> tuple[str | None, str | None]: + """Extract trace_id and span_id from OpenTelemetry span object. + + Handles both context-based and direct attribute access. + """ + trace_id = None + span_id = None + + context = getattr(span, "context", None) + if context: + trace_id = f"{context.trace_id:032x}" + span_id = f"{context.span_id:016x}" + else: + trace_id = getattr(span, "trace_id", None) + span_id = getattr(span, "span_id", None) + + return trace_id, span_id + + @staticmethod + def _create_span_stub_from_opentelemetry(span: Any) -> SpanStub: + """Create SpanStub from OpenTelemetry span object. + + This helper reduces code duplication between collectors. + """ + trace_id, span_id = BaseTelemetryCollector._extract_trace_span_ids(span) + attributes = BaseTelemetryCollector._convert_attributes_to_dict(span.attributes) or {} + + return SpanStub( + name=span.name, + attributes=attributes, + trace_id=trace_id, + span_id=span_id, + ) + + @staticmethod + def _create_span_stub_from_protobuf(span: Any, resource_attrs: dict[str, Any] | None = None) -> SpanStub: + """Create SpanStub from protobuf span object. + + This helper handles the different structure of protobuf spans. 
+ """ + attributes = attributes_to_dict(span.attributes) or {} + events = events_to_list(span.events) if span.events else None + trace_id = span.trace_id.hex() if span.trace_id else None + span_id = span.span_id.hex() if span.span_id else None + + return SpanStub( + name=span.name, + attributes=attributes, + resource_attributes=resource_attrs, + events=events, + trace_id=trace_id, + span_id=span_id, + ) + + @staticmethod + def _extract_metric_from_opentelemetry(metric: Any) -> MetricStub | None: + """Extract MetricStub from OpenTelemetry metric object. + + This helper reduces code duplication between collectors. + """ + if not (hasattr(metric, "name") and hasattr(metric, "data") and hasattr(metric.data, "data_points")): + return None + + if not (metric.data.data_points and len(metric.data.data_points) > 0): + return None + + # Get the value from the first data point + data_point = metric.data.data_points[0] + + # Handle different metric types + if hasattr(data_point, "value"): + # Counter or Gauge + value = data_point.value + elif hasattr(data_point, "sum"): + # Histogram - use the sum of all recorded values + value = data_point.sum + else: + return None + + # Extract attributes if available + attributes = {} + if hasattr(data_point, "attributes"): + attrs = data_point.attributes + if attrs is not None and hasattr(attrs, "items"): + attributes = dict(attrs.items()) + elif attrs is not None and not isinstance(attrs, dict): + attributes = dict(attrs) + + return MetricStub( + name=metric.name, + value=value, + attributes=attributes or {}, + ) + + @staticmethod + def _create_metric_stub_from_protobuf(metric: Any) -> MetricStub | None: + """Create MetricStub from protobuf metric object. + + Protobuf metrics have a different structure than OpenTelemetry metrics. + They can have sum, gauge, or histogram data. + """ + if not hasattr(metric, "name"): + return None + + # Try to extract value from different metric types + for metric_type in ["sum", "gauge", "histogram"]: + if hasattr(metric, metric_type): + metric_data = getattr(metric, metric_type) + if metric_data and hasattr(metric_data, "data_points"): + data_points = metric_data.data_points + if data_points and len(data_points) > 0: + data_point = data_points[0] + + # Extract attributes first (needed for all metric types) + attributes = ( + attributes_to_dict(data_point.attributes) if hasattr(data_point, "attributes") else {} + ) + + # Extract value based on metric type + if metric_type == "sum": + value = data_point.as_int + elif metric_type == "gauge": + value = data_point.as_double + else: # histogram + value = data_point.sum + + return MetricStub( + name=metric.name, + value=value, + attributes=attributes, + ) + return None def clear(self) -> None: self._clear_impl() - def _snapshot_spans(self) -> tuple[Any, ...]: # pragma: no cover - interface hook + def _snapshot_spans(self) -> tuple[SpanStub, ...]: # pragma: no cover - interface hook raise NotImplementedError - def _snapshot_metrics(self) -> Any | None: # pragma: no cover - interface hook + def _snapshot_metrics(self) -> tuple[MetricStub, ...] 
| None: # pragma: no cover - interface hook raise NotImplementedError def _clear_impl(self) -> None: # pragma: no cover - interface hook diff --git a/tests/integration/telemetry/collectors/in_memory.py b/tests/integration/telemetry/collectors/in_memory.py index 2cf320f7be..f431ed94d8 100644 --- a/tests/integration/telemetry/collectors/in_memory.py +++ b/tests/integration/telemetry/collectors/in_memory.py @@ -6,8 +6,6 @@ """In-memory telemetry collector for library-client tests.""" -from typing import Any - import opentelemetry.metrics as otel_metrics import opentelemetry.trace as otel_trace from opentelemetry import metrics, trace @@ -19,46 +17,41 @@ import llama_stack.core.telemetry.telemetry as telemetry_module -from .base import BaseTelemetryCollector, SpanStub +from .base import BaseTelemetryCollector, MetricStub, SpanStub class InMemoryTelemetryCollector(BaseTelemetryCollector): + """In-memory telemetry collector for library-client tests. + + Converts OpenTelemetry span objects to SpanStub objects to ensure + consistent interface with OTLP collector used in server mode. + """ + def __init__(self, span_exporter: InMemorySpanExporter, metric_reader: InMemoryMetricReader) -> None: self._span_exporter = span_exporter self._metric_reader = metric_reader - def _snapshot_spans(self) -> tuple[Any, ...]: + def _snapshot_spans(self) -> tuple[SpanStub, ...]: spans = [] for span in self._span_exporter.get_finished_spans(): - trace_id = None - span_id = None - context = getattr(span, "context", None) - if context: - trace_id = f"{context.trace_id:032x}" - span_id = f"{context.span_id:016x}" - else: - trace_id = getattr(span, "trace_id", None) - span_id = getattr(span, "span_id", None) - - stub = SpanStub( - span.name, - span.attributes, - getattr(span, "resource", None), - getattr(span, "events", None), - trace_id, - span_id, - ) - spans.append(stub) - + spans.append(self._create_span_stub_from_opentelemetry(span)) return tuple(spans) - def _snapshot_metrics(self) -> Any | None: + def _snapshot_metrics(self) -> tuple[MetricStub, ...] 
| None: data = self._metric_reader.get_metrics_data() - if data and data.resource_metrics: - resource_metric = data.resource_metrics[0] + if not data or not data.resource_metrics: + return None + + metric_stubs = [] + for resource_metric in data.resource_metrics: if resource_metric.scope_metrics: - return resource_metric.scope_metrics[0].metrics - return None + for scope_metric in resource_metric.scope_metrics: + for metric in scope_metric.metrics: + metric_stub = self._extract_metric_from_opentelemetry(metric) + if metric_stub: + metric_stubs.append(metric_stub) + + return tuple(metric_stubs) if metric_stubs else None def _clear_impl(self) -> None: self._span_exporter.clear() diff --git a/tests/integration/telemetry/collectors/otlp.py b/tests/integration/telemetry/collectors/otlp.py index 2d6cb0b7eb..750ca6a552 100644 --- a/tests/integration/telemetry/collectors/otlp.py +++ b/tests/integration/telemetry/collectors/otlp.py @@ -9,20 +9,20 @@ import gzip import os import threading +import time from http.server import BaseHTTPRequestHandler, HTTPServer from socketserver import ThreadingMixIn -from typing import Any from opentelemetry.proto.collector.metrics.v1.metrics_service_pb2 import ExportMetricsServiceRequest from opentelemetry.proto.collector.trace.v1.trace_service_pb2 import ExportTraceServiceRequest -from .base import BaseTelemetryCollector, SpanStub, attributes_to_dict, events_to_list +from .base import BaseTelemetryCollector, MetricStub, SpanStub, attributes_to_dict class OtlpHttpTestCollector(BaseTelemetryCollector): def __init__(self) -> None: self._spans: list[SpanStub] = [] - self._metrics: list[Any] = [] + self._metrics: list[MetricStub] = [] self._lock = threading.Lock() class _ThreadingHTTPServer(ThreadingMixIn, HTTPServer): @@ -47,11 +47,7 @@ def _handle_traces(self, request: ExportTraceServiceRequest) -> None: for scope_spans in resource_spans.scope_spans: for span in scope_spans.spans: - attributes = attributes_to_dict(span.attributes) - events = events_to_list(span.events) if span.events else None - trace_id = span.trace_id.hex() if span.trace_id else None - span_id = span.span_id.hex() if span.span_id else None - new_spans.append(SpanStub(span.name, attributes, resource_attrs or None, events, trace_id, span_id)) + new_spans.append(self._create_span_stub_from_protobuf(span, resource_attrs or None)) if not new_spans: return @@ -60,10 +56,13 @@ def _handle_traces(self, request: ExportTraceServiceRequest) -> None: self._spans.extend(new_spans) def _handle_metrics(self, request: ExportMetricsServiceRequest) -> None: - new_metrics: list[Any] = [] + new_metrics: list[MetricStub] = [] for resource_metrics in request.resource_metrics: for scope_metrics in resource_metrics.scope_metrics: - new_metrics.extend(scope_metrics.metrics) + for metric in scope_metrics.metrics: + metric_stub = self._create_metric_stub_from_protobuf(metric) + if metric_stub: + new_metrics.append(metric_stub) if not new_metrics: return @@ -75,11 +74,40 @@ def _snapshot_spans(self) -> tuple[SpanStub, ...]: with self._lock: return tuple(self._spans) - def _snapshot_metrics(self) -> Any | None: + def _snapshot_metrics(self) -> tuple[MetricStub, ...] 
| None: with self._lock: - return list(self._metrics) if self._metrics else None + return tuple(self._metrics) if self._metrics else None def _clear_impl(self) -> None: + """Clear telemetry over a period of time to prevent race conditions between tests.""" + with self._lock: + self._spans.clear() + self._metrics.clear() + + # Prevent race conditions where telemetry arrives after clear() but before + # the test starts, causing contamination between tests + deadline = time.time() + 2.0 # Maximum wait time + last_span_count = 0 + last_metric_count = 0 + stable_iterations = 0 + + while time.time() < deadline: + with self._lock: + current_span_count = len(self._spans) + current_metric_count = len(self._metrics) + + if current_span_count == last_span_count and current_metric_count == last_metric_count: + stable_iterations += 1 + if stable_iterations >= 4: # 4 * 50ms = 200ms of stability + break + else: + stable_iterations = 0 + last_span_count = current_span_count + last_metric_count = current_metric_count + + time.sleep(0.05) + + # Final clear to remove any telemetry that arrived during stabilization with self._lock: self._spans.clear() self._metrics.clear() diff --git a/tests/integration/telemetry/recordings/1fcfd86d8111374dc852cfdea6bfdb6a511f92cee84a6325b04ae84878512c30.json b/tests/integration/telemetry/recordings/1fcfd86d8111374dc852cfdea6bfdb6a511f92cee84a6325b04ae84878512c30.json index 1981a583ac..6a06ca2c52 100644 --- a/tests/integration/telemetry/recordings/1fcfd86d8111374dc852cfdea6bfdb6a511f92cee84a6325b04ae84878512c30.json +++ b/tests/integration/telemetry/recordings/1fcfd86d8111374dc852cfdea6bfdb6a511f92cee84a6325b04ae84878512c30.json @@ -30,7 +30,7 @@ "index": 0, "logprobs": null, "message": { - "content": "import torch\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\n# Load the pre-trained model and tokenizer\nmodel_name = \"CompVis/transformers-base-uncased\"\nmodel = AutoModelForCausalLM.from_pretrained(model_name)\ntokenizer = AutoTokenizer.from_pretrained(model_name)\n\n# Set the temperature to 0.7\ntemperature = 0.7\n\n# Define a function to generate text\ndef generate_text(prompt, max_length=100):\n input", + "content": "To test the trace function from OpenAI's API with a temperature of 0.7, you can use the following Python code:\n\n```python\nimport json\n\n# Import the required libraries\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\n# Set the API endpoint and model name\nmodel_name = \"dalle-mini\"\n\n# Initialize the model and tokenizer\nmodel = AutoModelForCausalLM.from_pretrained(model_name)\ntokenizer = AutoTokenizer.from_pretrained(model_name)\n\n", "refusal": null, "role": "assistant", "annotations": null, @@ -55,5 +55,6 @@ } }, "is_streaming": false - } + }, + "id_normalization_mapping": {} } diff --git a/tests/integration/telemetry/test_completions.py b/tests/integration/telemetry/test_completions.py index 5322f021a1..695f0c0363 100644 --- a/tests/integration/telemetry/test_completions.py +++ b/tests/integration/telemetry/test_completions.py @@ -4,48 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -"""Telemetry tests verifying @trace_protocol decorator format across stack modes.""" +"""Telemetry tests verifying @trace_protocol decorator format across stack modes. 

-import json
-
-
-def _span_attributes(span):
-    attrs = getattr(span, "attributes", None)
-    if attrs is None:
-        return {}
-    # ReadableSpan.attributes acts like a mapping
-    try:
-        return dict(attrs.items())  # type: ignore[attr-defined]
-    except AttributeError:
-        try:
-            return dict(attrs)
-        except TypeError:
-            return attrs
-
-
-def _span_attr(span, key):
-    attrs = _span_attributes(span)
-    return attrs.get(key)
-
-
-def _span_trace_id(span):
-    context = getattr(span, "context", None)
-    if context and getattr(context, "trace_id", None) is not None:
-        return f"{context.trace_id:032x}"
-    return getattr(span, "trace_id", None)
+Note: The mock_otlp_collector fixture automatically clears telemetry data
+before and after each test, ensuring test isolation.
+"""

-
-def _span_has_message(span, text: str) -> bool:
-    args = _span_attr(span, "__args__")
-    if not args or not isinstance(args, str):
-        return False
-    return text in args
+import json


 def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_model_id):
     """Verify streaming adds chunk_count and __type__=async_generator."""
-    mock_otlp_collector.clear()
-
     stream = llama_stack_client.chat.completions.create(
         model=text_model_id,
         messages=[{"role": "user", "content": "Test trace openai 1"}],
@@ -62,16 +31,16 @@ def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_mod
         (
             span
             for span in reversed(spans)
-            if _span_attr(span, "__type__") == "async_generator"
-            and _span_attr(span, "chunk_count")
-            and _span_has_message(span, "Test trace openai 1")
+            if span.get_span_type() == "async_generator"
+            and span.attributes.get("chunk_count")
+            and span.has_message("Test trace openai 1")
         ),
         None,
     )

     assert async_generator_span is not None

-    raw_chunk_count = _span_attr(async_generator_span, "chunk_count")
+    raw_chunk_count = async_generator_span.attributes.get("chunk_count")
     assert raw_chunk_count is not None
     chunk_count = int(raw_chunk_count)

@@ -80,7 +49,6 @@ def test_streaming_chunk_count(mock_otlp_collector, llama_stack_client, text_mod

 def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client, text_model_id):
     """Comprehensive validation of telemetry data format including spans and metrics."""
-    mock_otlp_collector.clear()

     response = llama_stack_client.chat.completions.create(
         model=text_model_id,
@@ -101,37 +69,36 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client,
     # Verify spans
     spans = mock_otlp_collector.get_spans(expected_count=7)
     target_span = next(
-        (span for span in reversed(spans) if _span_has_message(span, "Test trace openai with temperature 0.7")),
+        (span for span in reversed(spans) if span.has_message("Test trace openai with temperature 0.7")),
         None,
     )
     assert target_span is not None

-    trace_id = _span_trace_id(target_span)
+    trace_id = target_span.get_trace_id()
     assert trace_id is not None

-    spans = [span for span in spans if _span_trace_id(span) == trace_id]
-    spans = [span for span in spans if _span_attr(span, "__root__") or _span_attr(span, "__autotraced__")]
+    spans = [span for span in spans if span.get_trace_id() == trace_id]
+    spans = [span for span in spans if span.is_root_span() or span.is_autotraced()]
     assert len(spans) >= 4

     # Collect all model_ids found in spans
     logged_model_ids = []
     for span in spans:
-        attrs = _span_attributes(span)
+        attrs = span.attributes
         assert attrs is not None

         # Root span is created manually by tracing middleware, not by @trace_protocol decorator
-        is_root_span = attrs.get("__root__") is True
-
-        if is_root_span:
-            assert attrs.get("__location__") in ["library_client", "server"]
+        if span.is_root_span():
+            assert span.get_location() in ["library_client", "server"]
             continue

-        assert attrs.get("__autotraced__")
-        assert attrs.get("__class__") and attrs.get("__method__")
-        assert attrs.get("__type__") in ["async", "sync", "async_generator"]
+        assert span.is_autotraced()
+        class_name, method_name = span.get_class_method()
+        assert class_name and method_name
+        assert span.get_span_type() in ["async", "sync", "async_generator"]

-        args_field = attrs.get("__args__")
+        args_field = span.attributes.get("__args__")
         if args_field:
             args = json.loads(args_field)
             if "model_id" in args:
@@ -140,21 +107,40 @@ def test_telemetry_format_completeness(mock_otlp_collector, llama_stack_client,
     # At least one span should capture the fully qualified model ID
     assert text_model_id in logged_model_ids, f"Expected to find {text_model_id} in spans, but got {logged_model_ids}"

-    # TODO: re-enable this once metrics get fixed
-    """
-    # Verify token usage metrics in response
-    metrics = mock_otlp_collector.get_metrics()
-
-    assert metrics
-    for metric in metrics:
-        assert metric.name in ["completion_tokens", "total_tokens", "prompt_tokens"]
-        assert metric.unit == "tokens"
-        assert metric.data.data_points and len(metric.data.data_points) == 1
-        match metric.name:
-            case "completion_tokens":
-                assert metric.data.data_points[0].value == usage["completion_tokens"]
-            case "total_tokens":
-                assert metric.data.data_points[0].value == usage["total_tokens"]
-            case "prompt_tokens":
-                assert metric.data.data_points[0].value == usage["prompt_tokens"
-    """
+    # Verify token usage metrics in response using polling
+    expected_metrics = ["completion_tokens", "total_tokens", "prompt_tokens"]
+    metrics = mock_otlp_collector.get_metrics(expected_count=len(expected_metrics), expect_model_id=text_model_id)
+    assert len(metrics) > 0, "No metrics found within timeout"
+
+    # Filter metrics to only those from the specific model used in the request
+    # This prevents issues when multiple metrics with the same name exist from different models
+    # (e.g., when safety models like llama-guard are also called)
+    inference_model_metrics = {}
+    all_model_ids = set()
+
+    for name, metric in metrics.items():
+        if name in expected_metrics:
+            model_id = metric.attributes.get("model_id")
+            all_model_ids.add(model_id)
+            # Only include metrics from the specific model used in the test request
+            if model_id == text_model_id:
+                inference_model_metrics[name] = metric
+
+    # Verify expected metrics are present for our specific model
+    for metric_name in expected_metrics:
+        assert metric_name in inference_model_metrics, (
+            f"Expected metric {metric_name} for model {text_model_id} not found. "
+            f"Available models: {sorted(all_model_ids)}, "
+            f"Available metrics for {text_model_id}: {list(inference_model_metrics.keys())}"
+        )
+
+    # Verify metric values match usage data
+    assert inference_model_metrics["completion_tokens"].value == usage["completion_tokens"], (
+        f"Expected {usage['completion_tokens']} for completion_tokens, but got {inference_model_metrics['completion_tokens'].value}"
+    )
+    assert inference_model_metrics["total_tokens"].value == usage["total_tokens"], (
+        f"Expected {usage['total_tokens']} for total_tokens, but got {inference_model_metrics['total_tokens'].value}"
+    )
+    assert inference_model_metrics["prompt_tokens"].value == usage["prompt_tokens"], (
+        f"Expected {usage['prompt_tokens']} for prompt_tokens, but got {inference_model_metrics['prompt_tokens'].value}"
+    )