This repository was archived by the owner on Aug 1, 2024. It is now read-only.

Commit 4b9f210

add async operations for completions
1 parent: 6bbbf5c

3 files changed (+57 lines, -3 lines)


poetry.lock

Lines changed: 1 addition & 1 deletion. Generated file; the diff is not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -14,7 +14,7 @@ tenacity = "^8.2.3"
 requests = "^2.31.0"
 pydantic = "^2.5.2"
 pydantic-settings = "^2.1.0"
-pandas = "^2.1.3"
+pandas = "*"
 aiohttp = "^3.8.1"
 
 
```

src/vendi/completions/completions.py

Lines changed: 55 additions & 1 deletion
```diff
@@ -165,7 +165,7 @@ async def acreate(
     async def acreate_many(
         self,
         requests: list[CompletionRequest],
-    ) -> List[ChatCompletion]:
+    ) -> List[ChatCompletion] | List[VendiCompletionResponse]:
         """
         Create multiple completions on different models with the same prompt and parameters
         requests: A list of completion requests
```
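The widened annotation means callers can no longer assume every element is a `ChatCompletion`. A minimal handling sketch follows; the `completions` client instance and the `CompletionRequest` fields are assumptions for illustration, while `acreate_many` and its return types come from the hunk above:

```python
# Hedged sketch: `completions` (a client instance) and the CompletionRequest
# fields are assumed for illustration; acreate_many and its widened return
# annotation are taken from the diff above.
import asyncio

async def main() -> None:
    results = await completions.acreate_many(
        requests=[
            CompletionRequest(model="model-a", prompt="hello"),  # fields assumed
            CompletionRequest(model="model-b", prompt="hello"),
        ]
    )
    # Elements may be ChatCompletion or VendiCompletionResponse, so branch
    # on the concrete type before reading type-specific fields.
    for item in results:
        if isinstance(item, ChatCompletion):
            ...
        else:  # VendiCompletionResponse
            ...

asyncio.run(main())
```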
```diff
@@ -254,6 +254,36 @@ def run_batch_job(
             time.sleep(poll_interval)
         return job
 
+    async def arun_batch_job(
+        self,
+        dataset_id: uuid.UUID,
+        model_parameters: list[ModelParameters],
+        wait_until_complete: bool = False,
+        timeout: int = 3000,
+        poll_interval: int = 5,
+    ) -> BatchInference:
+        _res = await self.__aclient.post(
+            path="/platform/v1/inference/batch/",
+            json={
+                "dataset_id": str(dataset_id),
+                "model_parameters": [{**i.model_dump(), **i.model_extra} for i in model_parameters]
+            }
+        )
+        job = BatchInference(**_res)
+
+        if wait_until_complete:
+            start_time = time.time()
+            while True:
+                job = await self._aget_batch_job(job.id)
+                if job.status in [BatchInferenceStatus.COMPLETED, BatchInferenceStatus.FAILED]:
+                    return job
+                if time.time() - start_time > timeout:
+                    raise TimeoutError(
+                        "The batch job did not complete within the specified timeout. "
+                        "You can still check its status by using the batch_job_status method.")
+
+                await asyncio.sleep(poll_interval)
+
     def __post_batch_job(self, dataset_id: uuid.UUID, model_parameters: list[ModelParameters]) -> BatchInference:
         res = self.__client.post(
             uri="/platform/v1/inference/batch/",
```
```diff
@@ -280,3 +310,27 @@ def _get_batch_job(self, batch_inference_id: uuid.UUID) -> BatchInference:
             uri=f"/platform/v1/inference/batch/{batch_inference_id}"
         )
         return BatchInference(**res)
+
+    async def _aget_batch_job(self, batch_inference_id: uuid.UUID) -> BatchInference:
+        """
+        Get a batch inference job by ID
+        """
+        res = await self.__aclient.get(
+            path=f"/platform/v1/inference/batch/{batch_inference_id}"
+        )
+        return BatchInference(**res)
+
+    def list_batch_jobs(self) -> List[BatchInference]:
+        """
+        Get all batch inferences
+        """
+        res = self.__client.get(
+            uri="/platform/v1/inference/batch/"
+        )
+        return [BatchInference(**i) for i in res]
+
+    def delete_batch_job(self, batch_id: uuid.UUID):
+        """
+        Delete a batch inference job
+        """
+        return self.__client.delete(f"/platform/v1/inference/batch/{batch_id}")
```
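The remaining additions round out the batch API: the private async getter `_aget_batch_job` that backs the polling loop above, plus synchronous list and delete methods. A short sketch of the synchronous pair, again assuming a `completions` client instance:

```python
# Hedged sketch: `completions` is an assumed client instance; the methods
# and the BatchInference fields used here appear in the diff above.
jobs = completions.list_batch_jobs()   # returns List[BatchInference]
for job in jobs:
    print(job.id, job.status)

# Remove a job that is no longer needed.
completions.delete_batch_job(jobs[0].id)
```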
