From db5b65bd2f1a09d2bf0f6876c8334550aaa88c90 Mon Sep 17 00:00:00 2001
From: Lucas Alvares Gomes
Date: Fri, 18 Jul 2025 14:10:56 +0100
Subject: [PATCH] Add metrics to lightspeed-stack

This patch adds the /metrics endpoint to the project.

Differences from road-core/service:
* The metrics are prefixed with "ls_" instead of "ols_"
* The "provider_model_configuration" metric does not set the non-default
  models/providers to 0 because we currently do not have a way to set a
  default model/provider in the configuration. A TODO was left in the code.

Supported metrics:
* rest_api_calls_total: Counter to track REST API calls
* response_duration_seconds: Histogram to measure how long it takes to handle requests
* provider_model_configuration: Indicates what provider + model customers are using
* llm_calls_total: How many LLM calls were made for each provider + model
* llm_calls_failures_total: How many LLM calls failed
* llm_calls_validation_errors_total: How many LLM calls had validation errors

Missing metrics:
* llm_token_sent_total: How many tokens were sent
* llm_token_received_total: How many tokens were received

The above metrics are missing because the token counting PR
(https://github.com/lightspeed-core/lightspeed-stack/pull/215) has not been
merged yet.

Signed-off-by: Lucas Alvares Gomes
---
 pyproject.toml                             |   2 +
 src/app/endpoints/metrics.py               |  16 +++
 src/app/endpoints/query.py                 |  33 ++++-
 src/app/endpoints/streaming_query.py       |  14 +-
 src/app/main.py                            |  47 ++++++-
 src/app/routers.py                         |   2 +
 src/metrics/__init__.py                    |  51 ++++++++
 src/metrics/utils.py                       |  32 +++++
 tests/unit/app/endpoints/test_metrics.py   |  25 ++++
 tests/unit/app/endpoints/test_query.py     | 123 ++++++++++++++----
 .../app/endpoints/test_streaming_query.py  |  20 ++-
 tests/unit/app/test_routers.py             |   7 +-
 tests/unit/metrics/__init__.py             |   1 +
 tests/unit/metrics/test_utis.py            |  23 ++++
 uv.lock                                    |  13 ++
 15 files changed, 363 insertions(+), 46 deletions(-)
 create mode 100644 src/app/endpoints/metrics.py
 create mode 100644 src/metrics/__init__.py
 create mode 100644 src/metrics/utils.py
 create mode 100644 tests/unit/app/endpoints/test_metrics.py
 create mode 100644 tests/unit/metrics/__init__.py
 create mode 100644 tests/unit/metrics/test_utis.py

diff --git a/pyproject.toml b/pyproject.toml
index 3bfd84d1..29fe476a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,6 +13,8 @@ dependencies = [
     "llama-stack>=0.2.13",
     "rich>=14.0.0",
     "cachetools>=6.1.0",
+    "prometheus-client>=0.22.1",
+    "starlette>=0.47.1",
 ]
 
 [tool.pyright]
diff --git a/src/app/endpoints/metrics.py b/src/app/endpoints/metrics.py
new file mode 100644
index 00000000..9a8eb1a2
--- /dev/null
+++ b/src/app/endpoints/metrics.py
@@ -0,0 +1,16 @@
+"""Handler for REST API call to provide metrics."""
+
+from fastapi.responses import PlainTextResponse
+from fastapi import APIRouter, Request
+from prometheus_client import (
+    generate_latest,
+    CONTENT_TYPE_LATEST,
+)
+
+router = APIRouter(tags=["metrics"])
+
+
+@router.get("/metrics", response_class=PlainTextResponse)
+def metrics_endpoint_handler(_request: Request) -> PlainTextResponse:
+    """Handle request to the /metrics endpoint."""
+    return PlainTextResponse(generate_latest(), media_type=CONTENT_TYPE_LATEST)
diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index a10ed4f4..e1f13ede 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -24,6 +24,7 @@
 from client import LlamaStackClientHolder
 from configuration import configuration
 from app.endpoints.conversations import conversation_id_to_agent_id
+import metrics
 from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
 from models.requests import QueryRequest, Attachment
 import constants
@@ -122,7 +123,9 @@ def query_endpoint_handler(
     try:
         # try to get Llama Stack client
         client = LlamaStackClientHolder().get_client()
-        model_id = select_model_id(client.models.list(), query_request)
+        model_id, provider_id = select_model_and_provider_id(
+            client.models.list(), query_request
+        )
         response, conversation_id = retrieve_response(
             client,
             model_id,
@@ -130,6 +133,8 @@ def query_endpoint_handler(
             token,
             mcp_headers=mcp_headers,
         )
+        # Update metrics for the LLM call
+        metrics.llm_calls_total.labels(provider_id, model_id).inc()
 
         if not is_transcripts_enabled():
             logger.debug("Transcript collection is disabled in the configuration")
@@ -150,6 +155,8 @@ def query_endpoint_handler(
 
     # connection to Llama Stack server
     except APIConnectionError as e:
+        # Update metrics for the LLM call failure
+        metrics.llm_calls_failures_total.inc()
         logger.error("Unable to connect to Llama Stack: %s", e)
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -160,8 +167,10 @@ def query_endpoint_handler(
         ) from e
 
 
-def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> str:
-    """Select the model ID based on the request or available models."""
+def select_model_and_provider_id(
+    models: ModelListResponse, query_request: QueryRequest
+) -> tuple[str, str | None]:
+    """Select the model ID and provider ID based on the request or available models."""
     model_id = query_request.model
     provider_id = query_request.provider
 
@@ -173,9 +182,11 @@ def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> s
                 m
                 for m in models
                 if m.model_type == "llm"  # pyright: ignore[reportAttributeAccessIssue]
-            ).identifier
+            )
+            model_id = model.identifier
+            provider_id = model.provider_id
             logger.info("Selected model: %s", model)
-            return model
+            return model_id, provider_id
         except (StopIteration, AttributeError) as e:
             message = "No LLM model found in available models"
             logger.error(message)
@@ -201,7 +212,7 @@ def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> s
                 },
             )
 
