
Commit 03b3d61

fix mtp (#4105)
1 parent 17a2717 commit 03b3d61

File tree: 1 file changed (+17, -5 lines)

fastdeploy/model_executor/models/ernie4_5_mtp.py

Lines changed: 17 additions & 5 deletions
@@ -16,6 +16,7 @@
 
 from __future__ import annotations
 
+import re
 from functools import partial
 from typing import Dict, Union
 
@@ -250,7 +251,7 @@ def __init__(
         self.embed_tokens = fd_config.speculative_config.sharing_model.ernie.embed_tokens
         self.norm = fd_config.speculative_config.sharing_model.ernie.norm
 
-        self.layers = nn.LayerList(
+        self.mtp_block = nn.LayerList(
             [
                 Ernie4_5_DecoderLayer(
                     fd_config=fd_config,
@@ -296,7 +297,7 @@ def load_state_dict(self, state_dict):
         self.eh_proj.load_state_dict(state_dict)
         for i in range(self.num_layers):
             logger.info(f"Start load layer {i}")
-            self.layers[i].load_state_dict(state_dict)
+            self.mtp_block[i].load_state_dict(state_dict)
 
     def forward(
         self,
@@ -315,7 +316,7 @@ def forward(
         hidden_states = self.eh_proj(inputs_embedding)
         residual = None
         for i in range(self.num_layers):
-            hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
+            hidden_states, residual = self.mtp_block[i](forward_meta, hidden_states, residual)
 
         hidden_states = hidden_states + residual
 
@@ -374,17 +375,23 @@ def load_weights(self, weights_iterator) -> None:
             weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
         """
 
-        from fastdeploy.model_executor.utils import default_weight_loader
+        from fastdeploy.model_executor.utils import (
+            default_weight_loader,
+            process_weights_after_loading,
+        )
 
         all_param_mapping = [
             # (param_name, weight_name, expert_id, shard_id)
             ("embed_tokens.embeddings", "embed_tokens", None, None),
             ("lm_head.linear", "lm_head", None, None),
+            ("enorm", "mtp_emb_norm.0", None, None),
+            ("hnorm", "mtp_hidden_norm.0", None, None),
+            ("eh_proj.linear", "mtp_linear_proj.0", None, None),
         ]
 
         params_dict = dict(self.named_parameters())
         shard_id = None
-
+        process_weights_after_loading_fn = process_weights_after_loading(dict(self.named_sublayers()))
         for loaded_weight_name, loaded_weight in weights_iterator:
             for param_name, weight_name, exp_id, shard_id in all_param_mapping:
                 if weight_name not in loaded_weight_name:
@@ -396,11 +403,16 @@ def load_weights(self, weights_iterator) -> None:
             else:
                 if loaded_weight_name not in params_dict.keys():
                     continue
+                model_param_name = loaded_weight_name
                 param = params_dict[loaded_weight_name]
 
                 # Get weight loader from parameter and set weight
                 weight_loader = getattr(param, "weight_loader", default_weight_loader(self.fd_config))
                 weight_loader(param, loaded_weight)
+                model_sublayer_name = re.sub(
+                    r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$", "", model_param_name
+                )
+                process_weights_after_loading_fn(model_sublayer_name, param)
 
     def compute_logits(self, hidden_states: paddle.Tensor):
         """
