@@ -165,7 +165,7 @@ async def acreate(
165
165
async def acreate_many (
166
166
self ,
167
167
requests : list [CompletionRequest ],
168
- ) -> List [ChatCompletion ]:
168
+ ) -> List [ChatCompletion ] | List [ VendiCompletionResponse ] :
169
169
"""
170
170
Create multiple completions on different models with the same prompt and parameters
171
171
requests: A list of completionr requests
@@ -254,6 +254,36 @@ def run_batch_job(
254
254
time .sleep (poll_interval )
255
255
return job
256
256
257
+ async def arun_batch_job (
258
+ self ,
259
+ dataset_id : uuid .UUID ,
260
+ model_parameters : list [ModelParameters ],
261
+ wait_until_complete : bool = False ,
262
+ timeout : int = 3000 ,
263
+ poll_interval : int = 5 ,
264
+ ) -> BatchInference :
265
+ _res = await self .__aclient .post (
266
+ path = "/platform/v1/inference/batch/" ,
267
+ json = {
268
+ "dataset_id" : str (dataset_id ),
269
+ "model_parameters" : [{** i .model_dump (), ** i .model_extra } for i in model_parameters ]
270
+ }
271
+ )
272
+ job = BatchInference (** _res )
273
+
274
+ if wait_until_complete :
275
+ start_time = time .time ()
276
+ while True :
277
+ job = await self ._aget_batch_job (job .id )
278
+ if job .status in [BatchInferenceStatus .COMPLETED , BatchInferenceStatus .FAILED ]:
279
+ return job
280
+ if time .time () - start_time > timeout :
281
+ raise TimeoutError (
282
+ "The batch job did not complete within the specified timeout. "
283
+ "You can still check its status by using the batch_job_status method." )
284
+
285
+ await asyncio .sleep (poll_interval )
286
+
257
287
def __post_batch_job (self , dataset_id : uuid .UUID , model_parameters : list [ModelParameters ]) -> BatchInference :
258
288
res = self .__client .post (
259
289
uri = "/platform/v1/inference/batch/" ,
@@ -280,3 +310,27 @@ def _get_batch_job(self, batch_inference_id: uuid.UUID) -> BatchInference:
280
310
uri = f"/platform/v1/inference/batch/{ batch_inference_id } "
281
311
)
282
312
return BatchInference (** res )
313
+
314
+ async def _aget_batch_job (self , batch_inference_id : uuid .UUID ) -> BatchInference :
315
+ """
316
+ Get a batch inference object job by ID
317
+ """
318
+ res = await self .__aclient .get (
319
+ path = f"/platform/v1/inference/batch/{ batch_inference_id } "
320
+ )
321
+ return BatchInference (** res )
322
+
323
+ def list_batch_jobs (self ) -> List [BatchInference ]:
324
+ """
325
+ Get all batch inferences
326
+ """
327
+ res = self .__client .get (
328
+ uri = "/platform/v1/inference/batch/"
329
+ )
330
+ return [BatchInference (** i ) for i in res ]
331
+
332
+ def delete_batch_job (self , batch_id : uuid .UUID ):
333
+ """
334
+ Delete a batch inference job
335
+ """
336
+ return self .__client .delete (f"/platform/v1/inference/batch/{ batch_id } " )
0 commit comments