import multiprocessing
import os
from typing import Any, Dict, AsyncGenerator
from fastchat.utils import is_partial_stop
from gpt_server.model_backend.base import ModelBackend
from loguru import logger

import sglang as sgl

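# `pipeline` is an SGLang program: it appends each prompt segment to the running
# state `s` (string segments as text, anything else as an image for multimodal
# models) and then generates a completion named "response", capped at `max_tokens`.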
@sgl.function
def pipeline(s, prompt, max_tokens):
    for p in prompt:
        if isinstance(p, str):
            s += p
        else:
            s += sgl.image(p)
    s += sgl.gen("response", max_tokens=max_tokens)

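# SGLangBackend adapts an SGLang runtime to gpt_server's ModelBackend interface.
# Runtime settings are read from environment variables (num_gpus,
# gpu_memory_utilization, dtype, max_model_len, ...) that the gpt_server worker
# is expected to set before constructing the backend.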
class SGLangBackend(ModelBackend):
    def __init__(self, model_path) -> None:
        lora = os.getenv("lora", None)
        # os.getenv returns a string, so compare explicitly instead of bool(...),
        # which would be True for any non-empty value such as "False".
        enable_prefix_caching = os.getenv("enable_prefix_caching", "false").lower() in ("1", "true")
        max_model_len = os.getenv("max_model_len", None)
        tensor_parallel_size = int(os.getenv("num_gpus", "1"))
        gpu_memory_utilization = float(os.getenv("gpu_memory_utilization", 0.8))
        dtype = os.getenv("dtype", "auto")
        # LoRA-related settings are declared but not used by this backend yet.
        max_loras = 1
        enable_lora = False
        self.lora_requests = []
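        # sgl.Runtime launches the SGLang server in child processes, so force the
        # "spawn" start method to avoid CUDA issues with fork. The runtime is then
        # registered as the default backend, which is what pipeline.run() uses.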
        multiprocessing.set_start_method("spawn", force=True)
        runtime = sgl.Runtime(
            model_path=model_path,
            trust_remote_code=True,
            mem_fraction_static=gpu_memory_utilization,
            tp_size=tensor_parallel_size,
            dtype=dtype,
            context_length=int(max_model_len) if max_model_len else None,
            grammar_backend="xgrammar",
        )

        sgl.set_default_backend(runtime)

    async def stream_chat(self, params: Dict[str, Any]) -> AsyncGenerator:
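        """Stream a completion for `params`.

        Reads the prompt and sampling settings from the request dict, runs the
        SGLang pipeline with stream=True, and yields dicts containing the text
        generated so far, token usage, and the finish reason.
        """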
        prompt = params.get("prompt", "")
        messages = params["messages"]
        logger.info(prompt)
        request_id = params.get("request_id", "0")
        temperature = float(params.get("temperature", 0.8))
        top_p = float(params.get("top_p", 0.8))
        top_k = int(params.get("top_k", -1))  # -1 disables top-k filtering
        max_new_tokens = int(params.get("max_new_tokens", 1024 * 8))
        stop_str = params.get("stop", None)
        stop_token_ids = params.get("stop_words_ids", None) or []
        presence_penalty = float(params.get("presence_penalty", 0.0))
        frequency_penalty = float(params.get("frequency_penalty", 0.0))
        request = params.get("request", None)
        # Handle stop_str
        stop = set()
        if isinstance(stop_str, str) and stop_str != "":
            stop.add(stop_str)
        elif isinstance(stop_str, list) and stop_str != []:
            stop.update(stop_str)

        input_ids = params.get("input_ids", None)
        text_outputs = ""
        usage = {}  # last-seen usage; logged after the stream ends
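        # Run the SGLang program with streaming enabled. run() returns a
        # ProgramState right away; generation proceeds in the background and the
        # partial output is consumed through text_async_iter() below.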
        state = pipeline.run(
            prompt,
            max_new_tokens,  # positional: binds to the pipeline's max_tokens argument
            stop=list(stop),  # sglang expects a str or a list of stop strings
            stop_token_ids=stop_token_ids,
            temperature=temperature,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            top_k=top_k,
            top_p=top_p,
            stream=True,
        )
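        # text_async_iter yields the "response" variable incrementally; with
        # return_meta_data=True each chunk also carries token counts and the
        # finish reason. Chunks whose tail partially matches a stop string are
        # skipped so a stop sequence is never streamed to the client.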
        async for out, meta_info in state.text_async_iter(
            var_name="response", return_meta_data=True
        ):
            # prevent yielding a partial stop sequence
            partial_stop = any(is_partial_stop(out, i) for i in stop)
            if partial_stop:
                continue
            text_outputs += out
            # Placeholder: client aborts (e.g. via the `request` object) are not
            # handled yet, so this never becomes True.
            aborted = False
            prompt_tokens = meta_info["prompt_tokens"]
            completion_tokens = meta_info["completion_tokens"]
            usage = {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            }
            # tolerate a missing/None finish_reason on intermediate chunks
            finish_reason = meta_info["finish_reason"] or {}
            ret = {
                "text": text_outputs,
                "error_code": 0,
                "usage": usage,
                "finish_reason": finish_reason.get("type"),
            }
            yield ret

            if aborted:
                break
        logger.info(text_outputs)
        logger.info(usage)
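
# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original backend file). The
# model path and the params keys below are assumptions based on what
# stream_chat reads; in gpt_server this class is normally driven by the worker
# process rather than invoked directly.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        backend = SGLangBackend(model_path="Qwen/Qwen2-7B-Instruct")  # assumed model
        params = {
            "prompt": "Briefly introduce the Milky Way.",
            "messages": [],
            "temperature": 0.7,
            "top_p": 0.9,
            "max_new_tokens": 256,
        }
        last = {}
        async for chunk in backend.stream_chat(params):
            last = chunk
        print(last.get("text", ""))

    asyncio.run(_demo())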