@@ -179,30 +179,70 @@ def test_query_endpoint_handler_store_transcript(mocker):
179179 _test_query_endpoint_handler (mocker , store_transcript_to_file = True )
180180
181181
182- def test_select_model_and_provider_id (mocker ):
182+ def test_select_model_and_provider_id_from_request (mocker ):
183183 """Test the select_model_and_provider_id function."""
184- mock_client = mocker .Mock ()
185- mock_client .models .list .return_value = [
184+ mocker .patch (
185+ "metrics.utils.configuration.inference.default_provider" ,
186+ "default_provider" ,
187+ )
188+ mocker .patch (
189+ "metrics.utils.configuration.inference.default_model" ,
190+ "default_model" ,
191+ )
192+
193+ model_list = [
186194 mocker .Mock (identifier = "model1" , model_type = "llm" , provider_id = "provider1" ),
187195 mocker .Mock (identifier = "model2" , model_type = "llm" , provider_id = "provider2" ),
196+ mocker .Mock (
197+ identifier = "default_model" , model_type = "llm" , provider_id = "default_provider"
198+ ),
188199 ]
189200
201+ # Create a query request with model and provider specified
190202 query_request = QueryRequest (
191- query = "What is OpenStack?" , model = "model1 " , provider = "provider1 "
203+ query = "What is OpenStack?" , model = "model2 " , provider = "provider2 "
192204 )
193205
194- model_id , provider_id = select_model_and_provider_id (
195- mock_client .models .list (), query_request
206+ # Assert the model and provider from request take precedence from the configuration one
207+ model_id , provider_id = select_model_and_provider_id (model_list , query_request )
208+
209+ assert model_id == "model2"
210+ assert provider_id == "provider2"
211+
212+
213+ def test_select_model_and_provider_id_from_configuration (mocker ):
214+ """Test the select_model_and_provider_id function."""
215+ mocker .patch (
216+ "metrics.utils.configuration.inference.default_provider" ,
217+ "default_provider" ,
218+ )
219+ mocker .patch (
220+ "metrics.utils.configuration.inference.default_model" ,
221+ "default_model" ,
196222 )
197223
198- assert model_id == "model1"
199- assert provider_id == "provider1"
224+ model_list = [
225+ mocker .Mock (identifier = "model1" , model_type = "llm" , provider_id = "provider1" ),
226+ mocker .Mock (
227+ identifier = "default_model" , model_type = "llm" , provider_id = "default_provider"
228+ ),
229+ ]
230+
231+ # Create a query request without model and provider specified
232+ query_request = QueryRequest (
233+ query = "What is OpenStack?" ,
234+ )
235+
236+ model_id , provider_id = select_model_and_provider_id (model_list , query_request )
237+
238+ # Assert that the default model and provider from the configuration are returned
239+ assert model_id == "default_model"
240+ assert provider_id == "default_provider"
200241
201242
202- def test_select_model_and_provider_id_no_model (mocker ):
243+ def test_select_model_and_provider_id_first_from_list (mocker ):
203244 """Test the select_model_and_provider_id function when no model is specified."""
204- mock_client = mocker .Mock ()
205- mock_client .models .list .return_value = [
245+ model_list = [
206246 mocker .Mock (
207247 identifier = "not_llm_type" , model_type = "embedding" , provider_id = "provider1"
208248 ),
@@ -216,11 +256,10 @@ def test_select_model_and_provider_id_no_model(mocker):
216256
217257 query_request = QueryRequest (query = "What is OpenStack?" )
218258
219- model_id , provider_id = select_model_and_provider_id (
220- mock_client .models .list (), query_request
221- )
259+ model_id , provider_id = select_model_and_provider_id (model_list , query_request )
222260
223- # Assert return the first available LLM model
261+ # Assert return the first available LLM model when no model/provider is
262+ # specified in the request or in the configuration
224263 assert model_id == "first_model"
225264 assert provider_id == "provider1"
226265
0 commit comments