diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py
index bf854b3a35..bf91c7d939 100644
--- a/trl/scripts/vllm_serve.py
+++ b/trl/scripts/vllm_serve.py
@@ -323,6 +323,23 @@ async def generate(request: GenerateRequest):
         {"completion_ids": [[101, 102, 103], [201, 202, 203]]}
         ```
         """
+        all_outputs = await generate_raw(request)
+        completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
+        return {"completion_ids": completion_ids}
+
+    @app.post("/generate_raw/")
+    async def generate_raw(request: GenerateRequest):
+        """
+        Generates completions for the provided prompts and returns the raw list of vLLM `RequestOutput` objects.
+
+        Args:
+            request (`GenerateRequest`):
+                - `prompts` (list of `str`): A list of prompts (text strings) for the model to generate completions.
+
+        Returns:
+            `list[RequestOutput]`: A list of vLLM `RequestOutput` objects, one per prompt, each containing the
+                generated completions and their token IDs.
+        """
 
         # Guided decoding, if enabled
         if request.guided_decoding_regex is not None:
@@ -342,8 +359,7 @@ async def generate(request: GenerateRequest):
             guided_decoding=guided_decoding,
         )
         all_outputs = llm.generate(request.prompts, sampling_params=sampling_params)
-        completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
-        return {"completion_ids": completion_ids}
+        return all_outputs
 
     class InitCommunicatorRequest(BaseModel):
         host: str
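
For reference, a minimal standalone sketch of the flattening that the patched `/generate/` body now applies to the raw outputs of `generate_raw`. Here `SimpleNamespace` merely stands in for vLLM's `RequestOutput`/`CompletionOutput` objects, which a real server would get from `llm.generate(...)`:

```python
from types import SimpleNamespace

# Stand-ins for two RequestOutput objects, each with one CompletionOutput
# carrying its generated token IDs (shapes mirror the docstring example).
all_outputs = [
    SimpleNamespace(outputs=[SimpleNamespace(token_ids=(101, 102, 103))]),
    SimpleNamespace(outputs=[SimpleNamespace(token_ids=(201, 202, 203))]),
]

# Same comprehension as the new /generate/ body: flatten every completion of
# every request into a list of token-ID lists.
completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
assert completion_ids == [[101, 102, 103], [201, 202, 203]]
```

This keeps `/generate/` backward compatible (same `{"completion_ids": ...}` response) while `/generate_raw/` exposes the unflattened `RequestOutput` objects for callers that need more than token IDs.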