 from types import MethodType
 from typing import Optional, Tuple

-import transformers
 from torch import nn
 from transformers.models.codegen.modeling_codegen import (
     CodeGenAttention,
@@ -50,6 +49,16 @@
     GraniteModel,
     GraniteRMSNorm,
 )
+from transformers.models.granitemoe.modeling_granitemoe import (
+    GraniteMoeAttention,
+    GraniteMoeForCausalLM,
+    GraniteMoeModel,
+    GraniteMoeMoE,
+    GraniteMoeParallelExperts,
+    GraniteMoeRMSNorm,
+    GraniteMoeRotaryEmbedding,
+    GraniteMoeTopKGating,
+)
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaDecoderLayer,
@@ -72,6 +81,9 @@
 from transformers.models.llava.modeling_llava import (
     LlavaForConditionalGeneration,
 )
+from transformers.models.llava_next.modeling_llava_next import (
+    LlavaNextForConditionalGeneration,
+)
 from transformers.models.mistral.modeling_mistral import (
     MistralAttention,
     MistralDecoderLayer,
@@ -133,7 +145,6 @@

 from QEfficient.base.pytorch_transforms import ModuleMappingTransform, ModuleMethodMapperTransform
 from QEfficient.customop import CustomRMSNormAIC, GemmaCustomRMSNormAIC
-from QEfficient.transformers.cache_utils import QEffDynamicCache
 from QEfficient.transformers.models.codegen.modeling_codegen import (
     QEffCodeGenAttention,
     QeffCodeGenBlock,
@@ -181,6 +192,15 @@
     QEffGraniteForCausalLM,
     QEffGraniteModel,
 )
+from QEfficient.transformers.models.granitemoe.modeling_granitemoe import (
+    QEffGraniteMoeAttention,
+    QEffGraniteMoeForCausalLM,
+    QEffGraniteMoeModel,
+    QEffGraniteMoeMoE,
+    QEffGraniteMoeParallelExperts,
+    QEffGraniteMoeRotaryEmbedding,
+    QEffGraniteMoeTopKGating,
+)
 from QEfficient.transformers.models.internvl.modeling_internvl import (
     QEffInternVisionEmbeddings,
     QEffInternVLModel,
@@ -205,6 +225,9 @@
 from QEfficient.transformers.models.llava.modeling_llava import (
     QEffLlavaForConditionalGeneration,
 )
+from QEfficient.transformers.models.llava_next.modeling_llava_next import (
+    QEffLlavaNextForConditionalGeneration,
+)
 from QEfficient.transformers.models.mistral.modeling_mistral import (
     QEffMistralAttention,
     QEffMistralDecoderLayer,
@@ -287,6 +310,7 @@ class CustomOpsTransform(ModuleMappingTransform):
         Qwen2RMSNorm: CustomRMSNormAIC,
         MllamaTextRMSNorm: CustomRMSNormAIC,
         GraniteRMSNorm: CustomRMSNormAIC,
+        GraniteMoeRMSNorm: CustomRMSNormAIC,
     }


@@ -329,6 +353,8 @@ class KVCacheTransform(ModuleMappingTransform):
         Llama4TextExperts: QEffLlama4TextExperts,
         # Llava
         LlavaForConditionalGeneration: QEffLlavaForConditionalGeneration,
+        # Llava Next
+        LlavaNextForConditionalGeneration: QEffLlavaNextForConditionalGeneration,
         # Gemma
         GemmaAttention: QEffGemmaAttention,
         GemmaDecoderLayer: QEffGemmaDecoderLayer,
@@ -343,6 +369,14 @@ class KVCacheTransform(ModuleMappingTransform):
         GraniteModel: QEffGraniteModel,
         GraniteForCausalLM: QEffGraniteForCausalLM,
         GraniteAttention: QEffGraniteAttention,
+        # GraniteMoe
+        GraniteMoeModel: QEffGraniteMoeModel,
+        GraniteMoeForCausalLM: QEffGraniteMoeForCausalLM,
+        GraniteMoeAttention: QEffGraniteMoeAttention,
+        GraniteMoeRotaryEmbedding: QEffGraniteMoeRotaryEmbedding,
+        GraniteMoeParallelExperts: QEffGraniteMoeParallelExperts,
+        GraniteMoeTopKGating: QEffGraniteMoeTopKGating,
+        GraniteMoeMoE: QEffGraniteMoeMoE,
         # mllama
         MllamaTextRMSNorm: CustomRMSNormAIC,
         MllamaTextSelfAttention: QEffMllamaTextSelfAttention,
@@ -407,8 +441,6 @@ class KVCacheTransform(ModuleMappingTransform):
     @classmethod
     def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]:
         model, transformed = super().apply(model)
-        # FIXME: see if we can merge into _module_mapping dict
-        transformers.cache_utils.DynamicCache.update = QEffDynamicCache.update
         return model, transformed


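A note on how these mappings are consumed (not part of the diff itself): the classes registered in a ModuleMappingTransform subclass's _module_mapping are swapped for their QEff counterparts when the transform is applied to a model. The sketch below illustrates that class-swap pattern under the assumption that the transform walks the module tree and rebinds each mapped module's class in place; _SourceRMSNorm, _TargetRMSNorm, and apply_mapping are hypothetical stand-ins for illustration, not QEfficient or transformers APIs.

import torch
from torch import nn
from typing import Dict, Tuple, Type


class _SourceRMSNorm(nn.Module):
    # Stand-in for an original module class (e.g. GraniteMoeRMSNorm).
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.variance_epsilon)


class _TargetRMSNorm(_SourceRMSNorm):
    # Stand-in for the replacement class (e.g. CustomRMSNormAIC): same parameter
    # layout, different forward implementation, so swapping the class keeps the
    # loaded weights intact.
    pass


_module_mapping: Dict[Type[nn.Module], Type[nn.Module]] = {_SourceRMSNorm: _TargetRMSNorm}


def apply_mapping(model: nn.Module) -> Tuple[nn.Module, bool]:
    # Walk the module tree and rebind the class of every mapped module in place.
    transformed = False
    for module in model.modules():
        replacement = _module_mapping.get(type(module))
        if replacement is not None:
            module.__class__ = replacement
            transformed = True
    return model, transformed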