diff --git a/llm/auto_parallel/gpt-3/run_pretrain_auto.py b/llm/auto_parallel/gpt-3/run_pretrain_auto.py index 16fb05dae15b..8d01e6bdef4e 100644 --- a/llm/auto_parallel/gpt-3/run_pretrain_auto.py +++ b/llm/auto_parallel/gpt-3/run_pretrain_auto.py @@ -38,13 +38,16 @@ CosineAnnealingWithWarmupDecay, GPTConfig, GPTForCausalLMAuto, + GPTForCausalLMNet, GPTPretrainingCriterionAuto, + GPTPretrainingCriterionNet, LinearAnnealingWithWarmupDecay, ) from paddlenlp.utils.log import logger MODEL_CLASSES = { "gpt": (GPTConfig, GPTForCausalLMAuto, GPTPretrainingCriterionAuto), + "gpt_network": (GPTConfig, GPTForCausalLMNet, GPTPretrainingCriterionNet), } from paddlenlp.data.causal_dataset import ( @@ -104,6 +107,10 @@ class PreTrainingArguments(AutoTrainingArguments): default=False, metadata={"help": "Weather to run benchmark by autotuner. True for from_scratch and pad_max_length."}, ) + use_intermediate_api: bool = field( + default=False, + metadata={"help": "Weather to use auto_parallel intermediate api"}, + ) def __post_init__(self): super().__post_init__() diff --git a/llm/auto_parallel/llama/llama_with_api.sh b/llm/auto_parallel/llama/llama_with_api.sh new file mode 100644 index 000000000000..f6826509544e --- /dev/null +++ b/llm/auto_parallel/llama/llama_with_api.sh @@ -0,0 +1,101 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
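The pattern introduced above repeats for every model family in this PR: a `*_network` model/criterion pair is registered under a new `MODEL_CLASSES` key (here `gpt_network` pairs `GPTForCausalLMNet` with `GPTPretrainingCriterionNet`), and the pretraining script gains a `use_intermediate_api` flag; the new llama shell script below drives the same combination via `--model_type "llama_network"` and `--use_intermediate_api true`. A minimal sketch of how a script resolves these classes, assuming the usual `model_type` lookup used by the existing run_pretrain scripts; the helper name and checkpoint id are illustrative only:

```python
# Illustrative sketch, not part of the diff. Only MODEL_CLASSES and the
# from_config() call are taken from the PR; select_model_classes() is a
# hypothetical helper, and GPTForCausalLMNet is exported only after this
# PR's __init__ change.
from paddlenlp.transformers import (
    GPTConfig,
    GPTForCausalLMAuto,
    GPTForCausalLMNet,
    GPTPretrainingCriterionAuto,
    GPTPretrainingCriterionNet,
)

MODEL_CLASSES = {
    "gpt": (GPTConfig, GPTForCausalLMAuto, GPTPretrainingCriterionAuto),
    "gpt_network": (GPTConfig, GPTForCausalLMNet, GPTPretrainingCriterionNet),
}

def select_model_classes(model_type: str):
    # "gpt_network" together with --use_intermediate_api true is the
    # combination exercised by the new shell scripts; plain "gpt" is unchanged.
    return MODEL_CLASSES[model_type]

config_class, model_class, criterion_class = select_model_classes("gpt_network")
config = config_class.from_pretrained("gpt2-medium-en")   # illustrative checkpoint name
model = model_class.from_config(config, dtype="float32")  # lazy params, as the trainer expects
criterion = criterion_class(config)
```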
+ +set -x + +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT + +export NNODES=1 +export PADDLE_TRAINERS_NUM=1 + +export GLOG_v=0 + +export FLAGS_cudnn_deterministic=1 +export FLAGS_embedding_deterministic=1 +export FLAGS_max_inplace_grad_add=65536 +export FLAGS_enable_auto_parallel_align_mode=1 +export FLAGS_enable_pir_api=1 + +task_name="llama_auto" +rm -rf output +rm -rf log + +export SOT_LOG_LEVEL=4 +export PYTHONPATH=../../../:$PYTHONPATH + +#ulimit -c unlimited + +python -u -m paddle.distributed.launch \ + --gpus "0,1,2,3,4,5,6,7" \ + --log_dir "log" \ + ./run_pretrain_auto.py \ + --model_name_or_path "facebook/llama-7b" \ + --tokenizer_name_or_path "facebook/llama-7b" \ + --input_dir "./data" \ + --output_dir "./output" \ + --split 949,50,1 \ + --to_static false \ + --pipeline_parallel_degree 2 \ + --sharding_parallel_degree 2 \ + --tensor_parallel_degree 2 \ + --virtual_pp_degree 1 \ + --pipeline_schedule_mode "VPP" \ + --weight_decay 0.01 \ + --warmup_ratio 0.01 \ + --max_grad_norm 1.0 \ + --learning_rate 3e-05 \ + --min_learning_rate 3e-06 \ + --max_steps 10 \ + --logging_steps 1 \ + --eval_steps 10000 \ + --save_steps 1000 \ + --continue_training 0 \ + --do_train true \ + --do_eval false \ + --do_predict false \ + --disable_tqdm true \ + --save_total_limit 2 \ + --device gpu \ + --model_type "llama_network" \ + --use_intermediate_api true \ + --dataloader_num_workers 4 \ + --distributed_dataloader 0 \ + --enable_auto_parallel 1 \ + --per_device_train_batch_size 1 \ + --gradient_accumulation_steps 32 \ + --per_device_eval_batch_size 1 \ + --recompute false \ + --recompute_use_reentrant true \ + --skip_profile_timer true \ + --recompute_granularity full \ + --pp_recompute_interval 0 \ + --bf16 true \ + --fp16_opt_level "O2" \ + --amp_master_grad true \ + --fuse_attention_ffn false \ + --fuse_attention_qkv true \ + --use_flash_attention true \ + --use_fused_rope true \ + --use_fused_rms_norm false \ + --max_seq_length 4096 \ + --sequence_parallel true \ + --sharding "stage1" \ + --sharding_parallel_config "enable_stage1_tensor_fusion enable_stage1_overlap" \ + --tensor_parallel_config "enable_mp_async_allreduce" \ + --num_hidden_layers 4 \ + --auto_parallel_resume_form_hybrid_parallel true \ diff --git a/llm/auto_parallel/llama/run_pretrain_auto.py b/llm/auto_parallel/llama/run_pretrain_auto.py index 8d6ff5462807..d205e7afbe33 100644 --- a/llm/auto_parallel/llama/run_pretrain_auto.py +++ b/llm/auto_parallel/llama/run_pretrain_auto.py @@ -41,12 +41,15 @@ LinearAnnealingWithWarmupDecay, LlamaConfig, LlamaForCausalLM3DAuto, + LlamaForCausalLM3DNet, LlamaPretrainingCriterion3DAuto, + LlamaPretrainingCriterion3DNet, ) from paddlenlp.utils.log import logger MODEL_CLASSES = { "llama": (LlamaConfig, LlamaForCausalLM3DAuto, LlamaPretrainingCriterion3DAuto), + "llama_network": (LlamaConfig, LlamaForCausalLM3DNet, LlamaPretrainingCriterion3DNet), } @@ -107,6 +110,10 @@ class PreTrainingArguments(AutoTrainingArguments): default=False, metadata={"help": "Weather to run benchmark by autotuner. 
True for from_scratch and pad_max_length."}, ) + use_intermediate_api: bool = field( + default=False, + metadata={"help": "Weather to use auto_parallel intermediate api"}, + ) def __post_init__(self): super().__post_init__() @@ -551,6 +558,7 @@ def main(): config.use_recompute = training_args.recompute config.tensor_parallel_degree = training_args.tensor_parallel_degree config.tensor_parallel_rank = training_args.tensor_parallel_rank + config.sharding_parallel_degree = training_args.sharding_parallel_degree if training_args.strategy.pipeline.enable and config.virtual_pp_degree > 1: pipeline = training_args.strategy.pipeline @@ -571,10 +579,6 @@ def main(): model = model_class.from_config(config, dtype="float32") criterion = criterion_class(config) - for param in model.parameters(): - assert not param._is_initialized() - param.initialize() - if training_args.recompute: def fn(layer): @@ -628,6 +632,7 @@ def fn(layer): eval_dataset=eval_dataset if training_args.do_eval else None, optimizers=(None, lr_scheduler), tokenizer=tokenizer, + model_args=model_args, ) checkpoint = None diff --git a/llm/auto_parallel/qwen/run_intermediate_api.sh b/llm/auto_parallel/qwen/run_intermediate_api.sh new file mode 100644 index 000000000000..2ab4ab9f0358 --- /dev/null +++ b/llm/auto_parallel/qwen/run_intermediate_api.sh @@ -0,0 +1,99 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# just for debug + +set -x +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT + +export NNODES=1 +export PADDLE_TRAINERS_NUM=1 +export FLAGS_call_stack_level=3 +export FLAGS_use_cuda_managed_memory=true + +task_name="llama_auto" +rm -rf output/$task_name/ +rm -rf "output/$task_name""_log" + +export SOT_LOG_LEVEL=4 +export PYTHONPATH=../../../:$PYTHONPATH + + +rm -rf ./log/auto_3d_auto + +export FLAGS_embedding_deterministic=1 +export FLAGS_cudnn_deterministic=1 +export FLAGS_max_inplace_grad_add=65536 +export NVIDIA_TF32_OVERRIDE=0 +export FLAGS_enable_pir_in_executor=1 +export FLAGS_enable_pir_api=1 + + +python -u -m paddle.distributed.launch \ + --gpus "4,5" \ + --log_dir "log/auto_3d_auto" \ + run_pretrain_3D_auto.py \ + --model_name_or_path "qwen/qwen-14b" \ + --tokenizer_name_or_path "qwen/qwen-14b" \ + --model_type "qwen_network" \ + --use_intermediate_api true \ + --input_dir "../data" \ + --output_dir "./checkpoints/qwen_pretrain_ckpts" \ + --per_device_train_batch_size 1\ + --gradient_accumulation_steps 32\ + --per_device_eval_batch_size 16\ + --sharding "stage1" \ + --sharding_parallel_degree 1\ + --tensor_parallel_degree 2\ + --pipeline_parallel_degree 1\ + --virtual_pp_degree 1\ + --use_flash_attention false\ + --use_fused_rms_norm false\ + --use_fused_rope false\ + --max_seq_length 4096\ + --learning_rate 3e-05\ + --min_learning_rate 3e-06\ + --scale_loss 1024\ + --warmup_steps 30\ + --logging_steps 1\ + --max_steps 10000\ + --save_steps 1000\ + --eval_steps 10000\ + --weight_decay 0.01\ + --bf16 true\ + --fp16_opt_level "O2"\ + --amp_master_grad true \ + --warmup_ratio 0.01\ + --max_grad_norm 0.0\ + --dataloader_num_workers 4\ + --continue_training 0\ + --do_train true\ + --do_eval false\ + --do_predict false \ + --disable_tqdm true\ + --recompute false\ + --recompute_granularity "core_attn"\ + --recompute_use_reentrant true\ + --distributed_dataloader 0\ + --save_total_limit 2\ + --enable_auto_parallel 1\ + --to_static 1 \ + --num_hidden_layers 1 \ + --attention_probs_dropout_prob 0 \ + --hidden_dropout_prob 0 \ + --auto_parallel_resume_form_hybrid_parallel true \ diff --git a/llm/auto_parallel/qwen/run_pretrain_3D_auto.py b/llm/auto_parallel/qwen/run_pretrain_3D_auto.py index e1e0eae25491..98a13d9994b7 100644 --- a/llm/auto_parallel/qwen/run_pretrain_3D_auto.py +++ b/llm/auto_parallel/qwen/run_pretrain_3D_auto.py @@ -40,12 +40,15 @@ LinearAnnealingWithWarmupDecay, QWenConfig, QWenForCausalLM3DAuto, + QWenForCausalLM3DNet, QWenPretrainingCriterionAuto, + QWenPretrainingCriterionNet, ) from paddlenlp.utils.log import logger MODEL_CLASSES = { "qwen": (QWenConfig, QWenForCausalLM3DAuto, QWenPretrainingCriterionAuto), + "qwen_network": (QWenConfig, QWenForCausalLM3DNet, QWenPretrainingCriterionNet), } from paddlenlp.data.causal_dataset import ( @@ -113,6 +116,10 @@ class PreTrainingArguments(AutoTrainingArguments): default=False, metadata={"help": "whether use lazy init for model parameters"}, ) + use_intermediate_api: bool = field( + default=False, + metadata={"help": "Weather to use auto_parallel intermediate api"}, + ) def __post_init__(self): super().__post_init__() @@ -258,6 +265,8 @@ class ModelArguments: default=False, metadata={"help": "recompute_use_reentrant"}, ) + hidden_dropout_prob: float = field(default=0.1, metadata={"help": "The hidden dropout prob."}) + attention_probs_dropout_prob: float = field(default=0.1, metadata={"help": "The attention hidden dropout 
prob."}) def create_pretrained_dataset( @@ -559,7 +568,11 @@ def main(): # Create the learning_rate sheduler and optimizer if training_args.decay_steps is None: training_args.decay_steps = training_args.max_steps - warmup_steps = training_args.warmup_ratio * training_args.max_steps + + if training_args.warmup_steps > 0: + warmup_steps = training_args.warmup_steps + else: + warmup_steps = training_args.warmup_ratio * training_args.max_steps lr_scheduler = None if training_args.lr_scheduler_type.value == "cosine": diff --git a/paddlenlp/trainer/auto_trainer.py b/paddlenlp/trainer/auto_trainer.py index 2ebaea1d5699..35bdbbfa7b85 100644 --- a/paddlenlp/trainer/auto_trainer.py +++ b/paddlenlp/trainer/auto_trainer.py @@ -22,9 +22,14 @@ import paddle.distributed as dist import paddle.nn as nn from paddle.distributed import fleet +from paddle.distributed.auto_parallel.intermediate.parallelize import ( + parallelize_model, + parallelize_optimizer, +) from tqdm.auto import tqdm from paddlenlp.trainer import Trainer +from paddlenlp.transformers.model_utils import PretrainedModel from ..utils.batch_sampler import DistributedBatchSampler as NlpDistributedBatchSampler from ..utils.log import logger @@ -66,6 +71,45 @@ def loss_func(loss, outputs): kwargs.update({"criterion": loss_func}) + sequence_parallel = False + if kwargs.get("model_args", None) is not None: + model_args = kwargs.pop("model_args") + if hasattr(model_args, "sequence_parallel"): + sequence_parallel = model_args.sequence_parallel + + if kwargs.get("args", None) is not None and kwargs["args"].use_intermediate_api: + model = kwargs.get("model", None) + assert model is not None + assert isinstance(model, PretrainedModel) + for param in model.parameters(): + assert not param._is_initialized(), "intermediate_api needs lazy init" + + auto_dist_degree = { + "tensor_parallel": kwargs["args"].tensor_parallel_degree > 1, + "sequence_parallel": sequence_parallel, + "pipeline_parallel": kwargs["args"].pipeline_parallel_degree > 1, + "data_sharding_parallel": kwargs["args"].dataset_world_size > 1, + "sharding": kwargs["args"].sharding, + "sharding_mesh_dim": kwargs["args"].sharding_parallel_mesh_dimension, + } + auto_dist_config = model._generate_auto_dist_config(auto_dist_degree) + self.auto_dist_config = auto_dist_config + + model = parallelize_model( + model, + dp_config=auto_dist_config["dp_config"], + mp_config=auto_dist_config["mp_config"], + pp_config=auto_dist_config["pp_config"], + ) + + kwargs["model"] = model + + model = kwargs["model"] + for param in model.parameters(): + if not param._is_initialized(): + param.initialize() + kwargs["model"] = model + super().__init__(*args, **kwargs) assert self.args.enable_auto_parallel @@ -115,30 +159,39 @@ def _wrap_for_dist_loader(self, train_dataloader): return dist_loader def _wrap_for_auto(self, model, train_dataloader): - logger.info("Wrapping model for auto paralle") + logger.info(f"Wrapping model for auto parallel using intermediate api {self.args.use_intermediate_api} ") dist_loader = self._wrap_for_dist_loader(train_dataloader) - sharding_parallel_mesh_dimension = self.args.sharding_parallel_mesh_dimension - if ShardingOption.SHARD_OP in self.args.sharding: - self.optimizer = dist.shard_optimizer( - self.optimizer, - dist.ShardingStage1(sharding_mesh_dim=sharding_parallel_mesh_dimension), - self.args.gradient_accumulation_steps, - ) - elif ShardingOption.SHARD_GRAD_OP in self.args.sharding: - self.optimizer = dist.shard_optimizer( - self.optimizer, - 
dist.ShardingStage2(sharding_mesh_dim=sharding_parallel_mesh_dimension), - self.args.gradient_accumulation_steps, - ) - elif ShardingOption.FULL_SHARD in self.args.sharding: - self.optimizer = dist.shard_optimizer( + if self.args.use_intermediate_api: + assert self.auto_dist_config is not None + self.optimizer = parallelize_optimizer( self.optimizer, - dist.ShardingStage3(sharding_mesh_dim=sharding_parallel_mesh_dimension), - self.args.gradient_accumulation_steps, + dp_config=self.auto_dist_config["dp_config"], + mp_config=self.auto_dist_config["mp_config"], + pp_config=self.auto_dist_config["pp_config"], ) else: - self.optimizer = dist.shard_optimizer(self.optimizer, None, self.args.gradient_accumulation_steps) + sharding_parallel_mesh_dimension = self.args.sharding_parallel_mesh_dimension + if ShardingOption.SHARD_OP in self.args.sharding: + self.optimizer = dist.shard_optimizer( + self.optimizer, + dist.ShardingStage1(sharding_mesh_dim=sharding_parallel_mesh_dimension), + self.args.gradient_accumulation_steps, + ) + elif ShardingOption.SHARD_GRAD_OP in self.args.sharding: + self.optimizer = dist.shard_optimizer( + self.optimizer, + dist.ShardingStage2(sharding_mesh_dim=sharding_parallel_mesh_dimension), + self.args.gradient_accumulation_steps, + ) + elif ShardingOption.FULL_SHARD in self.args.sharding: + self.optimizer = dist.shard_optimizer( + self.optimizer, + dist.ShardingStage3(sharding_mesh_dim=sharding_parallel_mesh_dimension), + self.args.gradient_accumulation_steps, + ) + else: + self.optimizer = dist.shard_optimizer(self.optimizer, None, self.args.gradient_accumulation_steps) if self.args.to_static: unified_strategy = dist.Strategy() diff --git a/paddlenlp/transformers/gpt/__init__.py b/paddlenlp/transformers/gpt/__init__.py index 564ae17b1c15..06fb9ebfc7a1 100644 --- a/paddlenlp/transformers/gpt/__init__.py +++ b/paddlenlp/transformers/gpt/__init__.py @@ -15,5 +15,6 @@ from .configuration import * from .modeling import * from .modeling_auto import * +from .modeling_network import * from .modeling_pp import * from .tokenizer import * diff --git a/paddlenlp/transformers/gpt/modeling_network.py b/paddlenlp/transformers/gpt/modeling_network.py new file mode 100644 index 000000000000..90d7da773594 --- /dev/null +++ b/paddlenlp/transformers/gpt/modeling_network.py @@ -0,0 +1,1333 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
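Taken together, the `AutoTrainer` changes above are the functional core of this PR: with `use_intermediate_api` enabled, the still-lazy model is sharded up front via `parallelize_model`, its parameters are materialized afterwards, and `_wrap_for_auto` later wraps the optimizer with `parallelize_optimizer` instead of the manual `dist.shard_optimizer` branches. A condensed sketch of that flow, restricted to the calls visible in this diff; the standalone helper and its signature are illustrative:

```python
# Condensed from the AutoTrainer changes above; a sketch, not a drop-in API.
from paddle.distributed.auto_parallel.intermediate.parallelize import (
    parallelize_model,
    parallelize_optimizer,
)

def apply_intermediate_api(model, optimizer, args, sequence_parallel=False):
    # The intermediate API expects a lazily initialized model: the trainer
    # asserts that no parameter is initialized before parallelize_model() runs.
    auto_dist_degree = {
        "tensor_parallel": args.tensor_parallel_degree > 1,
        "sequence_parallel": sequence_parallel,
        "pipeline_parallel": args.pipeline_parallel_degree > 1,
        "data_sharding_parallel": args.dataset_world_size > 1,
        "sharding": args.sharding,
        "sharding_mesh_dim": args.sharding_parallel_mesh_dimension,
    }
    auto_dist_config = model._generate_auto_dist_config(auto_dist_degree)

    model = parallelize_model(
        model,
        dp_config=auto_dist_config["dp_config"],
        mp_config=auto_dist_config["mp_config"],
        pp_config=auto_dist_config["pp_config"],
    )
    # Parameters are materialized only after the model has been sharded.
    for param in model.parameters():
        if not param._is_initialized():
            param.initialize()

    # In the trainer this second step happens later, inside _wrap_for_auto().
    optimizer = parallelize_optimizer(
        optimizer,
        dp_config=auto_dist_config["dp_config"],
        mp_config=auto_dist_config["mp_config"],
        pp_config=auto_dist_config["pp_config"],
    )
    return model, optimizer, auto_dist_config
```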
+from __future__ import annotations + +import collections +import contextlib +import math +from functools import partial + +import numpy as np +import paddle +import paddle.incubate as incubate +import paddle.nn as nn +import paddle.nn.functional as F +import paddle.tensor as tensor +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddle.distributed.fleet.utils import recompute +from paddle.utils import try_import + +try: + from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + mark_as_sequence_parallel_parameter, + ) +except: + pass + +from ...utils.converter import StateDictNameMapping +from .. import PretrainedModel, register_base_model +from ..model_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) +from .configuration import GPT_PRETRAINED_INIT_CONFIGURATION, GPTConfig + +try: + from paddle.nn.functional.flash_attention import flash_attention +except: + flash_attention = None +try: + from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd +except: + FusedDropoutAdd = None + +__all__ = [ + "GPTModelNet", + "GPTPretrainedModelNet", + "GPTPretrainingCriterionNet", + "GPTLMHeadModelNet", + "GPTForCausalLMNet", + "GPTEmbeddingsNet", + "GPTDecoderLayerNet", + "GPTLayerNorm", +] + + +def get_mesh(pp_idx=0): + mesh = fleet.auto.get_mesh() + if "pp" in mesh.dim_names: + mesh = mesh.get_mesh_with_dim("pp")[pp_idx] + return mesh + + +def get_triangle_upper_mask(x, mask=None): + if mask is not None: + return mask + if paddle.is_compiled_with_xpu(): + # xpu does not support set constant to -np.inf + mask = paddle.full_like(x, -1e4) + else: + mask = paddle.full_like(x, -np.inf) + mask.stop_gradient = True + mask = paddle.triu(mask, diagonal=1) + mask.stop_gradient = True + return mask + + +def seed_guard_context(name=None): + if name in get_rng_state_tracker().states_: + return get_rng_state_tracker().rng_state(name) + else: + return contextlib.nullcontext() + + +def fast_layer_norm(input, weight, bias, eps): + fast_ln_lib = try_import("fast_ln") + return fast_ln_lib.fast_ln(input, weight, bias, eps)[0] + + +class GPTLayerNorm(nn.LayerNorm): + def __init__(self, config, normalized_shape, epsilon=1e-05, weight_attr=None, bias_attr=None, name=None): + super().__init__( + normalized_shape=normalized_shape, epsilon=epsilon, weight_attr=weight_attr, bias_attr=bias_attr + ) + self.config = config + self._check_normalized_shape(self._normalized_shape) + + def _check_normalized_shape(self, normalized_shape): + if isinstance(normalized_shape, (list, tuple)): + assert len(normalized_shape) == 1 + + def forward(self, input): + if self.config.use_fast_layer_norm: + return fast_layer_norm(input, self.weight, self.bias, self._epsilon) + return super().forward(input) + + +def _make_causal_mask(input_ids_shape, past_key_values_length): + """ + Make causal mask used for self-attention + """ + batch_size, target_length = input_ids_shape # target_length: seq_len + + mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool")) + + if past_key_values_length > 0: + # [tgt_len, tgt_len + past_len] + mask = paddle.concat([paddle.ones([target_length, past_key_values_length], dtype="bool"), mask], axis=-1) + + # [bs, 1, tgt_len, tgt_len + past_len] + return mask[None, None, :, :].expand([batch_size, 1, target_length, target_length + past_key_values_length]) + + +def _expand_2d_mask(mask, dtype, tgt_length): + """ + Expands attention_mask from `[batch_size, 
src_length]` to `[batch_size, 1, tgt_length, src_length]`. + """ + batch_size, src_length = mask.shape[0], mask.shape[-1] + tgt_length = tgt_length if tgt_length is not None else src_length + + mask = mask[:, None, None, :].astype("bool") + mask.stop_gradient = True + expanded_mask = mask.expand([batch_size, 1, tgt_length, src_length]) + + return expanded_mask + + +class MultiHeadAttentionNet(nn.Layer): + """ + Attention maps queries and a set of key-value pairs to outputs, and + Multi-Head Attention performs multiple parallel attentions to jointly attend + to information from different representation subspaces. + + """ + + Cache = collections.namedtuple("Cache", ["k", "v"]) + + def __init__(self, config, ipp=None): + super(MultiHeadAttentionNet, self).__init__() + + self.config = config + + # Recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + + self.use_flash_attention = config.use_flash_attention if flash_attention else False + + self.head_dim = config.hidden_size // config.num_attention_heads + assert ( + self.head_dim * config.num_attention_heads == config.hidden_size + ), "hidden_size must be divisible by num_attention_heads" + + self.num_attention_heads = config.num_attention_heads # default, without tensor parallel + + if self.config.fuse_attention_qkv: + self.qkv_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias_attr=True) + else: + self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias_attr=True) + self.k_proj = nn.Linear(config.hidden_size, config.hidden_size, bias_attr=True) + self.v_proj = nn.Linear(config.hidden_size, config.hidden_size, bias_attr=True) + + self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias_attr=True) + + def _fuse_prepare_qkv(self, query, use_cache=False, past_key_value=None): + if self.config.sequence_parallel: + # [bs, seq_len, num_head * head_dim] -> [bs / n, seq_len, num_head, head_dim] (n is model parallelism) + target_shape = [-1, self.config.seq_length, self.num_attention_heads, 3 * self.head_dim] + else: + target_shape = [0, 0, self.num_attention_heads, 3 * self.head_dim] + + # bs, seq_len, num_head * 3*head_dim + mix_layer = self.qkv_proj(query) + # bs, seq_len, num_head, 3*head_dim + mix_layer = paddle.reshape_(mix_layer, target_shape) + # query_states, key_states, value_states => bs, seq_len, num_head, head_dim + query_states, key_states, value_states = paddle.split(mix_layer, num_or_sections=3, axis=-1) + + # [bs, seq_len, num_head, head_dim] + if past_key_value is not None: + # reuse k, v, self_attention + # concat along seqlen dimension + key_states = paddle.concat([past_key_value[0], key_states], axis=1) + value_states = paddle.concat([past_key_value[1], value_states], axis=1) + + past_key_value = (key_states, value_states) if use_cache else None + + return query_states, key_states, value_states, past_key_value + + def _prepare_qkv(self, query, key, value, use_cache=False, past_key_value=None): + r""" + Prepares linearly projected queries, keys and values for use in the subsequent + multiple parallel attention. If `cache` is not None, using cached results + to reduce redundant calculations.
+ + """ + if self.config.sequence_parallel: + # [bs, seq_len, num_head * head_dim] -> [bs/n, seq_len, num_head * head_dim] (n is model parallelism) + target_shape = [-1, self.config.seq_length, self.num_attention_heads, self.head_dim] + else: + target_shape = [0, 0, self.num_attention_heads, self.head_dim] + + query_states = self.q_proj(query) + # [bs, seq_len, num_head, head_dim] + query_states = tensor.reshape(x=query_states, shape=target_shape) + + key_states = self.k_proj(key) + # [bs, seq_len, num_head, head_dim] + key_states = tensor.reshape(x=key_states, shape=target_shape) + + value_states = self.v_proj(value) + # [bs, seq_len, num_head, head_dim] + value_states = tensor.reshape(x=value_states, shape=target_shape) + + # [bs, seq_len, num_head, head_dim] + if past_key_value is not None: + # reuse k, v, self_attention + # concat along seqlen dimension + key_states = paddle.concat([past_key_value[0], key_states], axis=1) + value_states = paddle.concat([past_key_value[1], value_states], axis=1) + + past_key_value = (key_states, value_states) if use_cache else None + + return query_states, key_states, value_states, past_key_value + + def _flash_attention(self, q, k, v, attention_mask=None, output_attentions=False): + with seed_guard_context("local_seed"): + out, weights = flash_attention( + query=q, + key=k, + value=v, + dropout=self.config.attention_probs_dropout_prob, + causal=q.shape[1] != 1, + return_softmax=output_attentions, + training=self.training, + ) + # [bs, seq_len, num_head, head_dim] -> [bs, seq_len, num_head * head_dim] + out = tensor.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + return (out, weights) if output_attentions else out + + def _core_attention(self, q, k, v, attention_mask=None, output_attentions=False): + # [bs, seq_len, num_head, head_dim] -> [bs, num_head, seq_len, head_dim] + perm = [0, 2, 1, 3] + q = tensor.transpose(x=q, perm=perm) + k = tensor.transpose(x=k, perm=perm) + v = tensor.transpose(x=v, perm=perm) + # scale dot product attention + product = paddle.matmul(x=q * ((self.config.scale_qk_coeff * self.head_dim) ** -0.5), y=k, transpose_y=True) + if self.config.scale_qk_coeff != 1.0: + product = product.scale(self.config.scale_qk_coeff) + + # softmax_mask_fuse_upper_triangle is not supported sif paddle is not compiled with cuda/rocm + if not paddle.is_compiled_with_cuda(): + attention_mask = get_triangle_upper_mask(product, attention_mask) + if attention_mask is not None: + product = product + attention_mask.astype(product.dtype) + weights = F.softmax(product) + else: + weights = incubate.softmax_mask_fuse_upper_triangle(product) + + if self.config.attention_probs_dropout_prob: + with seed_guard_context("local_seed"): + weights = F.dropout( + weights, self.config.attention_probs_dropout_prob, training=self.training, mode="upscale_in_train" + ) + + out = paddle.matmul(weights, v) + + # combine heads + out = tensor.transpose(out, perm=[0, 2, 1, 3]) # bs, seq_len, num_head, head_dim + out = tensor.reshape(x=out, shape=[0, 0, -1]) # bs, seq_len, dim + + return (out, weights) if output_attentions else out + + def forward( + self, query, key, value, attention_mask=None, use_cache=False, past_key_value=None, output_attentions=False + ): + r""" + Applies multi-head attention to map queries and a set of key-value pairs + to outputs. 
+ """ + key = query if key is None else key + value = query if value is None else value + if self.config.fuse_attention_qkv: + # [bs, seq_len, num_head, head_dim] + q, k, v, past_key_value = self._fuse_prepare_qkv(query, use_cache, past_key_value) + else: + # [bs, seq_len, num_head, head_dim] + q, k, v, past_key_value = self._prepare_qkv(query, key, value, use_cache, past_key_value) + + if self.config.use_flash_attention: + # Flash Attention now ignore attention mask + # Current Flash Attention doesn't support attn maskt + # Paddle Flash Attention input [batch_size, seq_len, num_heads, head_dim] + # Torch Flash Attention input (batch_size, seqlen, nheads, headdim) + # bsz, q_len, num_heads, head_dim = q.shape + # TODO: Support attention mask for flash attention + attention_func = self._flash_attention + else: + # scale dot product attention + # [bs, seq_len, num_head,] + attention_func = self._core_attention + + has_gradient = (not q.stop_gradient) or (not k.stop_gradient) or (not v.stop_gradient) + if self.enable_recompute and self.config.recompute_granularity == "core_attn" and has_gradient: + outputs = recompute(attention_func, q, k, v, attention_mask, output_attentions, use_reentrant=False) + else: + outputs = attention_func(q, k, v, attention_mask=attention_mask, output_attentions=output_attentions) + + if output_attentions: + out, weights = outputs + else: + out = outputs + + # if sequence_parallel is true, out shape are [bs, seq_len, num_head * head_dim / n] + # else their shape are [bs, q_len, num_head * head_dim / n], n is mp parallelism. + + if self.config.sequence_parallel: + bs, seq_len, dim = out.shape + out = out.reshape([bs * seq_len, dim]) # [bs, seq_len, dim / n] => [bs * seq_len, dim / n] + + # project to output + out = self.out_proj(out) + # if sequence_parallel is true, out shape are [bs * seq_len / n, dim] + # else their shape are [bs, seq_len, dim], n is mp parallelism. + + outs = [out] + if output_attentions: + outs.append(weights) + if use_cache: + outs.append(past_key_value) + return out if len(outs) == 1 else tuple(outs) + + +class TransformerDecoder(nn.Layer): + """ + TransformerDecoder is a stack of N decoder layers. 
+ """ + + def __init__(self, config, decoder_layers, norm=None, hidden_size=None): + super(TransformerDecoder, self).__init__() + + self.config = config + self.layers = decoder_layers + + self.norm = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5) + if config.sequence_parallel: + mark_as_sequence_parallel_parameter(self.norm.weight) + mark_as_sequence_parallel_parameter(self.norm.bias) + + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + + @paddle.jit.not_to_static + def recompute_training( + self, + layer_module: nn.Layer, + hidden_states: paddle.Tensor, + past_key_value: paddle.Tensor, + attention_mask: paddle.Tensor, + use_cache: bool, + output_attentions: paddle.Tensor, + ): + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + # GPTDecoderLayer + # def forward( + # self, hidden_states, attention_mask=None, use_cache=False, past_key_value=None, output_attentions=False + # ): + hidden_states = recompute( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + use_cache, + past_key_value, + use_reentrant=self.config.recompute_use_reentrant, + ) + return hidden_states + + def forward( + self, + hidden_states, + attention_mask=None, + use_cache=False, + past_key_values=None, + output_attentions=False, + output_hidden_states=False, + return_dict=False, + ): + r""" + Applies a stack of N Transformer decoder layers on inputs. If `norm` is + provided, also applies layer normalization on the output of last decoder + layer. + """ + + # [bs * seq_len, embed_dim] -> [seq_len * bs / n, embed_dim] (sequence_parallel) + + output = hidden_states + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + next_decoder_cache = () if use_cache else None + + for i, decoder_layer in enumerate(self.layers): + has_gradient = not output.stop_gradient + if self.enable_recompute and has_gradient and self.config.recompute_granularity == "full": + outputs = self.recompute_training( + layer_module=decoder_layer, + hidden_states=output, + attention_mask=attention_mask, + use_cache=use_cache, + past_key_value=None, + output_attentions=output_attentions, + ) + else: + outputs = decoder_layer( + output, + attention_mask=attention_mask, + use_cache=use_cache, + past_key_value=past_key_values[i] if past_key_values is not None else None, + output_attentions=output_attentions, + ) + + # outputs = hidden_states if both use_cache and output_attentions are False + # Otherwise, outputs = (hidden_states, attention if output_attentions, cache if use_cache) + output = outputs[0] if (use_cache or output_attentions) else outputs + all_self_attentions = all_self_attentions + (outputs[1],) if output_attentions else None + all_hidden_states = all_hidden_states + (output,) if output_hidden_states else None + next_decoder_cache = next_decoder_cache + (outputs[-1],) if use_cache else None + + if self.norm is not None: + output = self.norm(output) + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + temp_list = [output, next_cache, all_hidden_states, all_self_attentions] + + if not (use_cache or output_attentions or output_hidden_states): + return output + + return tuple(v for v in temp_list if v is not None) + + return BaseModelOutputWithPastAndCrossAttentions( + 
last_hidden_state=output, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=None, + ) + + +class GPTDecoderLayerNet(nn.Layer): + """ + The transformer decoder layer. + + It contains multiheadattention and some linear layers. + """ + + def __init__(self, config: GPTConfig, ipp=None): + super(GPTDecoderLayerNet, self).__init__() + self.config = config + + # Recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + + if not FusedDropoutAdd: + config.use_fused_dropout_add = False + + self.self_attn = MultiHeadAttentionNet(config, ipp) + + self.linear1 = nn.Linear(config.hidden_size, config.intermediate_size, bias_attr=True) + self.linear2 = nn.Linear(config.intermediate_size, config.hidden_size, bias_attr=True) + # fix : change nn.LayerNorm(config.hidden_size, epsilon=1e-5, bias_attr=True) to GPTLayerNorm() + self.norm1 = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5, bias_attr=True) + self.norm2 = GPTLayerNorm(config, config.hidden_size, epsilon=1e-5, bias_attr=True) + + if config.sequence_parallel: + mark_as_sequence_parallel_parameter(self.norm1.weight) + mark_as_sequence_parallel_parameter(self.norm1.bias) + mark_as_sequence_parallel_parameter(self.norm2.weight) + mark_as_sequence_parallel_parameter(self.norm2.bias) + if config.use_fused_dropout_add: + self.fused_dropout_add1 = FusedDropoutAdd(config.attention_probs_dropout_prob, mode="upscale_in_train") + self.fused_dropout_add2 = FusedDropoutAdd(config.hidden_dropout_prob, mode="upscale_in_train") + else: + self.dropout1 = nn.Dropout(config.attention_probs_dropout_prob, mode="upscale_in_train") + self.dropout2 = nn.Dropout(config.hidden_dropout_prob, mode="upscale_in_train") + + if config.hidden_activation == "gelu": + self.activation = F.gelu + else: + self.activation = getattr(F, config.hidden_activation) + + def forward( + self, hidden_states, attention_mask=None, use_cache=False, past_key_value=None, output_attentions=False + ): + # when sequence_parallel=True: + # hidden_states => [bs * seq_len / n, embed_dim] + residual = hidden_states + if self.config.normalize_before: + hidden_states = self.norm1(hidden_states) + + # self.self_attn: + # def forward( + # self, query, key, value, attention_mask=None, use_cache=False, past_key_value=None, output_attentions=False + # ): + # self.self_attn(...) 
--> hidden_states, weights, (past_key_value) + has_gradient = not hidden_states.stop_gradient + if self.enable_recompute and has_gradient and self.config.recompute_granularity == "full_attn": + hidden_states = recompute( + self.self_attn, + hidden_states, + None, + None, + attention_mask, + use_cache, + past_key_value, + output_attentions, + use_reentrant=False, + ) + else: + hidden_states = self.self_attn( + hidden_states, None, None, attention_mask, use_cache, past_key_value, output_attentions + ) + + # when sequence_parallel=True: + # hidden_states => [bs * seq_len / n, embed_dim] + incremental_cache = hidden_states[-1] if use_cache else None + attention_weights = hidden_states[1] if output_attentions else None + hidden_states = hidden_states[0] if (use_cache or output_attentions) else hidden_states + + # Use a ternary operator for a more concise assignment of current_seed + current_seed = "local_seed" if self.config.sequence_parallel else "global_seed" + + # The 'with' block ensures the correct seed context is used + with seed_guard_context(current_seed): + if self.config.use_fused_dropout_add: + hidden_states = self.fused_dropout_add1(hidden_states, residual) + else: + hidden_states = residual + self.dropout1(hidden_states) + + if not self.config.normalize_before: + hidden_states = self.norm1(hidden_states) + + residual = hidden_states + if self.config.normalize_before: + hidden_states = self.norm2(hidden_states) + + # when sequence_parallel=True: + # hidden_states => [bs * seq_len / n, embed_dim] + with seed_guard_context(current_seed): + if not self.config.use_fused_dropout_add: + l_1 = self.linear1(hidden_states) + act = self.activation(l_1, approximate=True) + # NOTE(align_mode) + l_2 = self.linear2(act) + hidden_states = residual + self.dropout2(l_2) + else: + hidden_states = self.fused_dropout_add2( + self.linear2(self.activation(self.linear1(hidden_states), approximate=True)), residual + ) + if not self.config.normalize_before: + hidden_states = self.norm2(hidden_states) + + if not (output_attentions or use_cache): + return hidden_states + + temp_list = [ + hidden_states, + attention_weights, + incremental_cache, + ] + + return tuple(v for v in temp_list if v is not None) + + +class GPTEmbeddingsNet(nn.Layer): + """ + Include embeddings from word and position embeddings. 
+ """ + + def __init__( + self, + config, + ): + super(GPTEmbeddingsNet, self).__init__() + + self.config = config + + self.word_embeddings = nn.Embedding( + config.vocab_size, + config.hidden_size, + ) + + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, + config.hidden_size, + ) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, position_ids=None, inputs_embeddings=None): + if position_ids is None and inputs_embeddings is None: + raise ValueError("You have to specify either `inputs_embeddings` or `position_ids`)") + if position_ids is not None and inputs_embeddings is not None: + raise ValueError("You cannot specify both `inputs_embeddings` and `position_ids`)") + + with paddle.amp.auto_cast(False): + if input_ids is not None: + input_shape = input_ids.shape + inputs_embeddings = self.word_embeddings(input_ids) + else: + input_shape = inputs_embeddings.shape[:-1] + + if position_ids is None: + ones = paddle.ones(input_shape, dtype="int64") + seq_length = paddle.cumsum(ones, axis=-1) + position_ids = seq_length - ones + position_embeddings = self.position_embeddings(position_ids) + embeddings = inputs_embeddings + position_embeddings + + # exit() + if self.config.sequence_parallel: + # embeddings = dist.shard_tensor(embeddings,get_mesh(),[dist.Replicate(),dist.Replicate()]) + bs, seq_len, hidden_size = embeddings.shape + # [bs, seq_len, dim] -> [bs * seq_len, dim] + embeddings = paddle.reshape_(embeddings, [bs * seq_len, hidden_size]) + # [bs * seq_len / n, dim] (n is mp parallelism) + # embeddings = ScatterOp.apply(embeddings) + # Use a ternary operator for a more concise assignment of current_seed + current_seed = "local_seed" if self.config.sequence_parallel else "global_seed" + # The 'with' block ensures the correct seed context is used + with seed_guard_context(current_seed): + embeddings = self.dropout(embeddings) + # NOTE(align_mode) + return embeddings + + +class GPTPretrainedModelNet(PretrainedModel): + """ + An abstract class for pretrained GPT models. It provides GPT related + `model_config_file`, `resource_files_names`, `pretrained_resource_files_map`, + `pretrained_init_configuration`, `base_model_prefix` for downloading and + loading pretrained models. + See :class:`~paddlenlp.transformers.model_utils.PretrainedModel` for more details. 
+ """ + + model_config_file = "model_config.json" + resource_files_names = {"model_state": "model_state.pdparams"} + base_model_prefix = "gpt" + config_class = GPTConfig + pretrained_init_configuration = GPT_PRETRAINED_INIT_CONFIGURATION + + @classmethod + def _get_tensor_parallel_mappings(cls, config, is_split=True): + + from paddlenlp.transformers.conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_layers): + final_actions = {} + base_actions = { + # Column Linear + "layers.0.linear1.weight": partial(fn, is_column=True), + "layers.0.linear1.bias": partial(fn, is_column=True), + # Row Linear + "word_embeddings.weight": partial(fn, is_column=False), + "layers.0.self_attn.out_proj.weight": partial(fn, is_column=False), + "layers.0.linear2.weight": partial(fn, is_column=False), + } + + if config.fuse_attention_qkv: + base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.qkv_proj.bias"] = partial(fn, is_column=True) + else: + base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.bias"] = partial(fn, is_column=True) + + for key, action in base_actions.items(): + if "layers.0." 
in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers) + + return mappings + + @classmethod + def _get_name_mappings(cls, config: GPTConfig) -> list[StateDictNameMapping]: + mappings: list[StateDictNameMapping] = [] + model_mappings = [ + ["wte.weight", "embeddings.word_embeddings.weight"], + ["wpe.weight", "embeddings.position_embeddings.weight"], + ["ln_f.weight", "decoder.norm.weight"], + ["ln_f.bias", "decoder.norm.bias"], + ] + for layer_index in range(config.num_hidden_layers): + layer_mappings = [ + [f"h.{layer_index}.ln_1.weight", f"decoder.layers.{layer_index}.norm1.weight"], + [f"h.{layer_index}.ln_1.bias", f"decoder.layers.{layer_index}.norm1.bias"], + [f"h.{layer_index}.ln_2.weight", f"decoder.layers.{layer_index}.norm2.weight"], + [f"h.{layer_index}.ln_2.bias", f"decoder.layers.{layer_index}.norm2.bias"], + [f"h.{layer_index}.mlp.c_fc.weight", f"decoder.layers.{layer_index}.linear1.weight"], + [f"h.{layer_index}.mlp.c_fc.bias", f"decoder.layers.{layer_index}.linear1.bias"], + [f"h.{layer_index}.mlp.c_proj.weight", f"decoder.layers.{layer_index}.linear2.weight"], + [f"h.{layer_index}.mlp.c_proj.bias", f"decoder.layers.{layer_index}.linear2.bias"], + [f"h.{layer_index}.attn.c_proj.weight", f"decoder.layers.{layer_index}.self_attn.out_proj.weight"], + [f"h.{layer_index}.attn.c_proj.bias", f"decoder.layers.{layer_index}.self_attn.out_proj.bias"], + # attention + [ + f"h.{layer_index}.attn.c_attn.weight", + f"decoder.layers.{layer_index}.self_attn.q_proj.weight", + "split", + 0, + ], + [ + f"h.{layer_index}.attn.c_attn.bias", + f"decoder.layers.{layer_index}.self_attn.q_proj.bias", + "split", + 0, + ], + [ + f"h.{layer_index}.attn.c_attn.weight", + f"decoder.layers.{layer_index}.self_attn.k_proj.weight", + "split", + 1, + ], + [ + f"h.{layer_index}.attn.c_attn.bias", + f"decoder.layers.{layer_index}.self_attn.k_proj.bias", + "split", + 1, + ], + [ + f"h.{layer_index}.attn.c_attn.weight", + f"decoder.layers.{layer_index}.self_attn.v_proj.weight", + "split", + 2, + ], + [ + f"h.{layer_index}.attn.c_attn.bias", + f"decoder.layers.{layer_index}.self_attn.v_proj.bias", + "split", + 2, + ], + ] + + model_mappings.extend(layer_mappings) + # downstream mappings + if "GPT2Model" not in config.architectures: + for mapping in model_mappings: + mapping[0] = "transformer." + mapping[0] + mapping[1] = "gpt." + mapping[1] + if "GPT2ForTokenClassification" in config.architectures: + model_mappings.extend([["classifier.weight", "classifier.weight", "transpose"]]) + if "GPT2ForSequenceClassification" in config.architectures: + model_mappings.extend([["score.weight", "score.weight", "transpose"]]) + if "GPT2LMHeadModel" in config.architectures: + model_mappings.append(["lm_head.weight", "lm_head.decoder.weight"]) + + mappings = [StateDictNameMapping(*mapping) for mapping in model_mappings] + return mappings + + +@register_base_model +class GPTModelNet(GPTPretrainedModelNet): + r""" + The bare GPT Model transformer outputting raw hidden-states. + + This model inherits from :class:`~paddlenlp.transformers.model_utils.PretrainedModel`. + Refer to the superclass documentation for the generic methods. + + This model is also a Paddle `paddle.nn.Layer `__ subclass. Use it as a regular Paddle Layer + and refer to the Paddle documentation for all matter related to general usage and behavior. 
+ + Args: + vocab_size (int): + Vocabulary size of `inputs_ids` in `GPTModel`. Also is the vocab size of token embedding matrix. + Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling `GPTModel`. + hidden_size (int, optional): + Dimensionality of the embedding layer and decoder layer. Defaults to `768`. + num_hidden_layers (int, optional): + Number of hidden layers in the Transformer decoder. Defaults to `12`. + num_attention_heads (int, optional): + Number of attention heads for each attention layer in the Transformer decoder. + Defaults to `12`. + intermediate_size (int, optional): + Dimensionality of the feed-forward (ff) layer in the decoder. Input tensors + to ff layers are firstly projected from `hidden_size` to `intermediate_size`, + and then projected back to `hidden_size`. Typically `intermediate_size` is larger than `hidden_size`. + Defaults to `3072`. + hidden_act (str, optional): + The non-linear activation function in the feed-forward layer. + ``"gelu"``, ``"relu"`` and any other paddle supported activation functions + are supported. Defaults to `"gelu"`. + hidden_dropout_prob (float, optional): + The dropout probability for all fully connected layers in the embeddings and decoder. + Defaults to `0.1`. + attention_probs_dropout_prob (float, optional): + The dropout probability used in MultiHeadAttention in all decoder layers to drop some attention target. + Defaults to `0.1`. + max_position_embeddings (int, optional): + The maximum value of the dimensionality of position encoding, which dictates the maximum supported length of an input + sequence. Defaults to `512`. + type_vocab_size (int, optional): + The vocabulary size of the `token_type_ids`. Defaults to `16`. + + .. note:: + Please NOT using `type_vocab_size`, for it will be obsolete in the future.. + + initializer_range (float, optional): + The standard deviation of the normal initializer. Default to `0.02`. + + .. note:: + A normal_initializer initializes weight matrices as normal distributions. + See :meth:`GPTPretrainedModelNet._init_weights()` for how weights are initialized in `GPTModelNet`. + + pad_token_id(int, optional): + The index of padding token in the token vocabulary. + Defaults to `0`. 
+ + """ + + def __init__(self, config: GPTConfig): + super(GPTModelNet, self).__init__(config) + + self.config = config + + self.pad_token_id = config.pad_token_id + self.eos_token_id = config.eos_token_id + self.bos_token_id = config.bos_token_id + self.eol_token_id = config.eol_token_id + self.vocab_size = config.vocab_size + + self.bias = paddle.tril( + paddle.ones([1, 1, config.max_position_embeddings, config.max_position_embeddings], dtype="int64") + ) + self.embeddings = GPTEmbeddingsNet(config) + + decoder_layers = nn.LayerList() + for i in range(config.num_hidden_layers): + decoder_layers.append(GPTDecoderLayerNet(config)) + + self.decoder = TransformerDecoder( + config, + decoder_layers, + ) + + def get_layer_ipp(self, layer_index): + mesh = fleet.auto.get_mesh() + if "pp" not in mesh.dim_names: + return None + else: + pp_degree = mesh.get_dim_size("pp") + layer_per_stage = math.ceil(self.config.num_hidden_layers / pp_degree) + return layer_index // layer_per_stage + + def get_last_layer_ipp(self): + return self.get_layer_ipp(self.config.num_hidden_layers - 1) + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + @staticmethod + def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype): + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + if len(attention_mask.shape) == 2: + expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1]) + # For decoding phase in generation, seq_length = 1, we don't need to add causal mask + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, past_key_values_length=past_key_values_length + ) + # NOTE(zhaoyingli): infer spmd does not support [seq_len, seq_len] --> [batch, 1, seq_len, seq_len] in data_parallel + expanded_attn_mask = expanded_attn_mask & combined_attention_mask + # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len] + elif len(attention_mask.shape) == 3: + expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool") + # if attention_mask is already 4-D, do nothing + else: + expanded_attn_mask = attention_mask + else: + expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) + # Convert bool attention_mask to float attention mask, which will be added to attention_scores later + expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype) + return expanded_attn_mask + + def forward( + self, + input_ids=None, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + use_cache=False, + past_key_values=None, + output_attentions=False, + output_hidden_states=False, + return_dict=False, + ): + r""" + The GPTModelNet forward method, overrides the `__call__()` special method. + + Args: + input_ids (Tensor, optional): + Indices of input sequence tokens in the vocabulary. They are + numerical representations of tokens that build the input sequence. + Its data type should be `int64` and it has a shape of [batch_size, sequence_length]. + Defaults to None. + position_ids(Tensor, optional): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0, + max_position_embeddings - 1]``. + Shape as `(batch_size, num_tokens)` and dtype as int64. Defaults to `None`. 
+ attention_mask (Tensor, optional): + Mask used in self attention to avoid performing attention to some unwanted positions, + usually the subsequent positions. + It is a tensor with shape broadcasted to `[batch_size, num_attention_heads, sequence_length, sequence_length]`. + For example, its shape can be [batch_size, sequence_length], [batch_size, sequence_length, sequence_length], + [batch_size, num_attention_heads, sequence_length, sequence_length]. + Its data type should be int64. + The `masked` tokens have `0` values, and the `unmasked` tokens have `1` values. + Defaults to `None`, which means that no positions are masked. + inputs_embeds (Tensor, optional): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation + of shape `(batch_size, sequence_length, hidden_size)`. This is useful if you want more control over + how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. + Defaults to `None`. + use_cache (bool, optional): + Whether or not to use cache. Defaults to `False`. If set to `True`, key value states will be returned and + can be used to speed up decoding. + past_key_values (list, optional): + It is only used for inference and should be None for training. + Defaults to `None`. + output_attentions (bool, optional): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. Defaults to `False`. + output_hidden_states (bool, optional): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. Defaults to `False`. + return_dict (bool, optional): + Whether to return a :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPastAndCrossAttentions` object. If `False`, the output + will be a tuple of tensors. Defaults to `False`. + + Returns: + An instance of :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPastAndCrossAttentions` if + `return_dict=True`. Otherwise it returns a tuple of tensors corresponding + to ordered and not None (depending on the input arguments) fields of + :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPastAndCrossAttentions`. + + Especially, when `return_dict=output_hidden_states=output_attentions=False`, + returns tensor `outputs` which is the output at the last layer of the model. + Its data type should be float32 and has a shape of [batch_size, sequence_length, hidden_size]. + + Example: ..
code-block:: + + import paddle + from paddlenlp.transformers import GPTModelNet, GPTTokenizer + + tokenizer = GPTTokenizer.from_pretrained('gpt2-medium-en') + model = GPTModelNet.from_pretrained('gpt2-medium-en') + + inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!", return_token_type_ids=False) + inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} + output = model(**inputs) + """ + + if self.config.sequence_parallel and use_cache: + raise ValueError("We currently only support sequence parallel without cache.") + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.shape + input_ids = input_ids.reshape((-1, input_shape[-1])) + elif inputs_embeds is not None: + input_shape = inputs_embeds.shape[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + # input_shape => bs, seq_len + if past_key_values is None: + past_key_values = tuple([None] * len(self.decoder.layers)) + + if position_ids is None: + past_length = 0 + if past_key_values[0] is not None: + # bs, seq_len, num_head, head_dim + past_length = past_key_values[0][0].shape[1] + position_ids = paddle.arange(past_length, input_shape[-1] + past_length, dtype="int64") + position_ids = position_ids.unsqueeze(0) + position_ids = paddle.expand(position_ids, input_shape) + embedding_output = self.embeddings( + input_ids=input_ids, position_ids=position_ids, inputs_embeddings=inputs_embeds + ) + # TODO, use registered buffer + length = input_shape[-1] + if past_key_values[0] is not None: + cache_length = past_key_values[0][0].shape[1] + length = length + cache_length + else: + cache_length = 0 + + causal_mask = self.bias[:, :, cache_length:length, :length] + if attention_mask is not None: + if attention_mask.dtype != paddle.int64: + attention_mask = paddle.cast(attention_mask, dtype=paddle.int64) + if len(attention_mask.shape) == 2: + attention_mask = attention_mask[:, None, None, :] + attention_mask = (1.0 - (attention_mask & causal_mask)) * -1e4 + else: + attention_mask = (1.0 - causal_mask) * -1e4 + + # The tensor returned by triu not in static graph. + attention_mask.stop_gradient = True + + outputs = self.decoder( + embedding_output, + attention_mask=attention_mask, + use_cache=use_cache, + past_key_values=past_key_values, + output_hidden_states=output_hidden_states, + output_attentions=output_attentions, + return_dict=return_dict, + ) + + if output_hidden_states: + if return_dict: + outputs.hidden_states = (embedding_output,) + outputs.hidden_states + else: # outputs is a tuple + idx = 2 if use_cache else 1 + all_hidden_states = (embedding_output,) + outputs[idx] + outputs[idx] = all_hidden_states + + return outputs + + +class GPTPretrainingCriterionNet(paddle.nn.Layer): + """ + Criterion for GPT. It calculates the final loss. + """ + + def __init__(self, config): + super(GPTPretrainingCriterionNet, self).__init__() + self.config = config + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=config.ignore_index) + + def forward(self, prediction_scores, masked_lm_labels, loss_mask=None): + """ + Args: + prediction_scores(Tensor): + The logits of masked token prediction. Its data type should be float32 and + its shape is [batch_size, sequence_length, vocab_size]. 
+            masked_lm_labels(Tensor):
+                The labels of the masked language modeling, whose dimensionality matches `prediction_scores`
+                except for the vocabulary dimension. Its data type should be int64 and
+                its shape is [batch_size, sequence_length, 1].
+            loss_mask(Tensor, optional):
+                Mask used for calculating the loss of the masked language modeling to avoid
+                calculating some unwanted tokens.
+                Its data type should be float32 and its shape is [batch_size, sequence_length, 1].
+                Defaults to `None`.
+
+        Returns:
+            Tensor: The pretraining loss. Its data type should be float32 and its shape is [1].
+
+        """
+        with paddle.amp.auto_cast(False):
+            if len(prediction_scores.shape) < len(masked_lm_labels.unsqueeze(2).shape):
+                prediction_scores = paddle.unsqueeze_(prediction_scores, 0)
+
+            masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2))
+            masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32")
+            loss = paddle.mean(masked_lm_loss)
+        return loss
+
+
+class GPTLMHeadNet(nn.Layer):
+    def __init__(self, config: GPTConfig, embedding_weights=None, ipp=None):
+        super(GPTLMHeadNet, self).__init__()
+        self.config = config
+        self.transpose_y = True
+
+        if embedding_weights is not None:
+            self.transpose_y = True
+            self.weight = embedding_weights
+        else:
+            if config.tensor_parallel_degree > 1:
+                vocab_size = config.vocab_size // config.tensor_parallel_degree
+            else:
+                vocab_size = config.vocab_size
+
+            if vocab_size != config.vocab_size:
+                with get_rng_state_tracker().rng_state():
+                    self.weight = self.create_parameter(
+                        shape=[vocab_size, config.hidden_size],
+                        dtype=paddle.get_default_dtype(),
+                    )
+            else:
+                self.weight = self.create_parameter(
+                    shape=[vocab_size, config.hidden_size],
+                    dtype=paddle.get_default_dtype(),
+                )
+            # Must set the distributed attr for Tensor Parallel!
+            self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
+            if self.weight.is_distributed:
+                self.weight.split_axis = 0
+
+    def forward(self, hidden_states, tensor_parallel_output=None):
+
+        if self.config.sequence_parallel:
+            hidden_states = paddle.reshape(hidden_states, [-1, self.config.seq_length, self.config.hidden_size])
+
+        logits = paddle.matmul(hidden_states, self.weight, transpose_y=self.transpose_y)
+        return logits
+
+
+class GPTForCausalLMNet(GPTPretrainedModelNet):
+    """
+    The GPT Model with a `language modeling` head on top.
+
+    Args:
+        gpt (:class:`GPTModelNet`):
+            An instance of :class:`GPTModelNet`.
+ + """ + + _tied_weights_keys = ["lm_head.weight", "lm_head.decoder.weight"] + _keys_to_ignore_on_save = [r"lm_head.weight", r"lm_head.decoder.weight"] + _keys_to_ignore_on_load_missing = [r"lm_head.weight", r"lm_head.decoder.weight"] + + def __init__(self, config: GPTConfig): + super(GPTForCausalLMNet, self).__init__(config) + self.gpt = GPTModelNet(config) + self.lm_head = GPTLMHeadNet(config, embedding_weights=self.gpt.embeddings.word_embeddings.weight) + + self.tie_weights() + self.criterion = GPTPretrainingCriterionNet(config) + + def get_output_embeddings(self): + return self.lm_head + + def get_input_embeddings(self): + return self.gpt.embeddings.word_embeddings + + def forward( + self, + input_ids=None, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + use_cache=False, + past_key_values=None, + labels=None, + output_attentions=False, + output_hidden_states=False, + return_dict=False, + ): + r""" + + Args: + input_ids (Tensor, optional): + See :class:`GPTModelNet`. + position_ids (Tensor, optional): + See :class:`GPTModelNet`. + attention_mask (Tensor, optional): + See :class:`GPTModelNet`. + inputs_embeds (Tensor, optional): + See :class:`GPTModelNet`. + use_cache (bool, optional): + See :class:`GPTModelNet`. + past_key_values (Tensor, optional): + See :class:`GPTModelNet`. + labels (paddle.Tensor, optional): + A Tensor of shape `(batch_size, sequence_length)`. + Labels for language modeling. Note that the labels are shifted inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., vocab_size]` + Defaults to None. + output_attentions (bool, optional): + See :class:`GPTModelNet`. + output_hidden_states (bool, optional): + See :class:`GPTModelNet`. + return_dict (bool, optional): + See :class:`GPTModelNet`. + + Returns: + An instance of :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPastAndCrossAttentions` if + `return_dict=True`. Otherwise it returns a tuple of tensors corresponding + to ordered and not None (depending on the input arguments) fields of + :class:`~paddlenlp.transformers.model_outputs.BaseModelOutputWithPastAndCrossAttentions`. + + Especialy, when `return_dict=use_cache=output_attentions=output_hidden_states=False`, + returns a tensor `logits` which is the output of the gpt model. 
+ """ + input_type = type(input_ids) if input_ids is not None else type(inputs_embeds) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + outputs = self.gpt( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if isinstance(outputs, input_type): + hidden_states = outputs + else: + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + loss = None + if labels is not None: + loss = self.criterion(logits, labels) + + if not return_dict: + if isinstance(outputs, input_type): + return (loss, logits) if loss is not None else logits + outputs = (logits,) + outputs[1:] + return ((loss,) + outputs) if loss is not None else outputs + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_fast_entry(self, kwargs): + from paddlenlp.ops import FasterGPT + + use_fp16_decoding = kwargs.get("use_fp16_decoding", False) + decode_strategy = kwargs.get("decode_strategy") + if decode_strategy == "beam_search": + raise AttributeError("'beam_search' is not supported yet in the fast version of GPT") + # Currently, FasterTransformer only support restricted size_per_head. + size_per_head = self.gpt.config["hidden_size"] // self.gpt.config["num_attention_heads"] + if size_per_head not in [32, 64, 80, 96, 128]: + raise AttributeError( + "'size_per_head = %d' is not supported yet in the fast version of GPT" % size_per_head + ) + if kwargs["forced_bos_token_id"] is not None: + # not support for min_length yet in the fast version + raise AttributeError("'forced_bos_token_id != None' is not supported yet in the fast version") + if kwargs["min_length"] != 0: + # not support for min_length yet in the fast version + raise AttributeError("'min_length != 0' is not supported yet in the fast version") + self._fast_entry = FasterGPT(self, use_fp16_decoding=use_fp16_decoding).forward + return self._fast_entry + + def prepare_inputs_for_generation(self, input_ids, use_cache=False, past_key_values=None, **kwargs): + # only last token for inputs_ids if cache is defined in kwargs + position_ids = kwargs.get("position_ids", None) + # attention_mask = kwargs.get("attention_mask", None) + if past_key_values is not None: + input_ids = input_ids[:, -1].unsqueeze(-1) + if position_ids is not None: + position_ids = position_ids[:, -1].unsqueeze(-1) + return { + "input_ids": input_ids, + "position_ids": position_ids, + "attention_mask": None, + "use_cache": use_cache, + "past_key_values": past_key_values, + } + + @staticmethod + def prepare_attention_mask_for_generation(input_ids, pad_token_id, eos_token_id): + is_pad_token_in_inputs_ids = (pad_token_id is not None) and float(paddle.any(input_ids == pad_token_id)) + is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( + (eos_token_id is not None) and (pad_token_id != eos_token_id) + ) + if is_pad_token_in_inputs_ids and is_pad_token_not_equal_to_eos_token_id: + attention_mask = (input_ids != pad_token_id).astype("int64") + else: + attention_mask = paddle.ones_like(input_ids, dtype="int64") + return paddle.unsqueeze(attention_mask, axis=[1, 2]) + + +GPTLMHeadModelNet = GPTForCausalLMNet 
diff --git a/paddlenlp/transformers/llama/__init__.py b/paddlenlp/transformers/llama/__init__.py index 865637ecf4d5..a85b249f356d 100644 --- a/paddlenlp/transformers/llama/__init__.py +++ b/paddlenlp/transformers/llama/__init__.py @@ -15,6 +15,7 @@ from .configuration import * from .modeling import * from .modeling_auto import * +from .modeling_network import * from .modeling_pp import * from .tokenizer import * from .tokenizer_fast import * diff --git a/paddlenlp/transformers/llama/modeling_network.py b/paddlenlp/transformers/llama/modeling_network.py new file mode 100644 index 000000000000..f37bb9e18cfa --- /dev/null +++ b/paddlenlp/transformers/llama/modeling_network.py @@ -0,0 +1,1257 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Paddle Llama model""" +from __future__ import annotations + +import math +import os +import warnings +from functools import partial +from typing import Optional, Tuple + +import paddle +import paddle.nn.functional as F +from paddle import nn +from paddle.distributed import fleet +from paddle.distributed.fleet.utils import recompute + +try: + from paddle.incubate.nn.functional import fused_rotary_position_embedding +except ImportError: + fused_rotary_position_embedding = None + +try: + from paddle.incubate.nn.functional import swiglu +except ImportError: + + def swiglu(x, y=None): + if y is None: + x, y = paddle.chunk(x, chunks=2, axis=-1) + return F.silu(x) * y + + +from paddle.distributed.auto_parallel.intermediate.tensor_parallel import ( + ColWiseParallel, + RowWiseParallel, + SequenceParallelBegin, + SequenceParallelDisable, + SequenceParallelEnd, +) + +from paddlenlp.transformers.conversion_utils import ( + StateDictNameMapping, + init_name_mappings, +) +from paddlenlp.transformers.model_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) +from paddlenlp.transformers.model_utils import PretrainedModel, register_base_model + +from .configuration import ( + LLAMA_PRETRAINED_INIT_CONFIGURATION, + LLAMA_PRETRAINED_RESOURCE_FILES_MAP, + LlamaConfig, +) +from .modeling import ( + LlamaDynamicNTKScalingRotaryEmbedding, + LlamaLinearScalingRotaryEmbedding, + LlamaNTKScalingRotaryEmbedding, + LlamaRotaryEmbedding, + _expand_2d_mask, + _make_causal_mask, + apply_rotary_pos_emb, + build_alibi_tensor, + get_triangle_upper_mask, + repeat_kv, + rms_norm_fused, +) + +try: + from paddle.nn.functional.flash_attention import flash_attention +except: + flash_attention = None + +__all__ = [ + "LlamaForCausalLM3DNet", + "LlamaPretrainingCriterion3DNet", +] + + +def enable_fuse_ffn_qkv_pass(): + if os.getenv("FLAGS_enable_fused_ffn_qkv_pass") in [ + "True", + "true", + "1", + ]: + return True + else: + return False + + +def is_pp_enable(): + mesh = fleet.auto.get_mesh() + return "pp" in mesh.dim_names + + +def get_mesh(pp_idx=0): + mesh = fleet.auto.get_mesh() + if "pp" in 
mesh.dim_names: + mesh = mesh.get_mesh_with_dim("pp", pp_idx) + return mesh + + +def global_mesh_starts_with_pp(): + mesh = fleet.auto.get_mesh() + if is_pp_enable(): + return mesh.get_mesh_with_dim("pp") + else: + return mesh + + +def scaled_dot_product_attention( + query_states, + config, + key_states, + value_states, + attention_mask, + output_attentions, + alibi=None, +): + bsz, q_len, num_heads, head_dim = query_states.shape + _, kv_seq_len, _, _ = value_states.shape + + if config.use_flash_attention and flash_attention: + # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim] + # Torch Flash Attention input [ bz, nhead, seqlen, head_dim] + version = paddle.version.full_version + if version != "0.0.0" and version <= "2.5.2": + if alibi is not None: + raise ValueError("Flash Attention doesn't support alibi") + attn_output, attn_weights = flash_attention( + query_states, + key_states, + value_states, + causal=True, + return_softmax=output_attentions, + ) + else: + if alibi is not None: + attention_mask = attention_mask.cast(alibi.dtype) + alibi + attn_output = F.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + is_causal=attention_mask is None and query_states.shape[1] != 1, + ) + attn_weights = None + + attn_output = attn_output.reshape([bsz, q_len, head_dim * query_states.shape[-2]]) + return (attn_output, attn_weights) if output_attentions else attn_output + else: + # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim] + query_states = paddle.transpose(query_states, [0, 2, 1, 3]) + # merge with the next tranpose + key_states = paddle.transpose(key_states, [0, 2, 1, 3]) + value_states = paddle.transpose(value_states, [0, 2, 1, 3]) + # matmul and devide by sqrt(head_dim) + attn_weights = paddle.matmul(query_states / math.sqrt(head_dim), key_states.transpose([0, 1, 3, 2])) + # then add alibi bias + if alibi is not None: + attn_weights = attn_weights + alibi + if list(attn_weights.shape) != [bsz, num_heads, q_len, kv_seq_len]: + raise ValueError( + f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.shape}" + ) + + # NOTE: we only call get_triangle_upper_mask under PP setup + # FIXME ZHUI when we use pipeline parallel, the attention_mask can be None + # we just make it triangle_upper_mask + if attention_mask is None: + attention_mask = get_triangle_upper_mask(attn_weights) + + attention_mask = attention_mask.reshape([bsz, 1, q_len, kv_seq_len]) + if list(attention_mask.shape) != [bsz, 1, q_len, kv_seq_len]: + raise ValueError( + f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}" + ) + + attn_weights = attn_weights + attention_mask + with paddle.amp.auto_cast(False): + attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype) + + attn_output = paddle.matmul(attn_weights, value_states) + attn_output = attn_output.transpose([0, 2, 1, 3]) + # [bsz, q_len, num_heads, head_dim] -> [bsz, q_len, num_heads * head_dim] + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + return (attn_output, attn_weights) if output_attentions else attn_output + + +class LlamaRMSNormNet(nn.Layer): + def __init__(self, config): + super().__init__() + self.hidden_size = config.hidden_size + self.weight = paddle.create_parameter( + shape=[self.hidden_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.Constant(1.0), + ) + self.variance_epsilon = 
config.rms_norm_eps + self.config = config + + def forward(self, hidden_states): + if self.config.use_fused_rms_norm: + return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon) + + with paddle.amp.auto_cast(False): + variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) + hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + + if self.weight.dtype in [paddle.float16, paddle.bfloat16]: + hidden_states = paddle.cast(hidden_states, self.weight.dtype) + + return hidden_states * self.weight + + +class LlamaMLPNet(nn.Layer): + def __init__(self, config, ipp: Optional[int] = None): + super().__init__() + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.fuse_attention_ffn = config.fuse_attention_ffn + self.config = config + + if config.fuse_attention_ffn and not enable_fuse_ffn_qkv_pass(): + self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False) + else: + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False) + + def forward(self, x): + if self.fuse_attention_ffn and not enable_fuse_ffn_qkv_pass(): + x = swiglu(self.gate_up_fused_proj(x)) + else: + x = swiglu(self.gate_proj(x), self.up_proj(x)) + out = self.down_proj(x) + return out + + +class LlamaAttentionNet(nn.Layer): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False, ipp: Optional[int] = None): + super().__init__() + + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.head_dim = self.hidden_size // config.num_attention_heads + + self.num_key_value_heads = config.num_key_value_heads + assert config.num_attention_heads // config.num_key_value_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads + + self.max_position_embeddings = config.max_position_embeddings + self.seq_length = config.seq_length + + self.fuse_attention_qkv = config.fuse_attention_qkv + + self.kv_indices = None + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + + self.use_fused_rope = config.use_fused_rope + if self.use_fused_rope: + if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None: + warnings.warn( + "Enable fuse rope in the config, but fuse rope is not available. " + "Will disable fuse rope. Try using latest gpu version of Paddle." 
+ ) + self.use_fused_rope = False + + if self.fuse_attention_qkv and not enable_fuse_ffn_qkv_pass(): + self.qkv_proj = nn.Linear( + self.hidden_size, + self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim, + bias_attr=False, + ) + + else: + self.q_proj = nn.Linear( + self.hidden_size, + self.hidden_size, + bias_attr=False, + ) + + self.k_proj = nn.Linear( + self.hidden_size, + self.config.num_key_value_heads * self.head_dim, + bias_attr=False, + ) + + self.v_proj = nn.Linear( + self.hidden_size, + self.config.num_key_value_heads * self.head_dim, + bias_attr=False, + ) + self.o_proj = nn.Linear( + self.hidden_size, + self.hidden_size, + bias_attr=False, + ) + + if config.rope: + self._init_rope() + + self.config = config + + def _init_rope(self): + if self.config.rope_scaling_type is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.config.rope_theta, + ) + elif self.config.rope_scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=self.config.rope_scaling_factor, + base=self.config.rope_theta, + ) + elif self.config.rope_scaling_type == "ntk": + self.rotary_emb = LlamaNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=self.config.rope_scaling_factor, + base=self.config.rope_theta, + ) + elif self.config.rope_scaling_type == "dynamic_ntk": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=self.config.rope_scaling_factor, + base=self.config.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {self.config.rope_scaling_type}") + + def forward( + self, + hidden_states, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + alibi: Optional[paddle.Tensor] = None, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + """Input shape: Batch x Time x Channel""" + # [bs, seq_len, num_head * head_dim] or [seq_len / n, bs, num_head * head_dim] (if sequence_parallel) + # enter tp region + + if self.fuse_attention_qkv and not enable_fuse_ffn_qkv_pass(): + target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim] + mix_layer = self.qkv_proj(hidden_states) + mix_layer = paddle.reshape_(mix_layer, target_shape) + query_states, key_states, value_states = paddle.split( + mix_layer, + num_or_sections=[self.num_key_value_groups * self.head_dim, self.head_dim, self.head_dim], + axis=-1, + ) + if self.gqa_or_mqa: + query_states = paddle.reshape(query_states, [0, 0, self.num_heads, self.head_dim]) + else: + target_query_shape = [0, 0, self.num_heads, self.head_dim] + target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim] + + query_states = self.q_proj(hidden_states).reshape(shape=target_query_shape) + key_states = self.k_proj(hidden_states).reshape(shape=target_key_value_shape) + value_states = self.v_proj(hidden_states).reshape(shape=target_key_value_shape) + + kv_seq_len = key_states.shape[-3] + + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-3] + + if self.config.rope: + if self.use_fused_rope: + assert past_key_value is None, "fuse 
rotary not support cache kv for now" + batch_size, seq_length, num_heads, head_dim = query_states.shape + _, kv_seq_len, num_key_value_heads, _ = key_states.shape + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + paddle_version = float(paddle.__version__[:3]) + if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads): + query_states, _, _ = fused_rotary_position_embedding( + query_states, + None, + None, + sin=sin, + cos=cos, + position_ids=position_ids, + use_neox_rotary_style=False, + ) + key_states, _, _ = fused_rotary_position_embedding( + key_states, + None, + None, + sin=sin, + cos=cos, + position_ids=position_ids, + use_neox_rotary_style=False, + ) + else: + query_states, key_states, _ = fused_rotary_position_embedding( + query_states, + key_states, + v=None, + sin=sin, + cos=cos, + position_ids=position_ids, + use_neox_rotary_style=False, + ) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # hack here, because elementwise infer spmd not support broadcast now + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + # [bs, seq_len, num_head, head_dim] + if past_key_value is not None: + # reuse k, v, self_attention + key_states = paddle.concat([past_key_value[0], key_states], axis=1) + value_states = paddle.concat([past_key_value[1], value_states], axis=1) + + past_key_value = (key_states, value_states) if use_cache else None + + if self.kv_indices is not None: + key_states = paddle.index_select(key_states, self.kv_indices, axis=2) + value_states = paddle.index_select(value_states, self.kv_indices, axis=2) + + # TODO(wj-Mcat): use broadcast strategy when n_kv_heads = 1 + # repeat k/v heads if n_kv_heads < n_heads + # paddle version > 2.6 or develop support flash-attn with gqa/mqa + paddle_version = float(paddle.__version__[:3]) + if (paddle_version != 0.0) and (paddle_version <= 2.6): + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient) + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "core_attn" + ): + outputs = recompute( + scaled_dot_product_attention, + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + alibi, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = scaled_dot_product_attention( + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + alibi, + ) + if output_attentions: + attn_output, attn_weights = outputs + else: + attn_output = outputs + + # [bs, q_len, num_head * head_dim] + attn_output = self.o_proj(attn_output) + + # enter sp region + if not output_attentions: + attn_weights = None + + outputs = (attn_output,) + + if output_attentions: + outputs += (attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class LlamaDecoderLayerNet(nn.Layer): + def __init__(self, config, layerwise_recompute: bool = False, ipp: Optional[int] = None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.self_attn = LlamaAttentionNet(config, layerwise_recompute, ipp) + self.mlp = LlamaMLPNet(config, ipp) + self.input_layernorm = 
LlamaRMSNormNet(config) + self.post_attention_layernorm = LlamaRMSNormNet(config) + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + + def forward( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + alibi: Optional[paddle.Tensor] = None, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + """ + Args: + hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`paddle.Tensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `cache` key value states are returned and can be used to speed up decoding + (see `cache`). + cache (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states + """ + # [bs, seq_len, embed_dim] or [seq_len / n, bs, embed_dim] (if sequence_parallel) + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "full_attn" + ): + outputs = recompute( + self.self_attn, + hidden_states, + position_ids, + past_key_value, + attention_mask, + output_attentions, + use_cache, + alibi, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = self.self_attn( + hidden_states, + position_ids, + past_key_value, + attention_mask, + output_attentions, + use_cache, + alibi, + ) + + if type(outputs) is tuple: + hidden_states = outputs[0] + else: + hidden_states = outputs + + if output_attentions: + self_attn_weights = outputs[1] + + if use_cache: + present_key_value = outputs[2 if output_attentions else 1] + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + + # enter tp region + + hidden_states = self.mlp(hidden_states) + + # enter sp region + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + # remove empty tuple for pipeline parallel + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class GlobalOutputNet(nn.Layer): + def __init__(self, config) -> None: + super().__init__() + self.config = config + + @staticmethod + def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype): + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + if len(attention_mask.shape) == 2: + expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1]) + # For decoding phase in 
generation, seq_length = 1, we don't need to add causal mask + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, past_key_values_length=past_key_values_length + ) + expanded_attn_mask = expanded_attn_mask & combined_attention_mask + # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len] + elif len(attention_mask.shape) == 3: + expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool") + # if attention_mask is already 4-D, do nothing + else: + expanded_attn_mask = attention_mask + else: + expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) + # Convert bool attention_mask to float attention mask, which will be added to attention_scores later + expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype) + return expanded_attn_mask + + def forward( + self, position_ids, attention_mask, seq_length, batch_size, seq_length_with_past, cache_length, emb_dtype + ): + if position_ids is None and self.config.sep_parallel_degree > 1: + position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) + + if not self.config.use_flash_attention and attention_mask is None: + # [bs, seq_len] + attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + + if self.config.alibi: + if attention_mask is None: + attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + alibi = build_alibi_tensor(attention_mask, self.config.num_attention_heads, dtype=emb_dtype) + else: + alibi = None + if self.config.use_flash_attention and not self.config.alibi: + # attention_mask in flash_attn is always None for pretrain + # atttenton_mask is used in scaled_dot_product_attention with alibi_tensor + attention_mask = None + else: + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), cache_length, emb_dtype + ) # [bs, 1, seq_len, seq_len] + return position_ids, attention_mask, alibi + + +class LlamaPretrainedModelNet(PretrainedModel): + config_class = LlamaConfig + base_model_prefix = "llama" + pretrained_init_configuration = LLAMA_PRETRAINED_INIT_CONFIGURATION + pretrained_resource_files_map = LLAMA_PRETRAINED_RESOURCE_FILES_MAP + _keys_to_ignore_on_load_unexpected = [r"self_attn.rotary_emb.inv_freq"] + + @classmethod + def _get_name_mappings(cls, config: LlamaConfig) -> list[StateDictNameMapping]: + mappings: list[StateDictNameMapping] = [] + model_mappings = [ + ["embed_tokens.weight"], + ["norm.weight"], + ] + for layer_index in range(config.num_hidden_layers): + layer_mappings = [ + [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"], + [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"], + [f"layers.{layer_index}.input_layernorm.weight"], + [f"layers.{layer_index}.post_attention_layernorm.weight"], + ] + model_mappings.extend(layer_mappings) + + init_name_mappings(mappings=model_mappings) + # base-model prefix "LlamaModelNet" + if "LlamaModelNet" not in config.architectures: + for mapping in model_mappings: + mapping[0] = "model." 
+ mapping[0] + mapping[1] = "llama." + mapping[1] + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) + + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] + return mappings + + @classmethod + def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True): + + from paddlenlp.transformers.conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_layers): + final_actions = {} + + base_actions = { + "lm_head.weight": partial(fn, is_column=True), + # Row Linear + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + "layers.0.mlp.down_proj.weight": partial(fn, is_column=False), + } + + # Column Linear + if config.fuse_attention_qkv and not enable_fuse_ffn_qkv_pass(): + base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True) + else: + base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) + # if we have enough num_key_value_heads to split, then split it. + if config.num_key_value_heads % config.tensor_parallel_degree == 0: + base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True) + + if config.fuse_attention_ffn and not enable_fuse_ffn_qkv_pass(): + base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial( + fn, is_column=True, is_naive_2fuse=True + ) + else: + base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True) + + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers) + + return mappings + + ''' + def _init_weights(self, layer): + """Initialization hook""" + if isinstance( + layer, + ( + nn.Linear, + nn.Embedding, + LlamaLMHeadNet, + ), + ): + # In the dygraph mode, use the `set_value` to reset the parameter directly, + # and reset the `state_dict` to update parameter in static mode. + + if isinstance(layer.weight, paddle.Tensor): + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.llama.config.initializer_range, + shape=layer.weight.shape, + ) + ) + # Layer.apply is DFS https://github.com/PaddlePaddle/Paddle/blob/a6f5021fcc58b21f4414bae6bf4731ef6971582c/python/paddle/nn/layer/layers.py#L527-L530 + # sublayer is init first + # scale RowParallelLinear weight + with paddle.no_grad(): + if isinstance(layer, LlamaMLPNet): + factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) + layer.down_proj.weight.scale_(factor) + if isinstance(layer, LlamaAttentionNet): + factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) + layer.o_proj.weight.scale_(factor) + ''' + + +@register_base_model +class LlamaModelNet(LlamaPretrainedModelNet): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayerNet`] + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + self.recompute_granularity = config.recompute_granularity + self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else [] + # Recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.embed_tokens = nn.Embedding( + self.vocab_size, + self.hidden_size, + ) + self.global_layer = GlobalOutputNet(config=config) + + def get_layer_pp_info(layer_index): + mesh = fleet.auto.get_mesh() + if is_pp_enable() is False: + return None, False + else: + pp_degree = mesh.get_dim_size("pp") + layer_per_stage = math.ceil(config.num_hidden_layers / pp_degree) + input_need_reshard = layer_index % layer_per_stage == 0 + return layer_index // layer_per_stage, input_need_reshard + + decoder_layers = [] + self.next_pp_stage_indexes = [] + for i in range(config.num_hidden_layers): + pp_stage_id, input_need_reshard = get_layer_pp_info(i) + decoder_layers.append(LlamaDecoderLayerNet(config, i not in self.no_recompute_layers, pp_stage_id)) + if input_need_reshard: + self.next_pp_stage_indexes.append(i) + + self.layers = nn.LayerList(decoder_layers) + self.norm = LlamaRMSNormNet(config) + + self.gradient_checkpointing = False + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward( + self, + input_ids=None, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + use_cache=None, + past_key_values=None, + output_attentions=False, + output_hidden_states=None, + return_dict=False, + **kwargs, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if past_key_values is None: + past_key_values = tuple([None] * len(self.layers)) + + seq_length_with_past = seq_length + cache_length = 0 + if past_key_values[0] is not None: + cache_length = past_key_values[0][0].shape[1] + seq_length_with_past += cache_length + + if inputs_embeds is None: + with paddle.amp.auto_cast(False): + inputs_embeds = self.embed_tokens(input_ids) + + """ + if position_ids is None and self.config.sep_parallel_degree > 1: + position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length)) + # embed positions + if not self.config.use_flash_attention and attention_mask is None: + # [bs, seq_len] + attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + + if self.config.alibi: + if attention_mask is None: + attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + alibi 
= build_alibi_tensor(attention_mask, self.config.num_attention_heads, dtype=inputs_embeds.dtype) + else: + alibi = None + if self.config.use_flash_attention and not self.config.alibi: + # attention_mask in flash_attn is always None for pretrain + # atttenton_mask is used in scaled_dot_product_attention with alibi_tensor + attention_mask = None + else: + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype + ) # [bs, 1, seq_len, seq_len] + """ + position_ids, attention_mask, alibi = self.global_layer( + position_ids, + attention_mask, + seq_length, + batch_size, + seq_length_with_past, + cache_length, + inputs_embeds.dtype, + ) + # print(position_ids, attention_mask, alibi) + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, (decoder_layer) in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + past_key_value = past_key_values[idx] if past_key_values is not None else None + + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and idx not in self.no_recompute_layers + and has_gradient + and self.recompute_granularity == "full" + ): + layer_outputs = recompute( + decoder_layer, + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + alibi=alibi, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + alibi=alibi, + ) + + if type(layer_outputs) is tuple: + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=None, + ) + + +class LlamaPretrainingCriterion3DNet(paddle.nn.Layer): + """ + Criterion for Llama. + It calculates the final loss. 
+ """ + + def __init__(self, config): + + super(LlamaPretrainingCriterion3DNet, self).__init__() + self.ignore_index = getattr(config, "ignore_index", -100) + self.config = config + self.enable_parallel_cross_entropy = config.tensor_parallel_degree > 1 and config.tensor_parallel_output + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + def forward(self, prediction_scores, masked_lm_labels): + if self.enable_parallel_cross_entropy: + if prediction_scores.shape[-1] == self.config.vocab_size: + warnings.warn( + f"enable_parallel_cross_entropy, the vocab_size should be splited: {prediction_scores.shape[-1]}, {self.config.vocab_size}" + ) + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + # Force entropy same kernel + with paddle.amp.auto_cast(False): + if isinstance(prediction_scores, paddle.Tensor): + masked_lm_loss = self.loss_func( + prediction_scores.astype("float32")._use_gpudnn(False), + masked_lm_labels.unsqueeze(2), + ) + else: + + masked_lm_loss = self.loss_func( + prediction_scores.astype("float32"), + masked_lm_labels.unsqueeze(2), + ) + + masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32") + loss = paddle.mean(masked_lm_loss) + return loss + + +class LlamaLMHeadNet(nn.Layer): + def __init__(self, config: LlamaConfig): + super(LlamaLMHeadNet, self).__init__() + self.config = config + vocab_size = config.vocab_size + self.weight = self.create_parameter( + shape=[config.hidden_size, vocab_size], + dtype=paddle.get_default_dtype(), + ) + + def forward(self, hidden_states, tensor_parallel_output=None): + if tensor_parallel_output is None: + tensor_parallel_output = self.config.tensor_parallel_output + logits = paddle.matmul(hidden_states, self.weight, transpose_y=False) + return logits + + +class LlamaForCausalLM3DNet(LlamaPretrainedModelNet): + enable_to_static_method = True + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.config = config + + self.llama = LlamaModelNet(config) + self.lm_head = LlamaLMHeadNet(config) + + def get_input_embeddings(self): + return self.llama.embed_tokens + + def set_input_embeddings(self, value): + self.llama.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.llama = decoder + + def get_decoder(self): + return self.llama + + def prepare_inputs_for_generation( + self, input_ids, use_cache=False, past_key_values=None, inputs_embeds=None, **kwargs + ): + batch_size, seq_length = input_ids.shape + position_ids = kwargs.get("position_ids", paddle.arange(seq_length).expand((batch_size, seq_length))) + attention_mask = kwargs.get("attention_mask", None) + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(axis=-1) + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + def _get_model_inputs_spec(self, dtype: str): + return { + "input_ids": paddle.static.InputSpec(shape=[None, 
None], dtype="int64"), + "attention_mask": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + "position_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + } + + @staticmethod + def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False): + # update cache + if isinstance(outputs, tuple) and len(outputs) > 1 and not isinstance(outputs[1], paddle.Tensor): + model_kwargs["past_key_values"] = outputs[1] + + if isinstance(outputs, CausalLMOutputWithCrossAttentions) and "past_key_values" in outputs: + model_kwargs["past_key_values"] = outputs.past_key_values + + # update position_ids + if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None: + position_ids = model_kwargs["position_ids"] + model_kwargs["position_ids"] = paddle.concat([position_ids, position_ids[..., -1:] + 1], axis=-1) + + if not is_encoder_decoder and "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = paddle.concat( + [attention_mask, paddle.ones([attention_mask.shape[0], 1], dtype=attention_mask.dtype)], axis=-1 + ) + + return model_kwargs + + def forward( + self, + input_ids=None, + labels=None, + position_ids=None, + attention_mask=None, + inputs_embeds=None, + use_cache=False, + past_key_values=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + input_ids.stop_gradient = True + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + outputs = self.llama( + input_ids, # [bs, seq_len] + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] # [bs, seq_len, dim] + + # if labels is None,means we need full output, instead of tensor_parallel_output + # tensor_parallel_output is togather with ParallelCrossEntropy + tensor_parallel_output = ( + self.config.tensor_parallel_output and labels is not None and self.config.tensor_parallel_degree > 1 + ) + + logits = self.lm_head(hidden_states, tensor_parallel_output=tensor_parallel_output) + + return logits + + # loss = None + # if labels is not None: + # labels.stop_gradient = True + # loss = self.criterion(logits, labels) + + # if not return_dict: + # output = (logits,) + outputs[1:] + # return (loss,) + output if loss is not None else output + + # return CausalLMOutputWithCrossAttentions( + # loss=loss, + # logits=logits, + # past_key_values=outputs.past_key_values, + # hidden_states=outputs.hidden_states, + # attentions=outputs.attentions, + # ) + + def auto_dist_config(self, prefix=""): + if prefix != "": + assert prefix.endswith(".") + config = { + "sp_config": { + "parallelize_plan": { + f"{prefix}llama.embed_tokens": [ + ColWiseParallel(), + SequenceParallelBegin(), + ], + f"{prefix}llama.layers.*.self_attn.qkv_proj": ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.q_proj": ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.k_proj": ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.v_proj": ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.o_proj": RowWiseParallel(), + 
f"{prefix}llama.layers.*.self_attn": SequenceParallelDisable(), + f"{prefix}llama.layers.*.mlp.gate_proj": ColWiseParallel(), + f"{prefix}llama.layers.*.mlp.up_proj": ColWiseParallel(), + f"{prefix}llama.layers.*.mlp.gate_up_fused_proj": ColWiseParallel(), + f"{prefix}llama.layers.*.mlp.down_proj": RowWiseParallel(), + f"{prefix}llama.layers.*.mlp": SequenceParallelDisable(need_transpose=False), + f"{prefix}lm_head.weight": ColWiseParallel(), + f"{prefix}lm_head": SequenceParallelEnd(), + } + }, + "mp_config": { + "parallelize_plan": { + f"{prefix}llama.embed_tokens": ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.qkv_proj": ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.q_proj": ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.k_proj": ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.v_proj": ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.o_proj": RowWiseParallel(), + f"{prefix}llama.layers.*.mlp.gate_proj": ColWiseParallel(), + f"{prefix}llama.layers.*.mlp.up_proj": ColWiseParallel(), + f"{prefix}llama.layers.*.mlp.gate_up_fused_proj": ColWiseParallel(), + f"{prefix}llama.layers.*.mlp.down_proj": RowWiseParallel(), + f"{prefix}lm_head.weight": ColWiseParallel(), + } + }, + "pp_config": {"split_spec": f"{prefix}llama.layers", "global_spec": "llama.global_layer"}, + } + + return config diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py index fcc207a2e0bc..c25c750f3f37 100644 --- a/paddlenlp/transformers/model_utils.py +++ b/paddlenlp/transformers/model_utils.py @@ -2736,6 +2736,103 @@ def save_pretrained( f"index located at {save_index_file}." ) + def merge_auto_dist_configs(self, configs): + """ + Merged all auto dist configs into one config. + """ + assert isinstance(configs, (dict, list)) + if isinstance(configs, dict): + return configs + final_config = { + "mp_config": None, + "sp_config": None, + "pp_config": None, + } + for config in configs: + if "mp_config" in config and config["mp_config"] is not None: + if final_config["mp_config"] is None: + final_config["mp_config"] = config["mp_config"] + else: + for k, v in config["mp_config"]["parallelize_plan"].items(): + assert k not in final_config["mp_config"]["parallelize_plan"].keys() + final_config["mp_config"]["parallelize_plan"][k] = v + if "sp_config" in config and config["sp_config"] is not None: + if final_config["sp_config"] is None: + final_config["sp_config"] = config["sp_config"] + else: + for k, v in config["sp_config"]["parallelize_plan"].items(): + assert k not in final_config["sp_config"]["parallelize_plan"].keys() + final_config["sp_config"]["parallelize_plan"][k] = v + if "pp_config" in config and config["pp_config"] is not None: + if isinstance(config["pp_config"]["split_spec"], str): + config["pp_config"]["split_spec"] = [config["pp_config"]["split_spec"]] + if final_config["pp_config"] is None: + final_config["pp_config"] = config["pp_config"] + else: + final_config["pp_config"]["split_spec"] += config["pp_config"]["split_spec"] + + if final_config["pp_config"] is not None and len(final_config["pp_config"]["split_spec"]) == 1: + final_config["pp_config"]["split_spec"] = final_config["pp_config"]["split_spec"][0] + + return final_config + + def _generate_auto_dist_config(self, auto_dist_degree): + merged_config = { + "sp_config": None, + "mp_config": None, + "pp_config": None, + } + for name, layer in self.named_sublayers(include_self=True): + if hasattr(layer, "auto_dist_config"): + if name != "": + prefix = name + "." 
+                else:
+                    prefix = ""
+                layer_config = layer.auto_dist_config(prefix)
+                merged_config = self.merge_auto_dist_configs([merged_config, layer_config])
+                for _, deeper_layer in layer.named_sublayers():
+                    if hasattr(deeper_layer, "auto_dist_config"):
+                        # mask all `auto_dist_config` methods in deeper layers
+                        deeper_layer.auto_dist_config = lambda x: {}
+
+        final_config = {
+            "dp_config": None,
+            "mp_config": None,
+            "pp_config": None,
+        }
+
+        if "tensor_parallel" in auto_dist_degree and auto_dist_degree["tensor_parallel"]:
+            assert merged_config["mp_config"] is not None
+            final_config["mp_config"] = merged_config["mp_config"]
+
+        if "sequence_parallel" in auto_dist_degree and auto_dist_degree["sequence_parallel"]:
+            assert merged_config["sp_config"] is not None
+            final_config["mp_config"] = merged_config["sp_config"]
+
+        if "pipeline_parallel" in auto_dist_degree and auto_dist_degree["pipeline_parallel"]:
+            assert merged_config["pp_config"] is not None
+            final_config["pp_config"] = merged_config["pp_config"]
+
+        if "data_sharding_parallel" in auto_dist_degree and auto_dist_degree["data_sharding_parallel"]:
+            # to avoid a circular import
+            from paddlenlp.trainer.trainer_utils import ShardingOption
+
+            level = 0
+            if "sharding" in auto_dist_degree and auto_dist_degree["sharding"] is not None:
+                sharding = auto_dist_degree["sharding"]
+                if ShardingOption.SHARD_OP in sharding:
+                    level = 1
+                if ShardingOption.SHARD_GRAD_OP in sharding:
+                    level = 2
+                if ShardingOption.FULL_SHARD in sharding:
+                    level = 3
+            final_config["dp_config"] = {
+                "sharding_level": level,
+                "sharding_mesh_dim": auto_dist_degree.get("sharding_mesh_dim", None),
+            }
+
+        return final_config
+
 
 class PipelinePretrainedModel(PretrainedModel):
     def __init_hook__(self):
diff --git a/paddlenlp/transformers/qwen/__init__.py b/paddlenlp/transformers/qwen/__init__.py
index d64a428adbd0..7dd8305586e5 100644
--- a/paddlenlp/transformers/qwen/__init__.py
+++ b/paddlenlp/transformers/qwen/__init__.py
@@ -14,5 +14,6 @@
 from .configuration import *
 from .modeling import *
 from .modeling_3D_auto import *
+from .modeling_network import *
 from .modeling_pp import *
 from .tokenizer import *
diff --git a/paddlenlp/transformers/qwen/modeling_network.py b/paddlenlp/transformers/qwen/modeling_network.py
new file mode 100644
index 000000000000..4990e2e66398
--- /dev/null
+++ b/paddlenlp/transformers/qwen/modeling_network.py
@@ -0,0 +1,921 @@
+# Copyright (c) 2023 Alibaba Cloud and PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
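For context on the `merge_auto_dist_configs` and `_generate_auto_dist_config` helpers added to `paddlenlp/transformers/model_utils.py` above: each sublayer that exposes `auto_dist_config` contributes `mp_config`/`sp_config` parallelize plans (merged key by key, duplicate keys rejected by assertion) and a `pp_config` whose `split_spec` entries are collected into a list and collapsed back to a string when only one remains. A small illustrative sketch of that merging behaviour, using hypothetical placeholder plan values instead of the real `ColWiseParallel`/`RowWiseParallel` objects, where `model` is any `PretrainedModel` instance:

    # Hypothetical per-module configs; real plans map sublayer patterns to
    # paddle.distributed.auto_parallel parallel styles rather than strings.
    cfg_a = {
        "mp_config": {"parallelize_plan": {"llama.layers.*.self_attn.q_proj": "colwise"}},
        "pp_config": {"split_spec": "llama.layers", "global_spec": "llama.global_layer"},
    }
    cfg_b = {
        "mp_config": {"parallelize_plan": {"lm_head.weight": "colwise"}},
    }

    merged = model.merge_auto_dist_configs([cfg_a, cfg_b])
    # merged["mp_config"]["parallelize_plan"] now contains both keys;
    # merged["pp_config"]["split_spec"] is turned into a one-element list during
    # merging and collapses back to the string "llama.layers" at the end.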
+ +import math +import warnings +from functools import partial +from typing import List + +import paddle +import paddle.nn.functional as F +from paddle import Tensor, nn +from paddle.distributed import fleet +from paddle.distributed.fleet.utils import recompute +from paddle.utils import try_import + +from paddlenlp.transformers.model_outputs import BaseModelOutputWithPast +from paddlenlp.transformers.model_utils import PretrainedModel +from paddlenlp.utils.log import logger + +from ...utils.converter import StateDictNameMapping, init_name_mappings +from .configuration import QWenConfig + +__all__ = [ + "QWenBlockNet", + "QWenForCausalLM3DNet", + "QWenPretrainedModelNet", + "QWenModelNet", + "QWenLMHeadNet", + "QWenPretrainingCriterionNet", +] + + +MAX_NTK_SEQ_LENGTH = 32768 + +try: + from paddle.nn.functional.flash_attention import flash_attention +except: + flash_attention = None + +try: + from paddle.incubate.nn.functional import fused_rotary_position_embedding +except: + fused_rotary_position_embedding = None + +try: + from paddle.incubate.nn.functional import swiglu +except ImportError: + + def swiglu(x, y=None): + if y is None: + x, y = paddle.chunk(x, chunks=2, axis=-1) + return F.silu(x) * y + + +def get_mesh(pp_idx=0): + mesh = fleet.auto.get_mesh() + if "pp" in mesh.dim_names: + mesh = mesh.get_mesh_with_dim("pp")[pp_idx] + return mesh + + +def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True): + is_fleet_init = True + tensor_parallel_degree = 1 + try: + hcg = fleet.get_hybrid_communicate_group() + model_parallel_group = hcg.get_model_parallel_group() + tensor_parallel_degree = hcg.get_model_parallel_world_size() + except: + is_fleet_init = False + + if is_fleet_init and tensor_parallel_degree > 1 and y.is_distributed: + # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg' + input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group) + logits = paddle.matmul(input_parallel, y, transpose_y=False) + + if tensor_parallel_output: + return logits + + return paddle.distributed.collective._c_concat(logits, group=model_parallel_group) + + else: + logits = paddle.matmul(x, y, transpose_y=False) + return logits + + +def get_triangle_upper_mask(x, mask=None): + if mask is not None: + return mask + # [bsz, n_head, q_len, kv_seq_len] + shape = x.shape + # [bsz, 1, q_len, kv_seq_len] + shape[1] = 1 + mask = paddle.full(shape, paddle.finfo(x.dtype).min, dtype=x.dtype) + mask = paddle.triu(mask, diagonal=1) + mask.stop_gradient = True + return mask + + +attention_cnt = 0 + + +class QWenAttentionNet(nn.Layer): + def __init__(self, config, ipp=None): + super().__init__() + + self.config = config + self.seq_length = config.seq_length + self.hidden_size = config.hidden_size + self.split_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + + self.scale_attn_weights = True + self.enable_recompute = config.use_recompute + self.recompute_granularity = config.recompute_granularity + + self.projection_size = config.kv_channels * config.num_attention_heads + + assert self.projection_size % config.num_attention_heads == 0 + self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads + + self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size, bias_attr=True) + self.c_proj = nn.Linear(config.hidden_size, self.projection_size, 
bias_attr=False) + + if config.rotary_pct == 1.0: + self.rotary_ndims = None + else: + assert config.rotary_pct < 1 + self.rotary_ndims = int(self.hidden_size_per_attention_head * config.rotary_pct) + dim = self.rotary_ndims if self.rotary_ndims is not None else self.hidden_size_per_attention_head + self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base) + + self.use_dynamic_ntk = config.use_dynamic_ntk + self.use_logn_attn = config.use_logn_attn + + logn_list = [math.log(i, self.seq_length) if i > self.seq_length else 1 for i in range(1, MAX_NTK_SEQ_LENGTH)] + self.logn_tensor = paddle.to_tensor(logn_list)[None, :, None, None] + self._ntk_cached = 1.0 + + self.attn_dropout = nn.Dropout(config.attn_dropout_prob) + global attention_cnt + self.attention_cnt = attention_cnt + attention_cnt += 1 + + def _attn(self, query, key, value, attention_mask=None): + # Support the flash attention and normal attention + bsz, q_len, num_heads, head_dim = query.shape + _, kv_seq_len, _, _ = value.shape + if self.config.use_flash_attention and flash_attention is not None: + # Flash Attention now ignore attention mask + # Current Flash Attention doesn't support attn maskt + # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim] + # Torch Flash Attention input [ bz, nhead, seqlen, head_dim] + version = paddle.version.full_version + if version != "0.0.0" and version <= "2.5.2": + attn_output, attn_weights = flash_attention( + query, + key, + value, + causal=query.shape[1] != 1, + dropout=self.config.attn_dropout_prob, + return_softmax=self.config.attn_dropout_prob > 0.0, + ) + else: + attn_output = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + is_causal=attention_mask is None, + ) + attn_weights = None + return attn_output, attn_weights + else: + # [bz, sql, nh, hid] ==> [bz, nh, sql hdim] + query = query.transpose([0, 2, 1, 3]) + # [bz, sql, nh, hid] ==> [bz, nh, sql hdim] + key = key.transpose([0, 2, 1, 3]) + # [bz, sql, nh, hid] ==> [bz, nh, sql hdim] + value = value.transpose([0, 2, 1, 3]) + + attn_weights = paddle.matmul(query / math.sqrt(head_dim), key.transpose([0, 1, 3, 2])) + + if attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]: + raise ValueError( + f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.shape}" + ) + # If the attention mask is None, we need to construct the causal attention mask + if attention_mask is None: + attention_mask = get_triangle_upper_mask(attn_weights) + attn_weights = attn_weights + attention_mask + with paddle.amp.auto_cast(False): + attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(value.dtype) + + attn_weights = self.attn_dropout(attn_weights) + attn_output = paddle.matmul(attn_weights, value) + attn_output = attn_output.transpose([0, 2, 1, 3]) + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + new_shape = tensor.shape[:-1] + [num_heads, attn_head_size] + tensor = tensor.reshape(new_shape) + return tensor + + def _merge_heads(self, tensor, num_heads, attn_head_size): + new_shape = tensor.shape[:-2] + [ + num_heads * attn_head_size, + ] + return tensor.reshape(new_shape) + + def forward( + self, + hidden_states, + layer_past=None, + attention_mask=None, + position_ids=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + output_attentions=False, + use_cache=False, + ): + # # [bz, sql, hid] ==> [bz, sql, 3*hid] + mixed_x_layer = self.c_attn(hidden_states) + # [bz, sql, 
3*hid] ==> [bz, sql, hid] + target_shape = [0, 0, self.num_heads, 3 * self.head_dim] + + mixed_x_layer = paddle.reshape_(mixed_x_layer, target_shape) + query, key, value = paddle.split(mixed_x_layer, num_or_sections=3, axis=-1) + + # [bz, sql, hid] ==> [bz, sql, nh, hdim] + # query = self._split_heads(query, self.num_heads, self.head_dim) + # key = self._split_heads(key, self.num_heads, self.head_dim) + # value = self._split_heads(value, self.num_heads, self.head_dim) + + kv_seq_len = hidden_states.shape[1] + if layer_past: + # layer past[0] shape: bs * seq_len * head_num * dim + kv_seq_len += layer_past[0].shape[1] + if self.use_dynamic_ntk and kv_seq_len == hidden_states.shape[1] and not self.training: + context_value = math.log(kv_seq_len / self.seq_length, 2) + 1 + ntk_alpha = 2 ** math.ceil(context_value) - 1 + ntk_alpha = max(ntk_alpha, 1) + self._ntk_cached = ntk_alpha + else: + ntk_alpha = self._ntk_cached + rotary_pos_emb = self.rotary_emb(value, kv_seq_len, ntk_alpha=ntk_alpha) + + if rotary_pos_emb is not None: + if isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = rotary_pos_emb + else: + rotary_pos_emb = (rotary_pos_emb,) * 2 + + if rotary_pos_emb is not None: + cos, sin = rotary_pos_emb + if self.config.use_fused_rope: + query, key, _ = fused_rotary_position_embedding( + query, + key, + v=None, + sin=sin, + cos=cos, + position_ids=position_ids, + use_neox_rotary_style=False, + ) + else: + query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids=position_ids) + + if layer_past is not None: + past_key, past_value = layer_past[0], layer_past[1] + key = paddle.concat([past_key, key], axis=1) + value = paddle.concat([past_value, value], axis=1) + + if use_cache: + present = (key, value) + else: + present = None + + if self.use_logn_attn and not self.training: + if self.logn_tensor.dtype != query.dtype: + self.logn_tensor = self.logn_tensor.astype(query.dtype) + seq_start = key.shape[1] - query.shape[1] + seq_end = key.shape[1] + logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :] + query = query * logn_tensor.expand(query.shape) + + has_gradient = not (query.stop_gradient and key.stop_gradient and value.stop_gradient) + if self.enable_recompute and self.training and has_gradient and self.recompute_granularity == "core_attn": + attn_output, attn_weight = recompute( + self._attn, query, key, value, attention_mask, use_reentrant=self.config.recompute_use_reentrant + ) + else: + attn_output, attn_weight = self._attn(query, key, value, attention_mask) + context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim) + + attn_output = self.c_proj(context_layer) + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weight,) + + return outputs + + +class QWenMLPNet(nn.Layer): + def __init__(self, config, ipp=None): + super().__init__() + ff_dim_in = config.intermediate_size // 2 + self.fuse_attention_ffn = config.fuse_attention_ffn + self.w1 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=False) + self.w2 = nn.Linear(config.hidden_size, ff_dim_in, bias_attr=False) + self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias_attr=False) + + def forward(self, hidden_states): + # # up + # a1 = self.w1(hidden_states) + # # gate + # a2 = self.w2(hidden_states) + # intermediate_parallel = a1 * F.silu(a2) + # down + if self.fuse_attention_ffn: + intermediate_parallel = swiglu(self.gate_up_fused_proj(hidden_states)) + else: + intermediate_parallel = swiglu(self.w2(hidden_states), self.w1(hidden_states)) + output = 
self.c_proj(intermediate_parallel) + return output + + +class QWenBlockNet(nn.Layer): + def __init__(self, config, ipp=None, idx=None): + super().__init__() + self.config = config + self.ln_1 = QWenRMSNormNet(config) + self.attn = QWenAttentionNet(config, ipp) + self.ln_2 = QWenRMSNormNet(config) + self.mlp = QWenMLPNet(config, ipp) + self.idx = idx + + def forward( + self, + hidden_states, + layer_past=None, + attention_mask=None, + position_ids=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + use_cache=False, + output_attentions=False, + ): + layernorm_output = self.ln_1(hidden_states) + + attn_outputs = self.attn( + layernorm_output, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] + + outputs = attn_outputs[1:] + + residual = hidden_states + layernorm_input = attn_output + residual + + layernorm_output = self.ln_2(layernorm_input) + + residual = layernorm_input + mlp_output = self.mlp(layernorm_output) + hidden_states = residual + mlp_output + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + # remove empty tuple for pipeline parallel + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + return outputs + + +class QWenPretrainedModelNet(PretrainedModel): + config_class = QWenConfig + base_model_prefix = "qwen" + + @classmethod + def _get_tensor_parallel_mappings(cls, config, is_split=True): + + from paddlenlp.transformers.conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_hidden_layers): + final_actions = {} + base_actions = { + # Column Linear + "lm_head.weight": partial(fn, is_column=True), + "qwen.h.0.mlp.w2.weight": partial(fn, is_column=True), + "qwen.h.0.mlp.w1.weight": partial(fn, is_column=True), + "qwen.h.0.attn.c_attn.weight": partial(fn, is_column=True, is_naive_3fuse=True), + "qwen.h.0.attn.c_attn.bias": partial(fn, is_column=True, is_naive_3fuse=True), + # Row Linear + "qwen.wte.weight": partial(fn, is_column=False), + "qwen.h.0.mlp.c_proj.weight": partial(fn, is_column=False), + "qwen.h.0.attn.c_proj.weight": partial(fn, is_column=False), + } + for key, action in base_actions.items(): + if "h.0." 
in key: + for i in range(num_hidden_layers): + final_actions[key.replace("h.0.", f"h.{i}.")] = action + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers) + + return mappings + + @classmethod + def _get_name_mappings(cls, config: QWenConfig) -> List[StateDictNameMapping]: + mappings = [ + "wte.weight", + "ln_f.weight", + ] + + for layer_index in range(config.num_hidden_layers): + layer_mappings = [ + [ + f"h.{layer_index}.ln_1.weight", + f"h.{layer_index}.ln_1.weight", + ], + [ + f"h.{layer_index}.attn.c_attn.weight", + f"h.{layer_index}.attn.c_attn.weight", + "transpose", + ], + [ + f"h.{layer_index}.attn.c_attn.bias", + f"h.{layer_index}.attn.c_attn.bias", + ], + [ + f"h.{layer_index}.attn.c_proj.weight", + f"h.{layer_index}.attn.c_proj.weight", + "transpose", + ], + [ + f"h.{layer_index}.ln_2.weight", + f"h.{layer_index}.ln_2.weight", + ], + [ + f"h.{layer_index}.mlp.w1.weight", + f"h.{layer_index}.mlp.w1.weight", + "transpose", + ], + [ + f"h.{layer_index}.mlp.w2.weight", + f"h.{layer_index}.mlp.w2.weight", + "transpose", + ], + [ + f"h.{layer_index}.mlp.c_proj.weight", + f"h.{layer_index}.mlp.c_proj.weight", + "transpose", + ], + ] + mappings.extend(layer_mappings) + + init_name_mappings(mappings) + for mapping in mappings: + mapping[0] = "transformer." + mapping[0] + if len(mapping) > 1 and mapping[1] is not None: + mapping[1] = "qwen." + mapping[1] + + if config.architectures is not None: + if "QWenForCausalLM" in config.architectures or "QWenLMHeadModel" in config.architectures: + mappings.extend( + [ + [ + "lm_head.weight", + "lm_head.weight", + "transpose", + ] + ] + ) + + init_name_mappings(mappings) + return [StateDictNameMapping(*mapping) for mapping in mappings] + + +class QWenModelNet(QWenPretrainedModelNet): + def __init__(self, config): + super().__init__(config) + self.config = config + self.vocab_size = config.vocab_size + self.num_hidden_layers = config.num_hidden_layers + self.embed_dim = config.hidden_size + self.enable_recompute = config.use_recompute + self.recompute_granularity = config.recompute_granularity + + self.wte = nn.Embedding(self.vocab_size, self.embed_dim) + self.drop = nn.Dropout(config.emb_dropout_prob) + + self.h = nn.LayerList( + [ + QWenBlockNet( + config, + self.get_layer_ipp(i), + i, + ) + for i in range(config.num_hidden_layers) + ] + ) + self.ln_f = QWenRMSNormNet(config) + + def get_layer_ipp(self, layer_index): + mesh = fleet.auto.get_mesh() + if "pp" not in mesh.dim_names: + return None + else: + pp_degree = mesh.get_dim_size("pp") + layer_per_stage = math.ceil(self.config.num_hidden_layers / pp_degree) + return layer_index // layer_per_stage + + def get_last_layer_ipp(self): + return self.get_layer_ipp(self.config.num_hidden_layers - 1) + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + @paddle.jit.not_to_static + def recompute_training( + self, + block, + hidden_states, + layer_past, + attention_mask, + position_ids, + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + ): + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = recompute( + create_custom_forward(block), + hidden_states, + layer_past, + attention_mask, + position_ids, + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + 
use_reentrant=self.config.recompute_use_reentrant, + ) + return hidden_states + + def get_masks(self, batch_size, seq_length, past_length, dtype, padding_mask=None): + # casual mask + casual_mask = paddle.tril(paddle.ones([batch_size, 1, seq_length, seq_length], dtype="bool")) + if past_length > 0: + casual_mask = paddle.concat( + [paddle.ones([batch_size, 1, seq_length, past_length], dtype="bool"), casual_mask], axis=-1 + ) + + # seq_mask + if padding_mask is None: + padding_mask = paddle.ones((batch_size, 1, seq_length, seq_length + past_length), dtype="bool") + if len(padding_mask.shape) == 2: + # from Tokenizer + padding_mask = ( + padding_mask.unsqueeze(axis=[1, 2]) + .expand([batch_size, 1, seq_length, seq_length + past_length]) + .astype("bool") + ) + elif len(padding_mask.shape) == 3: + # [batch_size,tgt_length, src_length] -> [batch_size, 1, tgt_length, src_length] + padding_mask = padding_mask.unsqueeze(1).astype("bool") + elif len(padding_mask.shape) == 4: + padding_mask = padding_mask.astype("bool") + + casual_mask = casual_mask & padding_mask + + return casual_mask + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + position_ids=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.shape + input_ids = input_ids.reshape([-1, input_shape[-1]]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.shape[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].shape[1] + + encoder_attention_mask = None + if inputs_embeds is None: + with paddle.amp.auto_cast(False): + inputs_embeds = self.wte(input_ids) + + hidden_states = inputs_embeds + + # bool 4D mask + attention_mask = self.get_masks( + input_shape[0], input_shape[1], past_length, dtype=hidden_states.dtype, padding_mask=attention_mask + ) + # TODO(GhostScreaming): how to fix paddle.finfo? 
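+        # Turn the bool mask into an additive mask: 0.0 where attention is allowed and
+        # finfo(bfloat16).min where it is masked (bfloat16 is hard-coded, hence the TODO above).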
+ zero = paddle.zeros(attention_mask.shape, dtype=paddle.bfloat16) + neg_inf = paddle.full_like(attention_mask, paddle.finfo(paddle.bfloat16).min, dtype=paddle.bfloat16) + # dtype 4D mask + attention_mask = paddle.where(attention_mask, zero, neg_inf) + + hidden_states = self.drop(hidden_states) + output_shape = input_shape + [ + hidden_states.shape[-1], + ] + + if self.enable_recompute and self.training: + if use_cache: + logger.warning_once("`use_cache=True` is incompatible with recompute") + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + has_gradient = not hidden_states.stop_gradient + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + if self.enable_recompute and self.training and has_gradient and self.recompute_granularity == "full": + outputs = self.recompute_training( + block, + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + position_ids=position_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + if type(outputs) is tuple: + hidden_states = outputs[0] + else: + hidden_states = outputs + + if use_cache is True: + presents = presents + (outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[1],) + + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states.reshape(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class QWenLMHeadNet(nn.Layer): + def __init__(self, config: QWenConfig, ipp=None): + super(QWenLMHeadNet, self).__init__() + self.config = config + vocab_size = config.vocab_size + self.weight = self.create_parameter( + shape=[config.hidden_size, vocab_size], + dtype=paddle.get_default_dtype(), + ) + + def forward(self, hidden_states, tensor_parallel_output=None): + if tensor_parallel_output is None: + tensor_parallel_output = self.config.tensor_parallel_output + + logits = paddle.matmul(hidden_states, self.weight, transpose_y=False) + return logits + + +loss_cnt = 0 + + +class QWenPretrainingCriterionNet(paddle.nn.Layer): + """ + Criterion for Llama. + It calculates the final loss. 
+ """ + + def __init__(self, config): + + super(QWenPretrainingCriterionNet, self).__init__() + self.ignore_index = getattr(config, "ignore_index", -100) + self.config = config + self.enable_parallel_cross_entropy = config.tensor_parallel_degree > 1 and config.tensor_parallel_output + + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + def forward(self, prediction_scores, masked_lm_labels): + global loss_cnt + if self.enable_parallel_cross_entropy: + if prediction_scores.shape[-1] == self.config.vocab_size: + warnings.warn( + f"enable_parallel_cross_entropy, the vocab_size should be splited: {prediction_scores.shape[-1]}, {self.config.vocab_size}" + ) + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + with paddle.amp.auto_cast(False): + masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)) + # skip ignore_index which loss == 0 + masked_lm_loss = paddle.masked_select(masked_lm_loss, masked_lm_loss > 0).astype("float32") + loss = paddle.mean(masked_lm_loss) + + loss_cnt += 1 + return loss + + +class QWenForCausalLM3DNet(QWenPretrainedModelNet): + _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"] + + def __init__(self, config): + super().__init__(config) + self.qwen = QWenModelNet(config) + self.lm_head = QWenLMHeadNet(config) + + def forward( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + position_ids=None, + inputs_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.qwen( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # if labels is None,means we need full output, instead of tensor_parallel_output + # tensor_parallel_output is togather with ParallelCrossEntropy + tensor_parallel_output = ( + self.config.tensor_parallel_output and labels is not None and self.config.tensor_parallel_degree > 1 + ) + lm_logits = self.lm_head(hidden_states, tensor_parallel_output=tensor_parallel_output) + + return lm_logits + + +class RotaryEmbedding(nn.Layer): + def __init__(self, dim, base=10000): + super().__init__() + self.dim = dim + self.base = base + self.inv_freq = 1.0 / (self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim)) + self._seq_len_cached = 0 + self._ntk_alpha_cached = 1.0 + + def update_cos_sin_cache(self, max_seq_len, offset=0, ntk_alpha=1.0): + seqlen = max_seq_len + offset + if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached: + base = self.base * ntk_alpha ** (self.dim / (self.dim - 2)) + self.inv_freq = 1.0 / (base ** (paddle.arange(0, self.dim, 2, dtype=paddle.float32) / self.dim)) + self._seq_len_cached = max(2 * seqlen, 16) + self._ntk_alpha_cached = ntk_alpha + seq = paddle.arange(self._seq_len_cached) + with paddle.amp.auto_cast(enable=False): + freqs = paddle.outer(seq.astype(paddle.float32), 
self.inv_freq.astype(paddle.float32)) + emb = paddle.concat([freqs, freqs], axis=-1) + self.cos_cached = emb.cos()[None, :, None, :] + self.sin_cached = emb.sin()[None, :, None, :] + + def forward(self, x, max_seq_len, offset=0, ntk_alpha=1.0): + self.update_cos_sin_cache(max_seq_len, offset, ntk_alpha) + cos = self.cos_cached[:, offset : offset + max_seq_len, :, ...] + sin = self.sin_cached[:, offset : offset + max_seq_len, :, ...] + return ( + cos.cast(x.dtype) if cos.dtype != x.dtype else cos, + sin.cast(x.dtype) if sin.dtype != x.dtype else sin, + ) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return paddle.concat([-x2, x1], axis=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None): + if position_ids is None: + cos = cos[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim] + sin = sin[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim] + else: + cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim] + sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def rms_norm_fused(x_in, w, eps): + fused_ln = try_import("fused_ln") + return fused_ln.fused_rms_norm(x_in, w, eps)[0] + + +class QWenRMSNormNet(nn.Layer): + def __init__(self, config): + super().__init__() + self.config = config + self.eps = config.layer_norm_epsilon + self.weight = paddle.create_parameter( + shape=[config.hidden_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.Constant(1.0), + ) + + def _norm(self, x): + return x * paddle.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + if self.config.use_fused_rms_norm: + return rms_norm_fused(x, self.weight, self.eps) + + output = self._norm(x.astype(paddle.float32)).astype(x.dtype) + return output * self.weight
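+
+
+# Minimal usage sketch, assuming a default QWenConfig and an already-initialized
+# auto-parallel mesh (QWenModelNet calls fleet.auto.get_mesh() to assign a pipeline
+# stage index to each block):
+#
+#     config = QWenConfig()
+#     model = QWenForCausalLM3DNet(config)
+#     criterion = QWenPretrainingCriterionNet(config)
+#     logits = model(input_ids)          # [batch_size, seq_len, vocab_size]
+#     loss = criterion(logits, labels)   # mean cross entropy, skipping ignore_index tokens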