
Commit 9037b3b

LCORE-411: add token usage metrics
Signed-off-by: Haoyu Sun <[email protected]>
1 parent: d89f7a3
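
In brief: this commit adds per-provider, per-model token usage metrics. A new helper, update_llm_token_count_from_turn in src/metrics/utils.py, tokenizes each completed turn with the Llama 3 tokenizer and increments the llm_token_sent_total and llm_token_received_total counters. Both the query and streaming-query endpoints invoke it once a turn completes, with provider_id now threaded through retrieve_response so the counters carry the right labels. Unit tests cover the new helper, and the existing endpoint tests stub it out.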

File tree (7 files changed: +120 -1 lines changed)

src/app/endpoints/query.py
src/app/endpoints/streaming_query.py
src/metrics/__init__.py
src/metrics/utils.py
tests/unit/app/endpoints/test_query.py
tests/unit/app/endpoints/test_streaming_query.py
tests/unit/metrics/test_utis.py

src/app/endpoints/query.py

Lines changed: 8 additions & 0 deletions

@@ -220,6 +220,7 @@ async def query_endpoint_handler(
         query_request,
         token,
         mcp_headers=mcp_headers,
+        provider_id=provider_id,
     )
     # Update metrics for the LLM call
     metrics.llm_calls_total.labels(provider_id, model_id).inc()

@@ -395,6 +396,7 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branches
     query_request: QueryRequest,
     token: str,
     mcp_headers: dict[str, dict[str, str]] | None = None,
+    provider_id: str = "",
 ) -> tuple[TurnSummary, str]:
     """
     Retrieve response from LLMs and agents.

@@ -413,6 +415,7 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branches
 
     Parameters:
         model_id (str): The identifier of the LLM model to use.
+        provider_id (str): The identifier of the LLM provider to use.
         query_request (QueryRequest): The user's query and associated metadata.
         token (str): The authentication token for authorization.
         mcp_headers (dict[str, dict[str, str]], optional): Headers for multi-component processing.

@@ -512,6 +515,11 @@ async def retrieve_response(  # pylint: disable=too-many-locals,too-many-branches
         tool_calls=[],
     )
 
+    # Update token count metrics for the LLM call
+    metrics.update_llm_token_count_from_turn(
+        response, model_id, provider_id, system_prompt
+    )
+
     # Check for validation errors in the response
     steps = response.steps or []
     for step in steps:
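
The labels(provider_id, model_id).inc() pattern above is standard prometheus_client usage. As a minimal, self-contained sketch of the idea (the Counter definition and label values here are illustrative; the project's actual metric objects live in src/metrics and are not shown in this diff):

from prometheus_client import Counter

# Illustrative counter; the name and label positions mirror the diff's usage.
llm_calls_total = Counter(
    "llm_calls_total",
    "Total number of LLM calls",
    ["provider", "model"],
)

# Each distinct (provider, model) pair becomes its own time series:
llm_calls_total.labels("openai", "gpt-4o").inc()
llm_calls_total.labels("watsonx", "granite-13b").inc()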

src/app/endpoints/streaming_query.py

Lines changed: 4 additions & 0 deletions

@@ -614,6 +614,10 @@ async def response_generator(
                 summary.llm_response = interleaved_content_as_str(
                     p.turn.output_message.content
                 )
+                system_prompt = get_system_prompt(query_request, configuration)
+                metrics.update_llm_token_count_from_turn(
+                    p.turn, model_id, provider_id, system_prompt
+                )
             elif p.event_type == "step_complete":
                 if p.step_details.step_type == "tool_execution":
                     summary.append_tool_calls_from_llama(p.step_details)
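
Note the placement: in the streaming path the full turn only exists once the turn_complete event arrives, so the token count update hangs off that event rather than off individual streamed deltas. The system prompt is re-resolved with get_system_prompt at that point, presumably so the sent-token count includes it, matching what the non-streaming path passes to the helper.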

src/metrics/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -6,6 +6,8 @@
     Histogram,
 )
 
+from .utils import update_llm_token_count_from_turn
+
 # Counter to track REST API calls
 # This will be used to count how many times each API endpoint is called
 # and the status code of the response
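
This re-export makes the helper reachable as metrics.update_llm_token_count_from_turn, which is how the endpoints call it. The two token counters it increments are referenced but not defined in this diff; a plausible sketch of their shape, assuming prometheus_client and with help strings of my choosing:

from prometheus_client import Counter

# Assumed definitions: only the names and (provider, model) label order are
# confirmed by the diff; the documentation strings are guesses.
llm_token_sent_total = Counter(
    "llm_token_sent_total",
    "Number of tokens sent to the LLM",
    ["provider", "model"],
)
llm_token_received_total = Counter(
    "llm_token_received_total",
    "Number of tokens received from the LLM",
    ["provider", "model"],
)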

src/metrics/utils.py

Lines changed: 25 additions & 0 deletions

@@ -1,10 +1,15 @@
 """Utility functions for metrics handling."""
 
+from typing import cast
 from configuration import configuration
 from client import AsyncLlamaStackClientHolder
 from log import get_logger
 import metrics
 from utils.common import run_once_async
+from llama_stack_client.types.agents.turn import Turn
+from llama_stack.models.llama.llama3.tokenizer import Tokenizer
+from llama_stack.models.llama.llama3.chat_format import ChatFormat
+from llama_stack.models.llama.datatypes import RawMessage
 
 logger = get_logger(__name__)
 

@@ -48,3 +53,23 @@ async def setup_model_metrics() -> None:
         default_model_value,
     )
     logger.info("Model metrics setup complete")
+
+
+def update_llm_token_count_from_turn(
+    turn: Turn, model: str, provider: str, system_prompt: str = ""
+) -> None:
+    """Update the LLM calls metrics from a turn."""
+    tokenizer = Tokenizer.get_instance()
+    formatter = ChatFormat(tokenizer)
+
+    raw_message = cast(RawMessage, turn.output_message)
+    encoded_output = formatter.encode_dialog_prompt([raw_message])
+    token_count = len(encoded_output.tokens) if encoded_output.tokens else 0
+    metrics.llm_token_received_total.labels(provider, model).inc(token_count)
+
+    input_messages = [RawMessage(role="user", content=system_prompt)] + cast(
+        list[RawMessage], turn.input_messages
+    )
+    encoded_input = formatter.encode_dialog_prompt(input_messages)
+    token_count = len(encoded_input.tokens) if encoded_input.tokens else 0
+    metrics.llm_token_sent_total.labels(provider, model).inc(token_count)
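
Two details worth flagging. First, the counts are computed locally with the Llama 3 tokenizer and ChatFormat, so for non-Llama providers they approximate, rather than reproduce, provider-reported usage. Second, the system prompt is prepended as a user-role RawMessage ahead of the turn's input messages. A minimal sketch of the counting approach in isolation (runnable only where llama-stack is installed; the message contents are made up):

from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.datatypes import RawMessage

tokenizer = Tokenizer.get_instance()
formatter = ChatFormat(tokenizer)

# Same shape as the "sent" side of the helper: system prompt first,
# then the user's input messages.
messages = [
    RawMessage(role="user", content="You are a helpful assistant."),
    RawMessage(role="user", content="What is OpenStack?"),
]
encoded = formatter.encode_dialog_prompt(messages)
print(len(encoded.tokens or []))  # the number that would be inc()'d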

tests/unit/app/endpoints/test_query.py

Lines changed: 23 additions & 0 deletions

@@ -46,6 +46,14 @@ def dummy_request() -> Request:
     return req
 
 
+def mock_metrics(mocker):
+    """Helper function to mock metrics operations for query endpoints."""
+    mocker.patch(
+        "app.endpoints.query.metrics.update_llm_token_count_from_turn",
+        return_value=None,
+    )
+
+
 def mock_database_operations(mocker):
     """Helper function to mock database operations for query endpoints."""
     mocker.patch(

@@ -443,6 +451,7 @@ async def test_retrieve_response_no_returned_message(prepare_agent_mocks, mocker):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"

@@ -474,6 +483,7 @@ async def test_retrieve_response_message_without_content(prepare_agent_mocks, mocker):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"

@@ -506,6 +516,7 @@ async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"

@@ -544,6 +555,7 @@ async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocker):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"

@@ -593,6 +605,7 @@ def __repr__(self):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"

@@ -645,6 +658,7 @@ def __repr__(self):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"

@@ -699,6 +713,7 @@ def __repr__(self):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"

@@ -755,6 +770,7 @@ async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?", attachments=attachments)
     model_id = "fake_model_id"

@@ -809,6 +825,7 @@ async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocker):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?", attachments=attachments)
     model_id = "fake_model_id"

@@ -864,6 +881,7 @@ async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"

@@ -933,6 +951,7 @@ async def test_retrieve_response_with_mcp_servers_empty_token(
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"

@@ -994,6 +1013,7 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
    )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
     model_id = "fake_model_id"

@@ -1090,6 +1110,7 @@ async def test_retrieve_response_shield_violation(prepare_agent_mocks, mocker):
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?")
 

@@ -1326,6 +1347,7 @@ async def test_retrieve_response_no_tools_bypasses_mcp_and_rag(
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?", no_tools=True)
     model_id = "fake_model_id"

@@ -1376,6 +1398,7 @@ async def test_retrieve_response_no_tools_false_preserves_functionality(
         "app.endpoints.query.get_agent",
         return_value=(mock_agent, "fake_conversation_id", "fake_session_id"),
     )
+    mock_metrics(mocker)
 
     query_request = QueryRequest(query="What is OpenStack?", no_tools=False)
     model_id = "fake_model_id"
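
The pattern is uniform: every retrieve_response test gains a single mock_metrics(mocker) call right after the get_agent patch, so the new metrics.update_llm_token_count_from_turn call inside retrieve_response becomes a no-op and the unit tests never touch the real Llama 3 tokenizer.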

tests/unit/app/endpoints/test_streaming_query.py

Lines changed: 10 additions & 0 deletions

@@ -58,6 +58,14 @@ def mock_database_operations(mocker):
     mocker.patch("app.endpoints.streaming_query.persist_user_conversation_details")
 
 
+def mock_metrics(mocker):
+    """Helper function to mock metrics operations for streaming query endpoints."""
+    mocker.patch(
+        "app.endpoints.streaming_query.metrics.update_llm_token_count_from_turn",
+        return_value=None,
+    )
+
+
 SAMPLE_KNOWLEDGE_SEARCH_RESULTS = [
     """knowledge_search tool found 2 chunks:
 BEGIN of knowledge_search tool results.

@@ -346,12 +354,14 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False):
 @pytest.mark.asyncio
 async def test_streaming_query_endpoint_handler(mocker):
     """Test the streaming query endpoint handler with transcript storage disabled."""
+    mock_metrics(mocker)
     await _test_streaming_query_endpoint_handler(mocker, store_transcript=False)
 
 
 @pytest.mark.asyncio
 async def test_streaming_query_endpoint_handler_store_transcript(mocker):
     """Test the streaming query endpoint handler with transcript storage enabled."""
+    mock_metrics(mocker)
     await _test_streaming_query_endpoint_handler(mocker, store_transcript=True)
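
Here only the two endpoint-handler tests need the stub, since in the streaming path the token count update runs inside response_generator rather than in retrieve_response.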

tests/unit/metrics/test_utis.py

Lines changed: 48 additions & 1 deletion

@@ -1,6 +1,6 @@
 """Unit tests for functions defined in metrics/utils.py"""
 
-from metrics.utils import setup_model_metrics
+from metrics.utils import setup_model_metrics, update_llm_token_count_from_turn
 
 
 async def test_setup_model_metrics(mocker):

@@ -74,3 +74,50 @@ async def test_setup_model_metrics(mocker):
         ],
         any_order=False,  # Order matters here
     )
+
+
+def test_update_llm_token_count_from_turn(mocker):
+    """Test the update_llm_token_count_from_turn function."""
+    mocker.patch("metrics.utils.Tokenizer.get_instance")
+    mock_formatter_class = mocker.patch("metrics.utils.ChatFormat")
+    mock_formatter = mocker.Mock()
+    mock_formatter_class.return_value = mock_formatter
+
+    mock_received_metric = mocker.patch(
+        "metrics.utils.metrics.llm_token_received_total"
+    )
+    mock_sent_metric = mocker.patch("metrics.utils.metrics.llm_token_sent_total")
+
+    mock_turn = mocker.Mock()
+    # turn.output_message should satisfy the type RawMessage
+    mock_turn.output_message = {"role": "assistant", "content": "test response"}
+    # turn.input_messages should satisfy the type list[RawMessage]
+    mock_turn.input_messages = [{"role": "user", "content": "test input"}]
+
+    # Mock the encoded results with tokens
+    mock_encoded_output = mocker.Mock()
+    mock_encoded_output.tokens = ["token1", "token2", "token3"]  # 3 tokens
+    mock_encoded_input = mocker.Mock()
+    mock_encoded_input.tokens = ["token1", "token2"]  # 2 tokens
+    mock_formatter.encode_dialog_prompt.side_effect = [
+        mock_encoded_output,
+        mock_encoded_input,
+    ]
+
+    test_model = "test_model"
+    test_provider = "test_provider"
+    test_system_prompt = "test system prompt"
+
+    update_llm_token_count_from_turn(
+        mock_turn, test_model, test_provider, test_system_prompt
+    )
+
+    # Verify that llm_token_received_total.labels() was called with correct metrics
+    mock_received_metric.labels.assert_called_once_with(test_provider, test_model)
+    mock_received_metric.labels().inc.assert_called_once_with(
+        3
+    )  # token count from output
+
+    # Verify that llm_token_sent_total.labels() was called with correct metrics
+    mock_sent_metric.labels.assert_called_once_with(test_provider, test_model)
+    mock_sent_metric.labels().inc.assert_called_once_with(2)  # token count from input
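
A note on the test's mechanics: it patches Tokenizer.get_instance and ChatFormat so no real tokenizer assets are loaded, and the side_effect list hands back the output encoding first and the input encoding second, which matches the call order inside update_llm_token_count_from_turn (received tokens are counted before sent tokens).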
