Skip to content

Commit ca0ca04

Browse files
committed
fix: Update client tests
Update client unit tests according to the latest changes in the client. Signed-off-by: Phoevos Kalemkeris <[email protected]>
1 parent 1246ae7 commit ca0ca04

File tree

1 file changed

+78
-31
lines changed

1 file changed

+78
-31
lines changed

tests/unit/client/test_client.py

Lines changed: 78 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,24 @@ async def test_gateway_client_init():
2424
client = GatewayClient(
2525
base_url="http://localhost:8888/",
2626
default_model="test-model",
27-
polling_interval=0.01,
27+
polling_interval=0.5,
2828
timeout=0.1,
2929
)
3030
assert client.base_url == "http://localhost:8888"
3131
assert client.default_model == "test-model"
32-
assert client.polling_interval == 0.01
32+
assert client.polling_interval == 0.5
3333
assert client.timeout == 0.1
3434
assert client._client is None
3535

36+
client = GatewayClient(
37+
base_url="http://localhost:8888/",
38+
polling_interval=10,
39+
)
40+
assert client.polling_interval == 3.0 # Maximum 3.0 seconds
41+
42+
client.polling_interval = 0.05
43+
assert client.polling_interval == 0.5 # Minimum is 0.5 seconds
44+
3645

3746
@pytest.mark.asyncio
3847
async def test_gateway_client_aenter_aexit(mock_httpx_async_client):
@@ -65,16 +74,18 @@ async def test_submit_task_success(mock_httpx_async_client):
6574
_, mock_client_instance = mock_httpx_async_client
6675
mock_response = MagicMock()
6776
mock_response.json.return_value = {"uuid": "task-123", "status": "pending"}
68-
mock_client_instance.post.return_value = mock_response
77+
mock_response.raise_for_status.return_value = mock_response
78+
mock_client_instance.request.return_value = mock_response
6979

7080
async with GatewayClient(base_url="http://test-gateway.com") as client:
7181
task_info = await client.submit_task(
7282
model_name="my_model", task="process", data="some text"
7383
)
7484

