@@ -16,6 +16,7 @@
 
 from __future__ import annotations
 
+import re
 from functools import partial
 from typing import Dict, Union
 
@@ -250,7 +251,7 @@ def __init__(
         self.embed_tokens = fd_config.speculative_config.sharing_model.ernie.embed_tokens
         self.norm = fd_config.speculative_config.sharing_model.ernie.norm
 
-        self.layers = nn.LayerList(
+        self.mtp_block = nn.LayerList(
             [
                 Ernie4_5_DecoderLayer(
                     fd_config=fd_config,
@@ -296,7 +297,7 @@ def load_state_dict(self, state_dict):
         self.eh_proj.load_state_dict(state_dict)
         for i in range(self.num_layers):
             logger.info(f"Start load layer {i}")
-            self.layers[i].load_state_dict(state_dict)
+            self.mtp_block[i].load_state_dict(state_dict)
 
     def forward(
         self,
@@ -315,7 +316,7 @@ def forward(
         hidden_states = self.eh_proj(inputs_embedding)
         residual = None
         for i in range(self.num_layers):
-            hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
+            hidden_states, residual = self.mtp_block[i](forward_meta, hidden_states, residual)
 
         hidden_states = hidden_states + residual
 
@@ -374,17 +375,23 @@ def load_weights(self, weights_iterator) -> None:
             weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
         """
 
-        from fastdeploy.model_executor.utils import default_weight_loader
+        from fastdeploy.model_executor.utils import (
+            default_weight_loader,
+            process_weights_after_loading,
+        )
 
         all_param_mapping = [
             # (param_name, weight_name, expert_id, shard_id)
             ("embed_tokens.embeddings", "embed_tokens", None, None),
             ("lm_head.linear", "lm_head", None, None),
+            ("enorm", "mtp_emb_norm.0", None, None),
+            ("hnorm", "mtp_hidden_norm.0", None, None),
+            ("eh_proj.linear", "mtp_linear_proj.0", None, None),
         ]
 
         params_dict = dict(self.named_parameters())
         shard_id = None
-
+        process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
         for loaded_weight_name, loaded_weight in weights_iterator:
            for param_name, weight_name, exp_id, shard_id in all_param_mapping:
                if weight_name not in loaded_weight_name:
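For reference, a minimal standalone sketch of how the new mapping entries pair checkpoint weight names with model parameter names. The checkpoint key below is hypothetical; only the substring test visible in the hunk above is reproduced.

# Illustrative sketch, not part of the change: the new (param_name, weight_name)
# pairs added above, applied with the same "weight_name in loaded_weight_name"
# test that load_weights uses to select a mapping entry.
all_param_mapping = [
    ("enorm", "mtp_emb_norm.0", None, None),
    ("hnorm", "mtp_hidden_norm.0", None, None),
    ("eh_proj.linear", "mtp_linear_proj.0", None, None),
]

loaded_weight_name = "mtp_linear_proj.0.weight"  # hypothetical checkpoint key
for param_name, weight_name, _, _ in all_param_mapping:
    if weight_name in loaded_weight_name:
        print(param_name)  # -> eh_proj.linear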
@@ -396,11 +403,16 @@ def load_weights(self, weights_iterator) -> None:
             else:
                 if loaded_weight_name not in params_dict.keys():
                     continue
+                model_param_name = loaded_weight_name
                 param = params_dict[loaded_weight_name]
 
             # Get weight loader from parameter and set weight
             weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
             weight_loader(param, loaded_weight)
+            model_sublayer_name = re.sub(
+                r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$", "", model_param_name
+            )
+            process_weights_after_loading_fn(model_sublayer_name, param)
 
     def compute_logits(self, hidden_states: paddle.Tensor):
         """