38 changes: 22 additions & 16 deletions src/app/endpoints/query.py
@@ -31,7 +31,6 @@
from authorization.middleware import authorize
from client import AsyncLlamaStackClientHolder
from configuration import configuration
from metrics.utils import update_llm_token_count_from_turn
from models.config import Action
from models.database.conversations import UserConversation
from models.requests import Attachment, QueryRequest
@@ -55,6 +54,7 @@
from utils.mcp_headers import handle_mcp_headers_with_toolgroups, mcp_headers_dependency
from utils.transcripts import store_transcript
from utils.types import TurnSummary
from utils.token_counter import extract_and_update_token_metrics, TokenCounter

logger = logging.getLogger("app.endpoints.handlers")
router = APIRouter(tags=["query"])
@@ -279,16 +279,16 @@ async def query_endpoint_handler( # pylint: disable=R0914
user_conversation=user_conversation, query_request=query_request
),
)
summary, conversation_id, referenced_documents = await retrieve_response(
client,
llama_stack_model_id,
query_request,
token,
mcp_headers=mcp_headers,
provider_id=provider_id,
summary, conversation_id, referenced_documents, token_usage = (
await retrieve_response(
client,
llama_stack_model_id,
query_request,
token,
mcp_headers=mcp_headers,
provider_id=provider_id,
)
)
# Update metrics for the LLM call
metrics.llm_calls_total.labels(provider_id, model_id).inc()

# Get the initial topic summary for the conversation
topic_summary = None
@@ -371,6 +371,10 @@ async def query_endpoint_handler( # pylint: disable=R0914
rag_chunks=summary.rag_chunks if summary.rag_chunks else [],
tool_calls=tool_calls if tool_calls else None,
referenced_documents=referenced_documents,
truncated=False, # TODO: implement truncation detection
input_tokens=token_usage.input_tokens,
output_tokens=token_usage.output_tokens,
available_quotas={}, # TODO: implement quota tracking
)
logger.info("Query processing completed successfully!")
return response
@@ -583,7 +587,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
mcp_headers: dict[str, dict[str, str]] | None = None,
*,
provider_id: str = "",
) -> tuple[TurnSummary, str, list[ReferencedDocument]]:
) -> tuple[TurnSummary, str, list[ReferencedDocument], TokenCounter]:
"""
Retrieve response from LLMs and agents.

@@ -607,9 +611,9 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
mcp_headers (dict[str, dict[str, str]], optional): Headers for multi-component processing.

Returns:
tuple[TurnSummary, str, list[ReferencedDocument]]: A tuple containing
tuple[TurnSummary, str, list[ReferencedDocument], TokenCounter]: A tuple containing
a summary of the LLM or agent's response
content, the conversation ID and the list of parsed referenced documents.
content, the conversation ID, the list of parsed referenced documents, and token usage information.
"""
available_input_shields = [
shield.identifier
@@ -704,9 +708,11 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche

referenced_documents = parse_referenced_documents(response)

# Update token count metrics for the LLM call
# Update token count metrics and extract token usage in one call
model_label = model_id.split("/", 1)[1] if "/" in model_id else model_id
update_llm_token_count_from_turn(response, model_label, provider_id, system_prompt)
token_usage = extract_and_update_token_metrics(
response, model_label, provider_id, system_prompt
)

# Check for validation errors in the response
steps = response.steps or []
@@ -722,7 +728,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
"Response lacks output_message.content (conversation_id=%s)",
conversation_id,
)
return (summary, conversation_id, referenced_documents)
return (summary, conversation_id, referenced_documents, token_usage)


def validate_attachments_metadata(attachments: list[Attachment]) -> None:
29 changes: 23 additions & 6 deletions src/app/endpoints/streaming_query.py
@@ -56,6 +56,7 @@
validate_model_provider_override,
)
from utils.mcp_headers import handle_mcp_headers_with_toolgroups, mcp_headers_dependency
from utils.token_counter import TokenCounter, extract_token_usage_from_turn
from utils.transcripts import store_transcript
from utils.types import TurnSummary

@@ -154,17 +155,23 @@ def stream_start_event(conversation_id: str) -> str:
)


def stream_end_event(metadata_map: dict, media_type: str = MEDIA_TYPE_JSON) -> str:
def stream_end_event(
metadata_map: dict,
summary: TurnSummary, # pylint: disable=unused-argument
token_usage: TokenCounter,
media_type: str = MEDIA_TYPE_JSON,
) -> str:
"""
Yield the end of the data stream.

Format and return the end event for a streaming response,
including referenced document metadata and placeholder token
counts.
including referenced document metadata and token usage information.

Parameters:
metadata_map (dict): A mapping containing metadata about
referenced documents.
summary (TurnSummary): Summary of the conversation turn.
token_usage (TokenCounter): Token usage information.
media_type (str): The media type for the response format.

Returns:
@@ -199,8 +206,8 @@ def stream_end_event(metadata_map: dict, media_type: str = MEDIA_TYPE_JSON) -> s
"rag_chunks": [], # TODO(jboos): implement RAG chunks when summary is available
"referenced_documents": referenced_docs_dict,
"truncated": None, # TODO(jboos): implement truncated
"input_tokens": 0, # TODO(jboos): implement input tokens
"output_tokens": 0, # TODO(jboos): implement output tokens
"input_tokens": token_usage.input_tokens,
"output_tokens": token_usage.output_tokens,
},
"available_quotas": {}, # TODO(jboos): implement available quotas
}
@@ -787,6 +794,8 @@ async def response_generator(
# Send start event at the beginning of the stream
yield stream_start_event(conversation_id)

latest_turn: Any | None = None

async for chunk in turn_response:
if chunk.event is None:
continue
@@ -795,6 +804,7 @@
summary.llm_response = interleaved_content_as_str(
p.turn.output_message.content
)
latest_turn = p.turn
system_prompt = get_system_prompt(query_request, configuration)
try:
update_llm_token_count_from_turn(
@@ -812,7 +822,14 @@
chunk_id += 1
yield event

yield stream_end_event(metadata_map, media_type)
# Extract token usage from the turn
token_usage = (
extract_token_usage_from_turn(latest_turn)
if latest_turn is not None
else TokenCounter()
)

yield stream_end_event(metadata_map, summary, token_usage, media_type)

if not is_transcripts_enabled():
logger.debug("Transcript collection is disabled in the configuration")
37 changes: 32 additions & 5 deletions src/models/responses.py
@@ -185,11 +185,10 @@ class QueryResponse(BaseModel):
rag_chunks: List of RAG chunks used to generate the response.
referenced_documents: The URLs and titles for the documents used to generate the response.
tool_calls: List of tool calls made during response generation.
TODO: truncated: Whether conversation history was truncated.
TODO: input_tokens: Number of tokens sent to LLM.
TODO: output_tokens: Number of tokens received from LLM.
TODO: available_quotas: Quota available as measured by all configured quota limiters
TODO: tool_results: List of tool results.
truncated: Whether conversation history was truncated.
input_tokens: Number of tokens sent to LLM.
output_tokens: Number of tokens received from LLM.
available_quotas: Quota available as measured by all configured quota limiters.
"""

conversation_id: Optional[str] = Field(
@@ -229,6 +228,30 @@
],
)

truncated: bool = Field(
False,
description="Whether conversation history was truncated",
examples=[False, True],
)

input_tokens: int = Field(
0,
description="Number of tokens sent to LLM",
examples=[150, 250, 500],
)

output_tokens: int = Field(
0,
description="Number of tokens received from LLM",
examples=[50, 100, 200],
)

available_quotas: dict[str, int] = Field(
default_factory=dict,
description="Quota available as measured by all configured quota limiters",
examples=[{"daily": 1000, "monthly": 50000}],
)

# provides examples for /docs endpoint
model_config = {
"json_schema_extra": {
@@ -257,6 +280,10 @@ class QueryResponse(BaseModel):
"doc_title": "Operator Lifecycle Manager (OLM)",
}
],
"truncated": False,
"input_tokens": 150,
"output_tokens": 75,
"available_quotas": {"daily": 1000, "monthly": 50000},
}
]
}
130 changes: 130 additions & 0 deletions src/utils/token_counter.py
@@ -0,0 +1,130 @@
"""Helper classes to count tokens sent and received by the LLM."""

import logging
from dataclasses import dataclass
from typing import cast

from llama_stack.models.llama.datatypes import RawMessage
from llama_stack.models.llama.llama3.chat_format import ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack_client.types.agents.turn import Turn

import metrics

logger = logging.getLogger(__name__)


@dataclass
class TokenCounter:
"""Model representing token counter.

Attributes:
input_tokens: number of tokens sent to LLM
output_tokens: number of tokens received from LLM
input_tokens_counted: number of input tokens counted by the handler
llm_calls: number of LLM calls
"""

input_tokens: int = 0
output_tokens: int = 0
input_tokens_counted: int = 0
llm_calls: int = 0

def __str__(self) -> str:
"""Textual representation of TokenCounter instance."""
return (
f"{self.__class__.__name__}: "
+ f"input_tokens: {self.input_tokens} "
+ f"output_tokens: {self.output_tokens} "
+ f"counted: {self.input_tokens_counted} "
+ f"LLM calls: {self.llm_calls}"
)


def extract_token_usage_from_turn(turn: Turn, system_prompt: str = "") -> TokenCounter:
"""Extract token usage information from a turn.

This function uses the same tokenizer and logic as the metrics system
to ensure consistency between API responses and Prometheus metrics.

Args:
turn: The turn object containing token usage information
system_prompt: The system prompt used for the turn

Returns:
TokenCounter: Token usage information
"""
token_counter = TokenCounter()

try:
# Use the same tokenizer as the metrics system for consistency
tokenizer = Tokenizer.get_instance()
formatter = ChatFormat(tokenizer)

# Count output tokens (same logic as metrics.utils.update_llm_token_count_from_turn)
if hasattr(turn, "output_message") and turn.output_message:
raw_message = cast(RawMessage, turn.output_message)
encoded_output = formatter.encode_dialog_prompt([raw_message])
token_counter.output_tokens = (
len(encoded_output.tokens) if encoded_output.tokens else 0
)

# Count input tokens (same logic as metrics.utils.update_llm_token_count_from_turn)
if hasattr(turn, "input_messages") and turn.input_messages:
input_messages = cast(list[RawMessage], turn.input_messages)
if system_prompt:
input_messages = [
RawMessage(role="system", content=system_prompt)
] + input_messages
encoded_input = formatter.encode_dialog_prompt(input_messages)
token_counter.input_tokens = (
len(encoded_input.tokens) if encoded_input.tokens else 0
)
token_counter.input_tokens_counted = token_counter.input_tokens

token_counter.llm_calls = 1

except (AttributeError, TypeError, ValueError) as e:
logger.warning("Failed to extract token usage from turn: %s", e)
# Fallback to default values if token counting fails
token_counter.input_tokens = 100 # Default estimate
token_counter.output_tokens = 50 # Default estimate
token_counter.llm_calls = 1

return token_counter


def extract_and_update_token_metrics(
turn: Turn, model: str, provider: str, system_prompt: str = ""
) -> TokenCounter:
"""Extract token usage and update Prometheus metrics in one call.

This function combines the token counting logic with the metrics system
to ensure both API responses and Prometheus metrics are updated consistently.

Args:
turn: The turn object containing token usage information
model: The model identifier for metrics labeling
provider: The provider identifier for metrics labeling
system_prompt: The system prompt used for the turn

Returns:
TokenCounter: Token usage information
"""
token_counter = extract_token_usage_from_turn(turn, system_prompt)

# Update Prometheus metrics with the same token counts
try:
# Update the metrics using the same token counts we calculated
metrics.llm_token_sent_total.labels(provider, model).inc(
token_counter.input_tokens
)
metrics.llm_token_received_total.labels(provider, model).inc(
token_counter.output_tokens
)
metrics.llm_calls_total.labels(provider, model).inc()

except (AttributeError, TypeError, ValueError) as e:
logger.warning("Failed to update token metrics: %s", e)

return token_counter
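
As a minimal usage sketch (not part of this diff), a caller elsewhere in the service could reuse the new helper to populate the token fields added to QueryResponse. The build_token_fields wrapper and its turn argument are hypothetical; the import path, the extract_and_update_token_metrics signature, and the model-label split mirror the code above.

from utils.token_counter import TokenCounter, extract_and_update_token_metrics


def build_token_fields(
    turn, model_id: str, provider_id: str, system_prompt: str = ""
) -> dict[str, int]:
    """Count tokens for a completed turn and update the Prometheus counters."""
    # Mirror the label logic in query.py: drop an optional "provider/" prefix.
    model_label = model_id.split("/", 1)[1] if "/" in model_id else model_id
    usage: TokenCounter = extract_and_update_token_metrics(
        turn, model_label, provider_id, system_prompt
    )
    # These keys match the new QueryResponse fields in src/models/responses.py.
    return {"input_tokens": usage.input_tokens, "output_tokens": usage.output_tokens}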