15 changes: 12 additions & 3 deletions nanovllm/engine/model_runner.py
@@ -7,16 +7,24 @@
from nanovllm.config import Config
from nanovllm.engine.sequence import Sequence
from nanovllm.models.qwen3 import Qwen3ForCausalLM
from nanovllm.models.qwen2 import Qwen2ForCausalLM
from nanovllm.models.llama import LlamaForCausalLM
from nanovllm.layers.sampler import Sampler
from nanovllm.utils.context import set_context, get_context, reset_context
from nanovllm.utils.loader import load_model


class ModelRunner:
model_dict = {
"llama": LlamaForCausalLM,
"qwen2": Qwen2ForCausalLM,
"qwen3": Qwen3ForCausalLM,
}

def __init__(self, config: Config, rank: int, event: Event | list[Event]):
self.config = config
hf_config = config.hf_config
self.model_type = hf_config.model_type
self.block_size = config.kvcache_block_size
self.enforce_eager = config.enforce_eager
self.world_size = config.tensor_parallel_size
@@ -28,7 +36,7 @@ def __init__(self, config: Config, rank: int, event: Event | list[Event]):
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(hf_config.torch_dtype)
torch.set_default_device("cuda")
self.model = Qwen3ForCausalLM(hf_config)
self.model = ModelRunner.model_dict[self.model_type](hf_config)
load_model(self.model, config.model)
self.sampler = Sampler()
self.warmup_model()
@@ -105,10 +113,11 @@ def allocate_kv_cache(self):
peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
num_kv_heads = hf_config.num_key_value_heads // self.world_size
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
head_dim = hf_config.head_dim if hasattr(hf_config, "head_dim") else hf_config.hidden_size // hf_config.num_attention_heads
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
assert config.num_kvcache_blocks > 0
self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
layer_id = 0
for module in self.model.modules():
if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
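Note (not part of the diff): in this file the model class is now picked via ModelRunner.model_dict[hf_config.model_type], and the KV-cache sizing falls back to hidden_size // num_attention_heads when the config carries no head_dim. Below is a minimal, self-contained sketch of that sizing math, using a hypothetical FakeConfig in place of a real Hugging Face config:

import torch

class FakeConfig:
    # Hypothetical stand-in for hf_config; Llama-style checkpoints often omit head_dim.
    model_type = "llama"
    num_hidden_layers = 32
    num_attention_heads = 32
    num_key_value_heads = 8
    hidden_size = 4096
    torch_dtype = torch.bfloat16

def resolve_head_dim(cfg):
    # Mirrors the fallback added in allocate_kv_cache: prefer an explicit head_dim,
    # otherwise derive it from hidden_size / num_attention_heads.
    return getattr(cfg, "head_dim", None) or cfg.hidden_size // cfg.num_attention_heads

def kv_block_bytes(cfg, block_size, world_size=1):
    # 2x for K and V, per layer, per slot in the block, per KV head on this rank.
    num_kv_heads = cfg.num_key_value_heads // world_size
    return 2 * cfg.num_hidden_layers * block_size * num_kv_heads * resolve_head_dim(cfg) * cfg.torch_dtype.itemsize

print(resolve_head_dim(FakeConfig))      # 128
print(kv_block_bytes(FakeConfig, 256))   # 33554432 bytes (32 MiB) per KV-cache block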
4 changes: 2 additions & 2 deletions nanovllm/layers/rotary_embedding.py
@@ -55,14 +55,14 @@ def forward(
return query, key


@lru_cache(1)
# @lru_cache(1)
def get_rope(
head_size: int,
rotary_dim: int,
max_position: int,
base: float,
rope_scaling: dict | None = None,
):
assert rope_scaling is None
# assert rope_scaling is None
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
return rotary_emb
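Note (not part of the diff): one plausible reason for commenting out @lru_cache here (my reading; the PR does not say) is that functools.lru_cache hashes its arguments, and Llama checkpoints can pass rope_scaling as a dict, which is unhashable. The assert is relaxed for the same reason, though any scaling parameters are still ignored by the plain RotaryEmbedding returned here. A small demonstration of the hashing issue:

from functools import lru_cache

@lru_cache(1)
def cached_get_rope(head_size: int, rope_scaling: dict | None = None):
    return head_size  # stand-in for building a RotaryEmbedding

cached_get_rope(128, None)                       # fine: all arguments hashable
try:
    cached_get_rope(128, {"rope_type": "linear"})
except TypeError as err:
    print(err)                                   # unhashable type: 'dict'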
242 changes: 242 additions & 0 deletions nanovllm/models/llama.py
@@ -0,0 +1,242 @@
import torch
from torch import nn
import torch.distributed as dist
from transformers import LlamaConfig

from nanovllm.layers.activation import SiluAndMul
from nanovllm.layers.attention import Attention
from nanovllm.layers.layernorm import RMSNorm
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
from nanovllm.layers.rotary_embedding import get_rope
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead


class LlamaAttention(nn.Module):

def __init__(
self,
config: LlamaConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: tuple | None = None,
max_position_embeddings: int = 8192,
bias: bool = False,
bias_o_proj: bool = False,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = dist.get_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
assert self.total_num_kv_heads % tp_size == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
# MistralConfig has an optional head_dim introduced by Mistral-Nemo
head_dim = getattr(config, "head_dim", None)
if head_dim is None:
head_dim = self.hidden_size // self.total_num_heads
self.head_dim = head_dim
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings

self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
)

self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias_o_proj,
)

self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=self.max_position_embeddings,
base=self.rope_theta,
rope_scaling=rope_scaling,
)

self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
self.num_kv_heads,
)

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output = self.o_proj(attn_output)
return output


class LlamaMLP(nn.Module):

def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
bias: bool = False,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[intermediate_size] * 2,
bias=bias,
)
self.down_proj = RowParallelLinear(
input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
)
assert hidden_act == "silu"
self.act_fn = SiluAndMul()

def forward(self, x):
x = self.gate_up_proj(x)
x = self.act_fn(x)
x = self.down_proj(x)
return x


class LlamaDecoderLayer(nn.Module):

def __init__(
self,
config: LlamaConfig,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
# Support abacusai/Smaug-72B-v0.1 with attention_bias
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False)
bias_o_proj = attention_bias
# support internlm/internlm3-8b with qkv_bias
if hasattr(config, 'qkv_bias'):
attention_bias = config.qkv_bias

self.self_attn = LlamaAttention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=getattr(config, "num_key_value_heads",
config.num_attention_heads),
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
bias=attention_bias,
bias_o_proj=bias_o_proj,
)
self.mlp = LlamaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
bias=getattr(config, "mlp_bias", False),
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(positions=positions,
hidden_states=hidden_states)

# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual


class LlamaModel(nn.Module):

def __init__(
self,
config: LlamaConfig,
) -> None:
super().__init__()
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for layer in self.layers:
hidden_states, residual = layer(positions, hidden_states, residual)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states


class LlamaForCausalLM(nn.Module):
packed_modules_mapping = {
"q_proj": ("qkv_proj", "q"),
"k_proj": ("qkv_proj", "k"),
"v_proj": ("qkv_proj", "v"),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}

def __init__(
self,
config: LlamaConfig
) -> None:
super().__init__()
self.model = LlamaModel(config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
if config.tie_word_embeddings:
self.lm_head.weight.data = self.model.embed_tokens.weight.data

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
) -> torch.Tensor:
model_output = self.model(input_ids, positions)
return model_output

def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
logits = self.lm_head(hidden_states)
return logits
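Note (not part of the diff): packed_modules_mapping tells the checkpoint loader how the per-projection weights in a Hugging Face Llama checkpoint map onto the fused qkv_proj and gate_up_proj parameters. A simplified sketch of how such a mapping is typically consumed, using the class defined above (not nanovllm's actual load_model):

def route_weight(name, mapping):
    # Rewrite a checkpoint parameter name to its fused destination and shard id.
    for src, (dst, shard_id) in mapping.items():
        if src in name:
            return name.replace(src, dst), shard_id
    return name, None  # unfused weights load as-is

mapping = LlamaForCausalLM.packed_modules_mapping
print(route_weight("model.layers.0.self_attn.q_proj.weight", mapping))
# ('model.layers.0.self_attn.qkv_proj.weight', 'q')
print(route_weight("model.layers.0.mlp.up_proj.weight", mapping))
# ('model.layers.0.mlp.gate_up_proj.weight', 1)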