-    return model_id
+    return model_id, provider_id
 
 
 def _is_inout_shield(shield: Shield) -> bool:
@@ -218,7 +229,7 @@ def is_input_shield(shield: Shield) -> bool:
     return _is_inout_shield(shield) or not is_output_shield(shield)
 
 
-def retrieve_response(
+def retrieve_response(  # pylint: disable=too-many-locals
     client: LlamaStackClient,
     model_id: str,
     query_request: QueryRequest,
@@ -288,6 +299,14 @@ def retrieve_response(
         toolgroups=toolgroups or None,
     )
 
+    # Check for validation errors in the response
+    steps = getattr(response, "steps", [])
+    for step in steps:
+        if step.step_type == "shield_call" and step.violation:
+            # Metric for LLM validation errors
+            metrics.llm_calls_validation_errors_total.inc()
+            break
+
     return str(response.output_message.content), conversation_id  # type: ignore[union-attr]
 
 
diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py
index 601db15a..278ff412 100644
--- a/src/app/endpoints/streaming_query.py
+++ b/src/app/endpoints/streaming_query.py
@@ -23,6 +23,7 @@
 from auth import get_auth_dependency
 from client import AsyncLlamaStackClientHolder
 from configuration import configuration
+import metrics
 from models.requests import QueryRequest
 from utils.endpoints import check_configuration_loaded, get_system_prompt
 from utils.common import retrieve_user_id
@@ -37,7 +38,7 @@
     is_output_shield,
     is_transcripts_enabled,
     store_transcript,
-    select_model_id,
+    select_model_and_provider_id,
     validate_attachments_metadata,
 )
 
@@ -229,6 +230,8 @@ def _handle_shield_event(chunk: Any, chunk_id: int) -> Iterator[str]:
                 }
             )
         else:
+            # Metric for LLM validation errors
+            metrics.llm_calls_validation_errors_total.inc()
             violation = (
                 f"Violation: {violation.user_message} (Metadata: {violation.metadata})"
             )
