Commit efccbd3

remove gradio + add sglang backend code
1 parent 937f88c commit efccbd3

File tree

4 files changed, +128 -152 lines changed
gpt_server/model_backend/sglang_backend.py

+109
@@ -0,0 +1,109 @@
import multiprocessing
import os
from typing import Any, Dict, AsyncGenerator
from fastchat.utils import is_partial_stop
from gpt_server.model_backend.base import ModelBackend
from loguru import logger

import sglang as sgl


@sgl.function
def pipeline(s, prompt, max_tokens):
    for p in prompt:
        if isinstance(p, str):
            s += p
        else:
            s += sgl.image(p)
    s += sgl.gen("response", max_tokens=max_tokens)


class SGLangBackend(ModelBackend):
    def __init__(self, model_path) -> None:
        lora = os.getenv("lora", None)
        enable_prefix_caching = bool(os.getenv("enable_prefix_caching", False))
        max_model_len = os.getenv("max_model_len", None)
        tensor_parallel_size = int(os.getenv("num_gpus", "1"))
        gpu_memory_utilization = float(os.getenv("gpu_memory_utilization", 0.8))
        dtype = os.getenv("dtype", "auto")
        max_loras = 1
        enable_lora = False
        self.lora_requests = []
        # ---
        multiprocessing.set_start_method("spawn", force=True)
        runtime = sgl.Runtime(
            model_path=model_path,
            trust_remote_code=True,
            mem_fraction_static=gpu_memory_utilization,
            tp_size=tensor_parallel_size,
            dtype=dtype,
            context_length=int(max_model_len) if max_model_len else None,
            grammar_backend="xgrammar",
        )

        sgl.set_default_backend(runtime)

    async def stream_chat(self, params: Dict[str, Any]) -> AsyncGenerator:
        prompt = params.get("prompt", "")
        messages = params["messages"]
        logger.info(prompt)
        request_id = params.get("request_id", "0")
        temperature = float(params.get("temperature", 0.8))
        top_p = float(params.get("top_p", 0.8))
        top_k = params.get("top_k", -1.0)
        max_new_tokens = int(params.get("max_new_tokens", 1024 * 8))
        stop_str = params.get("stop", None)
        stop_token_ids = params.get("stop_words_ids", None) or []
        presence_penalty = float(params.get("presence_penalty", 0.0))
        frequency_penalty = float(params.get("frequency_penalty", 0.0))
        request = params.get("request", None)
        # Handle stop_str
        stop = set()
        if isinstance(stop_str, str) and stop_str != "":
            stop.add(stop_str)
        elif isinstance(stop_str, list) and stop_str != []:
            stop.update(stop_str)

        input_ids = params.get("input_ids", None)
        text_outputs = ""
        state = pipeline.run(
            prompt,
            max_new_tokens=max_new_tokens,
            stop_token_ids=stop_token_ids,
            stop=stop,
            temperature=temperature,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            top_k=top_k,
            top_p=top_p,
            stream=True,
        )
        async for out, meta_info in state.text_async_iter(
            var_name="response", return_meta_data=True
        ):

            partial_stop = any(is_partial_stop(out, i) for i in stop)
            # prevent yielding partial stop sequence
            if partial_stop:
                continue
            text_outputs += out
            aborted = False
            prompt_tokens = meta_info["prompt_tokens"]
            completion_tokens = meta_info["completion_tokens"]
            usage = {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            }
            ret = {
                "text": text_outputs,
                "error_code": 0,
                "usage": usage,
                "finish_reason": meta_info["finish_reason"]["type"],
            }
            yield ret

            if aborted:
                break
        logger.info(text_outputs)
        logger.info(usage)
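
For orientation, here is a minimal sketch of how a caller might drive the new backend. It is not part of this commit: the model path and sampling values are placeholders, and the only assumptions are the params keys that stream_chat itself reads above (note that "messages" must be present, since it is accessed as params["messages"]).

# Hypothetical driver for SGLangBackend.stream_chat (illustration only, not in this commit).
import asyncio

from gpt_server.model_backend.sglang_backend import SGLangBackend


async def main():
    # model path is a placeholder; any local or HF model that sglang can load
    backend = SGLangBackend(model_path="/models/Qwen2.5-7B-Instruct")
    params = {
        "prompt": "Hello, who are you?",  # a plain string is consumed character by character by the sgl pipeline
        "messages": [],                   # read by stream_chat, but unused by this backend
        "temperature": 0.7,
        "top_p": 0.9,
        "max_new_tokens": 256,
        "stop": ["</s>"],
    }
    final = None
    async for chunk in backend.stream_chat(params):
        # each chunk carries the accumulated text, token usage, and a finish_reason
        final = chunk
    if final is not None:
        print(final["text"])
        print(final["usage"], final["finish_reason"])


if __name__ == "__main__":
    asyncio.run(main())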

gpt_server/model_worker/base/model_worker_base.py

+6
@@ -122,7 +122,11 @@ def load_model_tokenizer(self, model_path):
 
             logger.info(f"{self.worker_name} 使用 vllm 后端")
             self.backend = VllmBackend(model_path=self.model_path)
+        elif "sglang" in os.getenv("backend"):
+            from gpt_server.model_backend.sglang_backend import SGLangBackend
 
+            logger.info(f"{self.worker_name} 使用 SGLang 后端")
+            self.backend = SGLangBackend(model_path=self.model_path)
         elif "lmdeploy" in os.getenv("backend"):
             from gpt_server.model_backend.lmdeploy_backend import LMDeployBackend
 
@@ -209,6 +213,8 @@ def run(cls):
             os.environ["backend"] = "lmdeploy-pytorch"
         elif args.backend == "lmdeploy-turbomind":
             os.environ["backend"] = "lmdeploy-turbomind"
+        elif args.backend == "sglang":
+            os.environ["backend"] = "sglang"
 
         if args.lora:
             os.environ["lora"] = args.lora

pyproject.toml

-1
@@ -11,7 +11,6 @@ dependencies = [
     "fastapi==0.114.1",
     "ffmpy",
     "fschat==0.2.36",
-    "gradio==4.26.0",
     "infinity-emb[all]==0.0.73",
     "lmdeploy==0.7.0.post3",
     "loguru>=0.7.2",