7585
assert task_info == {"uuid": "task-123", "status": "pending"}
76-
mock_client_instance.post.assert_awaited_once_with(
77-
"http://test-gateway.com/models/my_model/tasks/process",
86+
mock_client_instance.request.assert_awaited_once_with(
87+
method="POST",
88+
url="http://test-gateway.com/models/my_model/tasks/process",
7889
data="some text",
7990
json=None,
8091
files=None,
@@ -90,16 +101,18 @@ async def test_submit_task_with_default_model(mock_httpx_async_client):
90101
_, mock_client_instance = mock_httpx_async_client
91102
mock_response = MagicMock()
92103
mock_response.json.return_value = {"uuid": "task-456", "status": "pending"}
93-
mock_client_instance.post.return_value = mock_response
104+
mock_response.raise_for_status.return_value = mock_response
105+
mock_client_instance.request.return_value = mock_response
94106

95107
async with GatewayClient(
96108
base_url="http://test-gateway.com", default_model="default_model"
97109
) as client:
98110
task_info = await client.submit_task(task="process", data="some text")
99111

100112
assert task_info == {"uuid": "task-456", "status": "pending"}
101-
mock_client_instance.post.assert_awaited_once_with(
102-
"http://test-gateway.com/models/default_model/tasks/process",
113+
mock_client_instance.request.assert_awaited_once_with(
114+
method="POST",
115+
url="http://test-gateway.com/models/default_model/tasks/process",
103116
data="some text",
104117
json=None,
105118
files=None,
@@ -124,7 +137,7 @@ async def test_submit_task_http_error(mock_httpx_async_client):
124137
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
125138
"Bad Request", request=httpx.Request("POST", "url"), response=httpx.Response(400)
126139
)
127-
mock_client_instance.post.return_value = mock_response
140+
mock_client_instance.request.return_value = mock_response
128141

129142
with pytest.raises(httpx.HTTPStatusError):
130143
async with GatewayClient(base_url="http://test-gateway.com") as client:
@@ -137,7 +150,8 @@ async def test_submit_task_wait_for_completion_and_return_result(mock_httpx_asyn
137150
_, mock_client_instance = mock_httpx_async_client
138151
mock_response = MagicMock()
139152
mock_response.json.return_value = {"uuid": "task-123", "status": "pending"}
140-
mock_client_instance.post.return_value = mock_response
153+
mock_response.raise_for_status.return_value = mock_response
154+
mock_client_instance.request.return_value = mock_response
141155

142156
mock_wait_for_task = mocker.patch(
143157
"client.cogstack_model_gateway_client.client.GatewayClient.wait_for_task",
@@ -220,14 +234,20 @@ async def test_get_task_success(mock_httpx_async_client):
220234
mock_response = MagicMock()
221235
mock_response.raise_for_status.return_value = mock_response
222236
mock_response.json.return_value = {"uuid": "task-123", "status": "succeeded"}
223-
mock_client_instance.get.return_value = mock_response
237+
mock_client_instance.request.return_value = mock_response
224238

225239
async with GatewayClient(base_url="http://test-gateway.com") as client:
226240
task_info = await client.get_task("task-123")
227241

228242
assert task_info == {"uuid": "task-123", "status": "succeeded"}
229-
mock_client_instance.get.assert_awaited_once_with(
230-
"http://test-gateway.com/tasks/task-123", params={"detail": True, "download": False}
243+
mock_client_instance.request.assert_awaited_once_with(
244+
method="GET",
245+
url="http://test-gateway.com/tasks/task-123",
246+
params={"detail": True, "download": False},
247+
data=None,
248+
json=None,
249+
files=None,
250+
headers=None,
231251
)
232252
mock_response.raise_for_status.assert_called_once()
233253

@@ -239,7 +259,7 @@ async def test_get_task_result_json(mock_httpx_async_client):
239259
mock_response = MagicMock()
240260
mock_response.content = b'{"key": "value"}'
241261
mock_response.raise_for_status.return_value = mock_response
242-
mock_client_instance.get.return_value = mock_response
262+
mock_client_instance.request.return_value = mock_response
243263

244264
async with GatewayClient(base_url="http://test-gateway.com") as client:
245265
result = await client.get_task_result("task-123")
@@ -253,7 +273,7 @@ async def test_get_task_result_jsonl(mock_httpx_async_client):
253273
mock_response = MagicMock()
254274
mock_response.content = b'{"item": 1}\n{"item": 2}\n'
255275
mock_response.raise_for_status.return_value = mock_response
256-
mock_client_instance.get.return_value = mock_response
276+
mock_client_instance.request.return_value = mock_response
257277

258278
async with GatewayClient(base_url="http://test-gateway.com") as client:
259279
result = await client.get_task_result("task-123")
@@ -267,7 +287,7 @@ async def test_get_task_result_text(mock_httpx_async_client):
267287
mock_response = MagicMock()
268288
mock_response.content = b"plain text result"
269289
mock_response.raise_for_status.return_value = mock_response
270-
mock_client_instance.get.return_value = mock_response
290+
mock_client_instance.request.return_value = mock_response
271291

272292
async with GatewayClient(base_url="http://test-gateway.com") as client:
273293
result = await client.get_task_result("task-123")
@@ -281,7 +301,7 @@ async def test_get_task_result_binary(mock_httpx_async_client):
281301
mock_response = MagicMock()
282302
mock_response.content = b"\x80\x01\x02\x03" # Example binary data
283303
mock_response.raise_for_status.return_value = mock_response
284-
mock_client_instance.get.return_value = mock_response
304+
mock_client_instance.request.return_value = mock_response
285305

286306
async with GatewayClient(base_url="http://test-gateway.com") as client:
287307
result = await client.get_task_result("task-123")
@@ -295,7 +315,7 @@ async def test_get_task_result_no_parse(mock_httpx_async_client):
295315
mock_response = MagicMock()
296316
mock_response.content = b'{"key": "value"}'
297317
mock_response.raise_for_status.return_value = mock_response
298-
mock_client_instance.get.return_value = mock_response
318+
mock_client_instance.request.return_value = mock_response
299319

300320
async with GatewayClient(base_url="http://test-gateway.com") as client:
301321
result = await client.get_task_result("task-123", parse=False)
@@ -334,7 +354,7 @@ async def test_wait_for_task_timeout(mock_httpx_async_client, mocker):
334354

335355
async with GatewayClient(base_url="http://test-gateway.com") as client:
336356
client.timeout = 0.05
337-
client.polling_interval = 0.01
357+
client.polling_interval = 0.5
338358

339359
with pytest.raises(
340360
TimeoutError, match="Timed out waiting for task 'task-polling' to complete"
@@ -394,13 +414,20 @@ async def test_get_models_success(mock_httpx_async_client):
394414
_, mock_client_instance = mock_httpx_async_client
395415
mock_response = MagicMock()
396416
mock_response.json.return_value = ["model_a", "model_b"]
397-
mock_client_instance.get.return_value = mock_response
417+
mock_response.raise_for_status.return_value = mock_response
418+
mock_client_instance.request.return_value = mock_response
398419

399420
async with GatewayClient(base_url="http://test-gateway.com") as client:
400421
models = await client.get_models()
401422
assert models == ["model_a", "model_b"]
402-
mock_client_instance.get.assert_awaited_once_with(
403-
"http://test-gateway.com/models/", params={"verbose": False}
423+
mock_client_instance.request.assert_awaited_once_with(
424+
method="GET",
425+
url="http://test-gateway.com/models/",
426+
params={"verbose": False},
427+
data=None,
428+
json=None,
429+
files=None,
430+
headers=None,
404431
)
405432

406433

@@ -410,13 +437,20 @@ async def test_get_model_success(mock_httpx_async_client):
410437
_, mock_client_instance = mock_httpx_async_client
411438
mock_response = MagicMock()
412439
mock_response.json.return_value = {"name": "my_model", "status": "deployed"}
413-
mock_client_instance.get.return_value = mock_response
440+
mock_response.raise_for_status.return_value = mock_response
441+
mock_client_instance.request.return_value = mock_response
414442

415443
async with GatewayClient(base_url="http://test-gateway.com") as client:
416444
model_info = await client.get_model(model_name="my_model")
417445
assert model_info == {"name": "my_model", "status": "deployed"}
418-
mock_client_instance.get.assert_awaited_once_with(
419-
"http://test-gateway.com/models/my_model/info"
446+
mock_client_instance.request.assert_awaited_once_with(
447+
method="GET",
448+
url="http://test-gateway.com/models/my_model/info",
449+
params=None,
450+
data=None,
451+
json=None,
452+
files=None,
453+
headers=None,
420454
)
421455

422456

@@ -426,15 +460,22 @@ async def test_get_model_with_default_model(mock_httpx_async_client):
426460
_, mock_client_instance = mock_httpx_async_client
427461
mock_response = MagicMock()
428462
mock_response.json.return_value = {"name": "default_model", "status": "deployed"}
429-
mock_client_instance.get.return_value = mock_response
463+
mock_response.raise_for_status.return_value = mock_response
464+
mock_client_instance.request.return_value = mock_response
430465

431466
async with GatewayClient(
432467
base_url="http://test-gateway.com", default_model="default_model"
433468
) as client:
434469
model_info = await client.get_model()
435470
assert model_info == {"name": "default_model", "status": "deployed"}
436-
mock_client_instance.get.assert_awaited_once_with(
437-
"http://test-gateway.com/models/default_model/info"
471+
mock_client_instance.request.assert_awaited_once_with(
472+
method="GET",
473+
url="http://test-gateway.com/models/default_model/info",
474+
params=None,
475+
data=None,
476+
json=None,
477+
files=None,
478+
headers=None,
438479
)
439480

440481

@@ -452,7 +493,8 @@ async def test_deploy_model_success(mock_httpx_async_client):
452493
_, mock_client_instance = mock_httpx_async_client
453494
mock_response = MagicMock()
454495
mock_response.json.return_value = {"name": "new_model", "status": "deploying"}
455-
mock_client_instance.post.return_value = mock_response
496+
mock_response.raise_for_status.return_value = mock_response
497+
mock_client_instance.request.return_value = mock_response
456498

457499
async with GatewayClient(base_url="http://test-gateway.com") as client:
458500
deploy_info = await client.deploy_model(
@@ -462,11 +504,16 @@ async def test_deploy_model_success(mock_httpx_async_client):
462504
)
463505

464506
assert deploy_info == {"name": "new_model", "status": "deploying"}
465-
mock_client_instance.post.assert_awaited_once_with(
466-
"http://test-gateway.com/models/new_model",
507+
mock_client_instance.request.assert_awaited_once_with(
508+
method="POST",
509+
url="http://test-gateway.com/models/new_model",
467510
json={
468511
"tracking_id": None,
469512
"model_uri": "mlflow-artifacts:/1/runidabcd1234/artifacts/new_model",
470513
"ttl": 3600,
471514
},
515+
params=None,
516+
data=None,
517+
files=None,
518+
headers=None,
472519
)

0 commit comments

Comments
 (0)