# Copyright 2024-2025 ModelCloud.ai
# Copyright 2024-2025 [email protected]
# Contact: [email protected], x.com/qubitium
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from transformers import AutoModelForImageTextToText

from .._const import EXPERT_INDEX_PLACEHOLDER
from ..base import BaseGPTQModel


class Llama4GPTQ(BaseGPTQModel):
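    """GPTQ model definition for Llama 4 image-text-to-text checkpoints.

    Quantization targets the text decoder layers. The fused MoE expert weights
    are unpacked into per-expert Llama4TextMLP modules in after_model_load so
    each expert projection can be quantized independently.
    """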
    # There is a bug in the attention_mask handling of
    # transformers.models.llama4.modeling_llama4, so batch quantization
    # for Llama4 is temporarily not supported.
    support_batch_quantize = False
    loader = AutoModelForImageTextToText

    base_modules = ["language_model.model.embed_tokens", "language_model.model.norm"]
    pre_lm_head_norm_module = "language_model.model.norm"

    layers_node = "language_model.model.layers"
    layer_type = "Llama4TextDecoderLayer"

    dynamic_expert_index = "num_local_experts"
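
    # EXPERT_INDEX_PLACEHOLDER below is expanded to one entry per expert; the
    # expert count is read from the config attribute named by dynamic_expert_index.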
    layer_modules = [
        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj", "self_attn.o_proj"],

        [f"feed_forward.experts.{EXPERT_INDEX_PLACEHOLDER}.gate_proj", f"feed_forward.experts.{EXPERT_INDEX_PLACEHOLDER}.up_proj"],
        [f"feed_forward.experts.{EXPERT_INDEX_PLACEHOLDER}.down_proj"],

        ["feed_forward.shared_expert.gate_proj", "feed_forward.shared_expert.up_proj", "feed_forward.shared_expert.down_proj"],
    ]

    def before_model_load(self, load_quantized_model=False):
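        # When loading an already-quantized checkpoint, swap transformers'
        # Llama4TextMoe for a sequential per-expert variant so the quantized
        # per-expert weights can be loaded into matching submodules.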
        if load_quantized_model:
            import torch
            import torch.nn as nn
            import transformers.models.llama4.modeling_llama4 as llama4_modeling
            from transformers.integrations.hub_kernels import use_kernel_forward_from_hub

            @use_kernel_forward_from_hub("Llama4TextMoe")
            class SequentialLlama4TextMoe(torch.nn.Module):
                def __init__(self, config):
                    super().__init__()
                    self.top_k = config.num_experts_per_tok
                    self.hidden_dim = config.hidden_size
                    self.num_experts = config.num_local_experts
                    self.experts = nn.ModuleList(
                        [llama4_modeling.Llama4TextMLP(config) for _ in range(self.num_experts)]
                    )
                    self.router = llama4_modeling.Llama4Router(config)
                    self.shared_expert = llama4_modeling.Llama4TextMLP(config)

                def forward(self, hidden_states: torch.Tensor):
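                    # Dense MoE forward: every expert runs on all tokens and its output
                    # is weighted by the router score; experts not selected by the router
                    # receive a score of 0, so this reproduces the sparse routing result.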
                    hidden_states = hidden_states.reshape(-1, self.hidden_dim)
                    router_logits = self.router(hidden_states)
                    if isinstance(router_logits, tuple):
                        router_scores, router_logits = router_logits
                        router_scores = router_scores.t()
                    else:
                        # transformers < 4.54.0 only returns router_logits
                        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)

                        router_scores = (
                            torch.full_like(router_logits, float("-inf"))
                            .scatter_(1, router_indices, router_top_value)
                            .transpose(0, 1)
                        )
                        router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)

                    out = self.shared_expert(hidden_states)
                    for i in range(self.num_experts):
                        out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)

                    return out, router_logits

            llama4_modeling.Llama4TextMoe = SequentialLlama4TextMoe

    def after_model_load(self, model, load_quantized_model=False):
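        # For a freshly loaded (unquantized) model, unpack every fused Llama4TextMoe
        # block into per-expert Llama4TextMLP modules so GPTQ can quantize each expert
        # projection as an ordinary linear layer.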
        if load_quantized_model:
            return model

        from concurrent.futures import ThreadPoolExecutor
        from functools import partial

        import torch
        from transformers.modeling_utils import no_init_weights
        from transformers.models.llama4.modeling_llama4 import Llama4TextMLP, Llama4TextMoe

        # adapted/modified from https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modeling/llama4.py
        class SequentialLlama4TextExperts(torch.nn.ModuleList):
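            # The stock Llama4TextExperts keeps all experts fused in two 3-D tensors
            # (gate_up_proj and down_proj, indexed by expert along dim 0). This splits
            # them into one Llama4TextMLP per expert, transposed to nn.Linear layout.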
            def __init__(self, config, original):
                self.num_experts = original.gate_up_proj.shape[0]
                with no_init_weights():
                    super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)])
                intermediate_size = original.down_proj.shape[1]

                with torch.no_grad():
                    # Stack the fused weights so the split/transpose below is applied
                    # to all experts at once
                    gate_up_batch = torch.stack([original.gate_up_proj[i] for i in range(self.num_experts)])
                    down_batch = torch.stack([original.down_proj[i] for i in range(self.num_experts)])

                    # Split the gate/up halves and transpose to nn.Linear weight layout
                    gate_batch = gate_up_batch[:, :, :intermediate_size].transpose(-2, -1).contiguous()
                    up_batch = gate_up_batch[:, :, intermediate_size:].transpose(-2, -1).contiguous()
                    down_batch = down_batch.transpose(-2, -1).contiguous()

                    # Copy each expert's slice into its own MLP projections
                    for i in range(self.num_experts):
                        self[i].gate_proj.weight.data = gate_batch[i]
                        self[i].up_proj.weight.data = up_batch[i]
                        self[i].down_proj.weight.data = down_batch[i]

        class SequentialLlama4TextMoe(torch.nn.Module):
            def __init__(self, config, original):
                super().__init__()
                self.top_k = config.num_experts_per_tok
                self.hidden_dim = config.hidden_size
                self.num_experts = config.num_local_experts
                self.experts = SequentialLlama4TextExperts(config, original.experts)
                self.router = original.router
                self.shared_expert = original.shared_expert

            def forward(self, hidden_states: torch.Tensor):
                hidden_states = hidden_states.reshape(-1, self.hidden_dim)
                router_logits = self.router(hidden_states)
                if isinstance(router_logits, tuple):
                    router_scores, router_logits = router_logits
                    router_scores = router_scores.t()
                else:
                    # transformers < 4.54.0 only returns router_logits
                    router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)

                    router_scores = (
                        torch.full_like(router_logits, float("-inf"))
                        .scatter_(1, router_indices, router_top_value)
                        .transpose(0, 1)
                    )
                    router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)

                out = self.shared_expert(hidden_states)
                for i in range(self.num_experts):
                    out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)

                return out, router_logits
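
        # Replace every fused Llama4TextMoe module in the loaded model. The weight
        # unpacking is CPU-bound, so the per-layer conversions run in a small thread pool.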
        model = model.to("cpu")

        def process_module(name, module, model, config):
            if isinstance(module, Llama4TextMoe):
                new_module = SequentialLlama4TextMoe(config=config, original=module)
                parent_name, child_name = name.rsplit(".", maxsplit=1)
                parent = model.get_submodule(parent_name)
                setattr(parent, child_name, new_module)

        # Snapshot the module list first: the workers mutate the module tree while
        # the executor consumes the iterable.
        modules = list(model.named_modules())
        with ThreadPoolExecutor(max_workers=8) as executor:
            process_fn = partial(process_module, model=model, config=model.config.get_text_config())
            list(executor.map(lambda x: process_fn(x[0], x[1]), modules))

        return model