@@ -1,5 +1,7 @@
 import asyncio
 import json
+import threading
+import time
 from collections.abc import Iterable
 from functools import wraps
 
@@ -128,7 +130,7 @@ async def process(
         wait_for_completion: bool = True,
         return_result: bool = True,
     ):
-        """Generate annotations for the the provided text."""
+        """Generate annotations for the provided text."""
         return await self.submit_task(
             model_name=model_name,
             task="process",
@@ -290,9 +292,10 @@ async def deploy_model(
 class GatewayClientSync:
     def __init__(self, *args, **kwargs):
         self._client = GatewayClient(*args, **kwargs)
-        self._loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(self._loop)
-        self._loop.run_until_complete(self._client.__aenter__())
+        self._loop = None
+        self._own_loop = False
+        self._thread = None
+        self._setup_event_loop()
 
     @property
     def base_url(self):
@@ -322,40 +325,204 @@ def timeout(self):
     def timeout(self, value: float):
         self._client.timeout = value
 
+    def _setup_event_loop(self):
+        """Set up event loop handling for sync operations."""
+        try:
+            existing_loop = asyncio.get_running_loop()
+            self._loop = existing_loop
+            self._own_loop = False
+            self._setup_background_loop()
+        except RuntimeError:
+            self._loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(self._loop)
+            self._own_loop = True
+            self._loop.run_until_complete(self._client.__aenter__())
+
+    def _setup_background_loop(self):
+        """Set up a background thread with a persistent event loop and client."""
+
+        def run_background_loop():
+            self._background_loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(self._background_loop)
+
+            self._background_loop.run_until_complete(self._client.__aenter__())
+            self._background_loop.run_forever()
+
+        self._thread = threading.Thread(target=run_background_loop, daemon=True)
+        self._thread.start()
+
+        start_time, timeout = time.time(), 5.0
+        while not hasattr(self, "_background_loop") and time.time() - start_time < timeout:
+            time.sleep(0.01)
+
+        if not hasattr(self, "_background_loop"):
+            raise RuntimeError("Failed to start background event loop")
+
+    def _run_async(self, coro):
+        """Run an async coroutine, handling different event loop scenarios."""
+        if self._own_loop:
+            return self._loop.run_until_complete(coro)
+        else:
+            future = asyncio.run_coroutine_threadsafe(coro, self._background_loop)
+            return future.result(timeout=60)
+
     def __del__(self):
         try:
-            if hasattr(self, "_client") and self._client and self._client._client is not None:
-                self._loop.run_until_complete(self._client.__aexit__(None, None, None))
-                self._loop.close()
+            if not hasattr(self, "_client") or not self._client:
+                return
+
+            if self._own_loop and self._client._client is not None:
+                try:
+                    self._loop.run_until_complete(self._client.__aexit__(None, None, None))
+                    self._loop.close()
+                except Exception:
+                    pass
+            elif hasattr(self, "_background_loop") and hasattr(self, "_thread"):
+                if self._background_loop and not self._background_loop.is_closed():
+                    try:
+                        future = asyncio.run_coroutine_threadsafe(
+                            self._client.__aexit__(None, None, None), self._background_loop
+                        )
+                        future.result(timeout=1.0)
+                    except Exception:
+                        # If cleanup fails, we can't do much about it in __del__
+                        pass
+                    try:
+                        self._background_loop.call_soon_threadsafe(self._background_loop.stop)
+                    except Exception:
+                        pass
+                if self._thread and self._thread.is_alive():
+                    self._thread.join(timeout=1.0)
         except Exception:
             pass
 
-    def submit_task(self, *args, **kwargs):
-        return self._loop.run_until_complete(self._client.submit_task(*args, **kwargs))
+    def submit_task(
+        self,
+        model_name: str = None,
+        task: str = None,
+        data=None,
+        json=None,
+        files=None,
+        params=None,
+        headers=None,
+        wait_for_completion: bool = False,
+        return_result: bool = True,
+    ):
+        """Submit a task to the Gateway and return the task info."""
+        return self._run_async(
+            self._client.submit_task(
+                model_name=model_name,
+                task=task,
+                data=data,
+                json=json,
+                files=files,
+                params=params,
+                headers=headers,
+                wait_for_completion=wait_for_completion,
+                return_result=return_result,
+            )
+        )
+
+    def process(
+        self,
+        text: str,
+        model_name: str = None,
+        wait_for_completion: bool = True,
+        return_result: bool = True,
+    ):
+        """Generate annotations for the provided text."""
+        return self._run_async(
+            self._client.process(
+                text=text,
+                model_name=model_name,
+                wait_for_completion=wait_for_completion,
+                return_result=return_result,
+            )
+        )
 
-    def process(self, *args, **kwargs):
-        return self._loop.run_until_complete(self._client.process(*args, **kwargs))
+    def process_bulk(
+        self,
+        texts: list[str],
+        model_name: str = None,
+        wait_for_completion: bool = True,
+        return_result: bool = True,
+    ):
+        """Generate annotations for a list of texts."""
+        return self._run_async(
+            self._client.process_bulk(
+                texts=texts,
+                model_name=model_name,
+                wait_for_completion=wait_for_completion,
+                return_result=return_result,
+            )
+        )
 
-    def process_bulk(self, *args, **kwargs):
-        return self._loop.run_until_complete(self._client.process_bulk(*args, **kwargs))
+    def redact(
+        self,
+        text: str,
+        concepts_to_keep: Iterable[str] = None,
+        warn_on_no_redaction: bool = None,
+        mask: str = None,
+        hash: bool = None,
+        model_name: str = None,
+        wait_for_completion: bool = True,
+        return_result: bool = True,
+    ):
+        """Redact sensitive information from the provided text."""
+        return self._run_async(
+            self._client.redact(
+                text=text,
+                concepts_to_keep=concepts_to_keep,
+                warn_on_no_redaction=warn_on_no_redaction,
+                mask=mask,
+                hash=hash,
+                model_name=model_name,
+                wait_for_completion=wait_for_completion,
+                return_result=return_result,
+            )
+        )
 
-    def redact(self, *args, **kwargs):
-        return self._loop.run_until_complete(self._client.redact(*args, **kwargs))
+    def get_task(self, task_uuid: str, detail: bool = True):
+        """Get the details of a Gateway task by its UUID."""
+        return self._run_async(self._client.get_task(task_uuid=task_uuid, detail=detail))
 
-    def get_task(self, *args, **kwargs):
-        return self._loop.run_until_complete(self._client.get_task(*args, **kwargs))
+    def get_task_result(self, task_uuid: str, parse: bool = True):
+        """Get the result of a Gateway task by its UUID.
 
-    def get_task_result(self, *args, **kwargs):
-        return self._loop.run_until_complete(self._client.get_task_result(*args, **kwargs))
+        If parse is True, try to infer and parse the result as JSON, JSONL, or text.
+        Otherwise, return raw bytes.
+        """
+        return self._run_async(self._client.get_task_result(task_uuid=task_uuid, parse=parse))
 
-    def wait_for_task(self, *args, **kwargs):
-        return self._loop.run_until_complete(self._client.wait_for_task(*args, **kwargs))
+    def wait_for_task(self, task_uuid: str, detail: bool = True, raise_on_error: bool = False):
+        """Poll the Gateway until the task reaches a final state."""
+        return self._run_async(
+            self._client.wait_for_task(
+                task_uuid=task_uuid, detail=detail, raise_on_error=raise_on_error
+            )
+        )
 
-    def get_models(self, *args, **kwargs):
-        return self._loop.run_until_complete(self._client.get_models(*args, **kwargs))
+    def get_models(self, verbose: bool = False):
+        """Get the list of available models from the Gateway."""
+        return self._run_async(self._client.get_models(verbose=verbose))
 
-    def get_model(self, *args, **kwargs):
-        return self._loop.run_until_complete(self._client.get_model(*args, **kwargs))
+    def get_model(self, model_name: str = None):
+        """Get details of a specific model."""
+        return self._run_async(self._client.get_model(model_name=model_name))
 
-    def deploy_model(self, *args, **kwargs):
-        return self._loop.run_until_complete(self._client.deploy_model(*args, **kwargs))
+    def deploy_model(
+        self,
+        model_name: str = None,
+        tracking_id: str = None,
+        model_uri: str = None,
+        ttl: int = None,
+    ):
+        """Deploy a CogStack Model Serve model through the Gateway."""
+        return self._run_async(
+            self._client.deploy_model(
+                model_name=model_name,
+                tracking_id=tracking_id,
+                model_uri=model_uri,
+                ttl=ttl,
+            )
+        )
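For reference, here is a minimal usage sketch of the synchronous wrapper added above. The method names and keyword arguments (process, submit_task, wait_for_task, get_task_result, get_models) are taken from the diff; the import path, the constructor keywords, the model name, and the shape of the returned task info are assumptions, since GatewayClientSync only forwards *args/**kwargs to GatewayClient, whose constructor and return values are not shown here.

    # Hypothetical usage sketch -- import path, constructor keywords, model name,
    # and task-info keys are assumptions not confirmed by this diff.
    from cogstack_model_gateway.client import GatewayClientSync  # assumed module path

    client = GatewayClientSync(base_url="http://localhost:8000", timeout=30)  # assumed kwargs

    # List the models currently available behind the Gateway (signature from the diff).
    models = client.get_models(verbose=True)

    # Annotate a single text and block until the task finishes; internally this
    # dispatches the async call onto the wrapper's event loop via _run_async().
    annotations = client.process(
        text="Patient was prescribed 5 mg of amlodipine.",  # illustrative input
        model_name="medcat-snomed",                         # illustrative model name
        wait_for_completion=True,
        return_result=True,
    )

    # Alternatively, submit without waiting and poll later.
    task = client.submit_task(model_name="medcat-snomed", task="process", data="...")
    status = client.wait_for_task(task_uuid=task["uuid"], raise_on_error=True)  # "uuid" key assumed
    result = client.get_task_result(task_uuid=task["uuid"], parse=True)

Because the wrapper either owns its event loop or dispatches coroutines to a background thread's loop via asyncio.run_coroutine_threadsafe, the same synchronous calls should work both in plain scripts and in environments that already have a running event loop, such as a Jupyter notebook.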