Skip to content

Commit 499f4ab

Browse files
committed
client: Set default timeout to None (no timeout)
Set the default timeout for the GatewayClient to None, indicating no timeout while waiting for task completion. A float value can still be explicitly set to enforce a timeout. Signed-off-by: Phoevos Kalemkeris <[email protected]>
1 parent 9dcd51b commit 499f4ab

File tree

2 files changed

+36
-6
lines changed

2 files changed

+36
-6
lines changed

client/cogstack_model_gateway_client/client.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def __init__(
1414
base_url: str,
1515
default_model: str = None,
1616
polling_interval: float = 2.0,
17-
timeout: float = 300.0,
17+
timeout: float | None = None,
1818
):
1919
"""Initialize the GatewayClient with the base Gateway URL and optional parameters.
2020
@@ -24,9 +24,9 @@ def __init__(
2424
polling_interval (float, optional): The interval in seconds to poll for task completion.
2525
Defaults to 2.0 seconds, with a minimum of 0.5 and maximum of 3.0 seconds.
2626
timeout (float, optional): The client polling timeout while waiting for task completion.
27-
Defaults to 300.0 seconds. A TimeoutError in this case should rarely indicate that
28-
something went wrong, but rather that the task is taking longer than expected (e.g.
29-
common with long running tasks like training).
27+
Defaults to None (no timeout). When set to a float value, a TimeoutError will be
28+
raised if the task takes longer than the specified number of seconds. When None,
29+
the client will wait indefinitely for task completion.
3030
"""
3131
self.base_url = base_url.rstrip("/")
3232
self.default_model = default_model
@@ -248,7 +248,7 @@ async def wait_for_task(
248248
error_message = task.get("error_message", "Unknown error")
249249
raise RuntimeError(f"Task '{task_uuid}' failed: {error_message}")
250250
return task
251-
if asyncio.get_event_loop().time() - start > self.timeout:
251+
if self.timeout is not None and asyncio.get_event_loop().time() - start > self.timeout:
252252
raise TimeoutError(f"Timed out waiting for task '{task_uuid}' to complete")
253253
await asyncio.sleep(self.polling_interval)
254254

@@ -350,7 +350,7 @@ def timeout(self):
350350
return self._client.timeout
351351

352352
@timeout.setter
353-
def timeout(self, value: float):
353+
def timeout(self, value: float | None):
354354
self._client.timeout = value
355355

356356
def submit_task(

tests/unit/client/test_client.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ async def test_gateway_client_init():
4242
polling_interval=10,
4343
)
4444
assert client.polling_interval == 3.0 # Maximum 3.0 seconds
45+
assert client.timeout is None # Default timeout should be None
4546

4647
client.polling_interval = 0.05
4748
assert client.polling_interval == 0.5 # Minimum is 0.5 seconds
@@ -368,6 +369,35 @@ async def test_wait_for_task_timeout(mock_httpx_async_client, mocker):
368369
assert mock_get_task.await_count >= (client.timeout / client.polling_interval)
369370

370371

372+
@pytest.mark.asyncio
373+
async def test_wait_for_task_no_timeout(mock_httpx_async_client, mocker):
374+
"""Test wait_for_task doesn't timeout when timeout is None."""
375+
call_count = 0
376+
377+
async def mock_get_task_side_effect(*args, **kwargs):
378+
nonlocal call_count
379+
call_count += 1
380+
if call_count < 5: # Return pending for first 4 calls
381+
return {"uuid": "task-polling", "status": "pending"}
382+
else: # Return succeeded on 5th call
383+
return {"uuid": "task-polling", "status": "succeeded"}
384+
385+
mock_get_task = AsyncMock(side_effect=mock_get_task_side_effect)
386+
mocker.patch(
387+
"client.cogstack_model_gateway_client.client.GatewayClient.get_task", new=mock_get_task
388+
)
389+
mocker.patch("asyncio.sleep", new=AsyncMock())
390+
391+
async with GatewayClient(base_url="http://test-gateway.com") as client:
392+
assert client.timeout is None
393+
client.polling_interval = 0.01
394+
395+
result = await client.wait_for_task("task-polling")
396+
397+
assert result["status"] == "succeeded"
398+
assert mock_get_task.await_count == 5
399+
400+
371401
@pytest.mark.asyncio
372402
async def test_wait_for_task_failed_raise_on_error(mock_httpx_async_client, mocker):
373403
"""Test wait_for_task raises RuntimeError on task failure with raise_on_error."""

0 commit comments

Comments
 (0)