Skip to content

Commit ec79cb9

Browse files
committed
allow disabling query model and provider
1 parent 18fdf3c commit ec79cb9

File tree

7 files changed

+204
-0
lines changed

7 files changed

+204
-0
lines changed

src/app/endpoints/query.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
get_agent,
3636
get_system_prompt,
3737
validate_conversation_ownership,
38+
validate_model_provider_override,
3839
)
3940
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
4041
from utils.transcripts import store_transcript
@@ -174,6 +175,9 @@ async def query_endpoint_handler(
174175
"""
175176
check_configuration_loaded(configuration)
176177

178+
# Enforce configuration: optionally disallow overriding model/provider in requests
179+
validate_model_provider_override(query_request, configuration)
180+
177181
# log Llama Stack configuration, but without sensitive information
178182
llama_stack_config = configuration.llama_stack_configuration.model_copy()
179183
llama_stack_config.api_key = "********"

src/app/endpoints/streaming_query.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
3434
from utils.transcripts import store_transcript
3535
from utils.types import TurnSummary
36+
from utils.endpoints import validate_model_provider_override
3637

3738
from app.endpoints.query import (
3839
get_rag_toolgroups,
@@ -548,6 +549,9 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
548549

549550
check_configuration_loaded(configuration)
550551

552+
# Enforce configuration: optionally disallow overriding model/provider in requests
553+
validate_model_provider_override(query_request, configuration)
554+
551555
# log Llama Stack configuration, but without sensitive information
552556
llama_stack_config = configuration.llama_stack_configuration.model_copy()
553557
llama_stack_config.api_key = "********"

src/models/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,7 @@ class Customization(ConfigurationBase):
388388
"""Service customization."""
389389

390390
disable_query_system_prompt: bool = False
391+
disable_query_model_provider_override: bool = False
391392
system_prompt_path: Optional[FilePath] = None
392393
system_prompt: Optional[str] = None
393394

src/utils/endpoints.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,31 @@ def get_system_prompt(query_request: QueryRequest, config: AppConfig) -> str:
8484
return constants.DEFAULT_SYSTEM_PROMPT
8585

8686

87+
def validate_model_provider_override(
    query_request: QueryRequest, config: AppConfig
) -> None:
    """Reject requests that choose a model/provider when the service forbids it.

    The check is a no-op unless ``config.customization`` is present and its
    ``disable_query_model_provider_override`` flag is truthy.

    Args:
        query_request: Incoming query; only ``model`` and ``provider`` are read.
        config: Application configuration holding the optional customization.

    Raises:
        HTTPException: HTTP 422 when overrides are disabled and the request
            sets ``model`` or ``provider``.
    """
    customization = config.customization
    if customization is None:
        return

    # getattr keeps the check tolerant of customization objects lacking the flag.
    if not getattr(customization, "disable_query_model_provider_override", False):
        return

    # Nothing to reject when the request leaves both fields unset.
    if query_request.model is None and query_request.provider is None:
        return

    message = (
        "This instance does not support overriding model/provider in the query request "
        "(disable_query_model_provider_override is set). Please remove the model and "
        "provider fields from your request."
    )
    raise HTTPException(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
        detail={"response": message},
    )
110+
111+
87112
# # pylint: disable=R0913,R0917
88113
async def get_agent(
89114
client: AsyncLlamaStackClient,

tests/unit/app/endpoints/test_query.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1507,3 +1507,57 @@ def test_evaluate_model_hints(
15071507

15081508
assert provider_id == expected_provider
15091509
assert model_id == expected_model
1510+
1511+
1512+
@pytest.mark.asyncio
async def test_query_endpoint_rejects_model_provider_override_when_disabled(
    mocker, dummy_request
):
    """Query endpoint answers 422 when model/provider overrides are disabled.

    The configuration sets customization.disable_query_model_provider_override
    and the request carries both ``model`` and ``provider``; the handler is
    expected to raise before doing any further work.
    """
    service_settings = {
        "host": "localhost",
        "port": 8080,
        "auth_enabled": False,
        "workers": 1,
        "color_log": True,
        "access_log": True,
    }
    llama_stack_settings = {
        "api_key": "test-key",
        "url": "http://test.com:1234",
        "use_as_library_client": False,
    }

    # Configuration with the override switched off for clients
    app_config = AppConfig()
    app_config.init_from_dict(
        {
            "name": "test",
            "service": service_settings,
            "llama_stack": llama_stack_settings,
            "user_data_collection": {"transcripts_enabled": False},
            "mcp_servers": [],
            "customization": {
                "disable_query_model_provider_override": True,
            },
        }
    )

    # Make the endpoint module see the configuration built above
    mocker.patch("app.endpoints.query.configuration", app_config)

    # Request that attempts to select its own model and provider
    override_request = QueryRequest(query="What?", model="m", provider="p")

    with pytest.raises(HTTPException) as raised:
        await query_endpoint_handler(
            request=dummy_request, query_request=override_request, auth=MOCK_AUTH
        )

    error = raised.value
    assert error.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
    assert error.detail["response"] == (
        "This instance does not support overriding model/provider in the query request "
        "(disable_query_model_provider_override is set). Please remove the model and "
        "provider fields from your request."
    )

tests/unit/app/endpoints/test_streaming_query.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1515,3 +1515,61 @@ async def test_retrieve_response_no_tools_false_preserves_functionality(
15151515
stream=True,
15161516
toolgroups=expected_toolgroups,
15171517
)
1518+
1519+
1520+
@pytest.mark.asyncio
async def test_streaming_query_endpoint_rejects_model_provider_override_when_disabled(
    mocker,
):
    """Streaming endpoint answers 422 when model/provider overrides are disabled.

    The configuration sets customization.disable_query_model_provider_override
    and the request carries both ``model`` and ``provider``; the handler is
    expected to raise instead of starting a stream.
    """
    service_settings = {
        "host": "localhost",
        "port": 8080,
        "auth_enabled": False,
        "workers": 1,
        "color_log": True,
        "access_log": True,
    }
    llama_stack_settings = {
        "api_key": "test-key",
        "url": "http://test.com:1234",
        "use_as_library_client": False,
    }

    # Configuration with the override switched off for clients
    app_config = AppConfig()
    app_config.init_from_dict(
        {
            "name": "test",
            "service": service_settings,
            "llama_stack": llama_stack_settings,
            "user_data_collection": {"transcripts_enabled": False},
            "mcp_servers": [],
            "customization": {
                "disable_query_model_provider_override": True,
            },
        }
    )

    # Make the endpoint module see the configuration built above
    mocker.patch("app.endpoints.streaming_query.configuration", app_config)

    # Request that attempts to select its own model and provider
    override_request = QueryRequest(query="What?", model="m", provider="p")

    # Minimal ASGI scope; the handler only needs a Request object here
    http_request = Request(scope={"type": "http"})

    with pytest.raises(HTTPException) as raised:
        await streaming_query_endpoint_handler(
            http_request, override_request, auth=MOCK_AUTH
        )

    error = raised.value
    assert error.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
    assert error.detail["response"] == (
        "This instance does not support overriding model/provider in the query request "
        "(disable_query_model_provider_override is set). Please remove the model and "
        "provider fields from your request."
    )

tests/unit/utils/test_endpoints.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,3 +591,61 @@ async def test_get_agent_no_tools_false_preserves_parser(
591591
tool_parser=mock_parser,
592592
enable_session_persistence=True,
593593
)
594+
595+
596+
@pytest.fixture(name="config_with_override_disabled")
def config_with_override_disabled_fixture():
    """Build a config whose disable_query_model_provider_override flag is False.

    "Disabled" in the fixture name refers to the flag itself: with the flag
    off, requests ARE allowed to override model/provider.
    """
    # Shallow copy of the shared dict with only the customization replaced
    settings = {
        **config_dict,
        "customization": {
            "disable_query_model_provider_override": False,
        },
    }
    app_config = AppConfig()
    app_config.init_from_dict(settings)
    return app_config
606+
607+
608+
@pytest.fixture(name="config_with_override_enabled")
def config_with_override_enabled_fixture():
    """Build a config whose disable_query_model_provider_override flag is True.

    "Enabled" in the fixture name refers to the flag itself: with the flag
    on, requests are NOT allowed to override model/provider.
    """
    # Shallow copy of the shared dict with only the customization replaced
    settings = {
        **config_dict,
        "customization": {
            "disable_query_model_provider_override": True,
        },
    }
    app_config = AppConfig()
    app_config.init_from_dict(settings)
    return app_config
618+
619+
620+
def test_validate_model_provider_override_allowed_when_flag_false(
    config_with_override_disabled,
):
    """A request naming model/provider passes validation when the flag is False."""
    request_with_override = QueryRequest(query="q", model="m", provider="p")

    # The call completing without an exception is the assertion.
    endpoints.validate_model_provider_override(
        request_with_override,
        config_with_override_disabled,
    )
629+
630+
631+
def test_validate_model_provider_override_rejected_when_flag_true(
    config_with_override_enabled,
):
    """A request naming model/provider is rejected with HTTP 422 when the flag is True."""
    request_with_override = QueryRequest(query="q", model="m", provider="p")

    with pytest.raises(HTTPException) as raised:
        endpoints.validate_model_provider_override(
            request_with_override,
            config_with_override_enabled,
        )

    assert raised.value.status_code == 422
641+
642+
643+
def test_validate_model_provider_override_no_override_with_flag_true(
    config_with_override_enabled,
):
    """A request without model/provider passes validation even when the flag is True."""
    plain_request = QueryRequest(query="q")

    # The call completing without an exception is the assertion.
    endpoints.validate_model_provider_override(
        plain_request,
        config_with_override_enabled,
    )

0 commit comments

Comments
 (0)