@@ -421,7 +424,9 @@ async def streaming_query_endpoint_handler(
     try:
         # try to get Llama Stack client
         client = AsyncLlamaStackClientHolder().get_client()
-        model_id = select_model_id(await client.models.list(), query_request)
+        model_id, provider_id = select_model_and_provider_id(
+            await client.models.list(), query_request
+        )
         response, conversation_id = await retrieve_response(
             client,
             model_id,
@@ -465,9 +470,14 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
                 attachments=query_request.attachments or [],
             )
 
+        # Update metrics for the LLM call
+        metrics.llm_calls_total.labels(provider_id, model_id).inc()
+
         return StreamingResponse(response_generator(response))
     # connection to Llama Stack server
     except APIConnectionError as e:
+        # Update metrics for the LLM call failure
+        metrics.llm_calls_failures_total.inc()
         logger.error("Unable to connect to Llama Stack: %s", e)
         raise HTTPException(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
diff --git a/src/app/main.py b/src/app/main.py
index 90c09446..9aace6d3 100644
--- a/src/app/main.py
+++ b/src/app/main.py
@@ -1,13 +1,18 @@
 """Definition of FastAPI based web service."""
 
-from fastapi import FastAPI
+from typing import Callable, Awaitable
+
+from fastapi import FastAPI, Request, Response
 from fastapi.middleware.cors import CORSMiddleware
 
-from app import routers
+from starlette.routing import Mount, Route, WebSocketRoute
 
-import version
-from log import get_logger
+from app import routers
 from configuration import configuration
+from log import get_logger
+import metrics
+from metrics.utils import setup_model_metrics
 from utils.common import register_mcp_servers_async
+import version
 
 logger = get_logger(__name__)
@@ -34,9 +39,43 @@
     allow_headers=["*"],
 )
 
+
+@app.middleware("")
+async def rest_api_metrics(
+    request: Request, call_next: Callable[[Request], Awaitable[Response]]
+) -> Response:
+    """Middleware with REST API counter update logic."""
+    path = request.url.path
+    logger.debug("Received request for path: %s", path)
+
+    # ignore paths that are not part of the app routes
+    if path not in app_routes_paths:
+        return await call_next(request)
+
+    logger.debug("Processing API request for path: %s", path)
+
+    # measure time to handle duration + update histogram
+    with metrics.response_duration_seconds.labels(path).time():
+        response = await call_next(request)
+
+    # ignore /metrics endpoint that will be called periodically
+    if not path.endswith("/metrics"):
+        # just update metrics
+        metrics.rest_api_calls_total.labels(path, response.status_code).inc()
+    return response
+
+
 logger.info("Including routers")
 routers.include_routers(app)
 
+app_routes_paths = [
+    route.path
+    for route in app.routes
+    if isinstance(route, (Mount, Route, WebSocketRoute))
+]
+
+setup_model_metrics()
+
 
 @app.on_event("startup")
 async def startup_event() -> None:
diff --git a/src/app/routers.py b/src/app/routers.py
index abfc0059..9131076a 100644
--- a/src/app/routers.py
+++ b/src/app/routers.py
@@ -13,6 +13,7 @@
     streaming_query,
     authorized,
     conversations,
+    metrics,
 )
 
 
@@ -34,3 +35,4 @@ def include_routers(app: FastAPI) -> None:
     # road-core does not version these endpoints
     app.include_router(health.router)
     app.include_router(authorized.router)
