33 changes: 27 additions & 6 deletions paddleformers/nn/pp_model.py
@@ -509,6 +509,12 @@ class GeneralModelForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
_embed_cls = None
_rotary_emb_cls = None
_norm_cls = "rms_norm"
_mtp_layer_pipe_cls = None
_embedding_pipe_cls = None
_decoder_layer_pipe_cls = None
_criterion_pipe_cls = None
_lmhead_pipe_cls = None
_rms_norm_pipe_cls = None

def __init__(self, config: PretrainedConfig, **kwargs):
if getattr(config, "sliding_window", None) is not None and "sliding_attention" in getattr(
@@ -522,11 +528,18 @@ def __init__(self, config: PretrainedConfig, **kwargs):
# dynamic inherit DecoderLayer
if self._decoder_layer_cls is None:
raise ValueError("_decoder_layer_cls must be set before init.")

EmbeddingPipeCls = self._embedding_pipe_cls if self._embedding_pipe_cls is not None else Embedding

if self._decoder_layer_pipe_cls is None:
DecoderLayerPipe = make_decoder_layer_pipe(self._decoder_layer_cls)
else:
DecoderLayerPipe = self._decoder_layer_pipe_cls

LMHeadPipeCls = self._lmhead_pipe_cls if self._lmhead_pipe_cls is not None else LMHeadPipe
MTPLayerPipeCls = self._mtp_layer_pipe_cls if self._mtp_layer_pipe_cls is not None else None
RMSNormPipeCls = self._rms_norm_pipe_cls if self._rms_norm_pipe_cls is not None else RMSNormPipe

new_initializer_range = math.sqrt(0.3333 / config.hidden_size)
logger.info(f"change initializer-range from {config.initializer_range} to {new_initializer_range}")
config.initializer_range = new_initializer_range
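Together with the class attributes added at the top of the class, this resolution block lets a subclass swap in its own pipeline wrappers by plain attribute assignment, without overriding `__init__`. A minimal sketch of the intended usage follows; every `My*` name below is hypothetical and not part of this PR. Any hook left as `None` keeps the defaults resolved above (`Embedding`, `make_decoder_layer_pipe(...)`, `LMHeadPipe`, `RMSNormPipe`, and `CriterionLayerPipe` in `get_loss_fn`):

class MyModelForCausalLMPipe(GeneralModelForCausalLMPipe):
    # Hypothetical sketch: all "My*" classes are placeholders, not defined in this PR.
    _decoder_layer_cls = MyModelDecoderLayer           # still required, as before

    # New per-stage overrides introduced by this change; unset hooks keep the defaults.
    _embedding_pipe_cls = MyModelEmbeddingPipe         # used instead of Embedding
    _decoder_layer_pipe_cls = MyModelDecoderLayerPipe  # skips make_decoder_layer_pipe()
    _mtp_layer_pipe_cls = MyModelMTPLayerPipe          # enables the num_nextn_predict_layers MTP stages
    _lmhead_pipe_cls = MyModelLMHeadPipe               # used instead of LMHeadPipe
    _rms_norm_pipe_cls = MyModelRMSNormPipe            # used instead of RMSNormPipe
    _criterion_pipe_cls = MyModelCriterionPipe         # picked up by get_loss_fn()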
@@ -572,7 +585,7 @@ def __init__(self, config: PretrainedConfig, **kwargs):
else:
self.add_sequential_layer(
LayerDesc(
EmbeddingPipe, config=config, embed_cls=self._embed_cls, rotary_emb_cls=self._rotary_emb_cls
EmbeddingPipeCls, config=config, embed_cls=self._embed_cls, rotary_emb_cls=self._rotary_emb_cls
),
"model",
)
@@ -586,6 +599,12 @@ def __init__(self, config: PretrainedConfig, **kwargs):
),
f"model.layers.{i}",
)
for i in range(config.num_nextn_predict_layers):
if MTPLayerPipeCls is not None:
self.add_sequential_layer(
LayerDesc(MTPLayerPipeCls, config=config, layer_idx=config.num_hidden_layers + i),
f"model.layers.{config.num_hidden_layers + i}",
)
for i in range(config.add_tail_layers):
self.add_sequential_layer(
LayerDesc(
@@ -595,22 +614,22 @@ def __init__(self, config: PretrainedConfig, **kwargs):
)

self.add_sequential_layer(
LayerDesc(RMSNormPipe if self._norm_cls == "rms_norm" else LayerNormPipe, config=config),
LayerDesc(RMSNormPipeCls if self._norm_cls == "rms_norm" else LayerNormPipe, config=config),
"model.norm",
)

if config.tie_word_embeddings:
self.add_sequential_layer(
SharedLayerDesc(
"model_shared_weight",
LMHeadPipe,
LMHeadPipeCls,
shared_weight_attr="embedding_weight",
config=config,
),
"lm_head",
)
else:
self.add_sequential_layer(LayerDesc(LMHeadPipe, config=config), "lm_head")
self.add_sequential_layer(LayerDesc(LMHeadPipeCls, config=config), "lm_head")
recompute_interval = 0

seg_method = config.pp_seg_method if hasattr(config, "pp_seg_method") else "layer:DecoderLayer|EmptyLayer"
@@ -643,10 +662,12 @@ def __init__(self, config: PretrainedConfig, **kwargs):
)

def get_loss_fn(self, config):
CriterionPipeCls = self._criterion_pipe_cls if self._criterion_pipe_cls is not None else CriterionLayerPipe

if config.get("dpo_config", None) is not None:
loss_fn = CriterionLayerPipe(config, use_infohub=True)
loss_fn = CriterionPipeCls(config, use_infohub=True)
else:
loss_fn = CriterionLayerPipe(config)
loss_fn = CriterionPipeCls(config)

return loss_fn
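The criterion follows the same override-or-default pattern: `_criterion_pipe_cls` wins when set, otherwise `CriterionLayerPipe` is used, and the DPO branch is the only place that passes `use_infohub=True`. A self-contained toy reproduction of that selection logic (dummy classes only, not the real `CriterionLayerPipe` API):

class ToyDefaultCriterion:
    # Stand-in for the default criterion; only records how it was constructed.
    def __init__(self, config, use_infohub=False):
        self.name, self.use_infohub = "default", use_infohub

class ToyCustomCriterion(ToyDefaultCriterion):
    def __init__(self, config, use_infohub=False):
        super().__init__(config, use_infohub)
        self.name = "custom"

def toy_get_loss_fn(criterion_pipe_cls, config):
    # Mirrors get_loss_fn above: prefer the override, fall back to the default,
    # and enable use_infohub only when a dpo_config is present.
    cls = criterion_pipe_cls if criterion_pipe_cls is not None else ToyDefaultCriterion
    if config.get("dpo_config", None) is not None:
        return cls(config, use_infohub=True)
    return cls(config)

assert toy_get_loss_fn(None, {}).name == "default"
assert toy_get_loss_fn(ToyCustomCriterion, {"dpo_config": {}}).use_infohub is True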
