Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 34 additions & 22 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen
available_input_shields: list[str],
available_output_shields: list[str],
conversation_id: str | None,
no_tools: bool = False,
) -> tuple[Agent, str]:
"""Get existing agent or create a new one with session persistence."""
if conversation_id is not None:
Expand All @@ -99,7 +100,7 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen
instructions=system_prompt,
input_shields=available_input_shields if available_input_shields else [],
output_shields=available_output_shields if available_output_shields else [],
tool_parser=GraniteToolParser.get_parser(model_id),
tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id),
enable_session_persistence=True,
)
conversation_id = agent.create_session(get_suid())
Expand Down Expand Up @@ -288,36 +289,47 @@ def retrieve_response( # pylint: disable=too-many-locals
available_input_shields,
available_output_shields,
query_request.conversation_id,
query_request.no_tools or False,
)

# preserve compatibility when mcp_headers is not provided
if mcp_headers is None:
# bypass tools and MCP servers if no_tools is True
if query_request.no_tools:
mcp_headers = {}
mcp_headers = handle_mcp_headers_with_toolgroups(mcp_headers, configuration)
if not mcp_headers and token:
for mcp_server in configuration.mcp_servers:
mcp_headers[mcp_server.url] = {
"Authorization": f"Bearer {token}",
}

agent.extra_headers = {
"X-LlamaStack-Provider-Data": json.dumps(
{
"mcp_headers": mcp_headers,
}
),
}
agent.extra_headers = {}
toolgroups = None
else:
# preserve compatibility when mcp_headers is not provided
if mcp_headers is None:
mcp_headers = {}
mcp_headers = handle_mcp_headers_with_toolgroups(mcp_headers, configuration)
if not mcp_headers and token:
for mcp_server in configuration.mcp_servers:
mcp_headers[mcp_server.url] = {
"Authorization": f"Bearer {token}",
}

agent.extra_headers = {
"X-LlamaStack-Provider-Data": json.dumps(
{
"mcp_headers": mcp_headers,
}
),
}

vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
mcp_server.name for mcp_server in configuration.mcp_servers
]
# Convert empty list to None for consistency with existing behavior
if not toolgroups:
toolgroups = None

vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()]
toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
mcp_server.name for mcp_server in configuration.mcp_servers
]
response = agent.create_turn(
messages=[UserMessage(role="user", content=query_request.query)],
session_id=conversation_id,
documents=query_request.get_documents(),
stream=False,
toolgroups=toolgroups or None,
toolgroups=toolgroups,
)

# Check for validation errors in the response
Expand Down
59 changes: 36 additions & 23 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ async def get_agent(
available_input_shields: list[str],
available_output_shields: list[str],
conversation_id: str | None,
no_tools: bool = False,
) -> tuple[AsyncAgent, str]:
"""Get existing agent or create a new one with session persistence."""
if conversation_id is not None:
Expand All @@ -76,7 +77,7 @@ async def get_agent(
instructions=system_prompt,
input_shields=available_input_shields if available_input_shields else [],
output_shields=available_output_shields if available_output_shields else [],
tool_parser=GraniteToolParser.get_parser(model_id),
tool_parser=None if no_tools else GraniteToolParser.get_parser(model_id),
enable_session_persistence=True,
)
conversation_id = await agent.create_session(get_suid())
Expand Down Expand Up @@ -532,41 +533,53 @@ async def retrieve_response(
available_input_shields,
available_output_shields,
query_request.conversation_id,
query_request.no_tools or False,
)

# preserve compatibility when mcp_headers is not provided
if mcp_headers is None:
# bypass tools and MCP servers if no_tools is True
if query_request.no_tools:
mcp_headers = {}
agent.extra_headers = {}
toolgroups = None
else:
# preserve compatibility when mcp_headers is not provided
if mcp_headers is None:
mcp_headers = {}

mcp_headers = handle_mcp_headers_with_toolgroups(mcp_headers, configuration)
mcp_headers = handle_mcp_headers_with_toolgroups(mcp_headers, configuration)

if not mcp_headers and token:
for mcp_server in configuration.mcp_servers:
mcp_headers[mcp_server.url] = {
"Authorization": f"Bearer {token}",
}
if not mcp_headers and token:
for mcp_server in configuration.mcp_servers:
mcp_headers[mcp_server.url] = {
"Authorization": f"Bearer {token}",
}

agent.extra_headers = {
"X-LlamaStack-Provider-Data": json.dumps(
{
"mcp_headers": mcp_headers,
}
),
}
agent.extra_headers = {
"X-LlamaStack-Provider-Data": json.dumps(
{
"mcp_headers": mcp_headers,
}
),
}

logger.debug("Session ID: %s", conversation_id)
vector_db_ids = [
vector_db.identifier for vector_db in await client.vector_dbs.list()
]
toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
mcp_server.name for mcp_server in configuration.mcp_servers
]
# Convert empty list to None for consistency with existing behavior
if not toolgroups:
toolgroups = None

logger.debug("Session ID: %s", conversation_id)
vector_db_ids = [
vector_db.identifier for vector_db in await client.vector_dbs.list()
]
toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [
mcp_server.name for mcp_server in configuration.mcp_servers
]
response = await agent.create_turn(
messages=[UserMessage(role="user", content=query_request.query)],
session_id=conversation_id,
documents=query_request.get_documents(),
stream=True,
toolgroups=toolgroups or None,
toolgroups=toolgroups,
)

return response, conversation_id
3 changes: 3 additions & 0 deletions src/models/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class QueryRequest(BaseModel):
model: The optional model.
system_prompt: The optional system prompt.
attachments: The optional attachments.
no_tools: Whether to bypass all tools and MCP servers (default: False).

Example:
```python
Expand All @@ -82,6 +83,7 @@ class QueryRequest(BaseModel):
model: Optional[str] = None
system_prompt: Optional[str] = None
attachments: Optional[list[Attachment]] = None
no_tools: Optional[bool] = False
# media_type is not used in 'lightspeed-stack' that only supports application/json.
# the field is kept here to enable compatibility with 'road-core' clients.
media_type: Optional[str] = None
Expand All @@ -97,6 +99,7 @@ class QueryRequest(BaseModel):
"provider": "openai",
"model": "model-name",
"system_prompt": "You are a helpful assistant",
"no_tools": False,
"attachments": [
{
"attachment_type": "log",
Expand Down
Loading
Loading