+    app.include_router(metrics.router)
diff --git a/src/metrics/__init__.py b/src/metrics/__init__.py
new file mode 100644
index 00000000..0c424498
--- /dev/null
+++ b/src/metrics/__init__.py
@@ -0,0 +1,51 @@
+"""Metrics module for Lightspeed Stack."""
+
+from prometheus_client import (
+    Counter,
+    Gauge,
+    Histogram,
+)
+
+# Counter to track REST API calls
+# This will be used to count how many times each API endpoint is called
+# and the status code of the response
+rest_api_calls_total = Counter(
+    "ls_rest_api_calls_total", "REST API calls counter", ["path", "status_code"]
+)
+
+# Histogram to measure response durations
+# This will be used to track how long it takes to handle requests
+response_duration_seconds = Histogram(
+    "ls_response_duration_seconds", "Response durations", ["path"]
+)
+
+# Metric that indicates what provider + model customers are using so we can
+# understand what is popular/important
+provider_model_configuration = Gauge(
+    "ls_provider_model_configuration",
+    "LLM provider/models combinations defined in configuration",
+    ["provider", "model"],
+)
+
+# Metric that counts how many LLM calls were made for each provider + model
+llm_calls_total = Counter(
+    "ls_llm_calls_total", "LLM calls counter", ["provider", "model"]
+)
+
+# Metric that counts how many LLM calls failed
+llm_calls_failures_total = Counter("ls_llm_calls_failures_total", "LLM calls failures")
+
+# Metric that counts how many LLM calls had validation errors
+llm_calls_validation_errors_total = Counter(
+    "ls_llm_validation_errors_total", "LLM validation errors"
+)
+
+# TODO(lucasagomes): Add metric for token usage
+llm_token_sent_total = Counter(
+    "ls_llm_token_sent_total", "LLM tokens sent", ["provider", "model"]
+)
+
+# TODO(lucasagomes): Add metric for token usage
+llm_token_received_total = Counter(
+    "ls_llm_token_received_total", "LLM tokens received", ["provider", "model"]
+)
diff --git a/src/metrics/utils.py b/src/metrics/utils.py
new file mode 100644
index 00000000..62f94bb0
--- /dev/null
+++ b/src/metrics/utils.py
@@ -0,0 +1,32 @@
+"""Utility functions for metrics handling."""
+
+from client import LlamaStackClientHolder
+from log import get_logger
+import metrics
+
+logger = get_logger(__name__)
+
+
+# TODO(lucasagomes): Change this metric once we are allowed to set the
+# default model/provider via the configuration. The default provider/model
+# will be set to 1, and the rest will be set to 0.
+def setup_model_metrics() -> None:
+    """Perform setup of all metrics related to LLM model and provider."""
+    client = LlamaStackClientHolder().get_client()
+    models = [
+        model
+        for model in client.models.list()
+        if model.model_type == "llm"  # pyright: ignore[reportAttributeAccessIssue]
+    ]
+
+    for model in models:
+        provider = model.provider_id
+        model_name = model.identifier
+        if provider and model_name:
+            label_key = (provider, model_name)
+            metrics.provider_model_configuration.labels(*label_key).set(1)
+            logger.debug(
+                "Set provider/model configuration for %s/%s to 1",
+                provider,
+                model_name,
+            )
diff --git a/tests/unit/app/endpoints/test_metrics.py b/tests/unit/app/endpoints/test_metrics.py
new file mode 100644
index 00000000..19545541
--- /dev/null
+++ b/tests/unit/app/endpoints/test_metrics.py
@@ -0,0 +1,25 @@
+"""Unit tests for the /metrics REST API endpoint."""
+
+from app.endpoints.metrics import metrics_endpoint_handler
+
+
+def test_metrics_endpoint():
+    """Test the metrics endpoint handler."""
+    response = metrics_endpoint_handler(None)
+    assert response is not None
+    assert response.status_code == 200
+    assert "text/plain" in response.headers["Content-Type"]
+
+    response_body = response.body.decode()
+
+    # Check if the response contains Prometheus metrics format
+    assert "# TYPE ls_rest_api_calls_total counter" in response_body
+    assert "# TYPE ls_response_duration_seconds histogram" in response_body
+    assert "# TYPE ls_provider_model_configuration gauge" in response_body
+    assert "# TYPE ls_llm_calls_total counter" in response_body
+    assert "# TYPE ls_llm_calls_failures_total counter" in response_body
+    assert "# TYPE ls_llm_calls_failures_created gauge" in response_body
+    assert "# TYPE ls_llm_validation_errors_total counter" in response_body
+    assert "# TYPE ls_llm_validation_errors_created gauge" in response_body
+    assert "# TYPE ls_llm_token_sent_total counter" in response_body
+    assert "# TYPE ls_llm_token_received_total counter" in response_body
diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py
index 3a768440..343fa5ef 100644
--- a/tests/unit/app/endpoints/test_query.py
+++ b/tests/unit/app/endpoints/test_query.py
@@ -12,7 +12,7 @@
 from configuration import AppConfig
 from app.endpoints.query import (
     query_endpoint_handler,
-    select_model_id,
+    select_model_and_provider_id,
     retrieve_response,
     validate_attachments_metadata,
     is_transcripts_enabled,
@@ -63,6 +63,7 @@ def prepare_agent_mocks_fixture(mocker):
     """Fixture that yields mock agent when called."""
     mock_client = mocker.Mock()
     mock_agent = mocker.Mock()
+    mock_agent.create_turn.return_value.steps = []
     yield mock_client, mock_agent
     # cleanup agent cache after tests
     _agent_cache.clear()
@@ -107,6 +108,7 @@ def test_is_transcripts_disabled(setup_configuration, mocker):
 
 def _test_query_endpoint_handler(mocker, store_transcript_to_file=False):
     """Test the query endpoint handler."""
+    mock_metric = mocker.patch("metrics.llm_calls_total")
     mock_client = mocker.Mock()
     mock_lsc = mocker.patch("client.LlamaStackClientHolder.get_client")
     mock_lsc.return_value = mock_client
@@ -129,7 +131,10 @@ def _test_query_endpoint_handler(mocker, store_transcript_to_file=False):
         "app.endpoints.query.retrieve_response",
         return_value=(llm_response, conversation_id),
     )
