Commit 86fb6f0

Authored by Qubitium, LRL-ModelCloud, ZX-ModelCloud, LRL2-ModelCloud
Llama 4 Support (#1508)
* update transformers for llama 4
  Signed-off-by: Qubitium <[email protected]>
* add Llama4GPTQ
* use loader AutoModelForImageTextToText
* cleanup
* fix qkvo forward when every 4 layer
  Signed-off-by: ZX-ModelCloud <[email protected]>
* Update llama4.py
* add support_batch_quantize
* add warning
* Update README.md
* fix data
* fix input_ids
* update llama4 modules
* cleanup
* Revert "update llama4 modules"
  This reverts commit 250060b.
* llama4 add after_model_load
* speed up
* Update llama4.py
* fix config
* add before_model_load
* update readme
* cleanup

---------

Signed-off-by: Qubitium <[email protected]>
Signed-off-by: ZX-ModelCloud <[email protected]>
Co-authored-by: LRL-ModelCloud <[email protected]>
Co-authored-by: ZX-ModelCloud <[email protected]>
Co-authored-by: LRL2-ModelCloud <[email protected]>
1 parent 69c3db1 commit 86fb6f0

File tree

8 files changed (+230 additions, -14 deletions):

README.md
gptqmodel/looper/module_looper.py
gptqmodel/models/auto.py
gptqmodel/models/base.py
gptqmodel/models/definitions/__init__.py
gptqmodel/models/definitions/llama4.py
gptqmodel/models/loader.py
tests/models/test_llama4.py


README.md

Lines changed: 11 additions & 12 deletions
@@ -174,18 +174,17 @@ Native support support some of the most popular multi-modal models:
 ## Model Support
 | Model | | | | | | | | | |
 |-------------------|---|-------------------|---|----------------|---|----------------|---|------------|---|
-| Baichuan || EXAONE 3.0 || InternLM 1/2.5 || OPT || StableLM ||
-| Bloom || Falcon (H1) || Llama 1-3.3 || OLMo2 || StarCoder2 ||
-| ChatGLM || Gemma 1/2/3 || Llama 3.2 VL || Ovis 1.6/2 || TeleChat2 ||
-| CodeGen || GPTBigCod || LongLLaMA || Phi 1-4 || Yi ||
-| Cohere 1-2 || GPTQ-Neo/GPT-NeoX || Instella || Nemotron Ultra || Seed-OSS ||
-| DBRX Converted || GPT-2 || MiniCPM3 || PanGu-α || XVERSE ||
-| Deci || GPT-J || Mistral || Qwen 1/2/3 || | |
-| DeepSeek-V2/V3/R1 || GPT-OSS || Mixtral || Qwen 2/3 MoE || | |
-| DeepSeek-V2-Lite || Granite || MobileLLM || Qwen 2/2.5 VL || | |
-| Dream || GRIN-MoE || MOSS || Qwen 2.5 Omni || | |
-| ERNIE 4.5 || Hymba || MPT || RefinedWeb || | |
+| Baichuan || EXAONE 3.0 || InternLM 1/2.5 || MPT || RefinedWeb ||
+| Bloom || Falcon (H1) || Llama 1-3.3 || OPT || StableLM ||
+| ChatGLM || Gemma 1/2/3 || Llama 3.2 VL || OLMo2 || StarCoder2 ||
+| CodeGen || GPTBigCod || Llama 4 || Ovis 1.6/2 || TeleChat2 ||
+| Cohere 1-2 || GPTQ-Neo/GPT-NeoX || LongLLaMA || Phi 1-4 || Yi ||
+| DBRX Converted || GPT-2 || Instella || Nemotron Ultra || Seed-OSS ||
+| Deci || GPT-J || MiniCPM3 || PanGu-α || XVERSE ||
+| DeepSeek-V2/V3/R1 || GPT-OSS || Mistral || Qwen 1/2/3 || | |
+| DeepSeek-V2-Lite || Granite || Mixtral || Qwen 2/3 MoE || | |
+| Dream || GRIN-MoE || MobileLLM || Qwen 2/2.5 VL || | |
+| ERNIE 4.5 || Hymba || MOSS || Qwen 2.5 Omni || | |
 
 ## Platform and HW Support

gptqmodel/looper/module_looper.py

Lines changed: 4 additions & 1 deletion
@@ -196,7 +196,10 @@ def loop(self, auto_gc=True, calibration_enable_gpu_cache=True, buffered_fwd=Fal
 
         # dynamic expert layer index for model defs
         if self.gptq_model.dynamic_expert_index is not None:
-            num_experts = getattr(self.gptq_model.model.config, self.gptq_model.dynamic_expert_index)
+            if hasattr(self.gptq_model.model.config, "text_config"):
+                num_experts = getattr(self.gptq_model.model.config.text_config, self.gptq_model.dynamic_expert_index)
+            else:
+                num_experts = getattr(self.gptq_model.model.config, self.gptq_model.dynamic_expert_index)
             layer_modules = get_moe_layer_modules(layer_modules=self.gptq_model.layer_modules,
                                                   num_experts=num_experts)
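Note: Llama 4 is a multimodal checkpoint, so its expert count lives on the nested text config rather than the top-level config; the `hasattr` fallback above (repeated in gptqmodel/models/loader.py below) handles both shapes. A minimal sketch of the same resolution done standalone, assuming a stock transformers config for the model id referenced by the new test:

# Hedged sketch (not part of this commit): resolve the MoE expert count for
# configs that nest language-model settings under `text_config`, as Llama 4 does.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("meta-llama/Llama-4-Scout-17B-16E-Instruct")

# Mirror the hasattr fallback used in module_looper.py and loader.py.
text_config = config.text_config if hasattr(config, "text_config") else config
num_experts = getattr(text_config, "num_local_experts")  # Llama4GPTQ.dynamic_expert_index
print(num_experts)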

gptqmodel/models/auto.py

Lines changed: 2 additions & 0 deletions
@@ -99,6 +99,7 @@
 from .definitions.internlm import InternLMGPTQ  # noqa: E402
 from .definitions.internlm2 import InternLM2GPTQ  # noqa: E402
 from .definitions.llama import LlamaGPTQ  # noqa: E402
+from .definitions.llama4 import Llama4GPTQ  # noqa: E402
 from .definitions.longllama import LongLlamaGPTQ  # noqa: E402
 from .definitions.mimo import MimoGPTQ  # noqa: E402
 from .definitions.minicpm import MiniCPMGPTQ  # noqa: E402

@@ -145,6 +146,7 @@
     "gptj": GPTJGPTQ,
     "gpt2": GPT2GPTQ,
     "llama": LlamaGPTQ,
+    "llama4": Llama4GPTQ,
     "opt": OPTGPTQ,
     "moss": MOSSGPTQ,
     "chatglm": ChatGLM,

gptqmodel/models/base.py

Lines changed: 7 additions & 0 deletions
@@ -26,6 +26,7 @@
 import torch._dynamo
 import torch.nn as nn
 from tokenicer import Tokenicer
+from torch import LongTensor
 from transformers import (AutoModelForCausalLM, AutoProcessor, PreTrainedModel,
                           PreTrainedTokenizerBase, ProcessorMixin, modeling_utils)
 

@@ -123,6 +124,8 @@ class BaseGPTQModel(nn.Module):
 
     server = None
 
+    support_batch_quantize = True
+
     def __init__(
         self,
         model: PreTrainedModel,

@@ -370,6 +373,10 @@ def quantize(
                 "FORMAT.MARLIN is deprecated for quantization. Please switch to FORMAT.GPTQ. GPTQMOdel will auto-use Marlin kernel for accelerated inference for FORMAT.GPTQ."
             )
 
+        if self.support_batch_quantize is False:
+            batch_size = 1
+            log.warn("Batch quantization is not supported for this model. Setting batch_size to 1.")
+
         # Validate quant linear before quantization starts
         _ = select_quant_linear(
             bits=self.quantize_config.bits,

gptqmodel/models/definitions/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -50,6 +50,7 @@
 from .internlm import InternLMGPTQ
 from .internlm2 import InternLM2GPTQ
 from .llama import LlamaGPTQ
+from .llama4 import Llama4GPTQ
 from .longllama import LongLlamaGPTQ
 from .mimo import MimoGPTQ
 from .minicpm3 import MiniCPM3GPTQ

gptqmodel/models/definitions/llama4.py

Lines changed: 172 additions & 0 deletions

@@ -0,0 +1,172 @@
+# Copyright 2024-2025 ModelCloud.ai
+# Copyright 2024-2025 [email protected]
+# Contact: [email protected], x.com/qubitium
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers import AutoModelForImageTextToText
+from .._const import EXPERT_INDEX_PLACEHOLDER
+from ..base import BaseGPTQModel
+
+class Llama4GPTQ(BaseGPTQModel):
+    # some bug in the attention_mask of transformers.modeling_llama4,
+    # so batch quantization for Llama4 is temporarily not supported.
+    support_batch_quantize = False
+    loader = AutoModelForImageTextToText
+
+    base_modules = ["language_model.model.embed_tokens", "language_model.model.norm"]
+    pre_lm_head_norm_module = "language_model.model.norm"
+
+    layers_node = "language_model.model.layers"
+    layer_type = "Llama4TextDecoderLayer"
+
+    dynamic_expert_index = "num_local_experts"
+
+    layer_modules = [
+        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj", "self_attn.o_proj"],
+
+        [f"feed_forward.experts.{EXPERT_INDEX_PLACEHOLDER}.gate_proj", f"feed_forward.experts.{EXPERT_INDEX_PLACEHOLDER}.up_proj"],
+        [f"feed_forward.experts.{EXPERT_INDEX_PLACEHOLDER}.down_proj"],
+
+        ["feed_forward.shared_expert.gate_proj", "feed_forward.shared_expert.up_proj", "feed_forward.shared_expert.down_proj"],
+    ]
+
+    def before_model_load(self, load_quantized_model=False):
+        if load_quantized_model:
+            import torch
+            import torch.nn as nn
+            import transformers.models.llama4.modeling_llama4 as llama4_modeling
+            from transformers.integrations.hub_kernels import use_kernel_forward_from_hub
+
+            @use_kernel_forward_from_hub("Llama4TextMoe")
+            class SequentialLlama4TextMoe(torch.nn.Module):
+                def __init__(self, config):
+                    super().__init__()
+                    self.top_k = config.num_experts_per_tok
+                    self.hidden_dim = config.hidden_size
+                    print(config)
+                    self.num_experts = 16
+                    self.experts = nn.ModuleList(
+                        [llama4_modeling.Llama4TextMLP(config) for _ in range(self.num_experts)]
+                    )
+                    self.router = llama4_modeling.Llama4Router(config)
+                    self.shared_expert = llama4_modeling.Llama4TextMLP(config)
+
+                def forward(self, hidden_states: torch.Tensor):
+                    hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+                    router_logits = self.router(hidden_states)
+                    if isinstance(router_logits, tuple):
+                        router_scores, router_logits = router_logits
+                        router_scores = router_scores.t()
+                    else:
+                        # transformers < 4.54.0 only returns router_logits
+                        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
+
+                        router_scores = (
+                            torch.full_like(router_logits, float("-inf"))
+                            .scatter_(1, router_indices, router_top_value)
+                            .transpose(0, 1)
+                        )
+                        router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
+
+                    out = self.shared_expert(hidden_states)
+                    for i in range(self.num_experts):
+                        out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)
+
+                    return out, router_logits
+
+            llama4_modeling.Llama4TextMoe = SequentialLlama4TextMoe
+
+
+    def after_model_load(self, model, load_quantized_model=False):
+        if load_quantized_model:
+            return model
+
+        import os
+        import torch
+        from concurrent.futures import ThreadPoolExecutor
+        from functools import partial
+        from transformers.modeling_utils import no_init_weights
+        from transformers.models.llama4.modeling_llama4 import Llama4TextMLP, Llama4TextMoe
+
+        # adapted/modified from https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modeling/llama4.py
+        class SequentialLlama4TextExperts(torch.nn.ModuleList):
+            def __init__(self, config, original):
+                self.num_experts = original.gate_up_proj.shape[0]
+                with no_init_weights():
+                    super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)])
+                intermediate_size = original.down_proj.shape[1]
+
+                with torch.no_grad():
+                    # Batch process all expert parameters to avoid loops
+                    gate_up_batch = torch.stack([original.gate_up_proj[i] for i in range(self.num_experts)])
+                    down_batch = torch.stack([original.down_proj[i] for i in range(self.num_experts)])
+
+                    # Batch split and transpose
+                    gate_batch = gate_up_batch[:, :, :intermediate_size].transpose(-2, -1).contiguous()
+                    up_batch = gate_up_batch[:, :, intermediate_size:].transpose(-2, -1).contiguous()
+                    down_batch = down_batch.transpose(-2, -1).contiguous()
+
+                    # Batch assignment
+                    for i in range(self.num_experts):
+                        self[i].gate_proj.weight.data = gate_batch[i]
+                        self[i].up_proj.weight.data = up_batch[i]
+                        self[i].down_proj.weight.data = down_batch[i]
+
+        class SequentialLlama4TextMoe(torch.nn.Module):
+            def __init__(self, config, original):
+                super().__init__()
+                self.top_k = config.num_experts_per_tok
+                self.hidden_dim = config.hidden_size
+                self.num_experts = config.num_local_experts
+                self.experts = SequentialLlama4TextExperts(config, original.experts)
+                self.router = original.router
+                self.shared_expert = original.shared_expert
+
+            def forward(self, hidden_states: torch.Tensor):
+                hidden_states = hidden_states.reshape(-1, self.hidden_dim)
+                router_logits = self.router(hidden_states)
+                if isinstance(router_logits, tuple):
+                    router_scores, router_logits = router_logits
+                    router_scores = router_scores.t()
+                else:
+                    # transformers < 4.54.0 only returns router_logits
+                    router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
+
+                    router_scores = (
+                        torch.full_like(router_logits, float("-inf"))
+                        .scatter_(1, router_indices, router_top_value)
+                        .transpose(0, 1)
+                    )
+                    router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
+
+                out = self.shared_expert(hidden_states)
+                for i in range(self.num_experts):
+                    out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)
+
+                return out, router_logits
+
+        model = model.to("cpu")
+        def process_module(name, module, model, config):
+            if isinstance(module, Llama4TextMoe):
+                new_module = SequentialLlama4TextMoe(config=config, original=module)
+                parent, child = name.rsplit(".", maxsplit=1)
+                print("replace moe" + name + child)
+                parent = model.get_submodule(parent)
+                setattr(parent, child, new_module)
+        print("cpu count", os.cpu_count())
+        with ThreadPoolExecutor(max_workers=8) as executor:
+            process_fn = partial(process_module, model=model, config=model.config.get_text_config())
+            list(executor.map(lambda x: process_fn(x[0], x[1]), model.named_modules()))
+
+        return model
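With the definition above and the registry entries in gptqmodel/models/auto.py, Llama 4 goes through the normal GPTQModel quantization flow. A minimal, hedged usage sketch (not part of this commit; calibration text and output path are placeholders, and only bits/group_size are assumed for QuantizeConfig):

# Hedged sketch: quantize a Llama 4 checkpoint via the newly registered Llama4GPTQ.
from gptqmodel import GPTQModel, QuantizeConfig

calibration = ["GPTQModel is an easy-to-use LLM quantization toolkit."] * 256  # placeholder data

model = GPTQModel.load(
    "meta-llama/Llama-4-Scout-17B-16E-Instruct",  # id referenced by tests/models/test_llama4.py
    QuantizeConfig(bits=4, group_size=128),
)

# Llama4GPTQ sets support_batch_quantize = False, so quantize() clamps batch_size to 1.
model.quantize(calibration, batch_size=1)
model.save("Llama-4-Scout-17B-16E-Instruct-gptq-4bit")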

gptqmodel/models/loader.py

Lines changed: 4 additions & 1 deletion
@@ -478,7 +478,10 @@ def skip(*args, **kwargs):
         model.checkpoint_file_name = model_save_name
 
         if cls.dynamic_expert_index is not None:
-            num_experts = getattr(config, cls.dynamic_expert_index)
+            if hasattr(config, "text_config"):
+                num_experts = getattr(config.text_config, cls.dynamic_expert_index)
+            else:
+                num_experts = getattr(config, cls.dynamic_expert_index)
             cls.layer_modules = get_moe_layer_modules(layer_modules=cls.layer_modules,
                                                       num_experts=num_experts)

tests/models/test_llama4.py

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+# Copyright 2024-2025 ModelCloud.ai
+# Copyright 2024-2025 [email protected]
+# Contact: [email protected], x.com/qubitium
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from model_test import ModelTest
+
+
+class TestLlama4(ModelTest):
+    NATIVE_MODEL_ID = "/monster/data/model/Llama-4-Scout-17B-16E-Instruct"  # "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+    NATIVE_ARC_CHALLENGE_ACC = 0.3567
+    NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3805
+    QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.36
+    APPLY_CHAT_TEMPLATE = True
+    TRUST_REMOTE_CODE = False
+
+    def test_llama4(self):
+        self.quant_lm_eval()
