From a7a46fa1ff5ee5efa07a735e57ede8ec9c2817cb Mon Sep 17 00:00:00 2001
From: zhanglize
Date: Sun, 6 Jul 2025 20:59:10 +0800
Subject: [PATCH] Add Qwen2 model support

Add a Qwen2ForCausalLM implementation (attention with QKV bias, gated
MLP, decoder layer, model and LM head with optional tied word
embeddings) modeled on the existing Qwen3 module, and select the model
class in ModelRunner from hf_config.model_type.

Related model-runner changes:
* size the KV cache directly from free GPU memory and the configured
  gpu_memory_utilization instead of running a warmup pass first
* fall back to hidden_size // num_attention_heads when the HF config
  does not define head_dim (Qwen2 configs may omit it)
* add sanity asserts for the shared-memory message size, the dispatched
  method, and the prefill slot mapping
* save and restore the torch.cuda RNG-state helpers around CUDA graph
  capture
---
 nanovllm/engine/model_runner.py |  48 ++++----
 nanovllm/models/qwen2.py        | 204 ++++++++++++++++++++++++++++++++
 2 files changed, 230 insertions(+), 22 deletions(-)
 create mode 100644 nanovllm/models/qwen2.py

diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index d48a0eb1..1ee5aeb0 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -7,6 +7,7 @@
 from nanovllm.config import Config
 from nanovllm.engine.sequence import Sequence
 from nanovllm.models.qwen3 import Qwen3ForCausalLM
+from nanovllm.models.qwen2 import Qwen2ForCausalLM
 from nanovllm.layers.sampler import Sampler
 from nanovllm.utils.context import set_context, get_context, reset_context
 from nanovllm.utils.loader import load_model
@@ -22,17 +23,17 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event]):
         self.world_size = config.tensor_parallel_size
         self.rank = rank
         self.event = event
-
+
         dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
         torch.cuda.set_device(rank)
         default_dtype = torch.get_default_dtype()
         torch.set_default_dtype(hf_config.torch_dtype)
         torch.set_default_device("cuda")
-        self.model = Qwen3ForCausalLM(hf_config)
+        model_cls = Qwen2ForCausalLM if hf_config.model_type == "qwen2" else Qwen3ForCausalLM
+        self.model = model_cls(hf_config)
         load_model(self.model, config.model)
         self.sampler = Sampler()
-        self.warmup_model()
-        self.allocate_kv_cache()
+        self.allocate_kv_cache(config.gpu_memory_utilization)
         if not self.enforce_eager:
             self.capture_cudagraph()
         torch.set_default_device("cpu")
@@ -77,6 +78,7 @@ def write_shm(self, method_name, *args):
         assert self.world_size > 1 and not self.rank
         data = pickle.dumps([method_name, *args])
         n = len(data)
+        assert n + 4 <= self.shm.size
         self.shm.buf[0:4] = n.to_bytes(4, "little")
         self.shm.buf[4:n+4] = data
         for event in self.event:
@@ -86,28 +88,19 @@ def call(self, method_name, *args):
         if self.world_size > 1 and self.rank == 0:
             self.write_shm(method_name, *args)
         method = getattr(self, method_name, None)
+        assert callable(method)
         return method(*args)
 
-    def warmup_model(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
-        max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
-        num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
-        seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
-        self.run(seqs, True)
-        torch.cuda.empty_cache()
-
-    def allocate_kv_cache(self):
+    def allocate_kv_cache(self, gpu_memory_utilization):
         config = self.config
         hf_config = config.hf_config
         free, total = torch.cuda.mem_get_info()
         used = total - free
-        peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
-        current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
         num_kv_heads = hf_config.num_key_value_heads // self.world_size
+        if not hasattr(hf_config, "head_dim"):
+            hf_config.head_dim = hf_config.hidden_size // hf_config.num_attention_heads
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
-        config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
-        assert config.num_kvcache_blocks > 0
+        config.num_kvcache_blocks = int(total * gpu_memory_utilization - used) // block_bytes
         self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
         layer_id = 0
         for module in self.model.modules():
@@ -118,7 +111,10 @@ def allocate_kv_cache(self):
 
     def prepare_block_tables(self, seqs: list[Sequence]):
         max_len = max(len(seq.block_table) for seq in seqs)
-        block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
+        block_tables = [
+            seq.block_table + [-1] * (max_len - len(seq.block_table))
+            for seq in seqs
+        ]
         block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         return block_tables
 
@@ -141,8 +137,6 @@ def prepare_prefill(self, seqs: list[Sequence]):
             cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
             max_seqlen_q = max(seqlen_q, max_seqlen_q)
             max_seqlen_k = max(seqlen_k, max_seqlen_k)
-            if not seq.block_table:
-                continue
             for i in range(seq.num_cached_blocks, seq.num_blocks):
                 start = seq.block_table[i] * self.block_size
                 if i != seq.num_blocks - 1:
@@ -150,6 +144,7 @@ def prepare_prefill(self, seqs: list[Sequence]):
                 else:
                     end = start + seq.last_block_num_tokens
                 slot_mapping.extend(list(range(start, end)))
+        assert len(input_ids) == len(slot_mapping)
         if cu_seqlens_k[-1] > cu_seqlens_q[-1]:    # prefix cache
             block_tables = self.prepare_block_tables(seqs)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
@@ -186,7 +181,7 @@ def prepare_sample(self, seqs: list[Sequence]):
         return temperatures
 
     @torch.inference_mode()
-    def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
+    def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill):
         if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
             return self.model.compute_logits(self.model(input_ids, positions))
         else:
@@ -215,6 +210,12 @@ def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
 
     @torch.inference_mode()
     def capture_cudagraph(self):
+        get_rng_state = torch.cuda.get_rng_state
+        set_rng_state = torch.cuda.set_rng_state
+        rng_state = torch.cuda.get_rng_state()
+        torch.cuda.get_rng_state = lambda: rng_state
+        torch.cuda.set_rng_state = lambda _: None
+
         config = self.config
         hf_config = config.hf_config
         max_bs = min(self.config.max_num_seqs, 512)
@@ -249,3 +250,6 @@ def capture_cudagraph(self):
             block_tables=block_tables,
             outputs=outputs,
         )
+
+        torch.cuda.get_rng_state = get_rng_state
+        torch.cuda.set_rng_state = set_rng_state
diff --git a/nanovllm/models/qwen2.py b/nanovllm/models/qwen2.py
new file mode 100644
index 00000000..d1bfd815
--- /dev/null
+++ b/nanovllm/models/qwen2.py
@@ -0,0 +1,204 @@
+import torch
+from torch import nn
+import torch.distributed as dist
+from transformers import Qwen2Config
+
+from nanovllm.layers.activation import SiluAndMul
+from nanovllm.layers.attention import Attention
+from nanovllm.layers.layernorm import RMSNorm
+from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
+from nanovllm.layers.rotary_embedding import get_rope
+from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
+
+class Qwen2MLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+        )
+
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+        )
+        assert hidden_act == "silu"
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x = self.down_proj(x)
+        return x
+
+class Qwen2Attention(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        rope_scaling: tuple | None = None,
+    ) -> None:
+        super().__init__()
+        tp_size = dist.get_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        assert self.total_num_kv_heads % tp_size == 0
+        self.num_kv_heads = self.total_num_kv_heads // tp_size
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=True
+        )
+
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=self.rope_theta,
+            rope_scaling=rope_scaling
+        )
+
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor
+    ) -> torch.Tensor:
+        qkv = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v)
+        output = self.o_proj(attn_output)
+
+        return output
+
+
+class Qwen2DecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: Qwen2Config
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = Qwen2Attention(
+            hidden_size=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            max_position=config.max_position_embeddings,
+            rope_theta=getattr(config, "rope_theta", 1000000),
+            rope_scaling=getattr(config, "rope_scaling", None)
+        )
+
+        self.mlp = Qwen2MLP(
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+        )
+
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor | None
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(positions, hidden_states)
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+class Qwen2Model(nn.Module):
+    def __init__(
+        self,
+        config: Qwen2Config
+    ) -> None:
+        super().__init__()
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
+        self.layers = nn.ModuleList([Qwen2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor
+    ) -> torch.Tensor:
+        hidden_states = self.embed_tokens(input_ids)
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(positions, hidden_states, residual)
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+class Qwen2ForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "q_proj": ("qkv_proj", "q"),
+        "k_proj": ("qkv_proj", "k"),
+        "v_proj": ("qkv_proj", "v"),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(
+        self,
+        config: Qwen2Config
+    ) -> None:
+        super().__init__()
+        self.model = Qwen2Model(config)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        if config.tie_word_embeddings:
+            self.lm_head.weight.data = self.model.embed_tokens.weight.data
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions)
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        logits = self.lm_head(hidden_states)
+        return logits
\ No newline at end of file
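
A minimal smoke test for the Qwen2 path, assuming nano-vllm's top-level
LLM/SamplingParams interface; the checkpoint path below is illustrative:

    # Hypothetical usage sketch: point the path at a real local Qwen2 checkpoint.
    from nanovllm import LLM, SamplingParams

    llm = LLM("/path/to/Qwen2-7B-Instruct", enforce_eager=True, tensor_parallel_size=1)
    sampling_params = SamplingParams(temperature=0.6, max_tokens=64)
    outputs = llm.generate(["Hello, Nano-vLLM."], sampling_params)
    print(outputs[0]["text"])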