@@ -24,15 +24,24 @@ async def test_gateway_client_init():
24
24
client = GatewayClient (
25
25
base_url = "http://localhost:8888/" ,
26
26
default_model = "test-model" ,
27
- polling_interval = 0.01 ,
27
+ polling_interval = 0.5 ,
28
28
timeout = 0.1 ,
29
29
)
30
30
assert client .base_url == "http://localhost:8888"
31
31
assert client .default_model == "test-model"
32
- assert client .polling_interval == 0.01
32
+ assert client .polling_interval == 0.5
33
33
assert client .timeout == 0.1
34
34
assert client ._client is None
35
35
36
+ client = GatewayClient (
37
+ base_url = "http://localhost:8888/" ,
38
+ polling_interval = 10 ,
39
+ )
40
+ assert client .polling_interval == 3.0 # Maximum 3.0 seconds
41
+
42
+ client .polling_interval = 0.05
43
+ assert client .polling_interval == 0.5 # Minimum is 0.5 seconds
44
+
36
45
37
46
@pytest .mark .asyncio
38
47
async def test_gateway_client_aenter_aexit (mock_httpx_async_client ):
@@ -65,16 +74,18 @@ async def test_submit_task_success(mock_httpx_async_client):
65
74
_ , mock_client_instance = mock_httpx_async_client
66
75
mock_response = MagicMock ()
67
76
mock_response .json .return_value = {"uuid" : "task-123" , "status" : "pending" }
68
- mock_client_instance .post .return_value = mock_response
77
+ mock_response .raise_for_status .return_value = mock_response
78
+ mock_client_instance .request .return_value = mock_response
69
79
70
80
async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
71
81
task_info = await client .submit_task (
72
82
model_name = "my_model" , task = "process" , data = "some text"
73
83
)
74
84
75
85
assert task_info == {"uuid" : "task-123" , "status" : "pending" }
76
- mock_client_instance .post .assert_awaited_once_with (
77
- "http://test-gateway.com/models/my_model/tasks/process" ,
86
+ mock_client_instance .request .assert_awaited_once_with (
87
+ method = "POST" ,
88
+ url = "http://test-gateway.com/models/my_model/tasks/process" ,
78
89
data = "some text" ,
79
90
json = None ,
80
91
files = None ,
@@ -90,16 +101,18 @@ async def test_submit_task_with_default_model(mock_httpx_async_client):
90
101
_ , mock_client_instance = mock_httpx_async_client
91
102
mock_response = MagicMock ()
92
103
mock_response .json .return_value = {"uuid" : "task-456" , "status" : "pending" }
93
- mock_client_instance .post .return_value = mock_response
104
+ mock_response .raise_for_status .return_value = mock_response
105
+ mock_client_instance .request .return_value = mock_response
94
106
95
107
async with GatewayClient (
96
108
base_url = "http://test-gateway.com" , default_model = "default_model"
97
109
) as client :
98
110
task_info = await client .submit_task (task = "process" , data = "some text" )
99
111
100
112
assert task_info == {"uuid" : "task-456" , "status" : "pending" }
101
- mock_client_instance .post .assert_awaited_once_with (
102
- "http://test-gateway.com/models/default_model/tasks/process" ,
113
+ mock_client_instance .request .assert_awaited_once_with (
114
+ method = "POST" ,
115
+ url = "http://test-gateway.com/models/default_model/tasks/process" ,
103
116
data = "some text" ,
104
117
json = None ,
105
118
files = None ,
@@ -124,7 +137,7 @@ async def test_submit_task_http_error(mock_httpx_async_client):
124
137
mock_response .raise_for_status .side_effect = httpx .HTTPStatusError (
125
138
"Bad Request" , request = httpx .Request ("POST" , "url" ), response = httpx .Response (400 )
126
139
)
127
- mock_client_instance .post .return_value = mock_response
140
+ mock_client_instance .request .return_value = mock_response
128
141
129
142
with pytest .raises (httpx .HTTPStatusError ):
130
143
async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
@@ -137,7 +150,8 @@ async def test_submit_task_wait_for_completion_and_return_result(mock_httpx_asyn
137
150
_ , mock_client_instance = mock_httpx_async_client
138
151
mock_response = MagicMock ()
139
152
mock_response .json .return_value = {"uuid" : "task-123" , "status" : "pending" }
140
- mock_client_instance .post .return_value = mock_response
153
+ mock_response .raise_for_status .return_value = mock_response
154
+ mock_client_instance .request .return_value = mock_response
141
155
142
156
mock_wait_for_task = mocker .patch (
143
157
"client.cogstack_model_gateway_client.client.GatewayClient.wait_for_task" ,
@@ -220,14 +234,20 @@ async def test_get_task_success(mock_httpx_async_client):
220
234
mock_response = MagicMock ()
221
235
mock_response .raise_for_status .return_value = mock_response
222
236
mock_response .json .return_value = {"uuid" : "task-123" , "status" : "succeeded" }
223
- mock_client_instance .get .return_value = mock_response
237
+ mock_client_instance .request .return_value = mock_response
224
238
225
239
async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
226
240
task_info = await client .get_task ("task-123" )
227
241
228
242
assert task_info == {"uuid" : "task-123" , "status" : "succeeded" }
229
- mock_client_instance .get .assert_awaited_once_with (
230
- "http://test-gateway.com/tasks/task-123" , params = {"detail" : True , "download" : False }
243
+ mock_client_instance .request .assert_awaited_once_with (
244
+ method = "GET" ,
245
+ url = "http://test-gateway.com/tasks/task-123" ,
246
+ params = {"detail" : True , "download" : False },
247
+ data = None ,
248
+ json = None ,
249
+ files = None ,
250
+ headers = None ,
231
251
)
232
252
mock_response .raise_for_status .assert_called_once ()
233
253
@@ -239,7 +259,7 @@ async def test_get_task_result_json(mock_httpx_async_client):
239
259
mock_response = MagicMock ()
240
260
mock_response .content = b'{"key": "value"}'
241
261
mock_response .raise_for_status .return_value = mock_response
242
- mock_client_instance .get .return_value = mock_response
262
+ mock_client_instance .request .return_value = mock_response
243
263
244
264
async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
245
265
result = await client .get_task_result ("task-123" )
@@ -253,7 +273,7 @@ async def test_get_task_result_jsonl(mock_httpx_async_client):
253
273
mock_response = MagicMock ()
254
274
mock_response .content = b'{"item": 1}\n {"item": 2}\n '
255
275
mock_response .raise_for_status .return_value = mock_response
256
- mock_client_instance .get .return_value = mock_response
276
+ mock_client_instance .request .return_value = mock_response
257
277
258
278
async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
259
279
result = await client .get_task_result ("task-123" )
@@ -267,7 +287,7 @@ async def test_get_task_result_text(mock_httpx_async_client):
267
287
mock_response = MagicMock ()
268
288
mock_response .content = b"plain text result"
269
289
mock_response .raise_for_status .return_value = mock_response
270
- mock_client_instance .get .return_value = mock_response
290
+ mock_client_instance .request .return_value = mock_response
271
291
272
292
async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
273
293
result = await client .get_task_result ("task-123" )
@@ -281,7 +301,7 @@ async def test_get_task_result_binary(mock_httpx_async_client):
281
301
mock_response = MagicMock ()
282
302
mock_response .content = b"\x80 \x01 \x02 \x03 " # Example binary data
283
303
mock_response .raise_for_status .return_value = mock_response
284
- mock_client_instance .get .return_value = mock_response
304
+ mock_client_instance .request .return_value = mock_response
285
305
286
306
async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
287
307
result = await client .get_task_result ("task-123" )
@@ -295,7 +315,7 @@ async def test_get_task_result_no_parse(mock_httpx_async_client):
295
315
mock_response = MagicMock ()
296
316
mock_response .content = b'{"key": "value"}'
297
317
mock_response .raise_for_status .return_value = mock_response
298
- mock_client_instance .get .return_value = mock_response
318
+ mock_client_instance .request .return_value = mock_response
299
319
300
320
async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
301
321
result = await client .get_task_result ("task-123" , parse = False )
@@ -334,7 +354,7 @@ async def test_wait_for_task_timeout(mock_httpx_async_client, mocker):
334
354
335
355
async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
336
356
client .timeout = 0.05
337
- client .polling_interval = 0.01
357
+ client .polling_interval = 0.5
338
358
339
359
with pytest .raises (
340
360
TimeoutError , match = "Timed out waiting for task 'task-polling' to complete"
@@ -394,13 +414,20 @@ async def test_get_models_success(mock_httpx_async_client):
394
414
_ , mock_client_instance = mock_httpx_async_client
395
415
mock_response = MagicMock ()
396
416
mock_response .json .return_value = ["model_a" , "model_b" ]
397
- mock_client_instance .get .return_value = mock_response
417
+ mock_response .raise_for_status .return_value = mock_response
418
+ mock_client_instance .request .return_value = mock_response
398
419
399
420
async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
400
421
models = await client .get_models ()
401
422
assert models == ["model_a" , "model_b" ]
402
- mock_client_instance .get .assert_awaited_once_with (
403
- "http://test-gateway.com/models/" , params = {"verbose" : False }
423
+ mock_client_instance .request .assert_awaited_once_with (
424
+ method = "GET" ,
425
+ url = "http://test-gateway.com/models/" ,
426
+ params = {"verbose" : False },
427
+ data = None ,
428
+ json = None ,
429
+ files = None ,
430
+ headers = None ,
404
431
)
405
432
406
433
@@ -410,13 +437,20 @@ async def test_get_model_success(mock_httpx_async_client):
410
437
_ , mock_client_instance = mock_httpx_async_client
411
438
mock_response = MagicMock ()
412
439
mock_response .json .return_value = {"name" : "my_model" , "status" : "deployed" }
413
- mock_client_instance .get .return_value = mock_response
440
+ mock_response .raise_for_status .return_value = mock_response
441
+ mock_client_instance .request .return_value = mock_response
414
442
415
443
async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
416
444
model_info = await client .get_model (model_name = "my_model" )
417
445
assert model_info == {"name" : "my_model" , "status" : "deployed" }
418
- mock_client_instance .get .assert_awaited_once_with (
419
- "http://test-gateway.com/models/my_model/info"
446
+ mock_client_instance .request .assert_awaited_once_with (
447
+ method = "GET" ,
448
+ url = "http://test-gateway.com/models/my_model/info" ,
449
+ params = None ,
450
+ data = None ,
451
+ json = None ,
452
+ files = None ,
453
+ headers = None ,
420
454
)
421
455
422
456
@@ -426,15 +460,22 @@ async def test_get_model_with_default_model(mock_httpx_async_client):
426
460
_ , mock_client_instance = mock_httpx_async_client
427
461
mock_response = MagicMock ()
428
462
mock_response .json .return_value = {"name" : "default_model" , "status" : "deployed" }
429
- mock_client_instance .get .return_value = mock_response
463
+ mock_response .raise_for_status .return_value = mock_response
464
+ mock_client_instance .request .return_value = mock_response
430
465
431
466
async with GatewayClient (
432
467
base_url = "http://test-gateway.com" , default_model = "default_model"
433
468
) as client :
434
469
model_info = await client .get_model ()
435
470
assert model_info == {"name" : "default_model" , "status" : "deployed" }
436
- mock_client_instance .get .assert_awaited_once_with (
437
- "http://test-gateway.com/models/default_model/info"
471
+ mock_client_instance .request .assert_awaited_once_with (
472
+ method = "GET" ,
473
+ url = "http://test-gateway.com/models/default_model/info" ,
474
+ params = None ,
475
+ data = None ,
476
+ json = None ,
477
+ files = None ,
478
+ headers = None ,
438
479
)
439
480
440
481
@@ -452,7 +493,8 @@ async def test_deploy_model_success(mock_httpx_async_client):
452
493
_ , mock_client_instance = mock_httpx_async_client
453
494
mock_response = MagicMock ()
454
495
mock_response .json .return_value = {"name" : "new_model" , "status" : "deploying" }
455
- mock_client_instance .post .return_value = mock_response
496
+ mock_response .raise_for_status .return_value = mock_response
497
+ mock_client_instance .request .return_value = mock_response
456
498
457
499
async with GatewayClient (base_url = "http://test-gateway.com" ) as client :
458
500
deploy_info = await client .deploy_model (
@@ -462,11 +504,16 @@ async def test_deploy_model_success(mock_httpx_async_client):
462
504
)
463
505
464
506
assert deploy_info == {"name" : "new_model" , "status" : "deploying" }
465
- mock_client_instance .post .assert_awaited_once_with (
466
- "http://test-gateway.com/models/new_model" ,
507
+ mock_client_instance .request .assert_awaited_once_with (
508
+ method = "POST" ,
509
+ url = "http://test-gateway.com/models/new_model" ,
467
510
json = {
468
511
"tracking_id" : None ,
469
512
"model_uri" : "mlflow-artifacts:/1/runidabcd1234/artifacts/new_model" ,
470
513
"ttl" : 3600 ,
471
514
},
515
+ params = None ,
516
+ data = None ,
517
+ files = None ,
518
+ headers = None ,
472
519
)
0 commit comments