Skip to content

Commit c1c7ba3

Browse files
authored
Merge pull request #284 from umago/LCORE-418-default-model-provider-config
LCORE-418: Allow configuring a default model/provider
2 parents fb193c5 + 5943d1b commit c1c7ba3

File tree

7 files changed

+247
-30
lines changed

7 files changed

+247
-30
lines changed

src/app/endpoints/query.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,12 +175,26 @@ def select_model_and_provider_id(
175175
models: ModelListResponse, query_request: QueryRequest
176176
) -> tuple[str, str | None]:
177177
"""Select the model ID and provider ID based on the request or available models."""
178+
# If model_id and provider_id are provided in the request, use them
178179
model_id = query_request.model
179180
provider_id = query_request.provider
180181

181-
# TODO(lucasagomes): support default model selection via configuration
182-
if not model_id:
183-
logger.info("No model specified in request, using the first available LLM")
182+
# If model_id is not provided in the request, check the configuration
183+
if not model_id or not provider_id:
184+
logger.debug(
185+
"No model ID or provider ID specified in request, checking configuration"
186+
)
187+
model_id = configuration.inference.default_model # type: ignore[reportAttributeAccessIssue]
188+
provider_id = (
189+
configuration.inference.default_provider # type: ignore[reportAttributeAccessIssue]
190+
)
191+
192+
# If no model is specified in the request or configuration, use the first available LLM
193+
if not model_id or not provider_id:
194+
logger.debug(
195+
"No model ID or provider ID specified in request or configuration, "
196+
"using the first available LLM"
197+
)
184198
try:
185199
model = next(
186200
m
@@ -202,7 +216,8 @@ def select_model_and_provider_id(
202216
},
203217
) from e
204218

205-
logger.info("Searching for model: %s, provider: %s", model_id, provider_id)
219+
# Validate that the model_id and provider_id are in the available models
220+
logger.debug("Searching for model: %s, provider: %s", model_id, provider_id)
206221
if not any(
207222
m.identifier == model_id and m.provider_id == provider_id for m in models
208223
):

src/configuration.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
ServiceConfiguration,
1313
ModelContextProtocolServer,
1414
AuthenticationConfiguration,
15+
InferenceConfiguration,
1516
)
1617

1718
logger = logging.getLogger(__name__)
@@ -99,5 +100,13 @@ def customization(self) -> Optional[Customization]:
99100
), "logic error: configuration is not loaded"
100101
return self._configuration.customization
101102

103+
@property
104+
def inference(self) -> Optional[InferenceConfiguration]:
105+
"""Return inference configuration."""
106+
assert (
107+
self._configuration is not None
108+
), "logic error: configuration is not loaded"
109+
return self._configuration.inference
110+
102111

103112
configuration: AppConfig = AppConfig()

src/metrics/utils.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
"""Utility functions for metrics handling."""
22

3+
from configuration import configuration
34
from client import LlamaStackClientHolder
45
from log import get_logger
56
import metrics
67

78
logger = get_logger(__name__)
89

910

10-
# TODO(lucasagomes): Change this metric once we are allowed to set the the
11-
# default model/provider via the configuration.The default provider/model
12-
# will be set to 1, and the rest will be set to 0.
1311
def setup_model_metrics() -> None:
1412
"""Perform setup of all metrics related to LLM model and provider."""
1513
client = LlamaStackClientHolder().get_client()
@@ -19,14 +17,29 @@ def setup_model_metrics() -> None:
1917
if model.model_type == "llm" # pyright: ignore[reportAttributeAccessIssue]
2018
]
2119