-    mocker.patch("app.endpoints.query.select_model_id", return_value="fake_model_id")
+    mocker.patch(
+        "app.endpoints.query.select_model_and_provider_id",
+        return_value=("fake_model_id", "fake_provider_id"),
+    )
     mocker.patch(
"app.endpoints.query.is_transcripts_enabled", return_value=store_transcript_to_file, @@ -144,6 +149,9 @@ def _test_query_endpoint_handler(mocker, store_transcript_to_file=False): assert response.response == llm_response assert response.conversation_id == conversation_id + # Assert the metric for successful LLM calls is incremented + mock_metric.labels("fake_provider_id", "fake_model_id").inc.assert_called_once() + # Assert the store_transcript function is called if transcripts are enabled if store_transcript_to_file: mock_transcript.assert_called_once_with( @@ -171,8 +179,8 @@ def test_query_endpoint_handler_store_transcript(mocker): _test_query_endpoint_handler(mocker, store_transcript_to_file=True) -def test_select_model_id(mocker): - """Test the select_model_id function.""" +def test_select_model_and_provider_id(mocker): + """Test the select_model_and_provider_id function.""" mock_client = mocker.Mock() mock_client.models.list.return_value = [ mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), @@ -183,13 +191,16 @@ def test_select_model_id(mocker): query="What is OpenStack?", model="model1", provider="provider1" ) - model_id = select_model_id(mock_client.models.list(), query_request) + model_id, provider_id = select_model_and_provider_id( + mock_client.models.list(), query_request + ) assert model_id == "model1" + assert provider_id == "provider1" -def test_select_model_id_no_model(mocker): - """Test the select_model_id function when no model is specified.""" +def test_select_model_and_provider_id_no_model(mocker): + """Test the select_model_and_provider_id function when no model is specified.""" mock_client = mocker.Mock() mock_client.models.list.return_value = [ mocker.Mock( @@ -205,14 +216,17 @@ def test_select_model_id_no_model(mocker): query_request = QueryRequest(query="What is OpenStack?") - model_id = select_model_id(mock_client.models.list(), query_request) + model_id, provider_id = select_model_and_provider_id( + mock_client.models.list(), query_request + ) # Assert return the first available LLM model assert model_id == "first_model" + assert provider_id == "provider1" -def test_select_model_id_invalid_model(mocker): - """Test the select_model_id function with an invalid model.""" +def test_select_model_and_provider_id_invalid_model(mocker): + """Test the select_model_and_provider_id function with an invalid model.""" mock_client = mocker.Mock() mock_client.models.list.return_value = [ mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), @@ -222,8 +236,8 @@ def test_select_model_id_invalid_model(mocker): query="What is OpenStack?", model="invalid_model", provider="provider1" ) - with pytest.raises(Exception) as exc_info: - select_model_id(mock_client.models.list(), query_request) + with pytest.raises(HTTPException) as exc_info: + select_model_and_provider_id(mock_client.models.list(), query_request) assert ( "Model invalid_model from provider provider1 not found in available models" @@ -231,16 +245,16 @@ def test_select_model_id_invalid_model(mocker): ) -def test_no_available_models(mocker): - """Test the select_model_id function with an invalid model.""" +def test_select_model_and_provider_id_no_available_models(mocker): + """Test the select_model_and_provider_id function with no available models.""" mock_client = mocker.Mock() # empty list of models mock_client.models.list.return_value = [] query_request = QueryRequest(query="What is OpenStack?", model=None, provider=None) - with pytest.raises(Exception) as exc_info: - 
-        select_model_id(mock_client.models.list(), query_request)
+    with pytest.raises(HTTPException) as exc_info:
+        select_model_and_provider_id(mock_client.models.list(), query_request)
 
     assert "No LLM model found in available models" in str(exc_info.value)
 
@@ -304,6 +318,7 @@ def test_validate_attachments_metadata_invalid_content_type():
 
 def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker):
     """Test the retrieve_response function."""
