Skip to content

Commit f850f95

Browse files
committed
client: Fix GatewayClientSync loop handling issues
Revamp the design of GatewayClientSync to ensure it works in any environment. If an event loop is already running in the current thread, it should create a background thread with its own event loop to run the synchronous code. If no event loop is running, it should create and use its own in the current thread, taking care of cleaning up afterwards. Add docstrings and arguments to sync client methods to match their async counterparts and help with development. Signed-off-by: Phoevos Kalemkeris <[email protected]>
1 parent 35b4d1c commit f850f95

File tree

2 files changed

+438
-28
lines changed

2 files changed

+438
-28
lines changed

client/cogstack_model_gateway_client/client.py

Lines changed: 194 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import asyncio
22
import json
3+
import threading
4+
import time
35
from collections.abc import Iterable
46
from functools import wraps
57

@@ -128,7 +130,7 @@ async def process(
128130
wait_for_completion: bool = True,
129131
return_result: bool = True,
130132
):
131-
"""Generate annotations for the the provided text."""
133+
"""Generate annotations for the provided text."""
132134
return await self.submit_task(
133135
model_name=model_name,
134136
task="process",
@@ -290,9 +292,10 @@ async def deploy_model(
290292
class GatewayClientSync:
291293
def __init__(self, *args, **kwargs):
292294
self._client = GatewayClient(*args, **kwargs)
293-
self._loop = asyncio.new_event_loop()
294-
asyncio.set_event_loop(self._loop)
295-
self._loop.run_until_complete(self._client.__aenter__())
295+
self._loop = None
296+
self._own_loop = False
297+
self._thread = None
298+
self._setup_event_loop()
296299

297300
@property
298301
def base_url(self):
@@ -322,40 +325,204 @@ def timeout(self):
322325
def timeout(self, value: float):
323326
self._client.timeout = value
324327

328+
def _setup_event_loop(self):
329+
"""Set up event loop handling for sync operations."""
330+
try:
331+
existing_loop = asyncio.get_running_loop()
332+
self._loop = existing_loop
333+
self._own_loop = False
334+
self._setup_background_loop()
335+
except RuntimeError:
336+
self._loop = asyncio.new_event_loop()
337+
asyncio.set_event_loop(self._loop)
338+
self._own_loop = True
339+
self._loop.run_until_complete(self._client.__aenter__())
340+
341+
def _setup_background_loop(self):
342+
"""Set up a background thread with persistent event loop and client."""
343+
344+
def run_background_loop():
345+
self._background_loop = asyncio.new_event_loop()
346+
asyncio.set_event_loop(self._background_loop)
347+
348+
self._background_loop.run_until_complete(self._client.__aenter__())
349+
self._background_loop.run_forever()
350+
351+
self._thread = threading.Thread(target=run_background_loop, daemon=True)
352+
self._thread.start()
353+
354+
start_time, timeout = time.time(), 5.0
355+
while not hasattr(self, "_background_loop") and time.time() - start_time < timeout:
356+
time.sleep(0.01)
357+
358+
if not hasattr(self, "_background_loop"):
359+
raise RuntimeError("Failed to start background event loop")
360+
361+
def _run_async(self, coro):
362+
"""Run an async coroutine, handling different event loop scenarios."""
363+
if self._own_loop:
364+
return self._loop.run_until_complete(coro)
365+
else:
366+
future = asyncio.run_coroutine_threadsafe(coro, self._background_loop)
367+
return future.result(timeout=60)
368+
325369
def __del__(self):
326370
try:
327-
if hasattr(self, "_client") and self._client and self._client._client is not None:
328-
self._loop.run_until_complete(self._client.__aexit__(None, None, None))
329-
self._loop.close()
371+
if not hasattr(self, "_client") or not self._client:
372+
return
373+
374+
if self._own_loop and self._client._client is not None:
375+
try:
376+
self._loop.run_until_complete(self._client.__aexit__(None, None, None))
377+
self._loop.close()
378+
except Exception:
379+
pass
380+
elif hasattr(self, "_background_loop") and hasattr(self, "_thread"):
381+
if self._background_loop and not self._background_loop.is_closed():
382+
try:
383+
future = asyncio.run_coroutine_threadsafe(
384+
self._client.__aexit__(None, None, None), self._background_loop
385+
)
386+
future.result(timeout=1.0)
387+
except Exception:
388+
# If cleanup fails, we can't do much about it in __del__
389+
pass
390+
try:
391+
self._background_loop.call_soon_threadsafe(self._background_loop.stop)
392+
except Exception:
393+
pass
394+
if self._thread and self._thread.is_alive():
395+
self._thread.join(timeout=1.0)
330396
except Exception:
331397
pass
332398

333-
def submit_task(self, *args, **kwargs):
334-
return self._loop.run_until_complete(self._client.submit_task(*args, **kwargs))
399+
def submit_task(
400+
self,
401+
model_name: str = None,
402+
task: str = None,
403+
data=None,
404+
json=None,
405+
files=None,
406+
params=None,
407+
headers=None,
408+
wait_for_completion: bool = False,
409+
return_result: bool = True,
410+
):
411+
"""Submit a task to the Gateway and return the task info."""
412+
return self._run_async(
413+
self._client.submit_task(
414+
model_name=model_name,
415+
task=task,
416+
data=data,
417+
json=json,
418+
files=files,
419+
params=params,
420+
headers=headers,
421+
wait_for_completion=wait_for_completion,
422+
return_result=return_result,
423+
)
424+
)
425+
426+
def process(
427+
self,
428+
text: str,
429+
model_name: str = None,
430+
wait_for_completion: bool = True,
431+
return_result: bool = True,
432+
):
433+
"""Generate annotations for the provided text."""
434+
return self._run_async(
435+
self._client.process(
436+
text=text,
437+
model_name=model_name,
438+
wait_for_completion=wait_for_completion,
439+
return_result=return_result,
440+
)
441+
)
335442

336-
def process(self, *args, **kwargs):
337-
return self._loop.run_until_complete(self._client.process(*args, **kwargs))
443+
def process_bulk(
444+
self,
445+
texts: list[str],
446+
model_name: str = None,
447+
wait_for_completion: bool = True,
448+
return_result: bool = True,
449+
):
450+
"""Generate annotations for a list of texts."""
451+
return self._run_async(
452+
self._client.process_bulk(
453+
texts=texts,
454+
model_name=model_name,
455+
wait_for_completion=wait_for_completion,
456+
return_result=return_result,
457+
)
458+
)
338459

339-
def process_bulk(self, *args, **kwargs):
340-
return self._loop.run_until_complete(self._client.process_bulk(*args, **kwargs))
460+
def redact(
461+
self,
462+
text: str,
463+
concepts_to_keep: Iterable[str] = None,
464+
warn_on_no_redaction: bool = None,
465+
mask: str = None,
466+
hash: bool = None,
467+
model_name: str = None,
468+
wait_for_completion: bool = True,
469+
return_result: bool = True,
470+
):
471+
"""Redact sensitive information from the provided text."""
472+
return self._run_async(
473+
self._client.redact(
474+
text=text,
475+
concepts_to_keep=concepts_to_keep,
476+
warn_on_no_redaction=warn_on_no_redaction,
477+
mask=mask,
478+
hash=hash,
479+
model_name=model_name,
480+
wait_for_completion=wait_for_completion,
481+
return_result=return_result,
482+
)
483+
)
341484

342-
def redact(self, *args, **kwargs):
343-
return self._loop.run_until_complete(self._client.redact(*args, **kwargs))
485+
def get_task(self, task_uuid: str, detail: bool = True):
486+
"""Get a Gateway task details by its UUID."""
487+
return self._run_async(self._client.get_task(task_uuid=task_uuid, detail=detail))
344488

345-
def get_task(self, *args, **kwargs):
346-
return self._loop.run_until_complete(self._client.get_task(*args, **kwargs))
489+
def get_task_result(self, task_uuid: str, parse: bool = True):
490+
"""Get the result of a Gateway task by its UUID.
347491
348-
def get_task_result(self, *args, **kwargs):
349-
return self._loop.run_until_complete(self._client.get_task_result(*args, **kwargs))
492+
If parse is True, try to infer and parse the result as JSON, JSONL, or text.
493+
Otherwise, return raw bytes.
494+
"""
495+
return self._run_async(self._client.get_task_result(task_uuid=task_uuid, parse=parse))
350496

351-
def wait_for_task(self, *args, **kwargs):
352-
return self._loop.run_until_complete(self._client.wait_for_task(*args, **kwargs))
497+
def wait_for_task(self, task_uuid: str, detail: bool = True, raise_on_error: bool = False):
498+
"""Poll Gateway until the task reaches a final state."""
499+
return self._run_async(
500+
self._client.wait_for_task(
501+
task_uuid=task_uuid, detail=detail, raise_on_error=raise_on_error
502+
)
503+
)
353504

354-
def get_models(self, *args, **kwargs):
355-
return self._loop.run_until_complete(self._client.get_models(*args, **kwargs))
505+
def get_models(self, verbose: bool = False):
506+
"""Get the list of available models from the Gateway."""
507+
return self._run_async(self._client.get_models(verbose=verbose))
356508

357-
def get_model(self, *args, **kwargs):
358-
return self._loop.run_until_complete(self._client.get_model(*args, **kwargs))
509+
def get_model(self, model_name: str = None):
510+
"""Get details of a specific model."""
511+
return self._run_async(self._client.get_model(model_name=model_name))
359512

360-
def deploy_model(self, *args, **kwargs):
361-
return self._loop.run_until_complete(self._client.deploy_model(*args, **kwargs))
513+
def deploy_model(
514+
self,
515+
model_name: str = None,
516+
tracking_id: str = None,
517+
model_uri: str = None,
518+
ttl: int = None,
519+
):
520+
"""Deploy a CogStack Model Serve model through the Gateway."""
521+
return self._run_async(
522+
self._client.deploy_model(
523+
model_name=model_name,
524+
tracking_id=tracking_id,
525+
model_uri=model_uri,
526+
ttl=ttl,
527+
)
528+
)

0 commit comments

Comments
 (0)