20+
default_model_label = (
21+
configuration.inference.default_provider, # type: ignore[reportAttributeAccessIssue]
22+
configuration.inference.default_model, # type: ignore[reportAttributeAccessIssue]
23+
)
24+
2225
for model in models:
2326
provider = model.provider_id
2427
model_name = model.identifier
2528
if provider and model_name:
29+
# If the model/provider combination is the default, set the metric value to 1
30+
# Otherwise, set it to 0
31+
default_model_value = 0
2632
label_key = (provider, model_name)
27-
metrics.provider_model_configuration.labels(*label_key).set(1)
33+
if label_key == default_model_label:
34+
default_model_value = 1
35+
36+
# Set the metric for the provider/model configuration
37+
metrics.provider_model_configuration.labels(*label_key).set(
38+
default_model_value
39+
)
2840
logger.debug(
29-
"Set provider/model configuration for %s/%s to 1",
41+
"Set provider/model configuration for %s/%s to %d",
3042
provider,
3143
model_name,
44+
default_model_value,
3245
)

src/models/config.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,26 @@ def check_customization_model(self) -> Self:
185185
return self
186186

187187

188+
class InferenceConfiguration(BaseModel):
189+
"""Inference configuration."""
190+
191+
default_model: Optional[str] = None
192+
default_provider: Optional[str] = None
193+
194+
@model_validator(mode="after")
195+
def check_default_model_and_provider(self) -> Self:
196+
"""Check default model and provider."""
197+
if self.default_model is None and self.default_provider is not None:
198+
raise ValueError(
199+
"Default model must be specified when default provider is set"
200+
)
201+
if self.default_model is not None and self.default_provider is None:
202+
raise ValueError(
203+
"Default provider must be specified when default model is set"
204+
)
205+
return self
206+
207+
188208
class Configuration(BaseModel):
189209
"""Global service configuration."""
190210

@@ -197,6 +217,7 @@ class Configuration(BaseModel):
197217
AuthenticationConfiguration()
198218
)
199219
customization: Optional[Customization] = None
220+
inference: Optional[InferenceConfiguration] = InferenceConfiguration()
200221

201222
def dump(self, filename: str = "configuration.json") -> None:
202223
"""Dump actual configuration into JSON file."""

tests/unit/app/endpoints/test_query.py

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -179,30 +179,70 @@ def test_query_endpoint_handler_store_transcript(mocker):
179179
_test_query_endpoint_handler(mocker, store_transcript_to_file=True)
180180

181181

