Skip to content

Commit 6632a09

Browse files
committed
client: Raise error when returning result of failed task
Signed-off-by: Phoevos Kalemkeris <[email protected]>
1 parent 499f4ab commit 6632a09

File tree

5 files changed

+114
-35
lines changed

5 files changed

+114
-35
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
from cogstack_model_gateway_client.client import GatewayClient as GatewayClient
22
from cogstack_model_gateway_client.client import GatewayClientSync as GatewayClientSync
3+
from cogstack_model_gateway_client.exceptions import TaskFailedError as TaskFailedError

client/cogstack_model_gateway_client/client.py

Lines changed: 59 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import httpx
77

8-
from cogstack_model_gateway_client.exceptions import retry_if_network_error
8+
from cogstack_model_gateway_client.exceptions import TaskFailedError, retry_if_network_error
99

1010

1111
class GatewayClient:
@@ -103,7 +103,11 @@ async def submit_task(
103103
wait_for_completion: bool = False,
104104
return_result: bool = True,
105105
):
106-
"""Submit a task to the Gateway and return the task info."""
106+
"""Submit a task to the Gateway and return the task info.
107+
108+
Raises:
109+
TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
110+
"""
107111
model_name = model_name or self.default_model
108112
if not model_name:
109113
raise ValueError("Please provide a model name or set a default model for the client.")
@@ -118,7 +122,11 @@ async def submit_task(
118122
task_uuid = task_info["uuid"]
119123
task_info = await self.wait_for_task(task_uuid)
120124
if return_result:
121-
return await self.get_task_result(task_uuid)
125+
if task_info.get("status") == "succeeded":
126+
return await self.get_task_result(task_uuid)
127+
else:
128+
error_message = task_info.get("error_message", "Unknown error")
129+
raise TaskFailedError(task_uuid, error_message, task_info)
122130
return task_info
123131

124132
async def process(
@@ -128,7 +136,11 @@ async def process(
128136
wait_for_completion: bool = True,
129137
return_result: bool = True,
130138
):
131-
"""Generate annotations for the provided text."""
139+
"""Generate annotations for the provided text.
140+
141+
Raises:
142+
TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
143+
"""
132144
return await self.submit_task(
133145
model_name=model_name,
134146
task="process",
@@ -145,7 +157,11 @@ async def process_bulk(
145157
wait_for_completion: bool = True,
146158
return_result: bool = True,
147159
):
148-
"""Generate annotations for a list of texts."""
160+
"""Generate annotations for a list of texts.
161+
162+
Raises:
163+
TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
164+
"""
149165
return await self.submit_task(
150166
model_name=model_name,
151167
task="process_bulk",
@@ -166,7 +182,11 @@ async def redact(
166182
wait_for_completion: bool = True,
167183
return_result: bool = True,
168184
):
169-
"""Redact sensitive information from the provided text."""
185+
"""Redact sensitive information from the provided text.
186+
187+
Raises:
188+
TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
189+
"""
170190
params = {
171191
k: v
172192
for k, v in {
@@ -238,15 +258,20 @@ async def get_task_result(self, task_uuid: str, parse: bool = True):
238258
async def wait_for_task(
239259
self, task_uuid: str, detail: bool = True, raise_on_error: bool = False
240260
):
241-
"""Poll Gateway until the task reaches a final state."""
261+
"""Poll Gateway until the task reaches a final state.
262+
263+
Raises:
264+
TaskFailedError: If raise_on_error=True and the task fails.
265+
TimeoutError: If timeout is reached before task completion.
266+
"""
242267
start = asyncio.get_event_loop().time()
243268
while True:
244269
task = await self.get_task(task_uuid, detail=detail)
245270
status = task.get("status")
246271
if status in ("succeeded", "failed"):
247272
if status == "failed" and raise_on_error:
248273
error_message = task.get("error_message", "Unknown error")
249-
raise RuntimeError(f"Task '{task_uuid}' failed: {error_message}")
274+
raise TaskFailedError(task_uuid, error_message, task)
250275
return task
251276
if self.timeout is not None and asyncio.get_event_loop().time() - start > self.timeout:
252277
raise TimeoutError(f"Timed out waiting for task '{task_uuid}' to complete")
@@ -365,7 +390,11 @@ def submit_task(
365390
wait_for_completion: bool = False,
366391
return_result: bool = True,
367392
):
368-
"""Submit a task to the Gateway and return the task info."""
393+
"""Submit a task to the Gateway and return the task info.
394+
395+
Raises:
396+
TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
397+
"""
369398
return asyncio.run(
370399
self._client.submit_task(
371400
model_name=model_name,
@@ -387,7 +416,11 @@ def process(
387416
wait_for_completion: bool = True,
388417
return_result: bool = True,
389418
):
390-
"""Generate annotations for the provided text."""
419+
"""Generate annotations for the provided text.
420+
421+
Raises:
422+
TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
423+
"""
391424
return asyncio.run(
392425
self._client.process(
393426
text=text,
@@ -404,7 +437,11 @@ def process_bulk(
404437
wait_for_completion: bool = True,
405438
return_result: bool = True,
406439
):
407-
"""Generate annotations for a list of texts."""
440+
"""Generate annotations for a list of texts.
441+
442+
Raises:
443+
TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
444+
"""
408445
return asyncio.run(
409446
self._client.process_bulk(
410447
texts=texts,
@@ -425,7 +462,11 @@ def redact(
425462
wait_for_completion: bool = True,
426463
return_result: bool = True,
427464
):
428-
"""Redact sensitive information from the provided text."""
465+
"""Redact sensitive information from the provided text.
466+
467+
Raises:
468+
TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
469+
"""
429470
return asyncio.run(
430471
self._client.redact(
431472
text=text,
@@ -452,7 +493,12 @@ def get_task_result(self, task_uuid: str, parse: bool = True):
452493
return asyncio.run(self._client.get_task_result(task_uuid=task_uuid, parse=parse))
453494

454495
def wait_for_task(self, task_uuid: str, detail: bool = True, raise_on_error: bool = False):
455-
"""Poll Gateway until the task reaches a final state."""
496+
"""Poll Gateway until the task reaches a final state.
497+
498+
Raises:
499+
TaskFailedError: If raise_on_error=True and the task fails.
500+
TimeoutError: If timeout is reached before task completion.
501+
"""
456502
return asyncio.run(
457503
self._client.wait_for_task(
458504
task_uuid=task_uuid, detail=detail, raise_on_error=raise_on_error

client/cogstack_model_gateway_client/exceptions.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,22 @@
1212
log = logging.getLogger("cmg.client")
1313

1414

15+
class TaskFailedError(Exception):
16+
"""Raised when a Gateway task fails during execution.
17+
18+
Attributes:
19+
task_uuid: The UUID of the failed task
20+
error_message: The error message from the task
21+
task_info: The full task information dict (optional)
22+
"""
23+
24+
def __init__(self, task_uuid: str, error_message: str, task_info: dict = None):
25+
self.task_uuid = task_uuid
26+
self.error_message = error_message
27+
self.task_info = task_info
28+
super().__init__(f"Task '{task_uuid}' failed: {error_message}")
29+
30+
1531
def is_network_error(exception: Exception):
1632
"""Check if the exception is a network-related error."""
1733
return isinstance(

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,14 @@ line-length = 100
5959
target-version = "py312"
6060

6161
[tool.ruff.lint.isort]
62-
known-local-folder = ["common", "gateway", "scheduler", "client"]
62+
known-local-folder = ["common", "gateway", "scheduler", "cogstack_model_gateway_client"]
6363

6464
[tool.ruff.lint]
6565
select = ["E", "F", "I", "UP"]
6666

6767
[tool.pytest.ini_options]
6868
addopts = "-ra -s --disable-warnings --enable-cmg-logging"
69-
pythonpath = ["."]
69+
pythonpath = [".", "client"]
7070
testpaths = ["tests"]
7171

7272
[build-system]

tests/unit/client/test_client.py

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import httpx
99
import pytest
1010

11-
from client.cogstack_model_gateway_client.client import GatewayClient, GatewayClientSync
11+
from cogstack_model_gateway_client.client import GatewayClient, GatewayClientSync
12+
from cogstack_model_gateway_client.exceptions import TaskFailedError
1213

1314

1415
@pytest.fixture
@@ -159,11 +160,11 @@ async def test_submit_task_wait_for_completion_and_return_result(mock_httpx_asyn
159160
mock_client_instance.request.return_value = mock_response
160161

161162
mock_wait_for_task = mocker.patch(
162-
"client.cogstack_model_gateway_client.client.GatewayClient.wait_for_task",
163+
"cogstack_model_gateway_client.client.GatewayClient.wait_for_task",
163164
new=AsyncMock(return_value={"uuid": "task-123", "status": "succeeded"}),
164165
)
165166
mock_get_task_result = mocker.patch(
166-
"client.cogstack_model_gateway_client.client.GatewayClient.get_task_result",
167+
"cogstack_model_gateway_client.client.GatewayClient.get_task_result",
167168
new=AsyncMock(return_value="processed text"),
168169
)
169170

@@ -180,6 +181,29 @@ async def test_submit_task_wait_for_completion_and_return_result(mock_httpx_asyn
180181
mock_wait_for_task.assert_awaited_once_with("task-123")
181182
mock_get_task_result.assert_awaited_once_with("task-123")
182183

184+
mock_wait_for_task = mocker.patch(
185+
"cogstack_model_gateway_client.client.GatewayClient.wait_for_task",
186+
new=AsyncMock(
187+
return_value={
188+
"uuid": "task-123",
189+
"status": "failed",
190+
"error_message": "Processing failed",
191+
}
192+
),
193+
)
194+
195+
async with GatewayClient(base_url="http://test-gateway.com") as client:
196+
with pytest.raises(TaskFailedError, match="Task 'task-123' failed: Processing failed"):
197+
await client.submit_task(
198+
model_name="my_model",
199+
task="process",
200+
data="text",
201+
wait_for_completion=True,
202+
return_result=True,
203+
)
204+
205+
mock_wait_for_task.assert_awaited_once_with("task-123")
206+
183207

184208
@pytest.mark.asyncio
185209
async def test_process_method(mocker):
@@ -335,9 +359,7 @@ async def test_wait_for_task_succeeded(mock_httpx_async_client, mocker):
335359
{"uuid": "task-polling", "status": "pending"},
336360
{"uuid": "task-polling", "status": "succeeded", "result": "done"},
337361
]
338-
mocker.patch(
339-
"client.cogstack_model_gateway_client.client.GatewayClient.get_task", new=mock_get_task
340-
)
362+
mocker.patch("cogstack_model_gateway_client.client.GatewayClient.get_task", new=mock_get_task)
341363
mocker.patch("asyncio.sleep", new=AsyncMock())
342364

343365
async with GatewayClient(base_url="http://test-gateway.com") as client:
@@ -352,9 +374,7 @@ async def test_wait_for_task_succeeded(mock_httpx_async_client, mocker):
352374
async def test_wait_for_task_timeout(mock_httpx_async_client, mocker):
353375
"""Test wait_for_task raises TimeoutError."""
354376
mock_get_task = AsyncMock(return_value={"uuid": "task-polling", "status": "pending"})
355-
mocker.patch(
356-
"client.cogstack_model_gateway_client.client.GatewayClient.get_task", new=mock_get_task
357-
)
377+
mocker.patch("cogstack_model_gateway_client.client.GatewayClient.get_task", new=mock_get_task)
358378
mocker.patch("asyncio.sleep", new=AsyncMock())
359379

360380
async with GatewayClient(base_url="http://test-gateway.com") as client:
@@ -383,9 +403,7 @@ async def mock_get_task_side_effect(*args, **kwargs):
383403
return {"uuid": "task-polling", "status": "succeeded"}
384404

385405
mock_get_task = AsyncMock(side_effect=mock_get_task_side_effect)
386-
mocker.patch(
387-
"client.cogstack_model_gateway_client.client.GatewayClient.get_task", new=mock_get_task
388-
)
406+
mocker.patch("cogstack_model_gateway_client.client.GatewayClient.get_task", new=mock_get_task)
389407
mocker.patch("asyncio.sleep", new=AsyncMock())
390408

391409
async with GatewayClient(base_url="http://test-gateway.com") as client:
@@ -400,19 +418,19 @@ async def mock_get_task_side_effect(*args, **kwargs):
400418

401419
@pytest.mark.asyncio
402420
async def test_wait_for_task_failed_raise_on_error(mock_httpx_async_client, mocker):
403-
"""Test wait_for_task raises RuntimeError on task failure with raise_on_error."""
421+
"""Test wait_for_task raises TaskFailedError on task failure with raise_on_error."""
404422
mock_get_task = AsyncMock()
405423
mock_get_task.side_effect = [
406424
{"uuid": "task-polling", "status": "pending"},
407425
{"uuid": "task-polling", "status": "failed", "error_message": "Something went wrong"},
408426
]
409-
mocker.patch(
410-
"client.cogstack_model_gateway_client.client.GatewayClient.get_task", new=mock_get_task
411-
)
427+
mocker.patch("cogstack_model_gateway_client.client.GatewayClient.get_task", new=mock_get_task)
412428
mocker.patch("asyncio.sleep", new=AsyncMock())
413429

414430
async with GatewayClient(base_url="http://test-gateway.com") as client:
415-
with pytest.raises(RuntimeError, match="Task 'task-polling' failed: Something went wrong"):
431+
with pytest.raises(
432+
TaskFailedError, match="Task 'task-polling' failed: Something went wrong"
433+
):
416434
await client.wait_for_task("task-polling", raise_on_error=True)
417435

418436
assert mock_get_task.await_count == 2
@@ -426,9 +444,7 @@ async def test_wait_for_task_failed_no_raise_on_error(mock_httpx_async_client, m
426444
{"uuid": "task-polling", "status": "pending"},
427445
{"uuid": "task-polling", "status": "failed", "error_message": "Something went wrong"},
428446
]
429-
mocker.patch(
430-
"client.cogstack_model_gateway_client.client.GatewayClient.get_task", new=mock_get_task
431-
)
447+
mocker.patch("cogstack_model_gateway_client.client.GatewayClient.get_task", new=mock_get_task)
432448
mocker.patch("asyncio.sleep", new=AsyncMock())
433449

434450
async with GatewayClient(base_url="http://test-gateway.com") as client:

0 commit comments

Comments
 (0)