
Commit 681879a

align with sglang /generate output format
1 parent: 1a169fa

File tree: 2 files changed (+35, −27 lines)


lmdeploy/serve/openai/api_server.py

Lines changed: 28 additions & 24 deletions
@@ -38,9 +38,9 @@
     CompletionResponse, CompletionResponseChoice,
     CompletionResponseStreamChoice, CompletionStreamResponse, DeltaMessage,
     EmbeddingsRequest, EncodeRequest, EncodeResponse, ErrorResponse,
-    GenerateReqInput, GenerateReqOutput, GenerateRequest, LogProbs, ModelCard,
-    ModelList, ModelPermission, PoolingRequest, PoolingResponse, TopLogprob,
-    UpdateParamsRequest, UsageInfo)
+    GenerateReqInput, GenerateReqMetaOutput, GenerateReqOutput, GenerateRequest,
+    LogProbs, ModelCard, ModelList, ModelPermission, PoolingRequest,
+    PoolingResponse, TopLogprob, UpdateParamsRequest, UsageInfo)
 from lmdeploy.serve.openai.reasoning_parser.reasoning_parser import ReasoningParser, ReasoningParserManager
 from lmdeploy.serve.openai.tool_parser.tool_parser import ToolParser, ToolParserManager
 from lmdeploy.tokenizer import DetokenizeState, Tokenizer
@@ -937,47 +937,51 @@ async def generate(request: GenerateReqInput, raw_request: Request = None):
         do_preprocess=False,
     )

-    def create_generate_response_json(text, gen_tokens, logprobs, finish_reason):
-        response = GenerateReqOutput(
-            text=text,
-            gen_tokens=gen_tokens,
-            logprobs=logprobs or None,
-            finish_reason=finish_reason,
-        )
+    def create_finish_reason(finish_reason):
+        # TODO: add detail info
+        if not finish_reason:
+            return None
+        if finish_reason == 'length':
+            return dict(type='length')
+        if finish_reason == 'stop':
+            return dict(type='stop')
+        return dict(type='abort')
+
+    def create_generate_response_json(text, output_ids, logprobs, finish_reason):
+        meta = GenerateReqMetaOutput(finish_reason=create_finish_reason(finish_reason),
+                                     output_token_logprobs=logprobs or None)
+        response = GenerateReqOutput(text=text, output_ids=output_ids, meta_info=meta)
         return response.model_dump_json()

     async def generate_stream_generator():
-        text = ''  # full response
         async for res in result_generator:
-            text += res.response or ''
-            gen_tokens = res.token_ids
+            text = res.response or ''
+            output_ids = res.token_ids
             logprobs = []
             if res.logprobs:
                 for tok, tok_logprobs in zip(res.token_ids, res.logprobs):
-                    logprobs.append((tok, tok_logprobs[tok]))
-            response_json = create_generate_response_json(text, gen_tokens, logprobs, res.finish_reason)
+                    logprobs.append((tok_logprobs[tok], tok))
+            response_json = create_generate_response_json(text, output_ids, logprobs, res.finish_reason)
             yield f'data: {response_json}\n\n'
         yield 'data: [DONE]\n\n'

     if request.stream:
         return StreamingResponse(generate_stream_generator(), media_type='text/event-stream')

     text = ''
-    gen_tokens = []
+    output_ids = []
     logprobs = []
     async for res in result_generator:
         text += res.response or ''
-        gen_tokens.extend(res.token_ids or [])
+        output_ids.extend(res.token_ids or [])
         if res.logprobs:
             for tok, tok_logprobs in zip(res.token_ids, res.logprobs):
-                logprobs.append((tok, tok_logprobs[tok]))
+                logprobs.append((tok_logprobs[tok], tok))

-    response = GenerateReqOutput(
-        text=text,
-        gen_tokens=gen_tokens,
-        logprobs=logprobs or None,
-        finish_reason=res.finish_reason,
-    )
+    response = GenerateReqOutput(text=text,
+                                 output_ids=output_ids,
+                                 meta_info=GenerateReqMetaOutput(finish_reason=create_finish_reason(res.finish_reason),
+                                                                 output_token_logprobs=logprobs or None))
     return response
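After this change, each SSE chunk from /generate carries only that iteration's incremental text and output_ids (the old generator accumulated the full response server-side), plus a meta_info object with the sglang-style finish_reason dict; the logprob tuples also flip from (token_id, logprob) to sglang's (logprob, token_id) order. A minimal client sketch of consuming the new stream — the host, port, and request-body fields below are illustrative assumptions, not part of the commit (see GenerateReqInput for the real input schema):

import json

import requests

# Assumed server address and request fields, for illustration only.
resp = requests.post('http://localhost:23333/generate',
                     json={'prompt': 'Hello', 'stream': True},
                     stream=True)
text, output_ids = '', []
for line in resp.iter_lines(decode_unicode=True):
    if not line or not line.startswith('data: '):
        continue
    payload = line[len('data: '):]
    if payload == '[DONE]':
        break
    chunk = json.loads(payload)
    text += chunk['text']                   # chunks are incremental, so accumulate client-side
    output_ids.extend(chunk['output_ids'])
    finish = chunk['meta_info'].get('finish_reason')
    if finish:
        print('finish reason:', finish['type'])  # e.g. 'stop', 'length', 'abort'
print(text)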

lmdeploy/serve/openai/protocol.py

Lines changed: 7 additions & 3 deletions
@@ -458,9 +458,13 @@ class GenerateReqInput(BaseModel):
     include_stop_str_in_output: Optional[bool] = False


+class GenerateReqMetaOutput(BaseModel):
+    finish_reason: Optional[Dict[str, Any]] = None
+    output_token_logprobs: Optional[List[tuple[float, int]]] = None  # (logprob, token_id)
+
+
 # /generate output
 class GenerateReqOutput(BaseModel):
     text: str
-    gen_tokens: List[int]
-    logprobs: Optional[List[tuple[int, float]]] = None  # (token_id, logprob)
-    finish_reason: Optional[Literal['stop', 'length', 'tool_calls', 'error']] = None
+    output_ids: List[int]
+    meta_info: GenerateReqMetaOutput
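For reference, a self-contained sketch of the new output schema end to end (pydantic v2 assumed, matching the model_dump_json() call in api_server.py); the model definitions are copied from the diff, while the token id and logprob values are made up:

from typing import Any, Dict, List, Optional

from pydantic import BaseModel


class GenerateReqMetaOutput(BaseModel):
    finish_reason: Optional[Dict[str, Any]] = None
    output_token_logprobs: Optional[List[tuple[float, int]]] = None  # (logprob, token_id)


class GenerateReqOutput(BaseModel):
    text: str
    output_ids: List[int]
    meta_info: GenerateReqMetaOutput


out = GenerateReqOutput(text=' world',
                        output_ids=[1917],  # made-up token id
                        meta_info=GenerateReqMetaOutput(finish_reason=dict(type='stop'),
                                                        output_token_logprobs=[(-0.12, 1917)]))
print(out.model_dump_json())
# -> {"text":" world","output_ids":[1917],
#     "meta_info":{"finish_reason":{"type":"stop"},"output_token_logprobs":[[-0.12,1917]]}}

Tuples serialize as JSON arrays, so each logprob entry arrives on the wire as a two-element [logprob, token_id] list, matching sglang's /generate output.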