182-
def test_select_model_and_provider_id(mocker):
182+
def test_select_model_and_provider_id_from_request(mocker):
183183
"""Test the select_model_and_provider_id function."""
184-
mock_client = mocker.Mock()
185-
mock_client.models.list.return_value = [
184+
mocker.patch(
185+
"metrics.utils.configuration.inference.default_provider",
186+
"default_provider",
187+
)
188+
mocker.patch(
189+
"metrics.utils.configuration.inference.default_model",
190+
"default_model",
191+
)
192+
193+
model_list = [
186194
mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"),
187195
mocker.Mock(identifier="model2", model_type="llm", provider_id="provider2"),
196+
mocker.Mock(
197+
identifier="default_model", model_type="llm", provider_id="default_provider"
198+
),
188199
]
189200

201+
# Create a query request with model and provider specified
190202
query_request = QueryRequest(
191-
query="What is OpenStack?", model="model1", provider="provider1"
203+
query="What is OpenStack?", model="model2", provider="provider2"
192204
)
193205

194-
model_id, provider_id = select_model_and_provider_id(
195-
mock_client.models.list(), query_request
206+
# Assert the model and provider from the request take precedence over the configuration ones
207+
model_id, provider_id = select_model_and_provider_id(model_list, query_request)
208+
209+
assert model_id == "model2"
210+
assert provider_id == "provider2"
211+
212+
213+
def test_select_model_and_provider_id_from_configuration(mocker):
214+
"""Test the select_model_and_provider_id function."""
215+
mocker.patch(
216+
"metrics.utils.configuration.inference.default_provider",
217+
"default_provider",
218+
)
219+
mocker.patch(
220+
"metrics.utils.configuration.inference.default_model",
221+
"default_model",
196222
)
197223

198-
assert model_id == "model1"
199-
assert provider_id == "provider1"
224+
model_list = [
225+
mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"),
226+
mocker.Mock(
227+
identifier="default_model", model_type="llm", provider_id="default_provider"
228+
),
229+
]
230+
231+
# Create a query request without model and provider specified
232+
query_request = QueryRequest(
233+
query="What is OpenStack?",
234+
)
235+
236+
model_id, provider_id = select_model_and_provider_id(model_list, query_request)
237+
238+
# Assert that the default model and provider from the configuration are returned
239+
assert model_id == "default_model"
240+
assert provider_id == "default_provider"
200241

201242

202-
def test_select_model_and_provider_id_no_model(mocker):
243+
def test_select_model_and_provider_id_first_from_list(mocker):
203244
"""Test the select_model_and_provider_id function when no model is specified."""
204-
mock_client = mocker.Mock()
205-
mock_client.models.list.return_value = [
245+
model_list = [
206246
mocker.Mock(
207247
identifier="not_llm_type", model_type="embedding", provider_id="provider1"
208248
),
@@ -216,11 +256,10 @@ def test_select_model_and_provider_id_no_model(mocker):
216256

217257
query_request = QueryRequest(query="What is OpenStack?")
218258

219-
model_id, provider_id = select_model_and_provider_id(
220-
mock_client.models.list(), query_request
221-
)
259+
model_id, provider_id = select_model_and_provider_id(model_list, query_request)
222260

223-
# Assert return the first available LLM model
261+
# Assert return the first available LLM model when no model/provider is
262+
# specified in the request or in the configuration
224263
assert model_id == "first_model"
225264
assert provider_id == "provider1"
226265

tests/unit/metrics/test_utis.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,62 @@ def test_setup_model_metrics(mocker):
88

99
# Mock the LlamaStackAsLibraryClient
1010
mock_client = mocker.patch("client.LlamaStackClientHolder.get_client").return_value
11+
mocker.patch(
12+
"metrics.utils.configuration.inference.default_provider",
13+
"default_provider",
14+
)
15+
mocker.patch(
16+
"metrics.utils.configuration.inference.default_model",
17+
"default_model",
18+
)
1119

1220
mock_metric = mocker.patch("metrics.provider_model_configuration")
13-
fake_model = mocker.Mock(
14-
provider_id="test_provider",
15-
identifier="test_model",
21+
# Mock a model that is the default
22+
model_default = mocker.Mock(
23+
provider_id="default_provider",
24+
identifier="default_model",
1625
model_type="llm",
1726
)
18-
mock_client.models.list.return_value = [fake_model]
27+
# Mock a model that is not the default
28+
model_0 = mocker.Mock(
29+
provider_id="test_provider-0",
30+
identifier="test_model-0",
31+
model_type="llm",
32+
)
33+
# Mock a second model which is not default
34+
model_1 = mocker.Mock(
35+
provider_id="test_provider-1",
36+
identifier="test_model-1",
37+
model_type="llm",
38+
)
39+
# Mock a model that is not an LLM type, should be ignored
40+
not_llm_model = mocker.Mock(
41+
provider_id="not-llm-provider",
42+
identifier="not-llm-model",
43+
model_type="not-llm",
44+
)
45+
46+
# Mock the list of models returned by the client
47+
mock_client.models.list.return_value = [
48+
model_0,
49+
model_default,
50+
not_llm_model,
51+
model_1,
52+
]
1953

2054
setup_model_metrics()
2155

22-
# Assert that the metric was set correctly
23-
mock_metric.labels("test_provider", "test_model").set.assert_called_once_with(1)
56+
# Check that the provider_model_configuration metric was set correctly
57+
# The default model should have a value of 1, others should be 0
58+
assert mock_metric.labels.call_count == 3
59+
mock_metric.assert_has_calls(
60+
[
61+
mocker.call.labels("test_provider-0", "test_model-0"),
62+
mocker.call.labels().set(0),
63+
mocker.call.labels("default_provider", "default_model"),
64+
mocker.call.labels().set(1),
65+
mocker.call.labels("test_provider-1", "test_model-1"),
66+
mocker.call.labels().set(0),
67+
],
68+
any_order=False, # Order matters here
69+
)

0 commit comments

Comments
 (0)