5
5
6
6
import httpx
7
7
8
- from cogstack_model_gateway_client .exceptions import retry_if_network_error
8
+ from cogstack_model_gateway_client .exceptions import TaskFailedError , retry_if_network_error
9
9
10
10
11
11
class GatewayClient :
@@ -103,7 +103,11 @@ async def submit_task(
103
103
wait_for_completion : bool = False ,
104
104
return_result : bool = True ,
105
105
):
106
- """Submit a task to the Gateway and return the task info."""
106
+ """Submit a task to the Gateway and return the task info.
107
+
108
+ Raises:
109
+ TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
110
+ """
107
111
model_name = model_name or self .default_model
108
112
if not model_name :
109
113
raise ValueError ("Please provide a model name or set a default model for the client." )
@@ -118,7 +122,11 @@ async def submit_task(
118
122
task_uuid = task_info ["uuid" ]
119
123
task_info = await self .wait_for_task (task_uuid )
120
124
if return_result :
121
- return await self .get_task_result (task_uuid )
125
+ if task_info .get ("status" ) == "succeeded" :
126
+ return await self .get_task_result (task_uuid )
127
+ else :
128
+ error_message = task_info .get ("error_message" , "Unknown error" )
129
+ raise TaskFailedError (task_uuid , error_message , task_info )
122
130
return task_info
123
131
124
132
async def process (
@@ -128,7 +136,11 @@ async def process(
128
136
wait_for_completion : bool = True ,
129
137
return_result : bool = True ,
130
138
):
131
- """Generate annotations for the provided text."""
139
+ """Generate annotations for the provided text.
140
+
141
+ Raises:
142
+ TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
143
+ """
132
144
return await self .submit_task (
133
145
model_name = model_name ,
134
146
task = "process" ,
@@ -145,7 +157,11 @@ async def process_bulk(
145
157
wait_for_completion : bool = True ,
146
158
return_result : bool = True ,
147
159
):
148
- """Generate annotations for a list of texts."""
160
+ """Generate annotations for a list of texts.
161
+
162
+ Raises:
163
+ TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
164
+ """
149
165
return await self .submit_task (
150
166
model_name = model_name ,
151
167
task = "process_bulk" ,
@@ -166,7 +182,11 @@ async def redact(
166
182
wait_for_completion : bool = True ,
167
183
return_result : bool = True ,
168
184
):
169
- """Redact sensitive information from the provided text."""
185
+ """Redact sensitive information from the provided text.
186
+
187
+ Raises:
188
+ TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
189
+ """
170
190
params = {
171
191
k : v
172
192
for k , v in {
@@ -238,15 +258,20 @@ async def get_task_result(self, task_uuid: str, parse: bool = True):
238
258
async def wait_for_task (
239
259
self , task_uuid : str , detail : bool = True , raise_on_error : bool = False
240
260
):
241
- """Poll Gateway until the task reaches a final state."""
261
+ """Poll Gateway until the task reaches a final state.
262
+
263
+ Raises:
264
+ TaskFailedError: If raise_on_error=True and the task fails.
265
+ TimeoutError: If timeout is reached before task completion.
266
+ """
242
267
start = asyncio .get_event_loop ().time ()
243
268
while True :
244
269
task = await self .get_task (task_uuid , detail = detail )
245
270
status = task .get ("status" )
246
271
if status in ("succeeded" , "failed" ):
247
272
if status == "failed" and raise_on_error :
248
273
error_message = task .get ("error_message" , "Unknown error" )
249
- raise RuntimeError ( f"Task ' { task_uuid } ' failed: { error_message } " )
274
+ raise TaskFailedError ( task_uuid , error_message , task )
250
275
return task
251
276
if self .timeout is not None and asyncio .get_event_loop ().time () - start > self .timeout :
252
277
raise TimeoutError (f"Timed out waiting for task '{ task_uuid } ' to complete" )
@@ -365,7 +390,11 @@ def submit_task(
365
390
wait_for_completion : bool = False ,
366
391
return_result : bool = True ,
367
392
):
368
- """Submit a task to the Gateway and return the task info."""
393
+ """Submit a task to the Gateway and return the task info.
394
+
395
+ Raises:
396
+ TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
397
+ """
369
398
return asyncio .run (
370
399
self ._client .submit_task (
371
400
model_name = model_name ,
@@ -387,7 +416,11 @@ def process(
387
416
wait_for_completion : bool = True ,
388
417
return_result : bool = True ,
389
418
):
390
- """Generate annotations for the provided text."""
419
+ """Generate annotations for the provided text.
420
+
421
+ Raises:
422
+ TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
423
+ """
391
424
return asyncio .run (
392
425
self ._client .process (
393
426
text = text ,
@@ -404,7 +437,11 @@ def process_bulk(
404
437
wait_for_completion : bool = True ,
405
438
return_result : bool = True ,
406
439
):
407
- """Generate annotations for a list of texts."""
440
+ """Generate annotations for a list of texts.
441
+
442
+ Raises:
443
+ TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
444
+ """
408
445
return asyncio .run (
409
446
self ._client .process_bulk (
410
447
texts = texts ,
@@ -425,7 +462,11 @@ def redact(
425
462
wait_for_completion : bool = True ,
426
463
return_result : bool = True ,
427
464
):
428
- """Redact sensitive information from the provided text."""
465
+ """Redact sensitive information from the provided text.
466
+
467
+ Raises:
468
+ TaskFailedError: If the task fails and wait_for_completion=True, return_result=True.
469
+ """
429
470
return asyncio .run (
430
471
self ._client .redact (
431
472
text = text ,
@@ -452,7 +493,12 @@ def get_task_result(self, task_uuid: str, parse: bool = True):
452
493
return asyncio .run (self ._client .get_task_result (task_uuid = task_uuid , parse = parse ))
453
494
454
495
def wait_for_task (self , task_uuid : str , detail : bool = True , raise_on_error : bool = False ):
455
- """Poll Gateway until the task reaches a final state."""
496
+ """Poll Gateway until the task reaches a final state.
497
+
498
+ Raises:
499
+ TaskFailedError: If raise_on_error=True and the task fails.
500
+ TimeoutError: If timeout is reached before task completion.
501
+ """
456
502
return asyncio .run (
457
503
self ._client .wait_for_task (
458
504
task_uuid = task_uuid , detail = detail , raise_on_error = raise_on_error
0 commit comments