
Commit 77afbbc

Rebase and Minor Fixes

Signed-off-by: Mohit Soni <[email protected]>
1 parent e877f9f · commit 77afbbc

3 files changed: +43 −7 lines

QEfficient/base/pytorch_transforms.py (4 additions, 1 deletion)

```diff
@@ -162,7 +162,10 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
             logger.info(f"[layer {layer_idx:02d}] loaded gate_proj & up_proj from fused tensor (shape {fused.shape})")
             transformed = True
 
-        model.language_model = model_tmp
+        if hasattr(model, "language_model"):
+            model.language_model = model_tmp
+        else:
+            model = model_tmp
         return model, transformed
 
```
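The guard matters because this transform runs on two shapes of model: multimodal wrappers that nest their text decoder under a `language_model` attribute, and text-only models where the transformed module must replace the model wholesale. A minimal sketch of that control flow, using hypothetical stand-in classes; only `attach` mirrors the committed logic:

```python
import torch.nn as nn

class TextDecoder(nn.Module):
    """Hypothetical stand-in for a text-only model."""

class MultimodalWrapper(nn.Module):
    """Hypothetical stand-in for a vision-language wrapper."""
    def __init__(self):
        super().__init__()
        self.language_model = TextDecoder()

def attach(model: nn.Module, model_tmp: nn.Module) -> nn.Module:
    # Mirrors the committed guard: nest the transformed decoder when
    # the wrapper exposes one, otherwise replace the model itself.
    if hasattr(model, "language_model"):
        model.language_model = model_tmp
    else:
        model = model_tmp
    return model

assert isinstance(attach(MultimodalWrapper(), TextDecoder()).language_model, TextDecoder)
assert isinstance(attach(TextDecoder(), TextDecoder()), TextDecoder)
```

Before this fix, the unconditional `model.language_model = model_tmp` would have bolted a stray `language_model` submodule onto text-only models instead of returning the transformed module.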

QEfficient/transformers/models/llama4/modeling_llama4.py (3 additions, 2 deletions)

```diff
@@ -10,7 +10,7 @@
 
 import torch
 from torch import nn
-from transformers.cache_utils import Cache, DynamicCache
+from transformers.cache_utils import Cache
 from transformers.modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPast,
@@ -32,6 +32,7 @@
     repeat_kv,
 )
 
+from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask
 from QEfficient.utils import constants
 from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config
@@ -704,7 +705,7 @@ def forward(
         return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):
            return_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+            past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values)
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
```
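The replacement keeps the usual legacy-cache promotion pattern and only changes which cache class is constructed. A runnable sketch using the stock `DynamicCache` from `transformers` as a stand-in; assuming, as in QEfficient's `cache_utils`, that `QEffDynamicCache` subclasses `DynamicCache`, `from_legacy_cache` is the inherited classmethod and the call shape is identical (tensor shapes below are illustrative):

```python
import torch
from transformers.cache_utils import Cache, DynamicCache

# Legacy format: one (key, value) pair per layer, each of shape
# (batch, num_heads, seq_len, head_dim).
legacy = tuple(
    (torch.zeros(1, 8, 4, 64), torch.zeros(1, 8, 4, 64)) for _ in range(2)
)

past_key_values = legacy
if not isinstance(past_key_values, Cache):
    # The commit swaps DynamicCache for QEffDynamicCache at this point,
    # so the QEfficient cache's update() applies without monkey-patching.
    past_key_values = DynamicCache.from_legacy_cache(past_key_values)

print(past_key_values.get_seq_length())  # 4
```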

QEfficient/transformers/models/pytorch_transforms.py (36 additions, 4 deletions)

```diff
@@ -8,7 +8,6 @@
 from types import MethodType
 from typing import Optional, Tuple
 
-import transformers
 from torch import nn
 from transformers.models.codegen.modeling_codegen import (
     CodeGenAttention,
@@ -50,6 +49,16 @@
     GraniteModel,
     GraniteRMSNorm,
 )
+from transformers.models.granitemoe.modeling_granitemoe import (
+    GraniteMoeAttention,
+    GraniteMoeForCausalLM,
+    GraniteMoeModel,
+    GraniteMoeMoE,
+    GraniteMoeParallelExperts,
+    GraniteMoeRMSNorm,
+    GraniteMoeRotaryEmbedding,
+    GraniteMoeTopKGating,
+)
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaDecoderLayer,
@@ -72,6 +81,9 @@
 from transformers.models.llava.modeling_llava import (
     LlavaForConditionalGeneration,
 )
+from transformers.models.llava_next.modeling_llava_next import (
+    LlavaNextForConditionalGeneration,
+)
 from transformers.models.mistral.modeling_mistral import (
     MistralAttention,
     MistralDecoderLayer,
@@ -133,7 +145,6 @@
 
 from QEfficient.base.pytorch_transforms import ModuleMappingTransform, ModuleMethodMapperTransform
 from QEfficient.customop import CustomRMSNormAIC, GemmaCustomRMSNormAIC
-from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.models.codegen.modeling_codegen import (
     QEffCodeGenAttention,
     QeffCodeGenBlock,
@@ -181,6 +192,15 @@
     QEffGraniteForCausalLM,
     QEffGraniteModel,
 )
+from QEfficient.transformers.models.granitemoe.modeling_granitemoe import (
+    QEffGraniteMoeAttention,
+    QEffGraniteMoeForCausalLM,
+    QEffGraniteMoeModel,
+    QEffGraniteMoeMoE,
+    QEffGraniteMoeParallelExperts,
+    QEffGraniteMoeRotaryEmbedding,
+    QEffGraniteMoeTopKGating,
+)
 from QEfficient.transformers.models.internvl.modeling_internvl import (
     QEffInternVisionEmbeddings,
     QEffInternVLModel,
@@ -205,6 +225,9 @@
 from QEfficient.transformers.models.llava.modeling_llava import (
     QEffLlavaForConditionalGeneration,
 )
+from QEfficient.transformers.models.llava_next.modeling_llava_next import (
+    QEffLlavaNextForConditionalGeneration,
+)
 from QEfficient.transformers.models.mistral.modeling_mistral import (
     QEffMistralAttention,
     QEffMistralDecoderLayer,
@@ -287,6 +310,7 @@ class CustomOpsTransform(ModuleMappingTransform):
         Qwen2RMSNorm: CustomRMSNormAIC,
         MllamaTextRMSNorm: CustomRMSNormAIC,
         GraniteRMSNorm: CustomRMSNormAIC,
+        GraniteMoeRMSNorm: CustomRMSNormAIC,
     }
 
 
@@ -329,6 +353,8 @@ class KVCacheTransform(ModuleMappingTransform):
         Llama4TextExperts: QEffLlama4TextExperts,
         # Llava
         LlavaForConditionalGeneration: QEffLlavaForConditionalGeneration,
+        # Llava Next
+        LlavaNextForConditionalGeneration: QEffLlavaNextForConditionalGeneration,
         # Gemma
         GemmaAttention: QEffGemmaAttention,
         GemmaDecoderLayer: QEffGemmaDecoderLayer,
@@ -343,6 +369,14 @@ class KVCacheTransform(ModuleMappingTransform):
         GraniteModel: QEffGraniteModel,
         GraniteForCausalLM: QEffGraniteForCausalLM,
         GraniteAttention: QEffGraniteAttention,
+        # GraniteMoe
+        GraniteMoeModel: QEffGraniteMoeModel,
+        GraniteMoeForCausalLM: QEffGraniteMoeForCausalLM,
+        GraniteMoeAttention: QEffGraniteMoeAttention,
+        GraniteMoeRotaryEmbedding: QEffGraniteMoeRotaryEmbedding,
+        GraniteMoeParallelExperts: QEffGraniteMoeParallelExperts,
+        GraniteMoeTopKGating: QEffGraniteMoeTopKGating,
+        GraniteMoeMoE: QEffGraniteMoeMoE,
         # mllama
         MllamaTextRMSNorm: CustomRMSNormAIC,
         MllamaTextSelfAttention: QEffMllamaTextSelfAttention,
@@ -407,8 +441,6 @@ class KVCacheTransform(ModuleMappingTransform):
     @classmethod
     def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
         model, transformed = super().apply(model)
-        # FIXME: see if we can merge into _module_mapping dict
-        transformers.cache_utils.DynamicCache.update = QEffDynamicCache.update
         return model, transformed
```
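With GraniteMoe and LlavaNext wired in purely through the `_module_mapping` dictionaries, and with llama4 constructing `QEffDynamicCache` directly (see the diff above), the global `DynamicCache.update` monkey-patch is no longer needed, and `apply` reduces to the base-class mapping lookup. A minimal sketch of that class-swap pattern; the names are illustrative, not the QEfficient implementation:

```python
import torch.nn as nn

class Original(nn.Linear):
    """Hypothetical module targeted for replacement."""

class Patched(Original):
    """Hypothetical drop-in with modified forward behavior."""
    def forward(self, x):
        return super().forward(x) * 2

MAPPING = {Original: Patched}

def apply(model: nn.Module) -> tuple[nn.Module, bool]:
    transformed = False
    for module in model.modules():
        if type(module) in MAPPING:
            # Re-class in place: parameters are kept, only forward() changes.
            module.__class__ = MAPPING[type(module)]
            transformed = True
    return model, transformed

model, transformed = apply(nn.Sequential(Original(4, 4)))
assert transformed and isinstance(model[0], Patched)
```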
