2 changes: 2 additions & 0 deletions pyproject.toml
@@ -13,6 +13,8 @@ dependencies = [
"llama-stack>=0.2.13",
"rich>=14.0.0",
"cachetools>=6.1.0",
"prometheus-client>=0.22.1",
"starlette>=0.47.1",
]

[tool.pyright]
16 changes: 16 additions & 0 deletions src/app/endpoints/metrics.py
@@ -0,0 +1,16 @@
"""Handler for REST API call to provide metrics."""

from fastapi.responses import PlainTextResponse
from fastapi import APIRouter, Request
from prometheus_client import (
generate_latest,
CONTENT_TYPE_LATEST,
)

router = APIRouter(tags=["metrics"])


@router.get("/metrics", response_class=PlainTextResponse)
def metrics_endpoint_handler(_request: Request) -> PlainTextResponse:
"""Handle request to the /metrics endpoint."""
return PlainTextResponse(generate_latest(), media_type=CONTENT_TYPE_LATEST)
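Note: the new endpoint is self-contained, so it can be exercised without the rest of the service by mounting just this router. A minimal sketch (assuming `src` is on the import path and httpx is available for FastAPI's test client):

from fastapi import FastAPI
from fastapi.testclient import TestClient

from app.endpoints.metrics import router

app = FastAPI()
app.include_router(router)

response = TestClient(app).get("/metrics")
assert response.status_code == 200
assert response.headers["content-type"].startswith("text/plain")
print(response.text[:200])  # Prometheus text exposition format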
33 changes: 26 additions & 7 deletions src/app/endpoints/query.py
@@ -24,6 +24,7 @@
from client import LlamaStackClientHolder
from configuration import configuration
from app.endpoints.conversations import conversation_id_to_agent_id
import metrics
from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
from models.requests import QueryRequest, Attachment
import constants
@@ -122,14 +123,18 @@ def query_endpoint_handler(
try:
# try to get Llama Stack client
client = LlamaStackClientHolder().get_client()
model_id = select_model_id(client.models.list(), query_request)
model_id, provider_id = select_model_and_provider_id(
client.models.list(), query_request
)
response, conversation_id = retrieve_response(
client,
model_id,
query_request,
token,
mcp_headers=mcp_headers,
)
# Update metrics for the LLM call
metrics.llm_calls_total.labels(provider_id, model_id).inc()

if not is_transcripts_enabled():
logger.debug("Transcript collection is disabled in the configuration")
@@ -150,6 +155,8 @@

# connection to Llama Stack server
except APIConnectionError as e:
# Update metrics for the LLM call failure
metrics.llm_calls_failures_total.inc()
logger.error("Unable to connect to Llama Stack: %s", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
@@ -160,8 +167,10 @@
) from e


def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> str:
"""Select the model ID based on the request or available models."""
def select_model_and_provider_id(
models: ModelListResponse, query_request: QueryRequest
) -> tuple[str, str | None]:
"""Select the model ID and provider ID based on the request or available models."""
model_id = query_request.model
provider_id = query_request.provider

@@ -173,9 +182,11 @@ def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> s
m
for m in models
if m.model_type == "llm" # pyright: ignore[reportAttributeAccessIssue]
).identifier
)
model_id = model.identifier
provider_id = model.provider_id
logger.info("Selected model: %s", model)
return model
return model_id, provider_id
except (StopIteration, AttributeError) as e:
message = "No LLM model found in available models"
logger.error(message)
@@ -201,7 +212,7 @@ def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> s
},
)

return model_id
return model_id, provider_id


def _is_inout_shield(shield: Shield) -> bool:
@@ -218,7 +229,7 @@ def is_input_shield(shield: Shield) -> bool:
return _is_inout_shield(shield) or not is_output_shield(shield)


def retrieve_response(
def retrieve_response( # pylint: disable=too-many-locals
client: LlamaStackClient,
model_id: str,
query_request: QueryRequest,
@@ -288,6 +299,14 @@ def retrieve_response(
toolgroups=toolgroups or None,
)

# Check for validation errors in the response
steps = getattr(response, "steps", [])
for step in steps:
if step.step_type == "shield_call" and step.violation:
# Metric for LLM validation errors
metrics.llm_calls_validation_errors_total.inc()
break

return str(response.output_message.content), conversation_id # type: ignore[union-attr]


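For context, `llm_calls_total` follows the standard labelled-counter pattern: each provider/model pair becomes its own time series. A minimal, self-contained sketch (the provider and model names below are made up, and an isolated registry keeps the sketch out of the application's global metrics):

from prometheus_client import CollectorRegistry, Counter

registry = CollectorRegistry()
llm_calls = Counter(
    "demo_llm_calls", "LLM calls counter", ["provider", "model"], registry=registry
)

llm_calls.labels("openai", "gpt-4-turbo").inc()
llm_calls.labels("openai", "gpt-4-turbo").inc()
llm_calls.labels("watsonx", "granite-13b").inc()

# Each label combination is exposed as a separate sample.
for sample in llm_calls.collect()[0].samples:
    if sample.name.endswith("_total"):
        print(sample.labels, sample.value)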
14 changes: 12 additions & 2 deletions src/app/endpoints/streaming_query.py
@@ -23,6 +23,7 @@
from auth import get_auth_dependency
from client import AsyncLlamaStackClientHolder
from configuration import configuration
import metrics
from models.requests import QueryRequest
from utils.endpoints import check_configuration_loaded, get_system_prompt
from utils.common import retrieve_user_id
@@ -37,7 +38,7 @@
is_output_shield,
is_transcripts_enabled,
store_transcript,
select_model_id,
select_model_and_provider_id,
validate_attachments_metadata,
)

@@ -229,6 +230,8 @@ def _handle_shield_event(chunk: Any, chunk_id: int) -> Iterator[str]:
}
)
else:
# Metric for LLM validation errors
metrics.llm_calls_validation_errors_total.inc()
violation = (
f"Violation: {violation.user_message} (Metadata: {violation.metadata})"
)
@@ -421,7 +424,9 @@ async def streaming_query_endpoint_handler(
try:
# try to get Llama Stack client
client = AsyncLlamaStackClientHolder().get_client()
model_id = select_model_id(await client.models.list(), query_request)
model_id, provider_id = select_model_and_provider_id(
await client.models.list(), query_request
)
response, conversation_id = await retrieve_response(
client,
model_id,
@@ -465,9 +470,14 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
attachments=query_request.attachments or [],
)

# Update metrics for the LLM call
metrics.llm_calls_total.labels(provider_id, model_id).inc()

return StreamingResponse(response_generator(response))
# connection to Llama Stack server
except APIConnectionError as e:
# Update metrics for the LLM call failure
metrics.llm_calls_failures_total.inc()
logger.error("Unable to connect to Llama Stack: %s", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
47 changes: 43 additions & 4 deletions src/app/main.py
@@ -1,13 +1,18 @@
"""Definition of FastAPI based web service."""

from fastapi import FastAPI
from typing import Callable, Awaitable

from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from app import routers
from starlette.routing import Mount, Route, WebSocketRoute

import version
from log import get_logger
from app import routers
from configuration import configuration
from log import get_logger
import metrics
from metrics.utils import setup_model_metrics
from utils.common import register_mcp_servers_async
import version

logger = get_logger(__name__)

@@ -34,9 +39,43 @@
allow_headers=["*"],
)


@app.middleware("")
async def rest_api_metrics(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Middleware with REST API counter update logic."""
path = request.url.path
logger.debug("Received request for path: %s", path)

# ignore paths that are not part of the app routes
if path not in app_routes_paths:
return await call_next(request)

logger.debug("Processing API request for path: %s", path)

# measure time to handle duration + update histogram
with metrics.response_duration_seconds.labels(path).time():
response = await call_next(request)

# ignore /metrics endpoint that will be called periodically
if not path.endswith("/metrics"):
# just update metrics
metrics.rest_api_calls_total.labels(path, response.status_code).inc()
return response


logger.info("Including routers")
routers.include_routers(app)

app_routes_paths = [
route.path
for route in app.routes
if isinstance(route, (Mount, Route, WebSocketRoute))
]

setup_model_metrics()


@app.on_event("startup")
async def startup_event() -> None:
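The middleware relies on `Histogram.time()`, which works as a context manager and observes the elapsed wall-clock time into the labelled series. A small sketch of the same pattern in isolation (the request path is illustrative):

import time

from prometheus_client import CollectorRegistry, Histogram, generate_latest

registry = CollectorRegistry()
duration = Histogram(
    "demo_response_duration_seconds", "Response durations", ["path"], registry=registry
)

# Same shape as the middleware's `with metrics.response_duration_seconds.labels(path).time():`
with duration.labels("/v1/query").time():
    time.sleep(0.05)  # stand-in for `await call_next(request)`

print(generate_latest(registry).decode())  # bucket, count and sum samples for /v1/query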
2 changes: 2 additions & 0 deletions src/app/routers.py
@@ -13,6 +13,7 @@
streaming_query,
authorized,
conversations,
metrics,
)


@@ -34,3 +35,4 @@ def include_routers(app: FastAPI) -> None:
# road-core does not version these endpoints
app.include_router(health.router)
app.include_router(authorized.router)
app.include_router(metrics.router)
51 changes: 51 additions & 0 deletions src/metrics/__init__.py
@@ -0,0 +1,51 @@
"""Metrics module for Lightspeed Stack."""

from prometheus_client import (
Counter,
Gauge,
Histogram,
)

# Counter to track REST API calls
# This will be used to count how many times each API endpoint is called
# and the status code of the response
rest_api_calls_total = Counter(
"ls_rest_api_calls_total", "REST API calls counter", ["path", "status_code"]
)

# Histogram to measure response durations
# This will be used to track how long it takes to handle requests
response_duration_seconds = Histogram(
"ls_response_duration_seconds", "Response durations", ["path"]
)

# Metric that indicates what provider + model customers are using so we can
# understand what is popular/important
provider_model_configuration = Gauge(
"ls_provider_model_configuration",
"LLM provider/models combinations defined in configuration",
["provider", "model"],
)

# Metric that counts how many LLM calls were made for each provider + model
llm_calls_total = Counter(
"ls_llm_calls_total", "LLM calls counter", ["provider", "model"]
)

# Metric that counts how many LLM calls failed
llm_calls_failures_total = Counter("ls_llm_calls_failures_total", "LLM calls failures")

# Metric that counts how many LLM calls had validation errors
llm_calls_validation_errors_total = Counter(
"ls_llm_validation_errors_total", "LLM validation errors"
)

# TODO(lucasagomes): Add metric for token usage
llm_token_sent_total = Counter(
"ls_llm_token_sent_total", "LLM tokens sent", ["provider", "model"]
)

# TODO(lucasagomes): Add metric for token usage
llm_token_received_total = Counter(
"ls_llm_token_received_total", "LLM tokens received", ["provider", "model"]
)
32 changes: 32 additions & 0 deletions src/metrics/utils.py
@@ -0,0 +1,32 @@
"""Utility functions for metrics handling."""

from client import LlamaStackClientHolder
from log import get_logger
import metrics

logger = get_logger(__name__)


# TODO(lucasagomes): Change this metric once we are allowed to set the
# default model/provider via the configuration. The default provider/model
# will be set to 1, and the rest will be set to 0.
def setup_model_metrics() -> None:
"""Perform setup of all metrics related to LLM model and provider."""
client = LlamaStackClientHolder().get_client()
models = [
model
for model in client.models.list()
if model.model_type == "llm" # pyright: ignore[reportAttributeAccessIssue]
]

for model in models:
provider = model.provider_id
model_name = model.identifier
if provider and model_name:
label_key = (provider, model_name)
metrics.provider_model_configuration.labels(*label_key).set(1)
logger.debug(
"Set provider/model configuration for %s/%s to 1",
provider,
model_name,
)
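For reference, the gauge side of this is plain `labels(...).set(1)`. A sketch with a stubbed model list (the provider/model names are invented) shows the shape of what `setup_model_metrics` produces:

from types import SimpleNamespace

from prometheus_client import CollectorRegistry, Gauge

registry = CollectorRegistry()
configured = Gauge(
    "demo_provider_model_configuration",
    "LLM provider/models combinations defined in configuration",
    ["provider", "model"],
    registry=registry,
)

# Stand-in for client.models.list(); only "llm" entries are exported.
models = [
    SimpleNamespace(model_type="llm", provider_id="openai", identifier="gpt-4-turbo"),
    SimpleNamespace(model_type="embedding", provider_id="openai", identifier="text-embedding-3-small"),
]

for model in models:
    if model.model_type == "llm" and model.provider_id and model.identifier:
        configured.labels(model.provider_id, model.identifier).set(1)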
25 changes: 25 additions & 0 deletions tests/unit/app/endpoints/test_metrics.py
@@ -0,0 +1,25 @@
"""Unit tests for the /metrics REST API endpoint."""

from app.endpoints.metrics import metrics_endpoint_handler


def test_metrics_endpoint():
"""Test the metrics endpoint handler."""
response = metrics_endpoint_handler(None)
assert response is not None
assert response.status_code == 200
assert "text/plain" in response.headers["Content-Type"]

response_body = response.body.decode()

# Check if the response contains Prometheus metrics format
assert "# TYPE ls_rest_api_calls_total counter" in response_body
assert "# TYPE ls_response_duration_seconds histogram" in response_body
assert "# TYPE ls_provider_model_configuration gauge" in response_body
assert "# TYPE ls_llm_calls_total counter" in response_body
assert "# TYPE ls_llm_calls_failures_total counter" in response_body
assert "# TYPE ls_llm_calls_failures_created gauge" in response_body
assert "# TYPE ls_llm_validation_errors_total counter" in response_body
assert "# TYPE ls_llm_validation_errors_created gauge" in response_body
assert "# TYPE ls_llm_token_sent_total counter" in response_body
assert "# TYPE ls_llm_token_received_total counter" in response_body