+    mock_metric = mocker.patch("metrics.llm_calls_validation_errors_total")
     mock_client, mock_agent = prepare_agent_mocks
     mock_agent.create_turn.return_value.output_message.content = "LLM answer"
     mock_client.shields.list.return_value = []
@@ -327,6 +342,8 @@ def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker):
         mock_client, model_id, query_request, access_token
     )
 
+    # Assert that the metric for validation errors is NOT incremented
+    mock_metric.inc.assert_not_called()
     assert response == "LLM answer"
    assert conversation_id == "fake_session_id"
     mock_agent.create_turn.assert_called_once_with(
@@ -755,11 +772,12 @@ def test_retrieve_response_with_mcp_servers_empty_token(prepare_agent_mocks, moc
     )
 
 
-def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker):
+def test_retrieve_response_with_mcp_servers_and_mcp_headers(
+    prepare_agent_mocks, mocker
+):
     """Test the retrieve_response function with MCP servers configured."""
-    mock_agent = mocker.Mock()
+    mock_client, mock_agent = prepare_agent_mocks
     mock_agent.create_turn.return_value.output_message.content = "LLM answer"
-    mock_client = mocker.Mock()
     mock_client.shields.list.return_value = []
     mock_client.vector_dbs.list.return_value = []
 
@@ -846,6 +864,50 @@ def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker):
     )
 
 
+def test_retrieve_response_shield_violation(prepare_agent_mocks, mocker):
+    """Test the retrieve_response function."""
+    mock_metric = mocker.patch("metrics.llm_calls_validation_errors_total")
+    mock_client, mock_agent = prepare_agent_mocks
+    # Mock the agent's create_turn method to return a response with a shield violation
+    steps = [
+        mocker.Mock(
+            step_type="shield_call",
+            violation=True,
+        ),
+    ]
+    mock_agent.create_turn.return_value.steps = steps
+    mock_client.shields.list.return_value = []
+    mock_vector_db = mocker.Mock()
+    mock_vector_db.identifier = "VectorDB-1"
+    mock_client.vector_dbs.list.return_value = [mock_vector_db]
+
+    # Mock configuration with empty MCP servers
+    mock_config = mocker.Mock()
+    mock_config.mcp_servers = []
+    mocker.patch("app.endpoints.query.configuration", mock_config)
+    mocker.patch(
+        "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id")
+    )
+
+    query_request = QueryRequest(query="What is OpenStack?")
+
+    _, conversation_id = retrieve_response(
+        mock_client, "fake_model_id", query_request, "test_token"
+    )
+
+    # Assert that the metric for validation errors is incremented
+    mock_metric.inc.assert_called_once()
+
+    assert conversation_id == "fake_session_id"
+    mock_agent.create_turn.assert_called_once_with(
+        messages=[UserMessage(content="What is OpenStack?", role="user")],
+        session_id="fake_session_id",
+        documents=[],
+        stream=False,
+        toolgroups=get_rag_toolgroups(["VectorDB-1"]),
+    )
+
+
 def test_construct_transcripts_path(setup_configuration, mocker):
     """Test the construct_transcripts_path function."""
     # Update configuration for this test
@@ -937,21 +999,25 @@ def test_get_rag_toolgroups():
 
 def test_query_endpoint_handler_on_connection_error(mocker):
     """Test the query endpoint handler."""
handler.""" + mock_metric = mocker.patch("metrics.llm_calls_failures_total") + mocker.patch( "app.endpoints.query.configuration", return_value=mocker.Mock(), ) - # construct mocked query - query = "What is OpenStack?" - query_request = QueryRequest(query=query) + query_request = QueryRequest(query="What is OpenStack?") # simulate situation when it is not possible to connect to Llama Stack - mock_lsc = mocker.Mock() - mock_lsc.get_client.side_effect = APIConnectionError(request=query_request) + mock_get_client = mocker.patch("client.LlamaStackClientHolder.get_client") + mock_get_client.side_effect = APIConnectionError(request=query_request) + + with pytest.raises(HTTPException) as exc_info: + query_endpoint_handler(query_request, auth=MOCK_AUTH) - with pytest.raises(Exception): - query_endpoint_handler(query_request) + assert exc_info.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert "Unable to connect to Llama Stack" in str(exc_info.value.detail) + mock_metric.inc.assert_called_once() def test_get_agent_cache_hit(prepare_agent_mocks): @@ -1254,7 +1320,10 @@ def test_auth_tuple_unpacking_in_query_endpoint_handler(mocker): return_value=("test response", "test_conversation_id"), ) - mocker.patch("app.endpoints.query.select_model_id", return_value="test_model") + mocker.patch( + "app.endpoints.query.select_model_and_provider_id", + return_value=("test_model", "test_provider"), + ) mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) _ = query_endpoint_handler( diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 305a2861..2c548e25 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -222,7 +222,8 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) return_value=(mock_streaming_response, "test_conversation_id"), ) mocker.patch( - "app.endpoints.streaming_query.select_model_id", return_value="fake_model_id" + "app.endpoints.streaming_query.select_model_and_provider_id", + return_value=("fake_model_id", "fake_provider_id"), ) mocker.patch( "app.endpoints.streaming_query.is_transcripts_enabled", @@ -732,8 +733,11 @@ def test_stream_build_event_turn_complete(): assert '"id": 0' in result -def test_stream_build_event_shield_call_step_complete_no_violation(): +def test_stream_build_event_shield_call_step_complete_no_violation(mocker): """Test stream_build_event function with shield_call_step_complete event type.""" + # Mock the metric for validation errors + mock_metric = mocker.patch("metrics.llm_calls_validation_errors_total") + # Create a properly nested chunk structure # We cannot use 'mock' as 'hasattr(mock, "xxx")' adds the missing # attribute and therefore makes checks to see whether it is missing fail. 
@@ -760,10 +764,15 @@ def test_stream_build_event_shield_call_step_complete_no_violation():
     assert '"token": "No Violation"' in result
     assert '"role": "shield_call"' in result
     assert '"id": 0' in result
+    # Assert that the metric for validation errors is NOT incremented
+    mock_metric.inc.assert_not_called()
 
 
-def test_stream_build_event_shield_call_step_complete_with_violation():
+def test_stream_build_event_shield_call_step_complete_with_violation(mocker):
     """Test stream_build_event function with shield_call_step_complete event type with violation."""
+    # Mock the metric for validation errors
+    mock_metric = mocker.patch("metrics.llm_calls_validation_errors_total")
+
     # Create a properly nested chunk structure
     # We cannot use 'mock' as 'hasattr(mock, "xxx")' adds the missing
     # attribute and therefore makes checks to see whether it is missing fail.
@@ -798,6 +807,8 @@ def test_stream_build_event_shield_call_step_complete_with_violation():
     )
     assert '"role": "shield_call"' in result
     assert '"id": 0' in result
+    # Assert that the metric for validation errors is incremented
+    mock_metric.inc.assert_called_once()
 
 
 def test_stream_build_event_step_progress():
@@ -1527,7 +1538,8 @@ async def test_auth_tuple_unpacking_in_streaming_query_endpoint_handler(mocker):
     )
 
     mocker.patch(
-        "app.endpoints.streaming_query.select_model_id", return_value="test_model"
+        "app.endpoints.streaming_query.select_model_and_provider_id",
+        return_value=("test_model", "test_provider"),
     )
     mocker.patch(
         "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False
diff --git a/tests/unit/app/test_routers.py b/tests/unit/app/test_routers.py
index 335443e5..629a3113 100644
--- a/tests/unit/app/test_routers.py
+++ b/tests/unit/app/test_routers.py
@@ -15,6 +15,7 @@
     feedback,
     streaming_query,
     authorized,
+    metrics,
 )  # noqa:E402
 
 
@@ -44,7 +45,7 @@ def test_include_routers() -> None:
     include_routers(app)
 
     # are all routers added?
-    assert len(app.routers) == 10
+    assert len(app.routers) == 11
     assert root.router in app.get_routers()
     assert info.router in app.get_routers()
     assert models.router in app.get_routers()
@@ -55,6 +56,7 @@ def test_include_routers() -> None:
     assert health.router in app.get_routers()
     assert authorized.router in app.get_routers()
     assert conversations.router in app.get_routers()
+    assert metrics.router in app.get_routers()
 
 
 def test_check_prefixes() -> None:
@@ -63,7 +65,7 @@ def test_check_prefixes() -> None:
     include_routers(app)
 
     # are all routers added?
-    assert len(app.routers) == 10
+    assert len(app.routers) == 11
     assert app.get_router_prefix(root.router) is None
     assert app.get_router_prefix(info.router) == "/v1"
     assert app.get_router_prefix(models.router) == "/v1"
@@ -74,3 +76,4 @@ def test_check_prefixes() -> None:
     assert app.get_router_prefix(health.router) is None
     assert app.get_router_prefix(authorized.router) is None
     assert app.get_router_prefix(conversations.router) == "/v1"
+    assert app.get_router_prefix(metrics.router) is None
diff --git a/tests/unit/metrics/__init__.py b/tests/unit/metrics/__init__.py
new file mode 100644
index 00000000..7524ff7e
--- /dev/null
+++ b/tests/unit/metrics/__init__.py
@@ -0,0 +1 @@
+"""Unit tests for metrics."""
diff --git a/tests/unit/metrics/test_utis.py b/tests/unit/metrics/test_utis.py
new file mode 100644
index 00000000..5d273c3a
--- /dev/null
+++ b/tests/unit/metrics/test_utis.py
@@ -0,0 +1,23 @@
+"""Unit tests for functions defined in metrics/utils.py"""
+
+from metrics.utils import setup_model_metrics
+
+
+def test_setup_model_metrics(mocker):
+    """Test the setup_model_metrics function."""
+
+    # Mock the LlamaStackAsLibraryClient
+    mock_client = mocker.patch("client.LlamaStackClientHolder.get_client").return_value
+
+    mock_metric = mocker.patch("metrics.provider_model_configuration")
+    fake_model = mocker.Mock(
+        provider_id="test_provider",
+        identifier="test_model",
+        model_type="llm",
+    )
+    mock_client.models.list.return_value = [fake_model]
+
+    setup_model_metrics()
+
+    # Assert that the metric was set correctly
+    mock_metric.labels("test_provider", "test_model").set.assert_called_once_with(1)
diff --git a/uv.lock b/uv.lock
index 9360c035..946c3527 100644
--- a/uv.lock
+++ b/uv.lock
@@ -835,7 +835,9 @@ dependencies = [
     { name = "fastapi" },
     { name = "kubernetes" },
     { name = "llama-stack" },
+    { name = "prometheus-client" },
     { name = "rich" },
+    { name = "starlette" },
     { name = "uvicorn" },
 ]
 
@@ -867,7 +869,9 @@ requires-dist = [
     { name = "fastapi", specifier = ">=0.115.6" },
     { name = "kubernetes", specifier = ">=30.1.0" },
     { name = "llama-stack", specifier = ">=0.2.13" },
+    { name = "prometheus-client", specifier = ">=0.22.1" },
     { name = "rich", specifier = ">=14.0.0" },
+    { name = "starlette", specifier = ">=0.47.1" },
     { name = "uvicorn", specifier = ">=0.34.3" },
 ]
 
@@ -1459,6 +1463,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
 ]
 
+[[package]]
+name = "prometheus-client"
+version = "0.22.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/5e/cf/40dde0a2be27cc1eb41e333d1a674a74ce8b8b0457269cc640fd42b07cf7/prometheus_client-0.22.1.tar.gz", hash = "sha256:190f1331e783cf21eb60bca559354e0a4d4378facecf78f5428c39b675d20d28", size = 69746, upload-time = "2025-06-02T14:29:01.152Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/32/ae/ec06af4fe3ee72d16973474f122541746196aaa16cea6f66d18b963c6177/prometheus_client-0.22.1-py3-none-any.whl", hash = "sha256:cca895342e308174341b2cbf99a56bef291fbc0ef7b9e5412a0f26d653ba7094", size = 58694, upload-time = "2025-06-02T14:29:00.068Z" },
+]
+
 [[package]]
 name = "prompt-toolkit"
 version = "3.0.51"