diff --git a/paddleformers/examples/deepseek_v3/config/config.json b/paddleformers/examples/deepseek_v3/config/config.json new file mode 100644 index 00000000000..b5c0fcab696 --- /dev/null +++ b/paddleformers/examples/deepseek_v3/config/config.json @@ -0,0 +1,76 @@ +{ + "architectures": [ + "DeepseekV3ForCausalLM" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_deepseek.DeepseekV3Config", + "AutoModel": "modeling_deepseek.DeepseekV3Model", + "AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM" + }, + "aux_loss_alpha": 0.001, + "bos_token_id": 0, + "eos_token_id": 1, + "ep_size": 1, + "first_k_dense_replace": 3, + "hidden_act": "silu", + "hidden_size": 7168, + "initializer_range": 0.02, + "intermediate_size": 18432, + "kv_lora_rank": 512, + "max_position_embeddings": 163840, + "model_type": "deepseek_v3", + "moe_intermediate_size": 2048, + "moe_layer_freq": 1, + "n_group": 8, + "n_routed_experts": 8, + "n_shared_experts": 1, + "norm_topk_prob": true, + "num_attention_heads": 128, + "num_experts_per_tok": 8, + "num_hidden_layers": 15, + "num_key_value_heads": 128, + "num_nextn_predict_layers": 1, + "pretraining_tp": 1, + "q_lora_rank": 1536, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "beta_fast": 32, + "beta_slow": 1, + "factor": 40, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": 4096, + "type": "yarn" + }, + "rope_theta": 10000, + "routed_scaling_factor": 2.5, + "scoring_func": "sigmoid", + "seq_aux": true, + "tie_word_embeddings": false, + "topk_group": 4, + "topk_method": "noaux_tc", + "dtype": "bfloat16", + "transformers_version": "4.33.1", + "use_cache": true, + "v_head_dim": 128, + "vocab_size": 129280, + "using_flex_token": true, + "using_fake_gate": true, + "use_fused_rms_norm": true, + "fuse_attention_ffn": true, + "use_fused_rope": true, + "token_drop_steps": 0, + "recompute_fwd_gate_up": true, + "adaptive_remained_O1_recompute_ratio": 0.3, + "using_post_norm_recompute": true, + "is_split_group_gemm": false, + "use_dualpipev": true, + "send_mtp_embed": true, + "offline_quant_expert_weight": false, + "clear_origin_weight_when_offline_quant": false +} + diff --git a/paddleformers/examples/deepseek_v3/config/pretrain_argument.json b/paddleformers/examples/deepseek_v3/config/pretrain_argument.json new file mode 100644 index 00000000000..0c8d4aefed9 --- /dev/null +++ b/paddleformers/examples/deepseek_v3/config/pretrain_argument.json @@ -0,0 +1,53 @@ +{ + "model_name_or_path": "./config/", + "tokenizer_name_or_path": "deepseek-ai/DeepSeek-V3", + "input_dir": "./data", + "output_dir": "./checkpoints/pretrain_ckpts", + "per_device_train_batch_size": 1, + "gradient_accumulation_steps": 24, + "per_device_eval_batch_size": 1, + "tensor_parallel_degree": 1, + "pipeline_parallel_degree": 4, + "pipeline_parallel_config": "use_dualpipev", + "sharding_parallel_degree": 2, + "sharding_parallel_config": "split_param", + "sharding_comm_buffer_size_MB": 2048, + "expert_parallel_degree": 2, + "sharding": "stage1", + "virtual_pp_degree": 1, + "sequence_parallel": 0, + "use_flash_attention": true, + "max_seq_length": 4097, + "learning_rate": 3e-05, + "min_learning_rate": 3e-06, + "warmup_steps": 30, + "logging_steps": 1, + "max_steps": 200, + "save_steps": 5000, + "eval_steps": 1000, + "weight_decay": 0.01, + "bf16": true, + "fp16_opt_level": "O2", + "warmup_ratio": 0.01, + "max_grad_norm": 1.0, + "amp_master_grad": 1, + 
"dataloader_num_workers": 8, + "continue_training": 0, + "do_train": true, + "do_eval": true, + "do_predict": false, + "disable_tqdm": true, + "recompute": false, + "distributed_dataloader": 1, + "unified_checkpoint": true, + "save_total_limit": 2, + "skip_profile_timer": false, + "use_fused_rms_norm": true, + "fuse_attention_ffn": true, + "use_fused_rope": true, + "save_sharded_model": false, + "load_sharded_model": false, + "use_expert_parallel": true, + "unified_checkpoint_config": "skip_save_model_weight", + "offload_optim": true + } \ No newline at end of file diff --git a/paddleformers/examples/deepseek_v3/run.sh b/paddleformers/examples/deepseek_v3/run.sh new file mode 100644 index 00000000000..4e7943a73db --- /dev/null +++ b/paddleformers/examples/deepseek_v3/run.sh @@ -0,0 +1,21 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# llama 模型数据下载 +# mkdir -p data +# wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k.bin +# wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwebtext_100k.idx + +rm -rf output +nohup sh script/train_gpu.sh config/pretrain_argument.json > run.log 2>&1 & \ No newline at end of file diff --git a/paddleformers/examples/deepseek_v3/run_pretrain.py b/paddleformers/examples/deepseek_v3/run_pretrain.py new file mode 100644 index 00000000000..eaf966da095 --- /dev/null +++ b/paddleformers/examples/deepseek_v3/run_pretrain.py @@ -0,0 +1,615 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import copy +import math +import os +import sys +import time +from dataclasses import dataclass, field +from typing import Optional + +import paddle + +from paddleformers.data.causal_dataset import ( + build_train_valid_test_datasets, + check_data_split, + print_rank_0, +) +from paddleformers.trainer import ( + FP8QuantWeightCallback, + PdArgumentParser, + StepFlexToken, + Trainer, + TrainingArguments, + get_last_checkpoint, + set_seed, + speed_metrics, +) +from paddleformers.transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForCausalLMPipe, + AutoTokenizer, + CosineAnnealingWithWarmupDecay, + LinearAnnealingWithWarmupDecay, +) +from paddleformers.transformers.configuration_utils import LlmMetaConfig, llmmetaclass +from paddleformers.utils.batch_sampler import DistributedBatchSampler +from paddleformers.utils.log import logger +from paddleformers.utils.tools import get_env_device + +# Pretraining Environment Variables to support sharding stage1 overlap optimization. +os.environ["USE_CASUAL_MASK"] = "True" + + +from paddleformers.trainer.utils.doc import add_start_docstrings + + +@dataclass +@llmmetaclass +@add_start_docstrings(TrainingArguments.__doc__) +class PreTrainingArguments(TrainingArguments): + min_learning_rate: float = field( + default=1e-5, + metadata={"help": "Minimum learning rate decayed to."}, + ) + decay_steps: float = field( + default=None, + metadata={ + "help": "The steps used to control the learning rate. If the step > decay_steps, the min_learning_rate will be used." + }, + ) + enable_linear_fused_grad_add: bool = field( + default=False, + metadata={ + "help": "Enable fused linear grad add strategy, which will reduce elementwise add for grad accumulation in the backward of nn.Linear." + }, + ) + # NOTE(gongenlei): new add autotuner_benchmark + autotuner_benchmark: bool = field( + default=False, + metadata={"help": "Whether to run benchmark by autotuner. True for from_scratch and pad_max_length."}, + ) + unified_checkpoint: bool = field( + default=True, + metadata={"help": "Whether to use unified checkpoint."}, + ) + + def __post_init__(self): + super().__post_init__() + # NOTE(gongenlei): new add autotuner_benchmark + from paddleformers.trainer.trainer_utils import IntervalStrategy + + if self.autotuner_benchmark: + self.max_steps = 5 + self.do_train = True + self.do_export = False + self.do_predict = False + self.do_eval = False + self.overwrite_output_dir = True + self.load_best_model_at_end = False + self.report_to = [] + self.save_strategy = IntervalStrategy.NO + self.evaluation_strategy = IntervalStrategy.NO + self.unified_checkpoint = False + + +@dataclass +class DataArguments: + """ + Arguments pertaining to what data we are going to input our model for training and evaluating. + Using `PdArgumentParser` we can turn this class into argparse arguments to be able to + specify them on the command line. + """ + + input_dir: str = field( + default=None, metadata={"help": "The directory containing the pre-processed training data."} + ) + split: str = field(default="949,50,1", metadata={"help": "Train/valid/test data split."}) + + max_seq_length: int = field( + default=1024, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded."
+ }, + ) + share_folder: bool = field( + default=False, + metadata={"help": "Use share folder for data dir and output dir on multi machine."}, + ) + + data_impl: str = field(default="mmap", metadata={"help": "The format of the preprocessed data."}) + skip_warmup: bool = field( + default=True, + metadata={"help": "Whether to skip the warmup process of mmap files."}, + ) + data_cache: str = field(default=None, metadata={"help": "The path of the cached dataset."}) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to pre-train from. + """ + + model_name_or_path: str = field( + default="__internal_testing__/tiny-random-llama", + metadata={ + "help": "Path to pretrained model or model identifier from https://paddleformers.readthedocs.io/zh/latest/model_zoo/transformers.html" + }, + ) + tokenizer_name_or_path: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + + use_fast_layer_norm: bool = field( + default=False, + metadata={"help": "GPT3 model, use fast layernorm"}, + ) + + hidden_dropout_prob: float = field(default=0.1, metadata={"help": "The hidden dropout prob."}) + attention_probs_dropout_prob: float = field(default=0.1, metadata={"help": "The attention hidden dropout prob."}) + + fuse_attention_qkv: bool = field( + default=None, + metadata={"help": "whether to fuse attention qkv"}, + ) + fuse_attention_ffn: bool = field( + default=None, + metadata={"help": "whether to fuse first up and gate proj in mlp block"}, + ) + + continue_training: bool = field( + default=False, + metadata={ + "help": "Pre-training from existing paddleformers model weights. Default False and model will train from scratch. If set True, the model_name_or_path argument must exist in the paddleformers models." + }, + ) + num_hidden_layers: Optional[int] = field( + default=None, + metadata={"help": "num_hidden_layers."}, + ) + + +def create_pretrained_dataset( + data_args, + training_args, + data_file, + tokenizer, + need_data=True, +): + + check_data_split(data_args.split, training_args.do_train, training_args.do_eval, training_args.do_predict) + + train_val_test_num_samples = [ + training_args.per_device_train_batch_size + * training_args.dataset_world_size + * training_args.max_steps + * training_args.gradient_accumulation_steps, + training_args.per_device_eval_batch_size + * training_args.dataset_world_size + * training_args.eval_iters + * (training_args.max_steps // training_args.eval_steps + 1), + training_args.per_device_eval_batch_size * training_args.dataset_world_size * training_args.test_iters, + ] + + print_rank_0(" > datasets target sizes (minimum size):") + if training_args.do_train: + print_rank_0(" train: {}".format(train_val_test_num_samples[0])) + if training_args.do_eval: + print_rank_0(" validation: {}".format(train_val_test_num_samples[1])) + if training_args.do_predict: + print_rank_0(" test: {}".format(train_val_test_num_samples[2])) + + # Build the datasets. 
+ train_dataset, valid_dataset, test_dataset = build_train_valid_test_datasets( + data_prefix=data_file, + data_impl=data_args.data_impl, + splits_string=data_args.split, + train_val_test_num_samples=train_val_test_num_samples, + seq_length=data_args.max_seq_length, + seed=training_args.seed, + skip_warmup=data_args.skip_warmup, + share_folder=data_args.share_folder, + data_cache_path=data_args.data_cache, + need_data=need_data, + ) + + def print_dataset(data, mode="train"): + logger.info(f"Sample data for {mode} mode.") + # input_ids, loss_mask, attention_mask, position_ids, labels = data + input_ids = data["text"] + logger.info(tokenizer._decode(list(input_ids))) + + from paddleformers.data import Stack + + def _collate_data(data, stack_fn=Stack()): + tokens_ = stack_fn([x["text"] for x in data]) + + labels = copy.deepcopy(tokens_)[:, 1:] + tokens = tokens_[:, :-1] + + return { + "input_ids": tokens, + "labels": labels, + } + + if need_data: + if training_args.do_train: + print_dataset(train_dataset[0], "train") + if training_args.do_eval: + print_dataset(valid_dataset[0], "valid") + if training_args.do_predict: + print_dataset(test_dataset[0], "test") + + return train_dataset, valid_dataset, test_dataset, _collate_data + + +def get_train_data_file(args): + if len(args.input_dir.split()) > 1: + # weight-1 data-prefix-1 weight-2 data-prefix-2 ... + return args.input_dir.split() + else: + files = [ + os.path.join(args.input_dir, f) + for f in os.listdir(args.input_dir) + if (os.path.isfile(os.path.join(args.input_dir, f)) and ("_idx.npz" in str(f) or ".idx" in str(f))) + ] + files = [x.replace("_idx.npz", "") for x in files] + files = [x.replace(".idx", "") for x in files] + + if len(files) > 1: + ret = [] + logger.info("You are using multi-dataset:") + for x in files: + ret.append(1.0) + ret.append(x) + logger.info(" > set weight of %s dataset to 1.0" % x) + return ret + + return files + + +class PretrainingTrainer(Trainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_pretraining = True + + def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"): + # keep eval_dataloader + eval_dataloader = getattr(self, "eval_dataloader", None) + if eval_dataloader is None: + eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset + eval_dataloader = self.get_eval_dataloader(eval_dataset) + # must call data loader, otherwise, it will init many times, cause OOM error. + self.eval_dataloader = eval_dataloader() + + start_time = time.time() + # Temporarily disable metric computation, we will do it in the loop here. 
+ compute_metrics = self.compute_metrics + eval_loop = self.evaluation_loop + + output = eval_loop( + eval_dataloader, + description="Evaluation", + # No point gathering the predictions if there are no metrics, otherwise we defer to + # self.args.prediction_loss_only + prediction_loss_only=True if compute_metrics is None else None, + ignore_keys=ignore_keys, + # Only evaluate max_eval_iters + max_eval_iters=self.args.eval_iters, + ) + + total_batch_size = self.args.eval_batch_size * self.args.world_size + output.metrics.update( + speed_metrics( + metric_key_prefix, + start_time, + num_samples=output.num_samples, + num_steps=math.ceil(output.num_samples / total_batch_size), + ) + ) + + self.log(output.metrics) + + self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) + return output.metrics + + def _get_eval_sampler(self, eval_dataset) -> Optional[paddle.io.Sampler]: + return DistributedBatchSampler( + eval_dataset, + batch_size=self.args.per_device_eval_batch_size, + shuffle=False, + num_replicas=self.args.dataset_world_size, + rank=self.args.dataset_rank, + drop_last=self.args.dataloader_drop_last, + ) + + def _get_train_sampler(self) -> Optional[paddle.io.Sampler]: + return DistributedBatchSampler( + self.train_dataset, + batch_size=self.args.per_device_train_batch_size, + shuffle=False, + num_replicas=self.args.dataset_world_size, + rank=self.args.dataset_rank, + drop_last=self.args.dataloader_drop_last, + ) + + +def main(): + parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments)) + # Support format as "args.json --arg1 value1 --arg2 value2.” + # In case of conflict, command line arguments take precedence. + if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"): + model_args, data_args, training_args = parser.parse_json_file_and_cmd_lines() + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + if training_args.no_recompute_layers is not None: + training_args.no_recompute_layers.sort() + + if training_args.enable_linear_fused_grad_add: + from utils.fused_layers import mock_layers + + mock_layers() + + if model_args.tokenizer_name_or_path is None: + model_args.tokenizer_name_or_path = model_args.model_name_or_path + + if data_args.data_cache is not None: + os.makedirs(data_args.data_cache, exist_ok=True) + + paddle.set_device(training_args.device) + set_seed(seed=training_args.seed) + + training_args.eval_iters = 10 + training_args.test_iters = training_args.eval_iters * 10 + + # Log model and data config + training_args.print_config(model_args, "Model") + training_args.print_config(data_args, "Data") + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, " + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}" + ) + + # Detecting last checkpoint. + last_checkpoint = None + if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: + last_checkpoint = get_last_checkpoint(training_args.output_dir) + # if last_checkpoint is None and len( + # os.listdir(training_args.output_dir)) > 1: + # raise ValueError( + # f"Output directory ({training_args.output_dir}) already exists and is not empty. 
" + # "Use --overwrite_output_dir to overcome.") + if last_checkpoint is not None and training_args.resume_from_checkpoint is None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path, **{"download_hub": "bos"}) + config = AutoConfig.from_pretrained(model_args.model_name_or_path) + + # set all llm config + LlmMetaConfig.set_llm_config(config, training_args) + config.use_fast_layer_norm = model_args.use_fast_layer_norm + + config.seq_length = data_args.max_seq_length + # There are some technique extend RotaryEmbedding context. so don't change max_position_embeddings + if not model_args.continue_training: + config.max_position_embeddings = max(config.max_position_embeddings, data_args.max_seq_length) + + if not model_args.continue_training: + config.vocab_size = max(config.vocab_size, ((tokenizer.vocab_size - 1) // 128 + 1) * 128) + logger.info(f"Reset vocab size to {config.vocab_size} for batter amp peformance.") + + config.num_hidden_layers = ( + model_args.num_hidden_layers if model_args.num_hidden_layers is not None else config.num_hidden_layers + ) + # Config for model using dropout, such as GPT. + if hasattr(config, "use_dualpipev"): + # NOTE(zhangyuqin): In Paddle, the segmentation and scheduling of pipeline parallel + # models are separate. Therefore, first we need to set the flag in the model config + # to perform V-shape segmentation. Second, we need to set the flag in the training_args + # to configure strategy.hybrid_configs to choose the DualPipeV schedule. + config.use_dualpipev = "use_dualpipev" in training_args.pipeline_parallel_config + if hasattr(config, "hidden_dropout_prob"): + config.hidden_dropout_prob = model_args.hidden_dropout_prob + if hasattr(config, "attention_probs_dropout_prob"): + config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob + if model_args.fuse_attention_qkv is not None: + config.fuse_attention_qkv = model_args.fuse_attention_qkv + if model_args.fuse_attention_ffn is not None: + config.fuse_attention_ffn = model_args.fuse_attention_ffn + + if config.sequence_parallel: + assert config.tensor_parallel_degree > 1, "tensor_parallel_degree must be larger than 1 for sequence parallel." 
+ assert ( + config.num_attention_heads % config.sep_parallel_degree == 0 + ), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}" + assert ( + config.seq_length % config.context_parallel_degree == 0 + ), f"seq_length:{config.seq_length} must be divisible by context_parallel_degree {config.context_parallel_degree}" + + if training_args.sharding_parallel_config is not None: + # for stage1 overlap optimization + if ( + "enable_stage1_allgather_overlap" in training_args.sharding_parallel_config + or "enable_stage1_broadcast_overlap" in training_args.sharding_parallel_config + ): + from paddle.io.reader import use_pinned_memory + + use_pinned_memory(False) + + if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1: + try: + from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401 + + LinearConfig.enable_accumulate_steps_opt() + LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps) + except ImportError: + # It's OK not to use the accumulate_steps optimization + pass + + print("Final pre-training config:", config) + + # Set the dtype for loading model + dtype = "float32" + if training_args.fp16_opt_level == "O2": + if training_args.fp16: + dtype = "float16" + if training_args.bf16: + dtype = "bfloat16" + + model_class = AutoModelForCausalLM + if training_args.pipeline_parallel_degree > 1: + model_class = AutoModelForCausalLMPipe + if "LLama" in str(config.architectures): + try: + from utils.register_reshard import register_pp_reshard_information + + register_pp_reshard_information(config.num_hidden_layers) + except: + print("Failed to register llama pp reshard information.") + + architectures_to_check = {"Qwen2Moe", "DeepseekV2", "DeepseekV3"} + if ( + any(architecture in str(config.architectures) for architecture in architectures_to_check) + and training_args.data_parallel_degree > 1 + ): + training_args.use_expert_parallel = True + + if model_args.continue_training: + # NOTE(gongenlei): new add + if training_args.autotuner_benchmark: + model = model_class.from_config(config, dtype=dtype) + else: + model = model_class.from_pretrained( + model_args.model_name_or_path, + config=config, + dtype=dtype, + ) + else: + # Modify these to reduce the model depth: the first 3 DeepSeek layers are dense, the sparse (MoE) layers come after + # config.num_hidden_layers = 4 # 61 in V3 + # config.first_k_dense_replace = 0 # 3 in V3 + # # Modify these to reduce the number of experts; for expert parallelism the expert count must be divisible by the EP degree + # config.n_routed_experts = 64 # 256 in V3 + # config.num_experts_per_tok = 8 # 8 in V3 + # config.topk_group = 4 # 4 in V3 + + # config.using_flex_token = True + # config.num_nextn_predict_layers = 1 + # config.using_fake_gate = True + # config.use_fused_rms_norm = True + # config.fuse_attention_ffn = True + # config.use_fused_rope = True + # config.token_drop_steps = 0 + model = model_class.from_config(config, dtype=dtype) + + if training_args.recompute: + model.recompute_enable() + + # Create the learning_rate scheduler and optimizer + if training_args.decay_steps is None: + training_args.decay_steps = training_args.max_steps + + if training_args.warmup_steps > 0: + warmup_steps = training_args.warmup_steps + else: + warmup_steps = training_args.warmup_ratio * training_args.max_steps + + lr_scheduler = None + if training_args.lr_scheduler_type.value == "cosine": + lr_scheduler = CosineAnnealingWithWarmupDecay( + max_lr=training_args.learning_rate, + min_lr=training_args.min_learning_rate, + warmup_step=warmup_steps, + decay_step=training_args.decay_steps, + last_epoch=0, + ) + elif
training_args.lr_scheduler_type.value == "linear": + lr_scheduler = LinearAnnealingWithWarmupDecay( + max_lr=training_args.learning_rate, + min_lr=training_args.min_learning_rate, + warmup_step=warmup_steps, + decay_step=training_args.decay_steps, + last_epoch=0, + ) + + data_file = get_train_data_file(data_args) + train_dataset, eval_dataset, test_dataset, data_collator = create_pretrained_dataset( + data_args, + training_args, + data_file, + tokenizer, + need_data=training_args.should_load_dataset, + ) + + total_effective_tokens = ( + training_args.per_device_train_batch_size + * training_args.dataset_world_size + * training_args.max_steps + * training_args.gradient_accumulation_steps + * data_args.max_seq_length + ) + + callbacks = [StepFlexToken(), FP8QuantWeightCallback()] + + trainer = PretrainingTrainer( + model=model, + args=training_args, + data_collator=data_collator, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + optimizers=(None, lr_scheduler), + tokenizer=tokenizer, + callbacks=callbacks, + ) + + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + + # Training + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=checkpoint) + + # NOTE(gongenlei): new add + if not training_args.autotuner_benchmark: + metrics = train_result.metrics + if not int(os.getenv("test_ci_no_save_model", 0)): + trainer.save_model() + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + if training_args.do_predict: + test_ret = trainer.predict(test_dataset) + trainer.log_metrics("test", test_ret.metrics) + + if training_args.do_train and training_args.should_load_dataset: + effective_tokens_per_second = total_effective_tokens / train_result.metrics["train_runtime"] + print(f"Effective Tokens per second: {effective_tokens_per_second:.2f}") + print(f"ips: {effective_tokens_per_second:.2f} tokens/s") + + +if __name__ == "__main__": + main() diff --git a/paddleformers/examples/deepseek_v3/script/kill_process.sh b/paddleformers/examples/deepseek_v3/script/kill_process.sh new file mode 100644 index 00000000000..3c3db6a4639 --- /dev/null +++ b/paddleformers/examples/deepseek_v3/script/kill_process.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -x +skip_kill_time=${1:-"False"} +function kill_impl() { + skip_kill_time=$1 + # kill aadiff test finally. 
+ pids=`ps -ef | grep pretrain.py | grep -v grep | awk '{print $2}'` + if [[ "$pids" != "" ]] ; then + echo $pids + echo $pids | xargs kill -9 + fi + + echo "Killing processes on gpu" + lsof /dev/nvidia* | awk '{print $2}' | xargs -I {} kill -9 {} +} + +kill_impl $skip_kill_time || true \ No newline at end of file diff --git a/paddleformers/examples/deepseek_v3/script/selective_launch.py b/paddleformers/examples/deepseek_v3/script/selective_launch.py new file mode 100644 index 00000000000..1f8a37bfbc5 --- /dev/null +++ b/paddleformers/examples/deepseek_v3/script/selective_launch.py @@ -0,0 +1,73 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Selective launch script. + +Usage: python script/selective_launch.py ... +""" +import os +import sys + + +def parse_ranks(ranks_strs): + """ + parse_ranks + """ + # NOTE: You can return ranks directly here to change script/train_gpu.sh + # and script/kill_process.sh together + + # Example 1: Use contiguous nodes [8, 16) + return range(6, 7) + + # Example 2: Use non-contiguous nodes [4, 8) + {10} + [30, 32), i.e., [4, 5, 6, 7, 10, 30, 31] + # return list(range(0, 16)) + list(range(24, 40)) + + # Example 3: + # Just Python code, return any nodes you want! + + if not ranks_strs: + return None + + ranks = [] + for r in ranks_strs: + r = eval(r) + if isinstance(r, int): + ranks.append(r) + else: + ranks.extend(r) + return ranks + + +def main(port, ranks): + """ + main + """ + ips = [ip.strip() for ip in os.getenv("TRAINER_INSTANCES").split(",") if ip.strip()] + if ranks is None: + ranks = list(range(len(ips))) + ranks = sorted(list(set(ranks))) + my_rank = int(os.getenv("POD_INDEX", "0")) + if my_rank not in ranks: + return + + rank = ranks.index(my_rank) + nranks = len(ranks) + + master = ips[ranks[0]] + print(f"--master {master}:{port} --rank {rank} --nnodes {nranks}") + + +if __name__ == "__main__": + main(int(sys.argv[1]), parse_ranks(sys.argv[2:])) diff --git a/paddleformers/examples/deepseek_v3/script/train_gpu.sh b/paddleformers/examples/deepseek_v3/script/train_gpu.sh new file mode 100644 index 00000000000..c446c14aa09 --- /dev/null +++ b/paddleformers/examples/deepseek_v3/script/train_gpu.sh @@ -0,0 +1,76 @@ +#!/bin/bash + +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
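Illustrative aside (not part of the patch) on script/selective_launch.py above: as committed, parse_ranks() short-circuits to range(6, 7), so only the node with POD_INDEX == 6 prints launch arguments, and the other examples in its comments are unreachable. The minimal sketch below shows how its stdout is typically consumed; the IPs and POD_INDEX are made up for the example, and real values come from the cluster scheduler.

import os
import subprocess

os.environ["TRAINER_INSTANCES"] = ",".join(f"10.0.0.{i}" for i in range(8))
os.environ["POD_INDEX"] = "6"

# Capture the "--master ... --rank ... --nnodes ..." fragment printed by the script.
extra_args = subprocess.run(
    ["python", "script/selective_launch.py", "8090"],
    capture_output=True,
    text=True,
).stdout.strip()
# Here extra_args == "--master 10.0.0.6:8090 --rank 0 --nnodes 1"; an empty string
# means this node is not selected and the training launch should be skipped.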
+ +unset PADDLE_ELASTIC_JOB_ID +unset PADDLE_TRAINER_ENDPOINTS +unset DISTRIBUTED_TRAINER_ENDPOINTS +unset FLAGS_START_PORT +unset PADDLE_ELASTIC_TIMEOUT + +nnodes=$PADDLE_TRAINERS_NUM +rank=$PADDLE_TRAINER_ID + +for name in `env | grep -E 'PADDLE|ENDPOINT' | awk -F'=' '{print $1}'`; do + unset ${name} +done + +#export FLAGS_shard_bypass_dygraph_optimizer=1 +export NCCL_IB_GID_INDEX=3 +export NVSHMEM_IB_GID_INDEX=3 +export NVSHMEM_IB_TRAFFIC_CLASS=162 + +#export NVSHMEM_IB_ENABLE_IBGDA=true +##export NVSHMEM_DISABLE_P2P=1 +export NVSHMEM_BOOTSTRAP=UID +# export NVSHMEM_BOOTSTRAP_UID_SOCK_IFNAME==xgbe0 + +unset NVSHMEM_HCA_LIST +unset NVSHMEM_ENABLE_NIC_PE_MAPPING + +if [[ -z "$LAUNCH_CMD" ]]; then + exit 0 +fi + +export PYTHONPATH=../../../:$PYTHONPATH + +export PATH=/opt/nvidia/nsight-systems/2025.3.1/bin/:$PATH + +export DSV3_USE_FP8_GEMM=true +export DSV3_USE_ATTEN_RECOMPUTE=true +export FA_VERSION=3 +export CUDA_PATH=/usr/local/cuda-12.9 +export FLAGS_share_tensor_for_grad_tensor_holder=1 +export FLAGS_use_default_stream=false +export DSV3_USE_FP8_DISPATCH=true +export USE_DS_GEMM=false + +bash script/kill_process.sh + +export FLAGS_large_pool_auto_growth_chunk_size_in_mb=500 +export FLAGS_small_pool_auto_growth_chunk_size_in_mb=20 +export FLAGS_small_pool_size_in_mb=10 + +export FLAGS_samll_pool_pre_alloc_in_mb=500 +export FLAGS_large_pool_pre_alloc_in_mb=61440 + +export DSV3_FAST_PRETRAIN=true +# nsys profile --stats=true -t cuda,nvtx -o test_no_quant_cache --force-overwrite true \ +python3.10 -m paddle.distributed.launch \ + --log_dir output/paddle_distributed_logs \ + --nnodes 256 \ + --run_mode=collective \ + ${script:-run_pretrain.py} \ + $@ diff --git a/paddleformers/trainer/__init__.py b/paddleformers/trainer/__init__.py index 53ceb66a961..b129b4a20a1 100644 --- a/paddleformers/trainer/__init__.py +++ b/paddleformers/trainer/__init__.py @@ -75,6 +75,8 @@ "TrainerState", "DEFAULT_PROGRESS_CALLBACK", "TrainerCallback", + "StepFlexToken", + "FP8QuantWeightCallback", ], "trainer_utils": [ "get_last_checkpoint", diff --git a/paddleformers/trainer/trainer.py b/paddleformers/trainer/trainer.py index 55fb28d5c09..027f16f501a 100644 --- a/paddleformers/trainer/trainer.py +++ b/paddleformers/trainer/trainer.py @@ -92,6 +92,10 @@ RowParallelQuantizationLinear, ) +try: + from ..quantization.quantization_linear import QuantizationLinear +except: + QuantizationLinear = None try: from paddle.distributed.fleet.utils.sequence_parallel_utils import ( register_sequence_parallel_allreduce_hooks, @@ -201,6 +205,14 @@ DEFAULT_CALLBACKS = [DefaultFlowCallback] DEFAULT_PROGRESS_CALLBACK = ProgressCallback +# Name of the files used for checkpointing +TRAINING_ARGS_NAME = "training_args.bin" +TRAINER_STATE_NAME = "trainer_state.json" + +SCHEDULER_NAME = "scheduler.pdparams" +SCALER_NAME = "scaler.pdparams" + + if is_datasets_available(): import datasets diff --git a/paddleformers/trainer/trainer_callback.py b/paddleformers/trainer/trainer_callback.py index 812c8dc9f59..d3b50d856b7 100644 --- a/paddleformers/trainer/trainer_callback.py +++ b/paddleformers/trainer/trainer_callback.py @@ -20,12 +20,14 @@ """ import dataclasses import json +import os from dataclasses import dataclass from typing import Dict, List, Optional, Union import numpy as np from tqdm.auto import tqdm +from paddleformers.transformers.moe_utils import offload, reload from ..utils.log import logger from .trainer_utils import IntervalStrategy, has_length from .training_args import TrainingArguments @@ -39,6 +41,8 @@ "ProgressCallback", 
"PrinterCallback", "EarlyStoppingCallback", + "StepFlexToken", + "FP8QuantWeightCallback", ] @@ -608,3 +612,65 @@ def on_evaluate(self, args, state, control, metrics, **kwargs): self.check_metric_value(args, state, control, metric_value) if self.early_stopping_patience_counter >= self.early_stopping_patience: control.should_training_stop = True + + +class StepFlexToken(TrainerCallback): + def on_step_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + model = kwargs.pop("model") + if hasattr(model, "step_flex_token"): + model.step_flex_token(state.global_step) + + +g_shard_bypass_dygraph_optimizer = int(os.environ.get("FLAGS_shard_bypass_dygraph_optimizer", 0)) + + +def enable_in_dict_config(config, key): + """enable_in_dict_config""" + return key in config and config[key] + + +skip_count = 0 + + +class FP8QuantWeightCallback(TrainerCallback): + """ + FP8QuantWeightCallback + """ + + def on_step_begin(self, args, state, control, **kwargs): + """ + 每个step开始前把专家参数quant成fp8q + """ + model = kwargs["model"] + optimizer = kwargs["optimizer"] + global skip_count + + if not g_shard_bypass_dygraph_optimizer or skip_count == 0: + model.fp8_quant_weight(True) + optimizer.clear_param_storage("moe_expert") + optimizer.clear_param_storage("rms_linear") + optimizer.clear_param_storage("memory_attn") + optimizer.clear_param_storage("attn_out_project") + optimizer.clear_param_storage("shared_expert") + + self.moe_weights_name = [] + for param in optimizer._inner_opt._parameter_list: + color = getattr(param, "color", -1) + if isinstance(color, dict) and color["color"] == "moe_expert": + self.moe_weights_name.append(param.name) + + for name in self.moe_weights_name: + offload(optimizer._master_weights[name]) + + skip_count += 1 + + def on_optimizer_begin(self, args, state, control, **kwargs): + optimizer = kwargs["optimizer"] + for name in self.moe_weights_name: + reload(optimizer._master_weights[name]) diff --git a/paddleformers/trainer/training_args.py b/paddleformers/trainer/training_args.py index c505856532b..2b97b4a2afa 100644 --- a/paddleformers/trainer/training_args.py +++ b/paddleformers/trainer/training_args.py @@ -32,6 +32,7 @@ from paddle.distributed import fleet from ..utils.env import PREFIX_CHECKPOINT_DIR +from ..utils.fault_tolerance import is_ft_env from ..utils.log import logger from ..utils.pdc_sdk import FLASH_DEVICE from .trainer_utils import ( @@ -1405,12 +1406,15 @@ def is_segment_parallel_supported(): else: order = ["dp", "sharding", "pp", "mp"] if self.use_expert_parallel: - if self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1: - order.insert(-1, "ep") - sd_idx = order.index("sharding") - # if pp_first, the order = ["dp", "pp", "moe_sharding", "sharding", "sep", "ep", "mp"] - # if sharding_first, the order is ["dp", "moe_sharding", "sharding", "pp", "sep", "ep", "mp"] - order.insert(sd_idx, "moe_sharding") + if not os.getenv("DSV3_FAST_PRETRAIN", "False"): + if self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1: + order.insert(-1, "ep") + sd_idx = order.index("sharding") + # if pp_first, the order = ["dp", "pp", "moe_sharding", "sharding", "sep", "ep", "mp"] + # if sharding_first, the order is ["dp", "moe_sharding", "sharding", "pp", "sep", "ep", "mp"] + order.insert(sd_idx, "moe_sharding") + else: + order = order[1:-1] + ["dp", "mp"] if is_segment_parallel_supported(): hybrid_configs = { @@ -1564,6 +1568,10 @@ def is_segment_parallel_supported(): fleet.init(is_collective=True, 
strategy=strategy) logger.info(strategy) + if os.getenv("DSV3_FAST_PRETRAIN", "False"): + if self.expert_parallel_degree > 1: + self.add_moe_comm_group() + elif self.enable_auto_parallel: self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1) self.sep_parallel_degree = max(self.sep_parallel_degree, 1) diff --git a/paddleformers/transformers/__init__.py b/paddleformers/transformers/__init__.py index ecaa18a3e9e..4e4e362564e 100644 --- a/paddleformers/transformers/__init__.py +++ b/paddleformers/transformers/__init__.py @@ -129,6 +129,10 @@ "get_triangle_upper_mask", "DeepseekV2LinearScalingRotaryEmbedding", ], + "deepseek_v2.modeling_fast": [ + "DeepseekV2ModelFast", + "DeepseekV2PretrainedModelFast", + ], "deepseek_v2.modeling_auto": [ "DeepseekV2LMHeadAuto", "DeepseekV2ForCausalLMAuto", diff --git a/paddleformers/transformers/deepseek_v2/__init__.py b/paddleformers/transformers/deepseek_v2/__init__.py index a0fac197982..2c7634b8810 100644 --- a/paddleformers/transformers/deepseek_v2/__init__.py +++ b/paddleformers/transformers/deepseek_v2/__init__.py @@ -56,6 +56,12 @@ "yarn_find_correction_range", "get_triangle_upper_mask", "DeepseekV2LinearScalingRotaryEmbedding", + "set_global_step", + "get_global_step", + ], + "modeling_fast": [ + "DeepseekV2ModelFast", + "DeepseekV2PretrainedModelFast", ], "modeling_auto": [ "DeepseekV2LMHeadAuto", diff --git a/paddleformers/transformers/deepseek_v2/configuration.py b/paddleformers/transformers/deepseek_v2/configuration.py index 1feba3cbec7..e62ae3dc5ef 100644 --- a/paddleformers/transformers/deepseek_v2/configuration.py +++ b/paddleformers/transformers/deepseek_v2/configuration.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
""" DeepSeekV2 model configuration""" -from ..configuration_utils import PretrainedConfig +from paddleformers.transformers.configuration_utils import PretrainedConfig __all__ = [ "DeepseekV2Config", @@ -179,6 +179,18 @@ def __init__( attention_dropout=0.0, speculate_model_type=False, using_flex_token=False, + use_dualpipev=False, + send_mtp_embed=True, + using_post_norm_recompute=False, + recompute_fwd_gate_up=0, + is_split_group_gemm=False, + fakse_gate_restrict_balance=False, + adaptive_remained_O1_recompute_ratio=0, + offline_quant_expert_weight=True, + clear_origin_weight_when_offline_quant=True, + mlp_bwd_subbatch_rows=0, + mlp_fwd_subbatch_rows=0, + output_subbatch_rows=0, **kwargs, ): self.vocab_size = vocab_size @@ -227,6 +239,18 @@ def __init__( self.speculate_model_type = speculate_model_type self.use_fp8 = False self.using_flex_token = using_flex_token + self.use_dualpipev = use_dualpipev + self.send_mtp_embed = send_mtp_embed + self.using_post_norm_recompute = using_post_norm_recompute + self.recompute_fwd_gate_up = recompute_fwd_gate_up + self.is_split_group_gemm = is_split_group_gemm + self.fakse_gate_restrict_balance = fakse_gate_restrict_balance + self.adaptive_remained_O1_recompute_ratio = adaptive_remained_O1_recompute_ratio + self.offline_quant_expert_weight = offline_quant_expert_weight + self.clear_origin_weight_when_offline_quant = clear_origin_weight_when_offline_quant + self.mlp_bwd_subbatch_rows = mlp_bwd_subbatch_rows + self.mlp_fwd_subbatch_rows = mlp_fwd_subbatch_rows + self.output_subbatch_rows = output_subbatch_rows super().__init__( pad_token_id=pad_token_id, diff --git a/paddleformers/transformers/deepseek_v2/modeling.py b/paddleformers/transformers/deepseek_v2/modeling.py index 04a8651f43e..7e35a4e58b1 100644 --- a/paddleformers/transformers/deepseek_v2/modeling.py +++ b/paddleformers/transformers/deepseek_v2/modeling.py @@ -23,6 +23,7 @@ import contextlib import math +import os import warnings from functools import partial from typing import List, Optional, Tuple, Union @@ -35,7 +36,9 @@ from paddle.distributed import fleet from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker from paddle.distributed.fleet.recompute.recompute import recompute +from paddle.jit import to_static from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from paddle.utils import try_import try: from paddle.incubate.nn.functional import fused_rotary_position_embedding @@ -51,11 +54,12 @@ except: pass +from paddle import _C_ops try: from paddle.nn.functional.flash_attention import flash_attention except: flash_attention = None - +from paddleformers.transformers.model_utils import dtype_guard from ...utils.initializer import kaiming_uniform_ from ...utils.log import logger @@ -72,11 +76,46 @@ from ..model_utils import PretrainedModel, dtype_guard, register_base_model from ..moe_gate import PretrainedMoEGate from ..moe_layer import MoEFlexTokenLayer, MoELayer -from ..utils import device_guard +from ..utils import cast_if_needed, device_guard from . 
import fp8_linear as linear_utils from .configuration import DeepseekV2Config + +FA_VERSION = int(os.getenv("FA_VERSION", 2)) + +from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore + +from ..fp8_utils import ( + FP8KeepXLinear, + FP8Linear, + FP8LinearFunctionBase, + FP8Mlp, + cache_fp8_weight, + set_parameter_color, +) from .fp8_linear import Linear +DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true" +DSV3_USE_ATTEN_RECOMPUTE = os.getenv("DSV3_USE_ATTEN_RECOMPUTE", "False").lower() == "true" + +Linear = FP8Linear if DSV3_USE_FP8_GEMM else Linear + +try: + import fused_ln + from paddle.incubate.nn.functional import swiglu +except ImportError: + + def swiglu(x, y=None): + if y is None: + x, y = paddle.chunk(x, chunks=2, axis=-1) + return F.silu(x) * y + + +try: + from paddle.incubate.nn.functional import fused_partial_rope +except ImportError: + fused_partial_rope = None + + __all__ = [ "DeepseekV2LMHead", "DeepseekV2PretrainingCriterion", @@ -84,8 +123,54 @@ "DeepseekV2ForSequenceClassification", "DeepseekV2Model", "DeepseekV2PretrainedModel", + "set_global_step", + "get_global_step", ] +global_step = 0 + + +def set_global_step(cur_step): + global global_step + global_step = cur_step + + +def get_global_step(): + global global_step + return global_step + + +def rms_norm_fused(x_in, w, eps, use_fast_ln=False): + if use_fast_ln: + fast_ln = try_import("fast_ln") + return fast_ln.fast_rms_norm(x_in, w, eps)[0] + else: + fused_ln = try_import("fused_ln") + return fused_ln.fused_rms_norm(x_in, w, eps)[0] + + +def fusion_rms_norm(hidden_states, weight, variance_epsilon, use_fast_ln=False): + if get_env_device() == "npu": + return paddle.base.core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0] + if get_env_device() == "mlu": + return paddle.base.core.eager._run_custom_op("rms_norm_mlu", hidden_states, weight, variance_epsilon)[0] + elif get_env_device() == "gcu": + return paddle.base.core.eager._run_custom_op("rms_norm_gcu", hidden_states, weight, variance_epsilon)[0] + elif get_env_device() == "intel_hpu": + return paddle.incubate.nn.functional.fused_rms_norm( + hidden_states, weight, None, variance_epsilon, hidden_states.dim() - 1 + )[0] + elif get_env_device() == "xpu": + try: + import paddle_xpu_nn # noqa: F821 + + return paddle_xpu_nn.xpu_rms_norm(hidden_states, weight, variance_epsilon)[0] + except ImportError: + raise NotImplementedError( + f"Implementation of fused_rms_norm is not available on {get_env_device()}. 
Please install paddle_xpu to use this feature" + ) + return rms_norm_fused(hidden_states, weight, variance_epsilon, use_fast_ln) + def get_triangle_upper_mask(x, mask=None): if mask is not None: @@ -129,7 +214,35 @@ def assign_kv_heads(num_kv_heads: int, num_gpus: int): return assignment_list -def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True): +class LMHeadFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, weight, transpose_y): + out = paddle.matmul(x, weight, transpose_y = transpose_y) + + ctx.save_for_backward(x, weight, transpose_y) + return out + + @staticmethod + def backward(ctx, dout): + if dout.dtype == paddle.float32: + dout = dout.cast( paddle.bfloat16) + + x, weight, transpose_y = ctx.saved_tensor() + + dx = paddle.matmul( dout, weight, transpose_y = not transpose_y) + if transpose_y: + with paddle.amp.auto_cast(False): + paddle._C_ops.fused_linear_param_grad_add( + dout.reshape( [-1, dout.shape[-1]]), x.reshape( [-1, x.shape[-1]]), weight.main_grad, None, True, False + ) + else: + with paddle.amp.auto_cast(False): + paddle._C_ops.fused_linear_param_grad_add( + x.reshape([-1, x.shape[-1]]), dout.reshape([-1, dout.shape[-1]]), weight.main_grad, None, True, False + ) + return dx, None + +def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True): is_fleet_init = True tensor_parallel_degree = 1 try: @@ -147,7 +260,7 @@ def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True): if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed: # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg' input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group) - logits = paddle.matmul(input_parallel, y, transpose_y=False) + logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y) if tensor_parallel_output: return logits @@ -155,7 +268,7 @@ def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True): return paddle.distributed.collective._c_concat(logits, group=model_parallel_group) else: - logits = paddle.matmul(x, y, transpose_y=False) + logits = LMHeadFunction.apply(x, y, transpose_y=transpose_y) return logits @@ -328,17 +441,9 @@ def __init__(self, config: DeepseekV2Config, hidden_size=None, eps=1e-6, use_seq mark_as_sequence_parallel_parameter(self.weight) def forward(self, hidden_states): - if self.config.use_fused_rms_norm and get_env_device() == "xpu": - if self.weight.dtype != hidden_states.dtype: - hidden_states = paddle.cast(hidden_states, self.weight.dtype) - try: - import paddle_xpu_nn # noqa: F821 - - return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0] - except ImportError: - raise NotImplementedError( - f"Implementation of fused_rms_norm is not available on {get_env_device()}. 
Please install paddle_xpu to use this feature" - ) + if self.config.use_fused_rms_norm: + # fusion_rms_norm integrates support for multiple hardware backends + return fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon, self.config.use_fast_layer_norm) with paddle.amp.auto_cast(False): hidden_states = hidden_states.astype("float32") @@ -528,34 +633,35 @@ def __init__( super().__init__(dim, max_position_embeddings, base) def _set_cos_sin_cache(self, seq_len): - self.max_seq_len_cached = seq_len - dim = self.dim - - freq_extra = 1.0 / (self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim)) - freq_inter = 1.0 / (self.scaling_factor * self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim)) - - low, high = yarn_find_correction_range( - self.beta_fast, - self.beta_slow, - dim, - self.base, - self.original_max_position_embeddings, - ) - inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2) - self.inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + with paddle.amp.auto_cast(False): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / (self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim)) + freq_inter = 1.0 / (self.scaling_factor * self.base ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim)) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2) + self.inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask - t = paddle.arange(seq_len, dtype=paddle.float32) + t = paddle.arange(seq_len, dtype=paddle.float32) - freqs = paddle.outer(t, paddle.cast(self.inv_freq, dtype="float32")) + freqs = paddle.outer(t, paddle.cast(self.inv_freq, dtype="float32")) - _mscale = float( - yarn_get_mscale(self.scaling_factor, self.mscale) - / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) - ) + _mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) + / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) - emb = paddle.concat((freqs, freqs), axis=-1) - self.cos_cached = emb.cos() * _mscale - self.sin_cached = emb.sin() * _mscale + emb = paddle.concat((freqs, freqs), axis=-1) + self.cos_cached = emb.cos() * _mscale + self.sin_cached = emb.sin() * _mscale def rotate_half(x): @@ -592,7 +698,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, fuse_rope=False): b, s, h, d = k.shape k = k.reshape([b, s, h, d // 2, 2]).transpose([0, 1, 2, 4, 3]).reshape([b, s, h, d]) - if get_env_device() == "xpu" and fuse_rope: + if (get_env_device() == "xpu" or get_env_device() == "gpu") and fuse_rope: q_embed, k_embed, _ = fused_rotary_position_embedding( q, k, @@ -671,9 +777,84 @@ def forward(self, x): return down_proj + +class FusedNormGateFunc(paddle.autograd.PyLayer): + """recompute of postnorm and gate""" + + _current_norm_output = None + _current_invar = None + + @classmethod + def set_temporary_vars(cls, norm_output, invar): + FusedNormGateFunc._current_norm_output = norm_output + FusedNormGateFunc._current_invar = invar + + @classmethod + def clear_temporary_vars(cls): + FusedNormGateFunc._current_norm_output = None + FusedNormGateFunc._current_invar = None + + @staticmethod + def forward(ctx, x, rms_norm_weight, moe_gate_weight, eps): + ctx.dtype = paddle.float32 + norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + with paddle.amp.auto_cast(False): + gate_logits = F.linear(cast_if_needed(norm_output, ctx.dtype),
cast_if_needed(moe_gate_weight, ctx.dtype)) + + ctx.save_for_backward(x, rms_norm_weight, moe_gate_weight, eps) + return gate_logits, norm_output + + @staticmethod + def backward(ctx, d_gate_logits, d_norm_output): + x, rms_norm_weight, moe_gate_weight, eps = ctx.saved_tensor() + # recompute rmsnorm + norm_output = FusedNormGateFunc._current_norm_output + invar = FusedNormGateFunc._current_invar + if norm_output is None or invar is None: + norm_output, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + d_norm_output_linear, d_moe_gate_weight = paddle._C_ops.matmul_grad( + cast_if_needed(norm_output, ctx.dtype), + cast_if_needed(moe_gate_weight, ctx.dtype), + d_gate_logits, + False, + False, + ) + d_norm_output_linear, d_moe_gate_weight = cast_if_needed( + d_norm_output_linear, norm_output.dtype + ), cast_if_needed(d_moe_gate_weight, moe_gate_weight.dtype) + + d_norm_output = d_norm_output + d_norm_output_linear + dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, rms_norm_weight, invar, d_norm_output, eps) + + return dx, d_rms_norm_weight, d_moe_gate_weight + + +class TemporaryVarContext: + def __init__(self, norm_output, invar): + self.norm_output = norm_output + self.invar = invar + + def __enter__(self): + FusedNormGateFunc.set_temporary_vars(self.norm_output, self.invar) + + def __exit__(self, exc_type, exc_val, exc_tb): + FusedNormGateFunc.clear_temporary_vars() + + +def balance_expert_assignment(n, m, k): + assert k * n % m == 0 + matrix = paddle.zeros((n, m), dtype=paddle.int32) + for row in range(n): + start_col = row % m + for i in range(k): + col = (start_col + i) % m + matrix[row, col] = 1 + return matrix + + class FakeGate(paddle.autograd.PyLayer): @staticmethod - def forward(ctx, hidden_states, weight): + def forward(ctx, hidden_states, weight, fakse_gate_restrict_balance=False, num_experts_per_tok=8): expert_num = weight.shape[1] bsz, seq, _ = hidden_states.shape @@ -681,8 +862,12 @@ def forward(ctx, hidden_states, weight): ctx.x_dtype = hidden_states.dtype ctx.y_shape = weight.shape ctx.y_dtype = weight.dtype - - return paddle.randn([bsz, seq, expert_num]).cast(weight.dtype) + if fakse_gate_restrict_balance: + return paddle.reshape( + balance_expert_assignment(bsz * seq, expert_num, num_experts_per_tok), [bsz, seq, expert_num] + ) + else: + return paddle.randn([bsz, seq, expert_num]).cast(weight.dtype) @staticmethod def backward(ctx, grad_output): @@ -882,6 +1067,792 @@ def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor: return hidden_states.reshape([batch, slen, num_key_value_heads * n_rep, head_axis]) +@to_static(backend="CINN") +def qkv_pre_process_no_fuse( + q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids +): + bsz, q_len, _ = q.shape + + target_query_shape = [0, 0, num_heads, q_head_dim] + target_key_value_shape = [0, 0, num_heads, qk_nope_head_dim + v_head_dim] + + q = q.reshape(shape=target_query_shape) + q_nope = q[..., :qk_nope_head_dim] + q_pe = q[..., qk_nope_head_dim:] + + # DeepSeekV2 kv_lora_rank+qk_rope_head_dim=512+64 + + kv = kv.reshape(shape=target_key_value_shape) + + k_pe = k_pe.reshape([-1, q_len, 1, qk_rope_head_dim]).expand([-1, q_len, num_heads, qk_rope_head_dim]) + + # self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim = 128+64 + # self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim) = config.qk_nope_head_dim + self.v_head_dim = 128+128 + k_nope = kv[..., :qk_nope_head_dim] + value_states = kv[..., 
qk_nope_head_dim:] + + kv_seq_len = value_states.shape[1] + + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids, False) + + query_states = paddle.concat([q_nope, q_pe], axis=-1) + key_states = paddle.concat([k_nope, k_pe], axis=-1) + + return query_states, key_states, value_states + + +@to_static(backend="CINN") +def rearrange_kv(kv, k_pe, qk_nope_head_dim, num_heads): + k_nope = kv[..., :qk_nope_head_dim] + value_states = kv[..., qk_nope_head_dim:] + + k_pe = k_pe.expand([k_pe.shape[0], k_pe.shape[1], num_heads, k_pe.shape[3]]) + key_states = paddle.concat([k_nope, k_pe], axis=-1) + + return key_states, value_states + + +def qkv_pre_process( + q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids +): + if (fused_partial_rope is None) or (position_ids is not None): + return qkv_pre_process_no_fuse( + q, + kv, + k_pe, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + ) + + bsz, q_len, _ = q.shape + + target_query_shape = [0, 0, num_heads, q_head_dim] + target_key_value_shape = [0, 0, num_heads, qk_nope_head_dim + v_head_dim] + + q = q.reshape(shape=target_query_shape) + kv = kv.reshape(shape=target_key_value_shape) + k_pe = k_pe.reshape([-1, q_len, 1, qk_rope_head_dim]) + + value_states = kv[..., qk_nope_head_dim:] + + kv_seq_len = value_states.shape[1] + + cos, sin = rotary_emb(value_states, seq_len=kv_seq_len) + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + + query_states = fused_partial_rope(q, cos, sin) + k_pe = fused_partial_rope(k_pe, cos, sin) + + key_states, value_states = rearrange_kv(kv, k_pe, qk_nope_head_dim, num_heads) + + return query_states, key_states, value_states + + +def manul_fwd( + q_init, + kv_init, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, +): + + q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps) + q = paddle.matmul(q_ln_t, q_up_weight) + + compressed_kv, k_pe = paddle.split(kv_init, [kv_lora_rank, qk_rope_head_dim], axis=-1) + + kv_ln_t, kv_ln_invar = fused_ln.fused_rms_norm(compressed_kv, kv_ln_weight, eps) + + kv = paddle.matmul(kv_ln_t, kv_up_weight) + + query_states, key_states, value_states = qkv_pre_process( + q, kv, k_pe, rotary_emb, num_heads, q_head_dim, qk_nope_head_dim, v_head_dim, qk_rope_head_dim, position_ids + ) + + q_head_dim = query_states.shape[-1] + softmax_scale = softmax_scale * (q_head_dim**0.5) + query_states = query_states * softmax_scale + + attn_out, _, softmax_lse, seed_offset = _C_ops.flash_attn( + query_states, + key_states, + query_states, + None, + None, + 0.0, + True, + False, + False, + "", + ) + + return attn_out + + +class MemroyRecomputeAttnFunc(paddle.autograd.PyLayer): + @staticmethod + def forward( + ctx, + q_init, + kv_init, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, + ): + + bsz = q_init.shape[0] + q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps) + # q = paddle.matmul(q_ln_t, q_up_weight) + q_orig_shape = q_ln_t.shape + q = FP8LinearFunctionBase.compute_fp8_linear( + q_ln_t.reshape([-1, 
q_orig_shape[-1]]), q_up_weight, weight_transpose=True, return_transpose_only=True + ) + q = q.reshape(q_orig_shape[:-1] + [q_up_weight.shape[-1]]) + + compressed_kv, k_pe = paddle.split(kv_init, [kv_lora_rank, qk_rope_head_dim], axis=-1) + + kv_ln_t, kv_ln_invar = fused_ln.fused_rms_norm(compressed_kv, kv_ln_weight, eps) + # kv = paddle.matmul(kv_ln_t, kv_up_weight) + kv_orig_shape = kv_ln_t.shape + kv = FP8LinearFunctionBase.compute_fp8_linear( + kv_ln_t.reshape([-1, kv_orig_shape[-1]]), kv_up_weight, weight_transpose=True, return_transpose_only=True + ) + kv = kv.reshape(kv_orig_shape[:-1] + [kv_up_weight.shape[-1]]) + + query_states, key_states, value_states = qkv_pre_process( + q, + kv, + k_pe, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + ) + + q_head_dim = query_states.shape[-1] + + if FA_VERSION == 2: + softmax_scale = softmax_scale * (q_head_dim**0.5) + query_states = query_states * softmax_scale + kv_seq_len = value_states.shape[1] + v_num_heads = value_states.shape[2] + value_padding = paddle.zeros( + [bsz, kv_seq_len, v_num_heads, q_head_dim - v_head_dim], + dtype=value_states.dtype, + ) + value_states_pad = paddle.concat([value_states, value_padding], axis=-1) + + attn_out, _, softmax_lse, seed_offset = _C_ops.flash_attn( + query_states, + key_states, + value_states_pad, + None, + None, + 0.0, + True, + False, + False, + "", + ) + + elif FA_VERSION == 3: + attn_out, softmax_lse = _C_ops.flash_attn_v3( + query_states, + key_states, + value_states, + None, # q_v_ + None, # q_descale_ + None, # k_descale_ + None, # v_descale_ + softmax_scale, + True, + -1, # window_size_left + -1, # window_size_right + 0.0, # softcap + 1, # num_splits + False, # manual_set_pack_gqa + False, # pack_gqa_ + 0, # sm_margin + ) + else: + assert False, f"invalid {FA_VERSION=}" + + if FA_VERSION == 2: + ctx.save_for_backward( + q_init, + kv_init, + attn_out, + softmax_lse, + seed_offset, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, + ) + elif FA_VERSION == 3: + ctx.save_for_backward( + q_init, + kv_init, + attn_out, + softmax_lse, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, + ) + else: + assert False, f"invalid {FA_VERSION=}" + + return attn_out + + @staticmethod + def backward(ctx, dout): + if FA_VERSION == 2: + ( + q_init, + kv_init, + attn_out, + softmax_lse, + seed_offset, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, + ) = ctx.saved_tensor() + elif FA_VERSION == 3: + ( + q_init, + kv_init, + attn_out, + softmax_lse, + q_ln_weight, + kv_ln_weight, + q_up_weight, + kv_up_weight, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + eps, + kv_lora_rank, + softmax_scale, + ) = ctx.saved_tensor() + else: + assert False, f"invalid {FA_VERSION=}" + + q_ln_t, q_ln_invar = fused_ln.fused_rms_norm(q_init, q_ln_weight, eps) + + q_ln_fp8, q_ln_scale, q_ln_trans_fp8, q_ln_trans_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + q_ln_t.reshape([-1, q_ln_t.shape[-1]]), + 
output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + ) + + q_orig_shape = q_ln_t.shape + q = FP8LinearFunctionBase.compute_fp8_linear( + (q_ln_fp8, q_ln_scale), q_up_weight, weight_transpose=True, return_transpose_only=True + ) + q = q.reshape(q_orig_shape[:-1] + [q_up_weight.shape[-1]]) + + compressed_kv, k_pe = paddle.split(kv_init, [kv_lora_rank, qk_rope_head_dim], axis=-1) + + kv_ln_t, kv_ln_invar = fused_ln.fused_rms_norm(compressed_kv, kv_ln_weight, eps) + + kv_ln_fp8, kv_ln_scale, kv_ln_trans_fp8, kv_ln_trans_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + kv_ln_t.reshape([-1, kv_ln_t.shape[-1]]), + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + ) + kv_orig_shape = kv_ln_t.shape + kv = FP8LinearFunctionBase.compute_fp8_linear( + (kv_ln_fp8, kv_ln_scale), kv_up_weight, weight_transpose=True, return_transpose_only=True + ) + kv = kv.reshape(kv_orig_shape[:-1] + [kv_up_weight.shape[-1]]) + + paddle.base.core._set_has_grad(True) + q.stop_gradient = False + kv.stop_gradient = False + k_pe.stop_gradient = False + query_states, key_states, value_states = qkv_pre_process( + q, + kv, + k_pe, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + position_ids, + ) + + if FA_VERSION == 2: + q_head_dim = query_states.shape[-1] + query_states = query_states * softmax_scale + + bsz = value_states.shape[0] + kv_seq_len = value_states.shape[1] + v_num_heads = value_states.shape[2] + value_padding = paddle.zeros( + [bsz, kv_seq_len, v_num_heads, q_head_dim - v_head_dim], + dtype=value_states.dtype, + ) + value_states_pad = paddle.concat([value_states, value_padding], axis=-1) + + with paddle.no_grad(): + + q_grad, k_grad, v_grad = _C_ops.flash_attn_grad( + query_states, + key_states, + value_states_pad, + attn_out, + softmax_lse.view("bfloat16"), + seed_offset, + None, + dout, + 0.0, + True, + ) + + v_grad = v_grad[..., :v_head_dim] + q_grad = q_grad * softmax_scale + elif FA_VERSION == 3: + with paddle.no_grad(): + q_grad, k_grad, v_grad = _C_ops.flash_attn_v3_grad( + query_states, + key_states, + value_states, + attn_out, + softmax_lse.view("bfloat16"), + dout, + softmax_scale, + True, + -1, + -1, + 0.0, + 0, + ) + else: + assert False, f"invalid {FA_VERSION=}" + + d_q, d_kv, d_k_pe = paddle.grad( + outputs=[query_states, key_states, value_states], + inputs=[q, kv, k_pe], + grad_outputs=[q_grad, k_grad, v_grad], + create_graph=False, + retain_graph=False, + ) + + paddle.base.core._set_has_grad(False) + + # call up proj + if hasattr(kv_up_weight, "main_grad"): + d_kv_fp8, d_kv_scale, d_kv_t_fp8, d_kv_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + d_kv.reshape([-1, d_kv.shape[-1]]), + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + ) + + d_kv_ln_t = FP8LinearFunctionBase.compute_fp8_linear( + (d_kv_fp8, d_kv_scale), kv_up_weight, weight_transpose=False + ) + d_kv_ln_t = d_kv_ln_t.reshape(d_kv.shape[:-1] + [kv_up_weight.shape[0]]) + + def kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight): + FP8LinearFunctionBase.kitchen_gemm( + kv_ln_trans_fp8, + kv_ln_trans_scale, + d_kv_t_fp8, + d_kv_t_scale, + True, + True, + kv_up_weight.main_grad, + paddle.float32, + ) + + if WeightGradStore.enabled: + + WeightGradStore.put( + partial( + kv_up_weight_grad, kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, d_kv_t_scale, kv_up_weight + ) + ) + else: + kv_up_weight_grad(kv_ln_trans_fp8, kv_ln_trans_scale, d_kv_t_fp8, 
d_kv_t_scale, kv_up_weight) + + d_kv_up_weight = None + + else: + d_kv_ln_t, d_kv_up_weight = _C_ops.matmul_grad(kv_ln_t, kv_up_weight, d_kv, False, False) + + d_compressed_kv, d_kv_ln_weight = fused_ln.fused_rms_norm_grad_func( + compressed_kv, kv_ln_weight, kv_ln_invar, d_kv_ln_t, eps + ) + + d_kv_init = paddle.concat([d_compressed_kv, d_k_pe], axis=-1) + + if hasattr(q_up_weight, "main_grad"): + + d_q_fp8, d_q_scale, d_q_t_fp8, d_q_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + d_q.reshape([-1, d_q.shape[-1]]), + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + ) + # d_q_ln_t = paddle.matmul(d_q, q_up_weight, transpose_y=True) + + d_q_ln_t = FP8LinearFunctionBase.compute_fp8_linear( + (d_q_fp8, d_q_scale), q_up_weight, weight_transpose=False + ) + d_q_ln_t = d_q_ln_t.reshape(d_q.shape[:-1] + [q_up_weight.shape[0]]) + + def q_up_weight_grad(q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight): + FP8LinearFunctionBase.kitchen_gemm( + q_ln_trans_fp8, + q_ln_trans_scale, + d_q_t_fp8, + d_q_t_scale, + True, + True, + q_up_weight.main_grad, + paddle.float32, + ) + + if WeightGradStore.enabled: + WeightGradStore.put( + partial(q_up_weight_grad, q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight) + ) + else: + q_up_weight_grad(q_ln_trans_fp8, q_ln_trans_scale, d_q_t_fp8, d_q_t_scale, q_up_weight) + + d_q_up_weight = None + + else: + d_q_ln_t, d_q_up_weight = _C_ops.matmul_grad(q_ln_t, q_up_weight, d_q, False, False) + + d_q_init, d_q_ln_weight = fused_ln.fused_rms_norm_grad_func(q_init, q_ln_weight, q_ln_invar, d_q_ln_t, eps) + + return d_q_init, d_kv_init, d_q_ln_weight, d_kv_ln_weight, d_q_up_weight, d_kv_up_weight + + +class MemroyRecomputeAttn(paddle.nn.Layer): + def __init__( + self, + q_norm_hidden_size, + kv_norm_hidden_size, + q_up_in_dim, + q_up_out_dim, + kv_up_in_dim, + kv_up_out_dim, + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + eps, + kv_lora_rank, + softmax_scale, + ) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.q_ln_weight = paddle.create_parameter( + shape=[q_norm_hidden_size], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(1.0), + ) + self.kv_ln_weight = paddle.create_parameter( + shape=[kv_norm_hidden_size], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(1.0), + ) + + self.q_up_weight = self.create_parameter( + shape=[q_up_in_dim, q_up_out_dim], + dtype=self._dtype, + is_bias=False, + ) + + self.kv_up_weight = self.create_parameter( + shape=[kv_up_in_dim, kv_up_out_dim], + dtype=self._dtype, + is_bias=False, + ) + ( + self.rotary_emb, + self.num_heads, + self.q_head_dim, + self.qk_nope_head_dim, + self.v_head_dim, + self.qk_rope_head_dim, + self.eps, + self.kv_lora_rank, + self.softmax_scale, + ) = ( + rotary_emb, + num_heads, + q_head_dim, + qk_nope_head_dim, + v_head_dim, + qk_rope_head_dim, + eps, + kv_lora_rank, + softmax_scale, + ) + set_parameter_color([self.q_up_weight, self.kv_up_weight], "memory_attn") + + def fp8_quant_weight(self): + cache_fp8_weight(self.q_up_weight) + cache_fp8_weight(self.kv_up_weight) + + def forward(self, q_init, kv_init, position_ids): + + seq_len = q_init.shape[1] + + if self.rotary_emb.max_seq_len_cached is None or seq_len > self.rotary_emb.max_seq_len_cached: + self.rotary_emb._set_cos_sin_cache(seq_len) + + return MemroyRecomputeAttnFunc.apply( + q_init, + kv_init, + self.q_ln_weight, + self.kv_ln_weight, + self.q_up_weight, + 
self.kv_up_weight, + self.rotary_emb, + self.num_heads, + self.q_head_dim, + self.qk_nope_head_dim, + self.v_head_dim, + self.qk_rope_head_dim, + position_ids, + self.eps, + self.kv_lora_rank, + self.softmax_scale, + ) + + +class FusedRMSLinearFunc(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, rms_norm_weight, q_down_weight, kv_down_weight, eps): + + hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + + h_fp8, h_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + hidden_states.reshape([-1, hidden_states.shape[-1]]), output_scale_transpose=True, quant_method="1x128" + ) + + h_orig_shape = hidden_states.shape + q = FP8LinearFunctionBase.compute_fp8_linear( + (h_fp8, h_scale), q_down_weight, weight_transpose=True, return_transpose_only=True + ) + q = q.reshape(h_orig_shape[:-1] + [q_down_weight.shape[-1]]) + + kv = paddle.matmul(hidden_states, kv_down_weight) + + ctx.save_for_backward(x, rms_norm_weight, q_down_weight, kv_down_weight) + ctx.eps = eps + return q, kv + + @staticmethod + def backward(ctx, d_q, d_kv): + x, rms_norm_weight, q_down_weight, kv_down_weight = ctx.saved_tensor() + eps = ctx.eps + hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + + h_t_fp8, h_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + hidden_states.reshape([-1, hidden_states.shape[-1]]), + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + return_transpose_only=True, + ) + + h_grad, d_kv_down_weight = _C_ops.matmul_grad(hidden_states, kv_down_weight, d_kv, False, False) + + if hasattr(q_down_weight, "main_grad"): + d_q_fp8, d_q_scale, d_q_t_fp8, d_q_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + d_q.reshape([-1, d_q.shape[-1]]), + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + ) + FP8LinearFunctionBase.compute_fp8_linear( + (d_q_fp8, d_q_scale), q_down_weight, weight_transpose=False, out=h_grad.view([-1, h_grad.shape[-1]]) + ) + + def q_down_weight_grad(h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight): + FP8LinearFunctionBase.kitchen_gemm( + h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, True, True, q_down_weight.main_grad, paddle.float32 + ) + + if WeightGradStore.enabled: + WeightGradStore.put( + partial(q_down_weight_grad, h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight) + ) + else: + q_down_weight_grad(h_t_fp8, h_t_scale, d_q_t_fp8, d_q_t_scale, q_down_weight) + + d_q_down_weight = None + + else: + h_grad_0, d_q_down_weight = _C_ops.matmul_grad(hidden_states, q_down_weight, d_q, False, False) + h_grad = h_grad + h_grad_0 + + dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, rms_norm_weight, invar, h_grad, eps) + + return dx, d_rms_norm_weight, d_q_down_weight, d_kv_down_weight + + +class FusedRMSLinear(paddle.nn.Layer): + def __init__(self, hidden_size, q_out_dim, kv_outdim, eps=1e-6) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.rms_norm_weight = paddle.create_parameter( + shape=[hidden_size], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(1.0), + ) + + self.q_down_weight = self.create_parameter( + shape=[hidden_size, q_out_dim], + dtype=self._dtype, + is_bias=False, + ) + + self.kv_down_weight = self.create_parameter( + shape=[hidden_size, kv_outdim], + dtype=self._dtype, + is_bias=False, + ) + self.eps = eps + set_parameter_color([self.q_down_weight], "rms_linear") + + def fp8_quant_weight(self): + cache_fp8_weight(self.q_down_weight) + + def forward(self, 
x): + + return FusedRMSLinearFunc.apply(x, self.rms_norm_weight, self.q_down_weight, self.kv_down_weight, self.eps) + + +class FusedRMSLinearSingleFunc(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, rms_norm_weight, linear_weight, eps): + + hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + q = paddle.matmul(hidden_states, linear_weight) + + ctx.save_for_backward(x, rms_norm_weight, linear_weight, eps) + return q + + @staticmethod + def backward(ctx, d_q, d_kv): + x, rms_norm_weight, linear_weight, eps = ctx.saved_tensor() + hidden_states, invar = fused_ln.fused_rms_norm(x, rms_norm_weight, eps) + + h_grad, d_linear_weight = _C_ops.matmul_grad(hidden_states, linear_weight, d_q, False, False) + + dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, rms_norm_weight, invar, h_grad, eps) + + return dx, d_rms_norm_weight, d_linear_weight + + +class FusedRMSLinearSingle(paddle.nn.Layer): + def __init__(self, hidden_size, q_out_dim, kv_outdim, eps=1e-6) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.rms_norm_weight = paddle.create_parameter( + shape=[hidden_size], + dtype=self._dtype, + default_initializer=nn.initializer.Constant(1.0), + ) + + self.linear_weight = self.create_parameter( + shape=[hidden_size, q_out_dim], + dtype=self._dtype, + is_bias=False, + ) + self.eps = eps + + def forward(self, x): + + return FusedRMSLinearFunc.apply(x, self.rms_norm_weight, self.linear_weight, self.eps) + + # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 class DeepseekV2Attention(nn.Layer): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -1922,10 +2893,11 @@ def compute_loss(preds, labels): masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss) ) count = paddle.sum(binary_sequence) - if count == 0: - loss = paddle.sum(masked_lm_loss * binary_sequence) - else: - loss = paddle.sum(masked_lm_loss * binary_sequence) / count + loss = paddle.where( + count == 0, + paddle.sum(masked_lm_loss * binary_sequence), + paddle.sum(masked_lm_loss * binary_sequence) / count, + ) return loss def add_loss(main_loss, loss): @@ -1956,7 +2928,7 @@ def add_loss(main_loss, loss): class DeepseekV2LMHead(nn.Layer): - def __init__(self, config: DeepseekV2Config): + def __init__(self, config: DeepseekV2Config, embedding_weight=None): super(DeepseekV2LMHead, self).__init__() self.config = config @@ -1970,11 +2942,16 @@ def __init__(self, config: DeepseekV2Config): else: vocab_size = config.vocab_size - self.weight = self.create_parameter( - shape=[config.hidden_size, vocab_size], - dtype=paddle.get_default_dtype(), - default_initializer=nn.initializer.XavierNormal(1.0), - ) + if embedding_weight is not None: + self.transpose_y = True + self.weight = embedding_weight + else: + self.transpose_y = False + self.weight = self.create_parameter( + shape=[config.hidden_size, vocab_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.XavierNormal(1.0), + ) # Must set distributed attr for Tensor Parallel ! 
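        # A local vocab_size that differs from config.vocab_size indicates the head weight holds only a
        # tensor-parallel shard of the vocabulary, so it is flagged as distributed; otherwise the flag stays False.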
self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False if get_env_device() == "xpu": @@ -2004,7 +2981,7 @@ def forward(self, hidden_states, tensor_parallel_output=None): training=self.training, ) else: - logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output) + logits = parallel_matmul(hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output) return logits def extra_repr(self): diff --git a/paddleformers/transformers/deepseek_v2/modeling_fast.py b/paddleformers/transformers/deepseek_v2/modeling_fast.py new file mode 100644 index 00000000000..047fe6a269e --- /dev/null +++ b/paddleformers/transformers/deepseek_v2/modeling_fast.py @@ -0,0 +1,1580 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 DeepSeek. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Paddle DeepSeek model.""" + +import contextlib +import math +import os +import warnings +from functools import partial +from typing import List, Optional, Tuple, Union + +import paddle +import paddle.distributed as dist +import paddle.distributed.fleet.meta_parallel as mpu +import paddle.nn.functional as F +from paddle import Tensor, nn +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddle.distributed.fleet.recompute.recompute import recompute +from paddle.jit import to_static +from paddle.utils import try_import + +try: + from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + GatherOp, + ScatterOp, + mark_as_sequence_parallel_parameter, + ) +except: + pass + +from paddle import _C_ops +from paddleformers.transformers.model_utils import dtype_guard + +from ...utils.initializer import kaiming_uniform_ +from ...utils.log import logger +from ...utils.tools import get_env_device +from ..activations import ACT2FN +from ..conversion_utils import StateDictNameMapping, init_name_mappings +from ..llama import fusion_ops +from ..llama.modeling import get_use_casual_mask +from ..model_outputs import ( + BaseModelOutputWithPastAndMTP, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from ..model_utils import PretrainedModel, dtype_guard, register_base_model +from ..moe_gate import PretrainedMoEGate +from ..moe_layer import MoEFlexTokenLayer, MoELayer +from ..utils import cast_if_needed, device_guard +from . 
import fp8_linear as linear_utils +from .configuration import DeepseekV2Config + +FA_VERSION = int(os.getenv("FA_VERSION", 2)) + +from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore + +from ..fp8_utils import ( + FP8KeepXLinear, + FP8Linear, + FP8Mlp, + set_parameter_color, +) +from .fp8_linear import Linear + +DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true" +DSV3_USE_ATTEN_RECOMPUTE = os.getenv("DSV3_USE_ATTEN_RECOMPUTE", "False").lower() == "true" + +Linear = FP8Linear if DSV3_USE_FP8_GEMM else Linear + +try: + import fused_ln + from paddle.incubate.nn.functional import swiglu +except ImportError: + + def swiglu(x, y=None): + if y is None: + x, y = paddle.chunk(x, chunks=2, axis=-1) + return F.silu(x) * y + + +try: + from paddle.incubate.nn.functional import fused_partial_rope +except ImportError: + fused_partial_rope = None + +__all__ = [ + "DeepseekV2ModelFast", + "DeepseekV2PretrainedModelFast", +] + +from .modeling import (set_global_step, scaled_dot_product_attention, is_casual_mask, _make_causal_mask, _expand_2d_mask, yarn_get_mscale, apply_rotary_pos_emb, DeepseekV2RMSNorm, DeepseekV2YarnRotaryEmbedding, FusedRMSLinear, MemroyRecomputeAttn, FusedNormGateFunc, FakeGate) + +class DeepseekV2MLP(nn.Layer): + def __init__(self, config: DeepseekV2Config, hidden_size=None, intermediate_size=None, is_moe=False): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + self.fuse_attention_ffn = config.fuse_attention_ffn + + def linear_dtype_gaurd(): + if config.use_fp8: + return dtype_guard("float8_e4m3fn") + else: + return contextlib.nullcontext() + + if config.sequence_parallel: + ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear + RowParallelLinear = linear_utils.RowSequenceParallelLinear + else: + ColumnParallelLinear = linear_utils.ColumnParallelLinear + RowParallelLinear = linear_utils.RowParallelLinear + + with linear_dtype_gaurd(): + if config.tensor_parallel_degree > 1 and not is_moe: + self.gate_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + input_is_parallel=True, + has_bias=False, + ) + else: + if config.fuse_attention_ffn: + self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False) + else: + self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.fuse_attention_ffn: + x = swiglu(self.gate_up_fused_proj(x)) + else: + x = swiglu(self.gate_proj(x), self.up_proj(x)) + out = self.down_proj(x) + return out + + +class MoEGate(PretrainedMoEGate): + def __init__( + self, + config, + num_experts, + expert_hidden_size, + using_post_norm_recompute=False, + norm_weight=None, + norm_eps=None, + **kwargs + ): + super().__init__(config, num_experts, expert_hidden_size, **kwargs) + # [hidden_size, n_expert] + + self.scoring_func = config.scoring_func + 
self.topk_method = config.topk_method + + self.weight = paddle.create_parameter( + shape=[expert_hidden_size, num_experts], + dtype=paddle.float32, + is_bias=False, + # default_initializer=nn.initializer.Constant(1.0), + ) + + self.config = config + self.using_post_norm_recompute = using_post_norm_recompute + + if config.topk_method == "noaux_tc": + self.e_score_correction_bias = paddle.create_parameter( + shape=[num_experts], + dtype=paddle.float32, + default_initializer=nn.initializer.Constant(0.0), + ) + self.e_score_correction_bias.is_distributed = True + + if self.using_post_norm_recompute: + assert norm_weight is not None and norm_eps is not None + self.norm_weight = norm_weight + self.norm_eps = norm_eps + + self.using_flex_token = False + + def forward(self, hidden_states): + """ + Args: + hidden_states (_type_): [batch_size * seq_len, hidden_size] + """ + _, _, h_dim = hidden_states.shape + + # compute gating score + if self.using_post_norm_recompute: + logits, norm_out = FusedNormGateFunc.apply(hidden_states, self.norm_weight, self.weight, self.norm_eps) + if hasattr(self.config, "using_fake_gate") and self.config.using_fake_gate: + logits = FakeGate.apply( + hidden_states, + self.weight, + self.config.fakse_gate_restrict_balance, + self.config.num_experts_per_tok, + ) + else: + with paddle.amp.auto_cast(False): + hidden_states = hidden_states.cast(self.weight.dtype) + if hasattr(self.config, "using_fake_gate") and self.config.using_fake_gate: + logits = FakeGate.apply( + hidden_states, + self.weight, + self.config.fakse_gate_restrict_balance, + self.config.num_experts_per_tok, + ) + else: + logits = F.linear(hidden_states, self.weight, None) + + scores = self.gate_score_func(logits=logits) + scores = scores.cast(paddle.float32) + + # Compute all possible return values + if self.using_flex_token: + scores, routing_map, exp_counts, l_aux, l_zloss = self.topkgating_nodrop( + scores + ) # (scores, routing_map, exp_counts, l_aux, l_zloss) + ret = (scores, routing_map, l_aux, l_zloss) + else: + ret = self.topkgating(scores) # (capacity, combine_weights, dispatch_mask, exp_counts, l_aux, l_zloss) + + # Append norm_out if needed + if self.using_post_norm_recompute: + ret = (*ret, norm_out) + + return ret + + +class AddAuxiliaryLoss(paddle.autograd.PyLayer): + """ + The trick function of adding auxiliary (aux) loss, + which includes the gradient of the aux loss during backpropagation. + """ + + @staticmethod + def forward(ctx, x, loss): + ctx.dtype = loss.dtype + ctx.required_aux_loss = not loss.stop_gradient + return x.clone() # clone to avoid inplace problem when using overlap + + @staticmethod + def backward(ctx, grad_output): + grad_loss = None + if ctx.required_aux_loss: + grad_loss = paddle.ones(1, dtype=ctx.dtype) + return grad_output, grad_loss + + +class DeepseekV2MoE(MoELayer): + """ + A mixed expert module containing shared experts. 
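+    Each token is routed to `num_experts_per_tok` of the `n_routed_experts` routed experts by `MoEGate`,
+    while the shared experts run on every token and their output is added back to the routed result in
+    `post_process` (which also attaches the auxiliary load-balancing loss during training).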
+ """ + + def __init__(self, config: DeepseekV2Config, norm_weight=None, norm_eps=None): + assert config.tensor_parallel_degree <= 1, "tensor_parallel_degree should be 1" + + self.using_post_norm_recompute = config.using_post_norm_recompute + if self.using_post_norm_recompute: + assert norm_weight is not None and norm_eps is not None + + gate = MoEGate( + config=config, + num_experts=config.n_routed_experts, + expert_hidden_size=config.hidden_size, + top_k=config.num_experts_per_tok, + topk_method=config.topk_method, + n_group=config.n_group, + topk_group=config.topk_group, + norm_topk_prob=config.norm_topk_prob, + routed_scaling_factor=config.routed_scaling_factor, + drop_tokens=False, + using_post_norm_recompute=self.using_post_norm_recompute, + norm_weight=norm_weight, + norm_eps=norm_eps, + ) + DeepseekV2MLPClass = FP8Mlp if DSV3_USE_FP8_GEMM else DeepseekV2MLP + + super().__init__( + config=config, + moe_num_experts=config.n_routed_experts, + expert_class=DeepseekV2MLPClass, + expert_kwargs={ + "config": config, + "intermediate_size": config.moe_intermediate_size, + "is_moe": True, + }, + gate=gate, + capacity=2.0, + moe_group="expert", + using_post_norm_recompute=self.using_post_norm_recompute, + ) + + if config.offline_quant_expert_weight and config.clear_origin_weight_when_offline_quant: + moe_grad_group = fleet.get_hybrid_communicate_group().expert_grad_comm_group + expert_w1_list = [expert.w1 for expert in self.experts if expert is not None] + expert_w2_list = [expert.w2 for expert in self.experts if expert is not None] + for p in expert_w1_list: + setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group}) + for p in expert_w2_list: + setattr(p, "color", {"color": "moe_expert", "group": moe_grad_group}) + + self.alpha = config.aux_loss_alpha + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + if self.using_post_norm_recompute: + assert DeepseekV2MLPClass is FP8Mlp + self.shared_experts = DeepseekV2MLPClass( + config=config, + intermediate_size=intermediate_size, + is_moe=False, + using_post_norm_recompute=self.using_post_norm_recompute, + norm_weight=norm_weight, + norm_eps=norm_eps, + recompute_fwd_gate_up=True, + ) + else: + self.shared_experts = DeepseekV2MLPClass( + config=config, intermediate_size=intermediate_size, is_moe=False + ) + set_parameter_color([self.shared_experts.w1, self.shared_experts.w2], "shared_expert") + + def fp8_quant_weight(self, batch_mode=False): + """Quantize weights in FP8 format. + + Args: + batch_mode: If True, quantize all weights in batch mode using the first expert's weights. + If False, quantize each expert's weights individually. 
+ """ + + def quantize_weights(weight_list, weight_obj=None): + """Helper function to quantize a list of weights.""" + if weight_obj is None: + weight_obj = weight_list[0] + if hasattr(weight_obj, "fp8_weight_stacked"): + return + + # Quantize without transpose + fp8_weight, fp8_scale = paddle.incubate.nn.functional.fused_stack_transpose_quant( + weight_list, transpose=False + ) + setattr(weight_obj, "fp8_weight_stacked", fp8_weight) + setattr(weight_obj, "fp8_scale_stacked", fp8_scale) + + # Quantize with transpose + fp8_weight_t, fp8_scale_t = paddle.incubate.nn.functional.fused_stack_transpose_quant( + weight_list, transpose=True + ) + setattr(weight_obj, "fp8_weight_stacked_transpose", fp8_weight_t) + setattr(weight_obj, "fp8_scale_stacked_transpose", fp8_scale_t) + + if batch_mode: + # Batch mode: process all experts' weights together + expert_w1_list = [expert.w1 for expert in self.experts if expert is not None] + expert_w2_list = [expert.w2 for expert in self.experts if expert is not None] + + if expert_w1_list: + quantize_weights(expert_w1_list, expert_w1_list[0]) + if expert_w2_list: + quantize_weights(expert_w2_list, expert_w2_list[0]) + else: + # Individual mode: process each expert's weights separately + for expert in self.experts: + if expert is not None: + quantize_weights([expert.w1]) + quantize_weights([expert.w1]) + + if self.config.n_shared_experts is not None: + self.shared_experts.fp8_quant_weight() + + def forward(self, hidden_states): + if self.using_post_norm_recompute: + super().update_flex_token() + if self.using_flex_token: + probs, routing_map, l_aux, l_zloss, norm_out = self.router(hidden_states) + final_hidden_states, l_aux, l_zloss = super().forward( + norm_out, probs=probs, routing_map=routing_map, l_aux=l_aux, l_zloss=l_zloss + ) + else: + capacity, topk_weight, topk_ids, token_priority, l_aux, l_zloss, norm_out = self.gate(hidden_states) + final_hidden_states, l_aux, l_zloss = super().forward( + norm_out, + capacity=capacity, + topk_weight=topk_weight, + topk_ids=topk_ids, + token_priority=token_priority, + l_aux=l_aux, + l_zloss=l_zloss, + ) + final_hidden_states = self.post_process(hidden_states, final_hidden_states, l_aux) + else: + final_hidden_states, l_aux, l_zloss = super().forward(hidden_states) + final_hidden_states = self.post_process(hidden_states, final_hidden_states, l_aux) + return final_hidden_states + + def post_process(self, hidden_states, final_hidden_states, l_aux): + if self.training and self.alpha > 0.0: + l_aux = l_aux * self.alpha + final_hidden_states = AddAuxiliaryLoss.apply(final_hidden_states, l_aux) + + if self.config.n_shared_experts is not None: + shared_expert_output = self.shared_experts(hidden_states) + final_hidden_states = final_hidden_states + shared_expert_output + return final_hidden_states + +# # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 +class DeepseekV2Attention(nn.Layer): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DeepseekV2Config, layerwise_recompute: bool = False): + super().__init__() + self.config = config + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = 
config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.is_causal = True + self.fuse_rope = config.use_fused_rope + + if config.num_nextn_predict_layers > 0: + self.seq_length = config.seq_length - config.num_nextn_predict_layers + else: + self.seq_length = config.seq_length + self.sequence_parallel = config.sequence_parallel + + self.input_layernorm = DeepseekV2RMSNorm(config) + + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + + def linear_dtype_gaurd(): + if config.use_fp8: + return dtype_guard("float8_e4m3fn") + else: + return contextlib.nullcontext() + + # Note (@DrownFish19): For tensor parallel we consider that q_a_proj and kv_a_proj_with_mqa + # are the small weight and cannot achieve performance gain. So we use the original + # linear layers. We use the tensor parallel linear layers for q_proj,q_b_proj and kv_b_proj + # for which are the large weight and can achieve performance gain. + + self._init_rope() + self.softmax_scale = self.q_head_dim ** (-0.5) + + # fmt: off + if self.config.tensor_parallel_degree > 1: + # for tensor parallel + if config.sequence_parallel: + ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear + RowParallelLinear = linear_utils.RowSequenceParallelLinear + else: + ColumnParallelLinear = linear_utils.ColumnParallelLinear + RowParallelLinear = linear_utils.RowParallelLinear + + if self.q_lora_rank is None: + with linear_dtype_gaurd(): + self.q_proj = ColumnParallelLinear(self.hidden_size, self.num_heads * self.q_head_dim, has_bias=False, gather_output=True) + else: + with linear_dtype_gaurd(): + self.q_a_proj = Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias) + self.q_b_proj = ColumnParallelLinear(config.q_lora_rank, self.num_heads * self.q_head_dim, has_bias=False, gather_output=True) + self.q_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.q_lora_rank, use_sequence_parallel=False) + + with linear_dtype_gaurd(): + self.kv_a_proj_with_mqa = paddle.nn.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias) + self.kv_b_proj = ColumnParallelLinear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), has_bias=False, gather_output=True) + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, self.hidden_size, has_bias=config.attention_bias, input_is_parallel=False) + self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank, use_sequence_parallel=False) + else: + # for without tensor parallel + if DSV3_USE_ATTEN_RECOMPUTE: + self.fused_rms_norm_linear = FusedRMSLinear(self.hidden_size, config.q_lora_rank, config.kv_lora_rank + config.qk_rope_head_dim, 1e-6) + kv_up_dim = self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim) + self.memory_recompute_att = MemroyRecomputeAttn(config.q_lora_rank, config.kv_lora_rank, config.q_lora_rank, self.num_heads * self.q_head_dim, config.kv_lora_rank, kv_up_dim, self.rotary_emb, self.num_heads, self.q_head_dim, self.qk_nope_head_dim, self.v_head_dim, self.qk_rope_head_dim, 1e-6, self.kv_lora_rank, self.softmax_scale) + self.o_proj = 
FP8KeepXLinear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias) + else: + + if self.q_lora_rank is None: + with linear_dtype_gaurd(): + self.q_proj = Linear(self.hidden_size, self.num_heads * self.q_head_dim, bias_attr=False) + else: + with linear_dtype_gaurd(): + self.q_a_proj = Linear(self.hidden_size, config.q_lora_rank, bias_attr=config.attention_bias) + self.q_b_proj = Linear(config.q_lora_rank, self.num_heads * self.q_head_dim, bias_attr=False) + self.q_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.q_lora_rank) + + with linear_dtype_gaurd(): + self.kv_a_proj_with_mqa = paddle.nn.Linear(self.hidden_size, config.kv_lora_rank + config.qk_rope_head_dim, bias_attr=config.attention_bias) + self.kv_b_proj = Linear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias_attr=False) + self.o_proj = Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias) + self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank) + + # fmt: on + self.softmax_scale = self.q_head_dim ** (-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + self.attn_func = scaled_dot_product_attention + + def fp8_quant_weight(self): + + if DSV3_USE_ATTEN_RECOMPUTE: + self.o_proj.fp8_quant_weight() + self.memory_recompute_att.fp8_quant_weight() + self.fused_rms_norm_linear.fp8_quant_weight() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = DeepseekV2RotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = DeepseekV2DynamicNTKScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "yarn": + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rotary_emb = DeepseekV2YarnRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **kwargs, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: paddle.Tensor, seq_len: int, bsz: int): + return tensor.reshape([bsz, seq_len, self.num_heads, self.v_head_dim]).transpose([1, 0, 2, 3]) + + def forward( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + **kwargs, 
+ ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.shape + + # DeepSeekV2 q_lora_rank=1536 + # DeepSeekV2-lite q_lora_rank=None + if DSV3_USE_ATTEN_RECOMPUTE: + + q_t1, compressed_kv = self.fused_rms_norm_linear(hidden_states) + + outputs = self.memory_recompute_att(q_t1, compressed_kv, position_ids) + + if self.v_head_dim * self.num_heads != outputs.shape[-1]: + outputs = outputs.reshape([bsz, q_len, self.num_heads, -1]) + outputs = outputs[..., : self.v_head_dim] + outputs = outputs.reshape([bsz, q_len, -1]) + else: + # 这里多了一个layernorm,是因为把 DeepseekV2Attention 之外的一次计算放进来了 + hidden_states = self.input_layernorm(hidden_states) + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + + if self.sequence_parallel: + target_query_shape = [-1, self.seq_length, self.num_heads, self.q_head_dim] + target_key_value_shape = [-1, self.seq_length, self.num_heads, self.qk_nope_head_dim + self.v_head_dim] + else: + target_query_shape = [0, 0, self.num_heads, self.q_head_dim] + target_key_value_shape = [0, 0, self.num_heads, self.qk_nope_head_dim + self.v_head_dim] + + q = q.reshape(shape=target_query_shape) + q_nope, q_pe = paddle.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1) + + # DeepSeekV2 kv_lora_rank+qk_rope_head_dim=512+64 + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = paddle.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], axis=-1) + if self.sequence_parallel: + k_pe = GatherOp.apply(k_pe) + k_pe = k_pe.reshape([-1, q_len, 1, self.qk_rope_head_dim]).expand( + [-1, q_len, self.num_heads, self.qk_rope_head_dim] + ) + + # self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim = 128+64 + # self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim) = config.qk_nope_head_dim + self.v_head_dim = 128+128 + kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv)).reshape(shape=target_key_value_shape) + + k_nope, value_states = paddle.split(kv, [self.qk_nope_head_dim, self.v_head_dim], axis=-1) + kv_seq_len = value_states.shape[1] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-3] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos = cos[None, :, None, :] + sin = sin[None, :, None, :] + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids, self.fuse_rope) + + query_states = paddle.concat([q_nope, q_pe], axis=-1) + key_states = paddle.concat([k_nope, k_pe], axis=-1) + + # [bs, seq_len, num_head, head_dim] + if past_key_value is not None: + # reuse k, v, self_attention + key_states = paddle.concat([past_key_value[0], key_states], axis=1) + value_states = paddle.concat([past_key_value[1], value_states], axis=1) + past_key_value = (key_states, value_states) if use_cache else None + + has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient) + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "core_attn" + ): + outputs = recompute( + self.attn_func, + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + 
softmax_scale=self.softmax_scale, + training=self.training, + sequence_parallel=self.sequence_parallel, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = self.attn_func( + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + softmax_scale=self.softmax_scale, + training=self.training, + sequence_parallel=self.sequence_parallel, + ) + if output_attentions: + attn_output, attn_weights = outputs + else: + attn_output = outputs + + # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim] + # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism. + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + outputs = (attn_output,) + + if output_attentions: + outputs += (attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class DeepseekV2DecoderLayer(nn.Layer): + def __init__(self, config: DeepseekV2Config, layer_idx: int, layerwise_recompute: bool = False): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + self.using_post_norm_recompute = config.using_post_norm_recompute + + self.hidden_size = config.hidden_size + + self.self_attn = DeepseekV2Attention(config=config, layerwise_recompute=layerwise_recompute) + + DeepseekV2MLPClass = FP8Mlp if DSV3_USE_FP8_GEMM else DeepseekV2MLP + + self.input_layernorm = DeepseekV2RMSNorm(config) + self.post_attention_layernorm = DeepseekV2RMSNorm(config) + + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ): + self.mlp = ( + DeepseekV2MoE( + config, self.post_attention_layernorm.weight, self.post_attention_layernorm.variance_epsilon + ) + if config.using_post_norm_recompute + else DeepseekV2MoE(config) + ) + else: + self.mlp = DeepseekV2MLPClass(config) + + def fp8_quant_weight(self, batch_mode=False): + """fp8_quant_weight""" + if isinstance(self.mlp, DeepseekV2MoE): + # logger.info(f"fp8 quant weight for mlp {type(self.mlp)}") + self.mlp.fp8_quant_weight(batch_mode) + self.self_attn.fp8_quant_weight() + elif isinstance(self.mlp, FP8Mlp): + self.self_attn.fp8_quant_weight() + + def forward( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + """ + Args: + hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_axis)` + attention_mask (`paddle.Tensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + residual = hidden_states + + # Self Attention + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "full_attn" + ): + outputs = recompute( + self.self_attn, + hidden_states=hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + past_key_value=past_key_value, + use_cache=use_cache, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + **kwargs, + ) + else: + outputs = self.self_attn( + hidden_states=hidden_states, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + past_key_value=past_key_value, + use_cache=use_cache, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + **kwargs, + ) + + if type(outputs) is tuple: + hidden_states = outputs[0] + else: + hidden_states = outputs + + if output_attentions: + self_attn_weights = outputs[1] + + if use_cache: + present_key_value = outputs[2 if output_attentions else 1] + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + + if not (self.using_post_norm_recompute and isinstance(self.mlp, DeepseekV2MoE)): + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + def self_attn_compute(self, hidden_states, **kwargs): + residual = hidden_states + + # Self Attention + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "full_attn" + ): + outputs = recompute( + self.self_attn, + hidden_states=hidden_states, + position_ids=None, + attention_mask=None, + output_attentions=False, + past_key_value=None, + use_cache=False, + attn_mask_startend_row_indices=None, + **kwargs, + ) + else: + outputs = self.self_attn( + hidden_states=hidden_states, + position_ids=None, + attention_mask=None, + output_attentions=False, + past_key_value=None, + use_cache=False, + attn_mask_startend_row_indices=None, + **kwargs, + ) + + if type(outputs) is tuple: + hidden_states = outputs[0] + else: + hidden_states = outputs + + hidden_states = residual + hidden_states + + residual = hidden_states + + if not (self.using_post_norm_recompute and isinstance(self.mlp, DeepseekV2MoE)): + hidden_states = self.post_attention_layernorm(hidden_states) + + return hidden_states, residual + + def pre_dispatch_compute(self, hidden_states): + l_aux, l_zloss, intermediate_hidden_states, token_indices, token_probs = self.mlp.pre_dispatch_compute( + hidden_states + ) + + return l_aux, l_zloss, intermediate_hidden_states, token_indices, token_probs + + def expert_forward_compute(self, intermediate_hidden_states, dispatched_indices, 
dispatched_probs): + (global_input_tokens, token_permuted_indices, prob_permuted_indices) = self.mlp.post_dispatch_compute( + intermediate_hidden_states, dispatched_indices, dispatched_probs + ) + + expert_output = self.mlp.expert_forward(global_input_tokens) + + expert_output = self.mlp.pre_combine_compute( + expert_output, token_permuted_indices, prob_permuted_indices, dispatched_probs + ) + + return expert_output + + def post_combine_compute(self, residual, hidden_states, final_hidden_states, l_aux): + final_hidden_states = self.mlp.post_combine_compute(final_hidden_states) + + final_hidden_states = self.mlp.post_process(hidden_states, final_hidden_states, l_aux) + + final_hidden_states = residual + final_hidden_states + + outputs = (final_hidden_states,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class DeepseekV2MTPLayer(DeepseekV2DecoderLayer): + def __init__( + self, + config: DeepseekV2Config, + layer_idx: int, + layerwise_recompute: bool = False, + ): + super(DeepseekV2MTPLayer, self).__init__(config, layer_idx, layerwise_recompute) + + self.enorm = DeepseekV2RMSNorm(config) + self.hnorm = DeepseekV2RMSNorm(config) + self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias_attr=False) + + def forward( + self, + hidden_states: paddle.Tensor, + nextn_hidden_state: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + hidden_states = self.hnorm(hidden_states) + nextn_hidden_state = self.enorm(nextn_hidden_state) + + concat_h = paddle.concat([hidden_states, nextn_hidden_state], axis=-1) + hidden_states = LMHeadFunction.apply( concat_h, self.eh_proj.weight, False) + + layer_outputs = super(DeepseekV2MTPLayer, self).forward( + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + **kwargs, + ) + + if type(layer_outputs) is tuple: + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + return hidden_states + + +class DeepseekV2PretrainedModelFast(PretrainedModel): + config_class = DeepseekV2Config + base_model_prefix = "deepseek_v2" + _no_split_modules = ["DeepseekV2DecoderLayer"] + + def _get_model_flops(self, batch_size=1, seq_length=None, **kwargs): + from .mfu_utils import DeepSeekProjection + + # self._ + mfu_cal_proj = DeepSeekProjection(self.config) + if seq_length is None: + if hasattr(self.config, "seq_length"): + seq_length = self.config.seq_length + else: + seq_length = 2048 + + return mfu_cal_proj.get_num_flop_per_token() + + def _get_hardware_flops(self, *args, **kwargs): + return self._get_model_flops(*args, **kwargs) + + @classmethod + def _get_name_mappings(cls, config: DeepseekV2Config) -> list[StateDictNameMapping]: + mappings: list[StateDictNameMapping] = [] + model_mappings = [ + ["embed_tokens.weight"], + ["norm.weight"], + ] + # last one layer contains MTP (eagle) parameters for inference + for layer_index in range(config.num_hidden_layers + config.num_nextn_predict_layers): + layer_mappings = [ + [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.q_a_proj.weight", None, "transpose"], + 
[f"layers.{layer_index}.self_attn.q_a_layernorm.weight"], + [f"layers.{layer_index}.self_attn.q_b_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.kv_a_proj_with_mqa.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.kv_a_layernorm.weight"], + [f"layers.{layer_index}.self_attn.kv_b_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"], + [f"layers.{layer_index}.input_layernorm.weight"], + [f"layers.{layer_index}.post_attention_layernorm.weight"], + ] + model_mappings.extend(layer_mappings) + + # MoE parameters + model_mappings.append([f"layers.{layer_index}.mlp.gate.weight", None, "transpose"]) + model_mappings.append([f"layers.{layer_index}.mlp.gate.e_score_correction_bias"]) + for expert_idx in range(config.n_routed_experts): + expert_mappings = [ + [f"layers.{layer_index}.mlp.experts.{expert_idx}.gate_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.experts.{expert_idx}.up_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.experts.{expert_idx}.down_proj.weight", None, "transpose"], + ] + model_mappings.extend(expert_mappings) + model_mappings.append([f"layers.{layer_index}.mlp.shared_experts.gate_proj.weight", None, "transpose"]) + model_mappings.append([f"layers.{layer_index}.mlp.shared_experts.up_proj.weight", None, "transpose"]) + model_mappings.append([f"layers.{layer_index}.mlp.shared_experts.down_proj.weight", None, "transpose"]) + + # MTP (eagle) parameters for inference + if layer_index >= config.num_hidden_layers: + model_mappings.append([f"layers.{layer_index}.embed_tokens.weight"]) + model_mappings.append([f"layers.{layer_index}.enorm.weight"]) + model_mappings.append([f"layers.{layer_index}.hnorm.weight"]) + model_mappings.append([f"layers.{layer_index}.eh_proj.weight", None, "transpose"]) + model_mappings.append([f"layers.{layer_index}.shared_head.norm.weight"]) + model_mappings.append([f"layers.{layer_index}.shared_head.head.weight", None, "transpose"]) + + init_name_mappings(mappings=model_mappings) + if cls.base_model_class.__name__ not in config.architectures: + for mapping in model_mappings: + mapping[0] = "model." + mapping[0] + mapping[1] = f"{cls.base_model_prefix}." 
+ mapping[1] + if not config.tie_word_embeddings: + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) + + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] + return mappings + + @classmethod + def _get_tensor_parallel_mappings(cls, config: DeepseekV2Config, is_split=True): + from paddleformers.transformers.conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_layers): + final_actions = {} + + base_actions = { + # Row Linear + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + } + if config.use_fp8: + base_actions["layers.0.self_attn.o_proj.weight.weight_scale_inv"] = partial(fn, is_column=False) + + if config.tie_word_embeddings: + base_actions["lm_head.weight"] = partial(fn, is_column=False) + else: + base_actions["lm_head.weight"] = partial(fn, is_column=True) + + if not config.vocab_size % config.tensor_parallel_degree == 0: + base_actions.pop("lm_head.weight") + base_actions.pop("embed_tokens.weight") + + # Column Linear + base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.q_b_proj.weight"] = partial(fn, is_column=True) + + # if we have enough num_key_value_heads to split, then split it. + # ??? + if config.num_key_value_heads % config.tensor_parallel_degree == 0: + base_actions["layers.0.self_attn.kv_b_proj.weight"] = partial(fn, is_column=True) + if config.use_fp8: + base_actions["layers.0.self_attn.kv_b_proj.weight.weight_scale_inv"] = partial(fn, is_column=True) + + # dense mlp + base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.down_proj.weight"] = partial(fn, is_column=False) + if config.use_fp8: + base_actions["layers.0.mlp.up_proj.weight.weight_scale_inv"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.gate_proj.weight.weight_scale_inv"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.down_proj.weight.weight_scale_inv"] = partial(fn, is_column=False) + + # moe unit routed experts + moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() + expert_parallel_degree = dist.get_world_size(moe_group) + if expert_parallel_degree <= 1: + for e_i in range(config.n_routed_experts): + base_actions[f"layers.0.mlp.experts.{e_i}.up_proj.weight"] = partial(fn, is_column=True) + base_actions[f"layers.0.mlp.experts.{e_i}.gate_proj.weight"] = partial(fn, is_column=True) + base_actions[f"layers.0.mlp.experts.{e_i}.down_proj.weight"] = partial(fn, is_column=False) + + # moe unit shared experts + base_actions["layers.0.mlp.shared_experts.gate_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.shared_experts.up_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.shared_experts.down_proj.weight"] = partial(fn, is_column=False) + if config.use_fp8: + base_actions["layers.0.mlp.shared_experts.gate_proj.weight.weight_scale_inv"] = partial( + fn, is_column=True + ) + base_actions["layers.0.mlp.shared_experts.up_proj.weight.weight_scale_inv"] = partial( + 
fn, is_column=True + ) + base_actions["layers.0.mlp.shared_experts.down_proj.weight.weight_scale_inv"] = partial( + fn, is_column=False + ) + + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + # for MTP (eagle) parameters for inference + base_actions.pop("embed_tokens.weight") + base_actions.pop("lm_head.weight") + base_actions["layers.0.embed_tokens.weight"] = partial(fn, is_column=False) + base_actions["layers.0.shared_head.head.weight"] = partial(fn, is_column=True) + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range( + config.num_hidden_layers, config.num_hidden_layers + config.num_nextn_predict_layers + ): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + else: + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers) + + return mappings + + def _init_weights(self, layer): + if self.config.tensor_parallel_degree > 1: + rng_tracker = get_rng_state_tracker().rng_state + + if isinstance( + layer, + ( + nn.Linear, + nn.Embedding, + mpu.VocabParallelEmbedding, + mpu.RowParallelLinear, + mpu.ColumnParallelLinear, + linear_utils.RowSequenceParallelLinear, + linear_utils.ColumnSequenceParallelLinear, + Linear, + ), + ): + # In the dygraph mode, use the `set_value` to reset the parameter directly, + # and reset the `state_dict` to update parameter in static mode. + if isinstance(layer.weight, paddle.Tensor): + if layer.weight.is_distributed: + with rng_tracker(): + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.initializer_range, + shape=layer.weight.shape, + ) + ) + else: + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.config.initializer_range, + shape=layer.weight.shape, + ) + ) + + # set bias to zeros + if getattr(layer, "bias", None) is not None: + layer.bias.set_value(paddle.zeros(shape=layer.bias.shape)) + + if isinstance(layer, nn.Embedding): + if layer._padding_idx is not None: + layer.weight.data[layer._padding_idx].fill_(0) + + if isinstance(layer, MoEGate): + kaiming_uniform_(layer.weight, a=math.sqrt(5)) + + moe_grad_group = fleet.get_hybrid_communicate_group().expert_grad_comm_group + if moe_grad_group is not None and moe_grad_group.nranks > 1: + for p in layer.parameters(): + if hasattr(p, "color") and "color" in p.color: + if p.color["color"] == "moe_expert": + paddle.distributed.broadcast(p, src=moe_grad_group.ranks[0], group=moe_grad_group) + + def step_flex_token(self, cur_step): + set_global_step(cur_step) + +@register_base_model +class DeepseekV2ModelFast(DeepseekV2PretrainedModelFast): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`DeepseekV2DecoderLayer`] + + Args: + config: DeepseekV2Config + """ + + def __init__(self, config: DeepseekV2Config): + super().__init__(config) + + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # Recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.recompute_granularity = config.recompute_granularity + self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else [] + + if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: + self.embed_tokens = mpu.VocabParallelEmbedding(config.vocab_size, config.hidden_size) + else: + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + + self.layers = nn.LayerList( + [ + DeepseekV2DecoderLayer(config, layer_idx, layer_idx not in self.no_recompute_layers) + for layer_idx in range(config.num_hidden_layers) + ] + ) + for layer_idx in range(config.num_hidden_layers, config.num_hidden_layers + config.num_nextn_predict_layers): + self.layers.append(DeepseekV2MTPLayer(config, layer_idx, layer_idx not in self.no_recompute_layers)) + + self.norm = DeepseekV2RMSNorm(config) + + self.enable_recompute = False + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @staticmethod + def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype): + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + if len(attention_mask.shape) == 2: + expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1]) + # For decoding phase in generation, seq_length = 1, we don't need to add causal mask + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + past_key_values_length=past_key_values_length, + ) + expanded_attn_mask = expanded_attn_mask & combined_attention_mask + # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len] + elif len(attention_mask.shape) == 3: + expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool") + # if attention_mask is already 4-D, do nothing + else: + expanded_attn_mask = attention_mask + else: + expanded_attn_mask = _make_causal_mask( + input_shape, + past_key_values_length=past_key_values_length, + ) + # Convert bool attention_mask to float attention mask, which will be added to attention_scores later + if get_env_device() == "xpu": + x = paddle.to_tensor(0.0, dtype="float32") + y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32") + expanded_attn_mask = paddle.where(expanded_attn_mask, x, y) + else: + expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min).astype( + dtype + ) + return expanded_attn_mask + + @paddle.jit.not_to_static + def recompute_training_full( + self, + layer_module: nn.Layer, + hidden_states: Tensor, + position_ids: Optional[Tensor], + attention_mask: Tensor, + output_attentions: bool, + past_key_value: Tensor, + use_cache: bool, + attn_mask_startend_row_indices: Optional[Tensor] = None, + ): + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = recompute( + create_custom_forward(layer_module), + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + 
use_reentrant=self.config.recompute_use_reentrant, + ) + + return hidden_states + + def forward( + self, + input_ids: paddle.Tensor = None, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + use_cache: Optional[bool] = None, + past_key_values: Optional[List[paddle.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + attn_mask_startend_row_indices: Optional[Tensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPastAndMTP]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if self.config.num_nextn_predict_layers > 0: + seq_length -= self.config.num_nextn_predict_layers + + if attention_mask is not None: + attention_mask = attention_mask[ + :, :, : -self.config.num_nextn_predict_layers, : -self.config.num_nextn_predict_layers + ] + + if self.enable_recompute and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers." 
+ ) + use_cache = False + + if past_key_values is None: + past_key_values = tuple([None] * len(self.layers)) + # NOTE: to make cache can be clear in-time + past_key_values = list(past_key_values) + + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values[0] is not None: + past_key_values_length = past_key_values[0][0].shape[1] + seq_length_with_past += past_key_values_length + + if position_ids is None: + position_ids = paddle.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=paddle.int64 + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + # [bs, seq_len, dim] + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + if attn_mask_startend_row_indices is not None or get_use_casual_mask(): + attention_mask = None + else: + # [bs, seq_len] + attention_mask = ( + paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + if attention_mask is None + else attention_mask + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), past_key_values_length, inputs_embeds.dtype + ) # [bs, 1, seq_len, seq_len] + if self.config.use_flash_attention: + attention_mask = None if is_casual_mask(attention_mask) else attention_mask + + if self.config.num_nextn_predict_layers > 0: + inputs_embeds_extra = inputs_embeds[:, -self.config.num_nextn_predict_layers :, :] # [B, S, D] + inputs_embeds = inputs_embeds[:, : -self.config.num_nextn_predict_layers, :] + inputs_embeds_ori = inputs_embeds + + if self.config.sequence_parallel: + # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] + bs, seq_len, hidden_size = inputs_embeds.shape + inputs_embeds = paddle.reshape(inputs_embeds, [bs * seq_len, hidden_size]) + # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) + inputs_embeds = ScatterOp.apply(inputs_embeds) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + mtp_outputs = [] + + for idx in range(self.config.num_hidden_layers): + decoder_layer = self.layers[idx] + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and idx not in self.no_recompute_layers + and has_gradient + and self.recompute_granularity == "full" + ): + layer_outputs = self.recompute_training_full( + decoder_layer, + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + ) + + # NOTE: clear outdate cache after it has been used for memory saving + past_key_value = past_key_values[idx] = None + if type(layer_outputs) is tuple: + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if self.config.num_nextn_predict_layers > 0: + mtp_outputs.append(hidden_states) + + for nextn in range(self.config.num_nextn_predict_layers): + decoder_layer = self.layers[nextn + 
self.config.num_hidden_layers] + + if self.config.sequence_parallel: + hidden_states = GatherOp.apply(hidden_states) + hidden_states = hidden_states.reshape([-1, seq_length, hidden_states.shape[-1]]) + + inputs_embeds_cur_depth = paddle.concat( + [inputs_embeds_ori[:, (nextn + 1) :, :], inputs_embeds_extra[:, : (nextn + 1), :]], axis=1 + ) + + past_key_value = None + layer_outputs = decoder_layer( + hidden_states, + inputs_embeds_cur_depth, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + ) + + if isinstance(layer_outputs, (tuple, list)): + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + mtp_outputs.append(hidden_states) + mtp_outputs = [self.norm(hidden_states) for hidden_states in mtp_outputs] + hidden_states, mtp_outputs = mtp_outputs[0], mtp_outputs[1:] + else: + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple( + v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, mtp_outputs] if v is not None + ) + return BaseModelOutputWithPastAndMTP( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + mtp_outputs=mtp_outputs, + ) diff --git a/paddleformers/transformers/deepseek_v2/modeling_pp.py b/paddleformers/transformers/deepseek_v2/modeling_pp.py index 42b0e5de776..a659f976e72 100644 --- a/paddleformers/transformers/deepseek_v2/modeling_pp.py +++ b/paddleformers/transformers/deepseek_v2/modeling_pp.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
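+# NOTE: this module builds the overlapped pipeline-parallel schedule for DeepseekV2 pretraining:
+# each MoE decoder layer is split into ScheduleNode stages (attention/gate, dispatch, expert MLP,
+# combine, post-process) so the forward of one micro-batch can be overlapped with the backward of
+# another, using the zero-bubble utilities (WeightGradStore/EventStore) imported below.
+# A minimal sketch of the environment switches read in this file (values are illustrative;
+# DSV3_FAST_PRETRAIN is intended to select the fused modeling_fast.py layers):
+#
+#   export DSV3_FAST_PRETRAIN=true
+#   export DSV3_USE_FP8_GEMM=true      # use the FusionFp8DecoderLayerNode schedule path
+#   export DSV3_USE_FP8_DISPATCH=true  # enable the FP8 dispatch/combine path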
- +import math +import os from typing import OrderedDict, Tuple, Union import paddle @@ -20,32 +21,87 @@ import paddle.nn as nn from paddle.distributed.fleet.meta_parallel import ( LayerDesc, + LocalSharedLayerDesc, PipelineLayer, + ScheduleChunk, + ScheduleNode, SharedLayerDesc, ) +from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore + +try: + from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import EventStore +except ImportError: + EventStore = None from paddle.distributed.fleet.recompute.recompute import recompute from paddle.distributed.fleet.utils.sequence_parallel_utils import ScatterOp +from ...utils.log import logger from ...utils.tools import get_env_device from ..model_utils import PipelinePretrainedModel -from .modeling import ( - DeepseekV2Config, - DeepseekV2DecoderLayer, - DeepseekV2LMHead, - DeepseekV2Model, - DeepseekV2MTPLayer, - DeepseekV2PretrainedModel, - DeepseekV2PretrainingCriterion, - DeepseekV2RMSNorm, + +if not os.getenv("DSV3_FAST_PRETRAIN", "False"): + from .modeling import ( + DeepseekV2Config, + DeepseekV2DecoderLayer, + DeepseekV2LMHead, + DeepseekV2Model, + DeepseekV2MoE, + DeepseekV2MTPLayer, + DeepseekV2PretrainedModel, + DeepseekV2PretrainingCriterion, + DeepseekV2RMSNorm, + TemporaryVarContext, + set_global_step, + ) +else: + from .modeling import ( + DeepseekV2Config, + DeepseekV2LMHead, + DeepseekV2PretrainingCriterion, + DeepseekV2RMSNorm, + TemporaryVarContext, + set_global_step, + ) + from .modeling_fast import ( + DeepseekV2MoE, + DeepseekV2DecoderLayer, + DeepseekV2MTPLayer, + ) + from .modeling_fast import DeepseekV2ModelFast as DeepseekV2Model + from .modeling_fast import DeepseekV2PretrainedModelFast as DeepseekV2PretrainedModel + + +try: + import paddle.distributed.communication.deep_ep as deep_ep +except ImportError: + deep_ep = None + +from paddleformers.transformers.fused_a2a import ( + fused_combine_backward_func, + fused_combine_forward_func, + fused_dispatch_backward_func, + fused_dispatch_forward_func, ) +from paddleformers.transformers.moe_layer import FusionMoeNode + +from ..fp8_utils import FP8LinearFunctionBase __all__ = [ "DeepseekV2ForCausalLMPipe", ] +import queue + +global_inputs_embeds_mtp_queue = queue.Queue() + + +DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true" +DSV3_USE_FP8_DISPATCH = os.getenv("DSV3_USE_FP8_DISPATCH", "False").lower() == "true" + def parse_args(args): - if isinstance(args, tuple): + if isinstance(args, (tuple, list)): if len(args) == 4: hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = args @@ -55,6 +111,9 @@ def parse_args(args): elif len(args) == 2: hidden_states, attention_mask = args attn_mask_startend_row_indices, position_ids = None, None + else: # len(args) == 1: + hidden_states = args[0] + attention_mask, attn_mask_startend_row_indices, position_ids = None, None, None else: hidden_states = args attention_mask, attn_mask_startend_row_indices, position_ids = None, None, None @@ -93,6 +152,1181 @@ def get_attr(layer, name): return get_attr(layer._layer, name) +def calc_stream_wait(group_id): + comm_event = deep_ep.get_event_from_comm_stream(group_id) + comm_event.calc_stream_wait(group_id) + + +class TensorMeta: + """Recording the meta info of forward inputs, to avoid 0-size problems""" + + def __init__(self, tensor): + self.shape = tensor.shape + self.dtype = tensor.dtype + + +class PostProcessNode(ScheduleNode): + def __init__( + self, + send_mtp_embed, + training, + alpha, + config, + 
shared_experts=None, + using_post_norm_recompute=False, + output_mtp_embed_first=False, + name="PostProcessNode", + ): + self.send_mtp_embed = send_mtp_embed + self.shared_experts = shared_experts + self.training = training + self.config = config + self.alpha = alpha + self.using_post_norm_recompute = using_post_norm_recompute + self.output_mtp_embed_first = output_mtp_embed_first + self.name = name + + if self.using_post_norm_recompute: + assert self.shared_experts is not None + assert self.shared_experts.norm_weight is not None and self.shared_experts.norm_eps is not None + + def forward_without_residual(self, inputs): + + if isinstance(inputs, list): + inputs = tuple(inputs) + if self.using_post_norm_recompute: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states, norm_out) = inputs + else: + (hidden_states, residual, l_aux, final_hidden_states, norm_out) = inputs + else: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states) = inputs + else: + (hidden_states, residual, l_aux, final_hidden_states) = inputs + + with paddle.no_grad(): + if self.shared_experts is not None: + if self.using_post_norm_recompute: + _, _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd( + norm_out, self.shared_experts.w1, self.shared_experts.w2 + ) + norm_out = None + del norm_out + else: + _, _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd( + hidden_states, self.shared_experts.w1, self.shared_experts.w2 + ) + residual = residual + shared_expert_output + + self.x = hidden_states + self.l_aux = l_aux + + hidden_states = residual + hidden_states.stop_gradient = False + + if self.send_mtp_embed: + assert not self.output_mtp_embed_first, "forward_without_residual doesn't support output_mtp_embed_first" + hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) + self.mtp_embed_shape = inputs_embeds_mtp.shape # save the mtp_embed shape for the backward pass + + return return_args(hidden_states) + + def forward(self, inputs): + + if isinstance(inputs, list): + inputs = tuple(inputs) + + if self.using_post_norm_recompute: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states, norm_out) = inputs + else: + (hidden_states, residual, l_aux, final_hidden_states, norm_out) = inputs + else: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states) = inputs + else: + (hidden_states, residual, l_aux, final_hidden_states) = inputs + + with paddle.no_grad(): + if self.shared_experts is not None: + if self.using_post_norm_recompute: + _, _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd( + norm_out, self.shared_experts.w1, self.shared_experts.w2 + ) + norm_out = None + del norm_out + else: + _, _, _, shared_expert_output = FP8LinearFunctionBase.fp8_mlp_fwd( + hidden_states, self.shared_experts.w1, self.shared_experts.w2 + ) + final_hidden_states = final_hidden_states + shared_expert_output + + self.x = hidden_states + self.l_aux = l_aux + hidden_states = residual + final_hidden_states + + if self.send_mtp_embed: + if self.output_mtp_embed_first: + hidden_states = paddle.concat([inputs_embeds_mtp, hidden_states], axis=-1) + else: + hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) + self.mtp_embed_shape = inputs_embeds_mtp.shape # save the mtp_embed shape for the backward pass + + return return_args(hidden_states) + + @paddle.no_grad() + def backward(self, output_grad): + (do3,) = output_grad + + if
self.send_mtp_embed: + # 分割梯度:do3的前部分对应hidden_states,后部分对应inputs_embeds_mtp + hidden_size = do3.shape[-1] - self.mtp_embed_shape[-1] + if self.output_mtp_embed_first: + hidden_states_grad = do3[..., hidden_size:] + inputs_embeds_mtp_grad = do3[..., :hidden_size] + else: + hidden_states_grad = do3[..., :hidden_size] + inputs_embeds_mtp_grad = do3[..., hidden_size:] + else: + hidden_states_grad = do3 + inputs_embeds_mtp_grad = None + + if self.using_post_norm_recompute: + dx, norm_out, invar = FP8LinearFunctionBase.fp8_mlp_bwd_norm_rc( + hidden_states_grad, + self.x, + self.shared_experts.norm_weight, + self.shared_experts.norm_eps, + self.shared_experts.w1, + self.shared_experts.w2, + ) + else: + dx = FP8LinearFunctionBase.fp8_mlp_bwd( + hidden_states_grad, self.x, self.shared_experts.w1, self.shared_experts.w2, True + ) + + self.x = None + + residual_grad = hidden_states_grad + l_aux_grad = paddle.ones(1, dtype=self.l_aux.dtype) * self.alpha + final_hidden_states_grad = hidden_states_grad + + if self.using_post_norm_recompute: + if self.send_mtp_embed: + return ( + inputs_embeds_mtp_grad, + dx, + residual_grad, + l_aux_grad, + final_hidden_states_grad, + norm_out, + invar, + ) + else: + return (dx, residual_grad, l_aux_grad, final_hidden_states_grad, norm_out, invar) + else: + if self.send_mtp_embed: + return (inputs_embeds_mtp_grad, dx, residual_grad, l_aux_grad, final_hidden_states_grad) + else: + return (dx, residual_grad, l_aux_grad, final_hidden_states_grad) + + +class DecoderLayerNode(ScheduleNode): + def __init__( + self, + attn_node, + dispatch_node, + mlp_node, + combine_node, + post_process_node, + mlp_layer, + name="DecoderLayerNode", + ): + super().__init__(fwd_func=None, name=name) + assert (dispatch_node is None and combine_node is None) or ( + dispatch_node is not None and combine_node is not None + ) + self.attn_node = attn_node + self.dispatch_node = dispatch_node + self.mlp_node = mlp_node + self.combine_node = combine_node + self.post_process_node = post_process_node + + self.mlp_layer = mlp_layer + self.moe_group = mlp_layer.moe_group + self.moe_num_experts = mlp_layer.moe_num_experts + + self.states = None + self.hidden_states_meta = None + self.dispatched_probs_meta = None + self.combine_output_meta = None + + def dispatch_forward(self, inputs, previous_event=None, allocate_on_comm_stream=False): + paddle.base.core.nvprof_nvtx_push("raw_dispatch_forward") + if isinstance(inputs, list): + inputs = tuple(inputs) + ( + inputs_embeds_mtp, + hidden_states, + residual, + l_aux, + intermediate_hidden_states, + token_indices, + token_probs, + ) = inputs + + with paddle.no_grad(): + intermediate_hidden_states, dispatched_probs, states, _ = fused_dispatch_forward_func( + intermediate_hidden_states, + token_indices, + token_probs, + self.moe_num_experts, + self.moe_group, + previous_event=previous_event, + async_finish=True, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + dispatched_indices = states["dispatched_indices"] + self.mlp_layer.set_tokens_per_expert(states["tokens_per_expert"]) + dispatched_indices.stop_gradient = True + intermediate_hidden_states.stop_gradient = False + dispatched_probs.stop_gradient = False + self.states = states + self.hidden_states_meta = TensorMeta(intermediate_hidden_states) + self.dispatched_probs_meta = TensorMeta(dispatched_probs) + + inputs = ( + inputs_embeds_mtp, + hidden_states, + residual, + l_aux, + intermediate_hidden_states, + dispatched_indices, + dispatched_probs, + ) + paddle.base.core.nvprof_nvtx_pop() + return inputs + + 
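+    # combine_forward mirrors dispatch_forward: it launches the fused all-to-all combine on the
+    # communication stream (async_finish=True) and records the output TensorMeta so that
+    # combine_backward can synthesize a zero gradient of the right shape when none is provided.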
def combine_forward(self, inputs, previous_event=None): + paddle.base.core.nvprof_nvtx_push("raw_combine_forward") + if isinstance(inputs, list): + inputs = tuple(inputs) + (inputs_embeds_mtp, hidden_states, residual, l_aux, expert_output) = inputs + + with paddle.no_grad(): + combine_output = fused_combine_forward_func( + expert_output, self.moe_group, self.states, previous_event=previous_event, async_finish=True + ) + combine_output.stop_gradient = False + self.combine_output_meta = TensorMeta(combine_output) + inputs = (inputs_embeds_mtp, hidden_states, residual, l_aux, combine_output) + paddle.base.core.nvprof_nvtx_pop() + return inputs + + def dispatch_backward(self, output_grad): + paddle.base.core.nvprof_nvtx_push("raw_dispatch_backward") + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + intermediate_hidden_states_grad, + dispatched_indices_grad, + dispatched_probs_grad, + ) = output_grad + + if intermediate_hidden_states_grad is None: + intermediate_hidden_states_grad = paddle.zeros( + self.hidden_states_meta.shape, self.hidden_states_meta.dtype + ) + if dispatched_probs_grad is None: + dispatched_probs_grad = paddle.zeros(self.dispatched_probs_meta.shape, self.dispatched_probs_meta.dtype) + with paddle.no_grad(): + intermediate_hidden_states_grad, token_indices_grad, token_probs_grad = fused_dispatch_backward_func( + intermediate_hidden_states_grad, + dispatched_probs_grad, + self.moe_group, + self.states["handle"], + async_finish=True, + ) + + output_grad = ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + intermediate_hidden_states_grad, + token_indices_grad, + token_probs_grad, + ) + paddle.base.core.nvprof_nvtx_pop() + return output_grad + + def combine_backward(self, output_grad): + paddle.base.core.nvprof_nvtx_push("raw_combine_backward") + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + combine_output_grad, + ) = output_grad + + if combine_output_grad is None: + combine_output_grad = paddle.zeros(self.combine_output_meta.shape, self.combine_output_meta.dtype) + with paddle.no_grad(): + expert_output_grad = fused_combine_backward_func( + combine_output_grad, self.moe_group, self.states["handle"], async_finish=True + ) + + output_grad = ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + expert_output_grad, + ) + paddle.base.core.nvprof_nvtx_pop() + return output_grad + + def forward(self, inputs): + inputs = self.attn_node.forward(inputs) + + if self.dispatch_node is None: + inputs = self.dispatch_forward(inputs) + calc_stream_wait(self.moe_group.id) + else: + inputs = self.dispatch_node.forward(inputs) + + inputs = self.mlp_node.forward(inputs) + + if self.combine_node is None: + inputs = self.combine_forward(inputs) + calc_stream_wait(self.moe_group.id) + else: + inputs = self.combine_node.forward(inputs) + + inputs = self.post_process_node.forward(inputs) + return inputs + + def backward(self, output_grad=None, scaler=None): + assert (output_grad is not None) and (scaler is None) + + output_grad = self.post_process_node.backward(output_grad) + + if self.combine_node is None: + output_grad = self.combine_backward(output_grad) + calc_stream_wait(self.moe_group.id) + else: + output_grad = self.combine_node.backward(output_grad) + + output_grad = self.mlp_node.backward(output_grad) + + if self.dispatch_node is None: + output_grad = self.dispatch_backward(output_grad) + calc_stream_wait(self.moe_group.id) + else: + output_grad = 
self.dispatch_node.backward(output_grad) + + output_grad = self.attn_node.backward(output_grad) + return output_grad + + +class OverlapedScheduleChunk: + def __init__(self, forward_nodes, backward_nodes, use_fuion=True): + assert len(forward_nodes) == len(backward_nodes) + self.nodes = [] + for f, b in zip(forward_nodes, backward_nodes): + schedule_node_class = OverlapedScheduleNode + if use_fuion: + schedule_node_class = OverlapedFUsionScheduleNode + if isinstance(f, DenseDecoderLayerNode) or isinstance(b, DenseDecoderLayerNode): + schedule_node_class = OverlapedDenseFusionScheduleNode + self.nodes.append(schedule_node_class(f, b, f"OverlapedNode_{len(self.nodes)}")) + + def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None): + # print(" fwd pp stream", pp_stream) + event_to_wait = combine_bw_event_to_wait + for i, n in enumerate(self.nodes): + pp_stream_t = pp_stream + if i + 1 != len(self.nodes): + pp_stream_t = None + + inputs, output_grad, event_to_wait = n.forward_backward( + inputs, output_grad, combine_bw_event_to_wait=event_to_wait, pp_stream=pp_stream_t + ) + return inputs, output_grad, None + + +class OverlapedScheduleNode: + def __init__(self, forward_node, backward_node, name=""): + assert isinstance(forward_node, DecoderLayerNode) and isinstance(backward_node, DecoderLayerNode) + self.forward_node = forward_node + self.backward_node = backward_node + self.name = name + + def forward_backward(self, inputs, output_grad, event_to_wait=None): + paddle.base.core.nvprof_nvtx_push("forward_backward") + output_grad = self.backward_node.post_process_node.backward(output_grad) + + output_grad = self.backward_node.combine_backward(output_grad) + inputs = self.forward_node.attn_node.forward(inputs) + + calc_stream_wait(self.backward_node.moe_group.id) + attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id) + output_grad = self.backward_node.mlp_node.backward(output_grad) + inputs = self.forward_node.dispatch_forward( + inputs, previous_event=attn_compute_event, allocate_on_comm_stream=True + ) + + calc_stream_wait(self.forward_node.moe_group.id) + output_grad = self.backward_node.dispatch_backward(output_grad) + inputs = self.forward_node.mlp_node.forward(inputs) + + calc_stream_wait(self.backward_node.moe_group.id) + inputs = self.forward_node.combine_forward(inputs) + output_grad = self.backward_node.attn_node.backward(output_grad) + + calc_stream_wait(self.forward_node.moe_group.id) + inputs = self.forward_node.post_process_node.forward(inputs) + paddle.base.core.nvprof_nvtx_pop() + return inputs, output_grad + + +class FusionFp8DecoderLayerNode(ScheduleNode): + def __init__( + self, + attn_and_gate_node, + fp8_fusion_moe_node, + post_process_node, + mlp_layer, + send_mtp_embed, + using_post_norm_recompute=False, + name="", + ): + self.attn_and_gate_node = attn_and_gate_node + self.fp8_fusion_moe_node = fp8_fusion_moe_node + self.post_process_node = post_process_node + self.send_mtp_embed = send_mtp_embed + + self.using_post_norm_recompute = using_post_norm_recompute + self.name = name + + self.moe_group = mlp_layer.moe_group + + def attn_forward(self, inputs): + inputs = self.attn_and_gate_node.forward(inputs) + + if self.send_mtp_embed: + if self.using_post_norm_recompute: + inputs_embeds_mtp, hidden_states, residual, probs, routing_map, l_aux, norm_out = inputs + else: + inputs_embeds_mtp, hidden_states, residual, probs, routing_map, l_aux = inputs + else: + if self.using_post_norm_recompute: + hidden_states, 
residual, probs, routing_map, l_aux, norm_out = inputs + else: + hidden_states, residual, probs, routing_map, l_aux = inputs + + if self.using_post_norm_recompute: + hs_2d, token_indices, token_probs = self.fp8_fusion_moe_node.dispatch_quant_node.forward( + norm_out, probs, routing_map + ) + # common return values + ret = (hidden_states, residual, l_aux, hs_2d, token_indices, token_probs, norm_out) + + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret + return ret + else: + hs_2d, token_indices, token_probs = self.fp8_fusion_moe_node.dispatch_quant_node.forward( + hidden_states, probs, routing_map + ) + + # common return values + ret = (hidden_states, residual, l_aux, hs_2d, token_indices, token_probs) + + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret + return ret + + def dispatch_forward(self, inputs, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + inputs_embeds_mtp, hidden_states, residual, l_aux, hs_2d, token_indices, token_probs, norm_out = inputs + else: + hidden_states, residual, l_aux, hs_2d, token_indices, token_probs, norm_out = inputs + else: + if self.send_mtp_embed: + inputs_embeds_mtp, hidden_states, residual, l_aux, hs_2d, token_indices, token_probs = inputs + else: + hidden_states, residual, l_aux, hs_2d, token_indices, token_probs = inputs + + (hs_dispatched, dispatched_indices, dispatched_probs,) = self.fp8_fusion_moe_node.dispatch_node.forward( + hs_2d, + token_indices, + token_probs, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + ret = (hidden_states, residual, l_aux, hs_dispatched, dispatched_indices, dispatched_probs) + + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out) if self.using_post_norm_recompute else ret + return ret + + def mlp_forward(self, inputs): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp, + hidden_states, + residual, + l_aux, + hs_dispatched, + dispatched_indices, + dispatched_probs, + norm_out, + ) = inputs + else: + hidden_states, residual, l_aux, hs_dispatched, dispatched_indices, dispatched_probs, norm_out = inputs + else: + if self.send_mtp_embed: + ( + inputs_embeds_mtp, + hidden_states, + residual, + l_aux, + hs_dispatched, + dispatched_indices, + dispatched_probs, + ) = inputs + else: + hidden_states, residual, l_aux, hs_dispatched, dispatched_indices, dispatched_probs = inputs + + hidden_states_out = self.fp8_fusion_moe_node.mlp_node.forward( + hs_dispatched, dispatched_indices, dispatched_probs + ) + ret = (hidden_states, residual, l_aux, hidden_states_out) + + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out) if self.using_post_norm_recompute else ret + return ret + + def combine_forward(self, inputs, async_finish=False, previous_event=None, allocate_on_comm_stream=False): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, hidden_states_out, norm_out) = inputs + else: + (hidden_states, residual, l_aux, hidden_states_out, norm_out) = inputs + else: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, hidden_states_out) = inputs + else: + (hidden_states, residual, l_aux, hidden_states_out) = inputs + + output_combine = 
self.fp8_fusion_moe_node.combine_node.forward( + hidden_states_out, + async_finish=async_finish, + previous_event=previous_event, + allocate_on_comm_stream=allocate_on_comm_stream and previous_event is not None, + ) + + ret = (hidden_states, residual, l_aux, output_combine) + + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out) if self.using_post_norm_recompute else ret + return ret + + def post_process_forward(self, inputs, with_residual=True): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, output_combine, norm_out) = inputs + else: + (hidden_states, residual, l_aux, output_combine, norm_out) = inputs + else: + if self.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, output_combine) = inputs + else: + (hidden_states, residual, l_aux, output_combine) = inputs + final_hidden_states = self.fp8_fusion_moe_node.combine_quant_node.forward(output_combine) + + inputs = (hidden_states, residual, l_aux, final_hidden_states) + inputs = (inputs_embeds_mtp, *inputs) if self.send_mtp_embed else inputs + inputs = (*inputs, norm_out) if self.using_post_norm_recompute else inputs + + if with_residual: + inputs = self.post_process_node.forward(inputs) + else: + inputs = self.post_process_node.forward_without_residual(inputs) + return inputs + + def post_process_backward(self, output_grad, event_to_wait=None): + grad = self.post_process_node.backward(output_grad) + + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + final_hidden_states_grad, + norm_out, + invar, + ) = grad + else: + hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad, norm_out, invar = grad + else: + if self.send_mtp_embed: + inputs_embeds_mtp_grad, hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad = grad + else: + hidden_states_grad, residual_grad, l_aux_grad, final_hidden_states_grad = grad + + output_combine_grad, quant_event = self.fp8_fusion_moe_node.combine_quant_node.backward( + final_hidden_states_grad, event_to_wait + ) + + ret = (hidden_states_grad, residual_grad, l_aux_grad, output_combine_grad, quant_event) + ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret + return ret + + def combine_backward(self, output_grad, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + output_combine_grad, + quant_event, + norm_out, + invar, + ) = output_grad + else: + ( + hidden_states_grad, + residual_grad, + l_aux_grad, + output_combine_grad, + quant_event, + norm_out, + invar, + ) = output_grad + else: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + output_combine_grad, + quant_event, + ) = output_grad + else: + ( + hidden_states_grad, + residual_grad, + l_aux_grad, + output_combine_grad, + quant_event, + ) = output_grad + + if DSV3_USE_FP8_DISPATCH and quant_event is not None: + combine_backward_wait_event = quant_event + else: + combine_backward_wait_event = previous_event + hidden_states_out_grad = self.fp8_fusion_moe_node.combine_node.backward( + output_combine_grad, + async_finish=async_finish, + 
previous_event=combine_backward_wait_event, + allocate_on_comm_stream=allocate_on_comm_stream and quant_event is not None, + ) + + ret = (hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad) + ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret + return ret + + def mlp_backward(self, output_grad): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hidden_states_out_grad, + norm_out, + invar, + ) = output_grad + else: + hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad, norm_out, invar = output_grad + else: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hidden_states_out_grad, + ) = output_grad + else: + hidden_states_grad, residual_grad, l_aux_grad, hidden_states_out_grad = output_grad + hs_dispatched_grad, dispatched_probs_grad = self.fp8_fusion_moe_node.mlp_node.backward(hidden_states_out_grad) + + ret = (hidden_states_grad, residual_grad, l_aux_grad, hs_dispatched_grad, dispatched_probs_grad) + ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret + return ret + + def dispatch_backward(self, output_grad, async_finish=False, previous_event=None, allocate_on_comm_stream=False): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hs_dispatched_grad, + dispatched_probs_grad, + norm_out, + invar, + ) = output_grad + else: + ( + hidden_states_grad, + residual_grad, + l_aux_grad, + hs_dispatched_grad, + dispatched_probs_grad, + norm_out, + invar, + ) = output_grad + else: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hs_dispatched_grad, + dispatched_probs_grad, + ) = output_grad + else: + hidden_states_grad, residual_grad, l_aux_grad, hs_dispatched_grad, dispatched_probs_grad = output_grad + + hs_grad, token_probs_grad = self.fp8_fusion_moe_node.dispatch_node.backward( + hs_dispatched_grad, + dispatched_probs_grad, + async_finish=async_finish, + previous_event=previous_event, + allocate_on_comm_stream=allocate_on_comm_stream and previous_event is not None, + ) + + ret = (hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad) + ret = (inputs_embeds_mtp_grad, *ret) if self.send_mtp_embed else ret + ret = (*ret, norm_out, invar) if self.using_post_norm_recompute else ret + return ret + + def attn_backward(self, output_grad): + if self.using_post_norm_recompute: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hs_grad, + token_probs_grad, + norm_out, + invar, + ) = output_grad + inputs_embeds_mtp_grad_shape = hidden_states_grad.shape + inputs_embeds_mtp_grad_shape[-1] = -1 + inputs_embeds_mtp_grad = inputs_embeds_mtp_grad.view(inputs_embeds_mtp_grad_shape) + else: + hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad, norm_out, invar = output_grad + else: + if self.send_mtp_embed: + ( + inputs_embeds_mtp_grad, + hidden_states_grad, + residual_grad, + l_aux_grad, + hs_grad, + token_probs_grad, + ) = output_grad + inputs_embeds_mtp_grad_shape = hidden_states_grad.shape + inputs_embeds_mtp_grad_shape[-1] = -1 + inputs_embeds_mtp_grad = 
inputs_embeds_mtp_grad.view(inputs_embeds_mtp_grad_shape) + else: + hidden_states_grad, residual_grad, l_aux_grad, hs_grad, token_probs_grad = output_grad + + hidden_states_grad_, probs_grad, routing_map_grad = self.fp8_fusion_moe_node.dispatch_quant_node.backward( + hs_grad, token_probs_grad + ) + + output_grad = (residual_grad, probs_grad, routing_map_grad, l_aux_grad) + + output_grad = ( + (hidden_states_grad, *output_grad, hidden_states_grad_) + if self.using_post_norm_recompute + else (hidden_states_grad + hidden_states_grad_, *output_grad) + ) + output_grad = (inputs_embeds_mtp_grad, *output_grad) if self.send_mtp_embed else output_grad + + if self.using_post_norm_recompute: + with TemporaryVarContext(norm_out, invar): + output_grad = self.attn_and_gate_node.backward(output_grad) + else: + output_grad = self.attn_and_gate_node.backward(output_grad) + return output_grad + + def forward(self, inputs): + inputs = self.attn_forward(inputs) + inputs = self.dispatch_forward(inputs) + inputs = self.mlp_forward(inputs) + inputs = self.combine_forward(inputs) + inputs = self.post_process_forward(inputs) + return inputs + + def backward(self, output_grad=None, scaler=None): + assert (output_grad is not None) and (scaler is None) + output_grad = self.post_process_backward(output_grad) + output_grad = self.combine_backward(output_grad) + output_grad = self.mlp_backward(output_grad) + # todo(phlrain): overlap here + output_grad = self.dispatch_backward(output_grad) + output_grad = self.attn_backward(output_grad) + return output_grad + + +class DenseDecoderLayerNode(ScheduleNode): + def __init__( + self, + attn_node, + mlp_node, + name="DenseDecoderLayerNode", + ): + super().__init__(fwd_func=None, name=name) + self.attn_node = attn_node + self.mlp_node = mlp_node + + def forward(self, inputs): + inputs = self.attn_node.forward(inputs) + inputs = self.mlp_node.forward(inputs) + return inputs + + def backward(self, output_grad=None, scaler=None): + assert (output_grad is not None) and (scaler is None) + output_grad = self.mlp_node.backward(output_grad) + output_grad = self.attn_node.backward(output_grad) + return output_grad + + +class OverlapedFUsionScheduleNode: + def __init__(self, forward_node, backward_node, name=""): + assert isinstance(forward_node, FusionFp8DecoderLayerNode) and isinstance( + backward_node, FusionFp8DecoderLayerNode + ) + self.forward_node = forward_node + self.backward_node = backward_node + self.name = name + + def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None): + paddle.base.core.nvprof_nvtx_push("forward_backward") + + combine_bwd_event = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("attn_forward") + inputs = self.forward_node.attn_forward(inputs) + paddle.base.core.nvprof_nvtx_pop() + attn_compute_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("post_process_backward") + output_grad = self.backward_node.post_process_backward(output_grad, combine_bw_event_to_wait) + paddle.base.core.nvprof_nvtx_pop() + + paddle.base.core.nvprof_nvtx_push("combine_backward") + if combine_bw_event_to_wait is not None: + # print(" event", combine_bw_event_to_wait) + output_grad = self.backward_node.combine_backward( + output_grad, previous_event=combine_bw_event_to_wait, async_finish=True, allocate_on_comm_stream=True + ) + else: + output_grad = self.backward_node.combine_backward( + output_grad, 
previous_event=combine_bwd_event, async_finish=True, allocate_on_comm_stream=True + ) + # get combine event + combine_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() + + combine_backward_event.calc_stream_wait(self.backward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_push("mlp_backward_dx") + assert WeightGradStore.funcs_queue.empty() + WeightGradStore.enabled = True + output_grad = self.backward_node.mlp_backward(output_grad) + WeightGradStore.enabled = False + WeightGradStore.flush() + paddle.base.core.nvprof_nvtx_pop() + + output_grad_event = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("dispatch_forward") + inputs = self.forward_node.dispatch_forward( + inputs, previous_event=attn_compute_event, async_finish=True, allocate_on_comm_stream=True + ) + paddle.base.core.nvprof_nvtx_pop() + dispatch_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("dispatch_backward") + output_grad = self.backward_node.dispatch_backward( + output_grad, async_finish=True, previous_event=output_grad_event, allocate_on_comm_stream=True + ) + paddle.base.core.nvprof_nvtx_pop() + # get dispatch backward event + dispatch_backward_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("dispatch_backward_dw") + WeightGradStore.pop() + assert WeightGradStore.funcs_queue.empty() + paddle.base.core.nvprof_nvtx_pop() + + dispatch_forward_event.calc_stream_wait(self.forward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_push("mlp_forward") + inputs = self.forward_node.mlp_forward(inputs) + paddle.base.core.nvprof_nvtx_pop() + mlp_fwd_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id) + + if pp_stream is not None: + paddle.base.core.nvprof_nvtx_push("post_process_forward") + + final_out = self.forward_node.post_process_node.forward_without_residual(inputs) + paddle.base.core.nvprof_nvtx_pop() + + final_out_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("combine_forward") + inputs = self.forward_node.combine_forward( + inputs, previous_event=mlp_fwd_event, async_finish=True, allocate_on_comm_stream=True + ) + paddle.base.core.nvprof_nvtx_pop() + + combine_forward_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id) + + combine_fwd_out = inputs[-2] if self.forward_node.using_post_norm_recompute else inputs[-1] + + if pp_stream is not None: + send_recv_stream = paddle.device.Stream(stream_base=pp_stream) + + # combine_forward_event.custom_stream_wait( pp_stream) + # final_out_event.custom_stream_wait(pp_stream) + + paddle.base.core.nvprof_nvtx_push("pp stream add") + + with paddle.device.stream_guard(send_recv_stream): + combine_forward_event.current_stream_wait() + final_out_event.current_stream_wait() + + inputs = final_out + combine_fwd_out + + final_out._record_stream() + combine_fwd_out._record_stream() + + paddle.base.core.nvprof_nvtx_pop() + + dispatch_backward_event.calc_stream_wait(self.backward_node.moe_group.id) + + paddle.base.core.nvprof_nvtx_push("attn_backward") + assert WeightGradStore.funcs_queue.empty() + WeightGradStore.enabled = True + output_grad = self.backward_node.attn_backward(output_grad) + event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id) + + if EventStore is not None: + 
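+            # Publish the calc-stream event recorded after attn_backward through the zero-bubble
+            # EventStore (EventStore is None on Paddle builds without zero_bubble_utils.EventStore).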
EventStore.set(event_to_wait) + + WeightGradStore.enabled = False + WeightGradStore.flush() + WeightGradStore.pop() + assert WeightGradStore.funcs_queue.empty() + + paddle.base.core.nvprof_nvtx_pop() + + # residual add + if pp_stream is None: + combine_forward_event.calc_stream_wait(self.forward_node.moe_group.id) + + final_out = self.forward_node.post_process_node.forward_without_residual(inputs) + if final_out.shape[-1] != combine_fwd_out.shape[-1]: + final_out[:, :, : combine_fwd_out.shape[-1]] += combine_fwd_out # 直接广播并相加 + else: + final_out += combine_fwd_out + inputs = final_out + combine_fwd_out._record_stream() + + paddle.base.core.nvprof_nvtx_pop() + return inputs, output_grad, event_to_wait + + +class OverlapedDenseFusionScheduleNode: + def __init__(self, forward_node, backward_node, name=""): + assert isinstance(forward_node, FusionFp8DecoderLayerNode) or isinstance( + backward_node, FusionFp8DecoderLayerNode + ) + assert isinstance(forward_node, DenseDecoderLayerNode) or isinstance( + backward_node, DenseDecoderLayerNode + ) + self.forward_node = forward_node + self.backward_node = backward_node + self.name = name + + def forward_backward(self, inputs, output_grad, combine_bw_event_to_wait=None, pp_stream=None): + # Dense forward + MoE backward + if isinstance(self.forward_node, DenseDecoderLayerNode): + paddle.base.core.nvprof_nvtx_push("dense_fw_moe_bw") + + paddle.base.core.nvprof_nvtx_push("dense_attn_moe_combine") + # Note: the input combine_bw_event_to_wait is unreliable, we need to record a new event here. + combine_bw_event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id) + output_grad = self.backward_node.post_process_backward(output_grad, combine_bw_event_to_wait) + output_grad = self.backward_node.combine_backward( + output_grad, previous_event=combine_bw_event_to_wait, async_finish=True, allocate_on_comm_stream=True + ) + combine_bw_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id) + inputs = self.forward_node.attn_node.forward(inputs) + combine_bw_event.calc_stream_wait(self.backward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_attn_moe_combine + + paddle.base.core.nvprof_nvtx_push("moe_mlp") + output_grad = self.backward_node.mlp_backward(output_grad) + paddle.base.core.nvprof_nvtx_pop() # moe_mlp + + paddle.base.core.nvprof_nvtx_push("dense_mlp_moe_dispatch") + output_grad = self.backward_node.dispatch_backward( + output_grad, async_finish=True, allocate_on_comm_stream=True + ) + dispatch_bw_event = deep_ep.get_event_from_comm_stream(self.backward_node.moe_group.id) + inputs = self.forward_node.mlp_node.forward(inputs) + dispatch_bw_event.calc_stream_wait(self.backward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_mlp_moe_dispatch + + paddle.base.core.nvprof_nvtx_push("moe_attn") + output_grad = self.backward_node.attn_backward(output_grad) + paddle.base.core.nvprof_nvtx_pop() # moe_attn + + event_to_wait = deep_ep.get_event_from_calc_stream(self.backward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_fw_moe_bw + + # Dense backward + MoE forward + else: + paddle.base.core.nvprof_nvtx_push("dense_bw_moe_fw") + + paddle.base.core.nvprof_nvtx_push("moe_attn") + inputs = self.forward_node.attn_forward(inputs) + attn_fw_event = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # moe_attn + + paddle.base.core.nvprof_nvtx_push("dense_mlp_moe_dispatch") + output_grad = 
self.backward_node.mlp_node.backward(output_grad) + inputs = self.forward_node.dispatch_forward( + inputs, previous_event=attn_fw_event, async_finish=True, allocate_on_comm_stream=True + ) + dispatch_fw_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id) + dispatch_fw_event.calc_stream_wait(self.forward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_mlp_moe_dispatch + + paddle.base.core.nvprof_nvtx_push("moe_mlp") + inputs = self.forward_node.mlp_forward(inputs) + paddle.base.core.nvprof_nvtx_pop() # moe_mlp + + paddle.base.core.nvprof_nvtx_push("dense_attn_moe_combine") + inputs = self.forward_node.combine_forward( + inputs, async_finish=True, allocate_on_comm_stream=True + ) + combine_fw_event = deep_ep.get_event_from_comm_stream(self.forward_node.moe_group.id) + output_grad = self.backward_node.attn_node.backward(output_grad) + combine_fw_event.calc_stream_wait(self.forward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_attn_moe_combine + + paddle.base.core.nvprof_nvtx_push("moe_post") + inputs = self.forward_node.post_process_forward(inputs) + paddle.base.core.nvprof_nvtx_pop() # moe_post + + event_to_wait = deep_ep.get_event_from_calc_stream(self.forward_node.moe_group.id) + paddle.base.core.nvprof_nvtx_pop() # dense_bw_moe_fw + + return inputs, output_grad, event_to_wait + + +def build_overlapped_nodes(forward_chunk, backward_chunk): + overlap_element_class = ( + FusionFp8DecoderLayerNode if DSV3_USE_FP8_GEMM else DecoderLayerNode, + DenseDecoderLayerNode + ) + forward_decoder_layer_num = 0 + backward_decoder_layer_num = 0 + assert isinstance(forward_chunk, ScheduleChunk) and isinstance(backward_chunk, ScheduleChunk) + for n in forward_chunk.nodes: + if isinstance(n, overlap_element_class): + forward_decoder_layer_num += 1 + for n in reversed(backward_chunk.nodes): + if isinstance(n, overlap_element_class): + backward_decoder_layer_num += 1 + + overlap_layers_num = min(forward_decoder_layer_num, backward_decoder_layer_num) + forward_pre_overlap_layers = [] + forward_post_overlap_layers = [] + forward_overlap_layers = [] + is_pre = True + for n in forward_chunk.nodes: + if not isinstance(n, overlap_element_class): + if is_pre: + forward_pre_overlap_layers.append(n) + else: + forward_post_overlap_layers.append(n) + else: + is_pre = False + if len(forward_overlap_layers) == overlap_layers_num: + forward_post_overlap_layers.append(n) + else: + forward_overlap_layers.append(n) + forward_pre_node = ScheduleChunk(forward_pre_overlap_layers) + forward_post_node = ScheduleChunk(forward_post_overlap_layers) + + backward_pre_overlap_layers = [] + backward_post_overlap_layers = [] + backward_overlap_layers = [] + is_pre = True + for n in reversed(backward_chunk.nodes): + if not isinstance(n, overlap_element_class): + if is_pre: + backward_pre_overlap_layers.append(n) + else: + backward_post_overlap_layers.append(n) + else: + is_pre = False + if len(backward_overlap_layers) == overlap_layers_num: + backward_post_overlap_layers.append(n) + else: + backward_overlap_layers.append(n) + + backward_pre_node = ScheduleChunk(list(reversed(backward_pre_overlap_layers))) + backward_post_node = ScheduleChunk(list(reversed(backward_post_overlap_layers))) + + overlap_node = OverlapedScheduleChunk(forward_overlap_layers, backward_overlap_layers, use_fuion=DSV3_USE_FP8_GEMM) + return forward_pre_node, backward_pre_node, overlap_node, forward_post_node, backward_post_node + + class DeepseekV2EmbeddingPipe(nn.Layer): def __init__(self, config: 
DeepseekV2Config): super(DeepseekV2EmbeddingPipe, self).__init__() @@ -160,6 +1394,7 @@ def forward(self, args): # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) inputs_embeds = ScatterOp.apply(inputs_embeds) embeds_res = [inputs_embeds] + mtp_embeds = [] for depth in range(self.config.num_nextn_predict_layers): inputs_embeds_mtp = paddle.concat( [ @@ -171,12 +1406,19 @@ def forward(self, args): if self.sequence_parallel: inputs_embeds_mtp = inputs_embeds_mtp.reshape([-1, inputs_embeds_mtp.shape[-1]]) inputs_embeds_mtp = ScatterOp.apply(inputs_embeds_mtp) - embeds_res.append(inputs_embeds_mtp) - # if not self.sequence_parallel - # mtp_embeds: [B*num_nextn_predict_layers, seq_len, hidden_size] - # else: - # mtp_embeds: [B*seq_len*num_nextn_predict_layers, hidden_size] - inputs_embeds = paddle.concat(embeds_res, axis=-1) + mtp_embeds.append(inputs_embeds_mtp) + + if self.config.send_mtp_embed: + embeds_res.extend(mtp_embeds) + # if not self.sequence_parallel + # mtp_embeds: [B*num_nextn_predict_layers, seq_len, hidden_size] + # else: + # mtp_embeds: [B*seq_len*num_nextn_predict_layers, hidden_size] + inputs_embeds = paddle.concat(embeds_res, axis=-1) + else: + global global_inputs_embeds_mtp_queue + cloned_mtp_embeds = [t.detach() for t in mtp_embeds] + global_inputs_embeds_mtp_queue.put(cloned_mtp_embeds) return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids) else: if self.sequence_parallel: @@ -184,15 +1426,18 @@ def forward(self, args): inputs_embeds = ScatterOp.apply(inputs_embeds) return return_args(inputs_embeds, attention_mask, attn_mask_startend_row_indices, position_ids) + def build_schedule_node(self): + return ScheduleNode(self.forward, name="DeepseekV2EmbeddingPipe") + class DeepseekV2DecoderLayerPipe(DeepseekV2DecoderLayer): def forward(self, args): hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) - if self.config.num_nextn_predict_layers > 0: + if self.config.send_mtp_embed: batch_size, _, hidden_size = hidden_states.shape batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1) - inputs_embeds_mtp = hidden_states[..., -batch_size_mtp:] + inputs_embeds_mtp = hidden_states[..., batch_size_mtp:] hidden_states = hidden_states[..., :batch_size_mtp] has_gradient = not hidden_states.stop_gradient @@ -235,19 +1480,285 @@ def forward(self, args): attn_mask_startend_row_indices=attn_mask_startend_row_indices, ) - if self.config.num_nextn_predict_layers > 0: + if self.config.send_mtp_embed: hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids) + def attn_compute(self, args): + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) + assert attention_mask is None + assert attn_mask_startend_row_indices is None + assert position_ids is None + assert self.config.send_mtp_embed + + batch_size, _, hidden_size = hidden_states.shape + batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1) + inputs_embeds_mtp = hidden_states[..., batch_size_mtp:] + hidden_states = hidden_states[..., :batch_size_mtp] + + def attn_compute_func(hidden_states): + hidden_states, residual = self.self_attn_compute(hidden_states) + l_aux, _, intermediate_hidden_states, token_indices, token_probs = self.pre_dispatch_compute(hidden_states) + return (hidden_states, residual, l_aux, intermediate_hidden_states, token_indices, 
token_probs) + + has_gradient = not hidden_states.stop_gradient + if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: + # for pretrain + outputs = recompute( + attn_compute_func, + hidden_states, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = attn_compute_func(hidden_states) + + return (inputs_embeds_mtp, *outputs) + + def attn_compute_for_fusion(self, args): + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) + assert attention_mask is None + assert attn_mask_startend_row_indices is None + assert position_ids is None + + send_mtp_embed = self.config.send_mtp_embed + + if send_mtp_embed: + # slice from holy tensor + batch_size, _, hidden_size = hidden_states.shape + batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1) + inputs_embeds_mtp = hidden_states[..., batch_size_mtp:] + hidden_states = hidden_states[..., :batch_size_mtp] + + hidden_states, residual = self.self_attn_compute(hidden_states) + _, _, d_model = hidden_states.shape + + if self.using_post_norm_recompute: + probs, routing_map, l_aux, _, norm_out = self.mlp.router(hidden_states) + else: + probs, routing_map, l_aux, _ = self.mlp.router(hidden_states) + + # common return values + ret = ( + hidden_states, + residual, + probs, + routing_map, + l_aux, + ) + # append mtp embed if needed + ret = (inputs_embeds_mtp, *ret) if send_mtp_embed else ret + # append norm_out if using post_norm recompute + ret = (*ret, norm_out) if self.using_post_norm_recompute else ret + + return ret + + def mlp_compute(self, inputs): + if isinstance(inputs, list): + inputs = tuple(inputs) + send_mtp_embed = self.config.send_mtp_embed + + if send_mtp_embed: + ( + inputs_embeds_mtp, + hidden_states, + residual, + l_aux, + intermediate_hidden_states, + dispatched_indices, + dispatched_probs, + ) = inputs + else: + ( + hidden_states, + residual, + l_aux, + intermediate_hidden_states, + dispatched_indices, + dispatched_probs, + ) = inputs + has_gradient = not intermediate_hidden_states.stop_gradient + if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: + expert_output = recompute( + self.expert_forward_compute, + intermediate_hidden_states, + dispatched_indices, + dispatched_probs, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + expert_output = self.expert_forward_compute( + intermediate_hidden_states, dispatched_indices, dispatched_probs + ) + if send_mtp_embed: + return (inputs_embeds_mtp, hidden_states, residual, l_aux, expert_output) + else: + return (hidden_states, residual, l_aux, expert_output) + + def post_process_compute(self, inputs): + send_mtp_embed = self.config.send_mtp_embed + + if isinstance(inputs, list): + inputs = tuple(inputs) + if send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, combine_output) = inputs + else: + (hidden_states, residual, l_aux, combine_output) = inputs + has_gradient = not hidden_states.stop_gradient + if self.enable_recompute and self.config.recompute_granularity == "full" and has_gradient: + hidden_states = recompute( + self.post_combine_compute, + residual, + hidden_states, + combine_output, + l_aux, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + hidden_states = self.post_combine_compute( + residual, + hidden_states, + combine_output, + l_aux, + ) + if send_mtp_embed: + hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) + + return return_args(hidden_states) + + def 
post_process_compute_for_fusion(self, inputs): + send_mtp_embed = self.config.send_mtp_embed + + if isinstance(inputs, list): + inputs = tuple(inputs) + + if send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual, l_aux, final_hidden_states) = inputs + else: + (hidden_states, residual, l_aux, final_hidden_states) = inputs + + final_hidden_states = self.mlp.post_process(hidden_states, final_hidden_states, l_aux) + + hidden_states = residual + final_hidden_states + + hidden_states = (hidden_states,) + + if type(hidden_states) is tuple and len(hidden_states) == 1: + hidden_states = hidden_states[0] + + if send_mtp_embed: + hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) + + return return_args(hidden_states) + + def attn_compute_dense(self, args): + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) + assert attention_mask is None + assert attn_mask_startend_row_indices is None + assert position_ids is None + + if self.config.send_mtp_embed: + batch_size, _, hidden_size = hidden_states.shape + batch_size_mtp = hidden_size // (self.config.num_nextn_predict_layers + 1) + inputs_embeds_mtp = hidden_states[..., batch_size_mtp:] + hidden_states = hidden_states[..., :batch_size_mtp] + + hidden_states, residual = self.self_attn_compute(hidden_states) + + ret = (hidden_states, residual) + ret = (inputs_embeds_mtp, *ret) if self.config.send_mtp_embed else ret + return ret + + def mlp_compute_dense(self, inputs): + if self.config.send_mtp_embed: + (inputs_embeds_mtp, hidden_states, residual) = inputs + else: + (hidden_states, residual) = inputs + + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + if self.config.send_mtp_embed: + hidden_states = paddle.concat([hidden_states, inputs_embeds_mtp], axis=-1) + + return hidden_states + + def build_schedule_node(self): + if isinstance(self.mlp, DeepseekV2MoE): + self.mlp.update_flex_token() + if self.mlp.using_flex_token: + if DSV3_USE_FP8_GEMM: + attn_and_gate_node = ScheduleNode(self.attn_compute_for_fusion, name="attn_and_gate_node") + + # recompute_fwd_gate_up_ may be 1, 0 or -1, 1 means recompute, 0 means disable recompute, -1 means adaptive recompute. 
+ recompute_fwd_gate_up_ = 1 if self.layer_idx in self.config.recompute_fwd_gate_up_list else 0 + if recompute_fwd_gate_up_ == 0 and self.config.adaptive_remained_O1_recompute_ratio: + recompute_fwd_gate_up_ = -1 + + fp8_fusion_moe_node = FusionMoeNode( + self.mlp, + recompute_fwd_gate_up=recompute_fwd_gate_up_, + is_split_group_gemm=self.config.is_split_group_gemm, + mlp_fwd_subbatch_rows=self.config.mlp_fwd_subbatch_rows, + mlp_bwd_subbatch_rows=self.config.mlp_bwd_subbatch_rows, + output_subbatch_rows=self.config.output_subbatch_rows, + name="fp8_fusion_moe_node", + ) + post_process_node = PostProcessNode( + self.config.send_mtp_embed, + self.mlp.training, + self.mlp.alpha, + self.config, + self.mlp.shared_experts, + self.config.using_post_norm_recompute, + output_mtp_embed_first=isinstance(self, DeepseekV2MTPLayer), + name="post_process_node", + ) + return FusionFp8DecoderLayerNode( + attn_and_gate_node=attn_and_gate_node, + fp8_fusion_moe_node=fp8_fusion_moe_node, + post_process_node=post_process_node, + mlp_layer=self.mlp, + send_mtp_embed=self.config.send_mtp_embed, + using_post_norm_recompute=self.config.using_post_norm_recompute, + name="FusionFp8DecoderLayerNode", + ) + else: + attn_node = ScheduleNode(self.attn_compute, name="attn_node") + mlp_node = ScheduleNode(self.mlp_compute, name="mlp_node") + post_process_node = ScheduleNode(self.post_process_compute, name="post_process_node") + return DecoderLayerNode( + attn_node=attn_node, + dispatch_node=None, + mlp_node=mlp_node, + combine_node=None, + post_process_node=post_process_node, + mlp_layer=self.mlp, + name="DecoderLayerNode", + ) + + attn_node = ScheduleNode(self.attn_compute_dense, name="attn_node") + mlp_node = ScheduleNode(self.mlp_compute_dense, name="mlp_node") + return DenseDecoderLayerNode( + attn_node=attn_node, + mlp_node=mlp_node, + name="DenseDecoderLayerNode", + ) + class DeepseekV2MTPLayerPipe(DeepseekV2MTPLayer): def forward(self, args): hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) - hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1) - hidden_states_main_model = hidden_states_list[0] - inputs_embeds_cur_depth_list = hidden_states_list[1:] + if self.config.send_mtp_embed: + hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1) + hidden_states_main_model = hidden_states_list[0] + inputs_embeds_cur_depth_list = hidden_states_list[1:] + else: + hidden_states_main_model = hidden_states + global global_inputs_embeds_mtp_queue + inputs_embeds_cur_depth_list = global_inputs_embeds_mtp_queue.get() + has_gradient = not hidden_states_main_model.stop_gradient if attention_mask is not None and attention_mask.dtype == paddle.int32: @@ -299,6 +1810,70 @@ def forward(self, args): hidden_states = paddle.concat(output_list, axis=-1) return return_args(hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids) + def attn_compute_for_fusion(self, args): + hidden_states, attention_mask, attn_mask_startend_row_indices, position_ids = parse_args(args) + assert attention_mask is None + assert attn_mask_startend_row_indices is None + assert position_ids is None + assert self.config.num_nextn_predict_layers == 1 + + if self.config.send_mtp_embed: + hidden_states_list = paddle.split(hidden_states, self.config.num_nextn_predict_layers + 1, axis=-1) + hidden_states_main_model = hidden_states_list[0] + inputs_embeds_cur_depth_list = hidden_states_list[1:] + else: + 
hidden_states_main_model = hidden_states + global global_inputs_embeds_mtp_queue + inputs_embeds_cur_depth_list = global_inputs_embeds_mtp_queue.get() + + hidden_states = hidden_states_main_model + nextn_hidden_state = inputs_embeds_cur_depth_list[0] + + # mtp compute + hidden_states = self.hnorm(hidden_states) + nextn_hidden_state = self.enorm(nextn_hidden_state) + + hidden_states = self.eh_proj(paddle.concat([hidden_states, nextn_hidden_state], axis=-1)) + + # attention compute + hidden_states, residual = self.self_attn_compute(hidden_states) + + if self.using_post_norm_recompute: + probs, routing_map, l_aux, _, norm_out = self.mlp.router(hidden_states) + else: + probs, routing_map, l_aux, _ = self.mlp.router(hidden_states) + + # common return values + ret = ( + hidden_states_main_model, + hidden_states, + residual, + probs, + routing_map, + l_aux, + ) + ret = (*ret, norm_out) if self.using_post_norm_recompute else ret + + return ret + + def build_schedule_node(self): + if isinstance(self.mlp, DeepseekV2MoE): + self.mlp.update_flex_token() + if ( + self.mlp.using_flex_token and + DSV3_USE_FP8_GEMM and + self.config.num_nextn_predict_layers == 1 + ): + prev_send_mtp_embed = self.config.send_mtp_embed + self.config.send_mtp_embed = True # must be True in MTP node + + node = DeepseekV2DecoderLayerPipe.build_schedule_node(self) + assert isinstance(node, FusionFp8DecoderLayerNode) + + self.config.send_mtp_embed = prev_send_mtp_embed + return node + return ScheduleNode(self.forward, name="DeepseekV2MTPLayerPipe") + class DeepseekV2RMSNormPipe(nn.Layer): def __init__(self, config): @@ -321,10 +1896,13 @@ def forward(self, args): else: return self.norm(hidden_states) + def build_schedule_node(self): + return ScheduleNode(self.forward, name="DeepseekV2RMSNormPipe") + class DeepseekV2LMHeadPipe(DeepseekV2LMHead): - def __init__(self, config): - super(DeepseekV2LMHeadPipe, self).__init__(config) + def __init__(self, config, embedding_weight=None): + super(DeepseekV2LMHeadPipe, self).__init__(config, embedding_weight=embedding_weight) @property def embedding_weight(self): @@ -340,6 +1918,9 @@ def forward(self, args: Union[Tuple, paddle.Tensor]): logits = super().forward(hidden_states) return logits + def build_schedule_node(self): + return ScheduleNode(self.forward, name="DeepseekV2LMHeadPipe") + class DeepseekV2PretrainingCriterionPipe(DeepseekV2PretrainingCriterion): def forward(self, logits, labels): @@ -348,9 +1929,14 @@ def forward(self, logits, labels): logits = logits[0] loss = super().forward(logits, labels, mtp_logits=mtp_logits) else: + if isinstance(logits, (tuple, list)): + logits = logits[0] loss = super().forward(logits, labels) return loss + def build_schedule_node(self): + return ScheduleNode(self.forward, name="DeepseekV2PretrainingCriterionPipe") + class DeepseekV2ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer): """DeepseekV2ForPretraining adapted for pipeline parallelism. @@ -371,6 +1957,9 @@ class DeepseekV2ForCausalLMPipe(PipelinePretrainedModel, PipelineLayer): # DONOT Add base_model_prefix !!!! 
+ def step_flex_token(self, cur_step): + set_global_step(cur_step) + @classmethod def _prepare_pipeline_inputs_func(cls, inputs): first_stage_keys = ["input_ids", "attention_mask", "attn_mask_startend_row_indices", "position_ids"] @@ -408,6 +1997,10 @@ def __init__(self, config: DeepseekV2Config): assert len(self.no_recompute_layers) == 0, "for pp with full recompute, no_recompute_layers is not support" virtual_pp_degree = getattr(self.config, "virtual_pp_degree", 1) + use_dualpipev = getattr(self.config, "use_dualpipev", False) + if use_dualpipev: + assert LocalSharedLayerDesc is not None, "LocalSharedLayerDesc is None, please update your paddle." + shared_class = LocalSharedLayerDesc if use_dualpipev else SharedLayerDesc def get_hcg(): return fleet.get_hybrid_communicate_group() @@ -422,7 +2015,7 @@ def get_hcg(): if config.tie_word_embeddings: self.add_sequential_layer( - SharedLayerDesc( + shared_class( "DeepseekV2_shared_weight", DeepseekV2EmbeddingPipe, shared_weight_attr="embedding_weight", @@ -435,6 +2028,43 @@ def get_hcg(): LayerDesc(DeepseekV2EmbeddingPipe, config=config), self._base_model.base_model_prefix ) + def compute_recompute_fwd_gate_up_list(pp_nums, all_dl_nums, dense_dl_nums, recompute_fwd_gate_up): + all_layers_nums = all_dl_nums + 4 # embedding, rms, lm_head, mtp + segment_size = all_layers_nums // pp_nums + boundary = math.ceil((1 + dense_dl_nums) / segment_size) * segment_size + recompute_fwd_gate_up_list = [dense_dl_nums] + for idx in range(boundary - 1, all_dl_nums, segment_size): + recompute_fwd_gate_up_list.append(idx) + + # If `recompute_fwd_gate_up` is a Boolean value and is True, means all O1 will be recomputed. + # Otherwise `recompute_fwd_gate_up` should be an integer representing how many O1 are recomputed. + assert isinstance(recompute_fwd_gate_up, (int, bool)) + if type(recompute_fwd_gate_up) is bool: + enable_k_o1_rc = segment_size if recompute_fwd_gate_up is True else 0 + else: + enable_k_o1_rc = recompute_fwd_gate_up + + ret = [] + for i in range(len(recompute_fwd_gate_up_list)): + for k in range(min(segment_size, enable_k_o1_rc)): + ret.append(recompute_fwd_gate_up_list[i] + k) + return ret + + pp_nums = ( + self.config["pipeline_parallel_degree"] * 2 + if self.config.use_dualpipev + else self.config["pipeline_parallel_degree"] + ) + recompute_fwd_gate_up_list = compute_recompute_fwd_gate_up_list( + pp_nums, + self.config.num_hidden_layers, + self.config.first_k_dense_replace, + self.config.recompute_fwd_gate_up, + ) + + logger.info(f"recompute_fwd_gate_up_list: {recompute_fwd_gate_up_list}") + config.recompute_fwd_gate_up_list = recompute_fwd_gate_up_list + for i in range(config.num_hidden_layers): self.add_sequential_layer( LayerDesc( @@ -455,7 +2085,7 @@ def get_hcg(): if config.tie_word_embeddings: self.add_sequential_layer( - SharedLayerDesc( + shared_class( "DeepseekV2_shared_weight", DeepseekV2LMHeadPipe, shared_weight_attr="embedding_weight", @@ -491,11 +2121,69 @@ def get_hcg(): "partition": False, }, num_virtual_pipeline_stages=virtual_pp_degree, + use_dualpipev=use_dualpipev, ) # You should call init here, since there is a diamond inheritance problem self.apply(self._init_weights) # DON'T init PipelinePretrainedModel # PipelinePretrainedModel.__init__(self.super(), config=config) + def fp8_quant_weight(self, batch_mode=False): + """fp8_quant_weight""" + with paddle.no_grad(): + for i, layer in self._sub_layers.items(): + if isinstance( + layer, paddle.distributed.fleet.meta_parallel.parallel_layers.pp_layers.PipelineLayerChunk + ): + for 
i, sub_layer in layer.named_sublayers(): + if isinstance(sub_layer, DeepseekV2DecoderLayer) and hasattr(sub_layer, "fp8_quant_weight"): + sub_layer.fp8_quant_weight(batch_mode) + if isinstance(layer, DeepseekV2DecoderLayer) and hasattr(layer, "fp8_quant_weight"): + layer.fp8_quant_weight(batch_mode) + def get_loss_fn(self, config): + return DeepseekV2PretrainingCriterionPipe(config) + + def overlapped_forward_backward( + self, + forward_chunk, # the module of the forward chunk + forward_inputs, + forward_loss_fn_node, + backward_chunk, # the module of the backward chunk, maybe not used + backward_loss_fn_node, + backward_input_grads, + scaler, + combine_bw_event_to_wait=None, + pp_stream=None, + ): + if backward_loss_fn_node is not None: + if scaler: + backward_input_grads = backward_loss_fn_node.backward(scaler=scaler) + else: + backward_input_grads = backward_loss_fn_node.backward() + + ( + forward_pre_node, + backward_pre_node, + overlap_node, + forward_post_node, + backward_post_node, + ) = build_overlapped_nodes(forward_chunk, backward_chunk) + forward_inputs = forward_pre_node.forward(forward_inputs) + backward_input_grads = backward_pre_node.backward(backward_input_grads) + forward_inputs, backward_input_grads, _ = overlap_node.forward_backward( + forward_inputs, + backward_input_grads, + combine_bw_event_to_wait=combine_bw_event_to_wait, + pp_stream=pp_stream, + ) + forward_inputs = forward_post_node.forward(forward_inputs) + backward_input_grads = backward_post_node.backward(backward_input_grads) + + if forward_loss_fn_node is not None: + forward_loss = forward_loss_fn_node.forward(forward_inputs) + else: + forward_loss = None + + forward_inputs = [forward_inputs] if isinstance(forward_inputs, paddle.Tensor) else forward_inputs + return forward_inputs, forward_loss, backward_input_grads diff --git a/paddleformers/transformers/deepseek_v3/modeling.py b/paddleformers/transformers/deepseek_v3/modeling.py index 51c0d1978fe..8f6ed05c3e7 100644 --- a/paddleformers/transformers/deepseek_v3/modeling.py +++ b/paddleformers/transformers/deepseek_v3/modeling.py @@ -25,13 +25,25 @@ import paddle -from ..deepseek_v2.modeling import ( - DeepseekV2ForSequenceClassification, - DeepseekV2LMHead, - DeepseekV2Model, - DeepseekV2PretrainedModel, - DeepseekV2PretrainingCriterion, -) +import os +if os.getenv("DSV3_FAST_PRETRAIN", "False").lower() not in ("true", "1"): + from ..deepseek_v2.modeling import ( + DeepseekV2ForSequenceClassification, + DeepseekV2LMHead, + DeepseekV2Model, + DeepseekV2PretrainedModel, + DeepseekV2PretrainingCriterion, + ) +else: + from ..deepseek_v2.modeling import ( + DeepseekV2ForSequenceClassification, + DeepseekV2LMHead, + DeepseekV2PretrainingCriterion, + ) + + from ..deepseek_v2.modeling_fast import DeepseekV2ModelFast as DeepseekV2Model + from ..deepseek_v2.modeling_fast import DeepseekV2PretrainedModelFast as DeepseekV2PretrainedModel + from ..model_outputs import CausalLMOutputWithPast from ..model_utils import register_base_model from .configuration import DeepseekV3Config diff --git a/paddleformers/transformers/fp8_utils.py b/paddleformers/transformers/fp8_utils.py new file mode 100644 index 00000000000..93790005d67 --- /dev/null +++ b/paddleformers/transformers/fp8_utils.py @@ -0,0 +1,1252 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from functools import partial + +import numpy +import paddle +import paddle.nn.functional as F + +try: + import fused_ln + from paddle.incubate.nn.functional import swiglu +except ImportError: + + def swiglu(x, y=None): + if y is None: + x, y = paddle.chunk(x, chunks=2, axis=-1) + return F.silu(x) * y + + +from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore + +USE_DS_GEMM = os.getenv("USE_DS_GEMM", "False").lower() == "true" + +try: + if USE_DS_GEMM: + import deep_gemm + else: + from paddle.incubate.fp8 import deep_gemm +except: + pass + + +__all__ = [ + "FP8LinearFunctionBase", + "FP8Linear", + "FP8GroupGemmMlpFunctionNode", +] + + +def set_parameter_color( + parameters, color, group=None, offline_quant_expert_weight=True, clear_origin_weight_when_offline_quant=True +): + if offline_quant_expert_weight and clear_origin_weight_when_offline_quant: + if group is None: + for p in parameters: + if hasattr(p, "color") and p.color is not None: + continue + setattr(p, "color", {"color": color}) + else: + for p in parameters: + if hasattr(p, "color") and p.color is not None: + continue + setattr(p, "color", {"color": color, "group": group}) + + +def extract_first_if_tuple(x): + return x[0] if isinstance(x, tuple) else x + + +def _get_fp8_weight_and_scale(weight, stacked=False, transpose=False): + """_get_fp8_weight_and_scale""" + if stacked: + if transpose: + fp8_weight, fp8_scale = weight.fp8_weight_stacked_transpose, weight.fp8_scale_stacked_transpose + else: + fp8_weight, fp8_scale = weight.fp8_weight_stacked, weight.fp8_scale_stacked + else: + if transpose: + fp8_weight, fp8_scale = weight.fp8_weight_transpose, weight.fp8_scale_transpose + else: + fp8_weight, fp8_scale = weight.fp8_weight, weight.fp8_scale + return fp8_weight, fp8_scale + + +def fused_stack_quant(expert_weight_list, transpose=False): + if hasattr(expert_weight_list[0], "fp8_weight_stacked"): + w, scale = _get_fp8_weight_and_scale(expert_weight_list[0], stacked=True, transpose=transpose) + else: + w, scale = paddle.incubate.nn.functional.fused_stack_transpose_quant(expert_weight_list, transpose=transpose) + return w, scale + + +def weight_quant(weight, transpose=False): + if transpose: + if hasattr(weight, "fp8_weight_transpose"): + return weight.fp8_weight_transpose, weight.fp8_scale_transpose + else: + return paddle.incubate.nn.functional.fp8_quant_blockwise( + weight, + output_scale_transpose=False, + quant_method="128x128", + input_transpose=True, + return_transpose_only=True, + ) + else: + if hasattr(weight, "fp8_weight"): + return weight.fp8_weight, weight.fp8_scale + else: + return paddle.incubate.nn.functional.fp8_quant_blockwise( + weight, + output_scale_transpose=False, + quant_method="128x128", + input_transpose=False, + return_transpose_only=False, + ) + + +class FP8LinearFunctionBase: + @staticmethod + def dequantize_fp8_to_fp32(fp8_tensor, scale): + res = fp8_tensor.reshape([-1, 128]).astype("bfloat16") * (scale.reshape([-1, 1])) + return res.reshape(fp8_tensor.shape) + + @staticmethod + def padding(x, axis): + if x.shape[axis] % 512 != 0: + if (x.shape[axis] + 
128 - (x.shape[axis] % 128)) % 512 != 0: + padding_size = 512 + else: + padding_size = 128 + pad_size = padding_size - (x.shape[axis] % padding_size) + if axis == 0: + x = paddle.concat([x, paddle.zeros([pad_size, x.shape[-1]], dtype=x.dtype)], axis=0) + else: + x = paddle.concat([x, paddle.zeros([x.shape[0], pad_size], dtype=x.dtype)], axis=-1) + return x + + @staticmethod + def padding_and_quant_input(tensor): + """Quantize input to FP8, with fallback to padded transposed version if shape not aligned.""" + if tensor.shape[0] % 512 != 0: + tensor_fp8, tensor_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + tensor, output_scale_transpose=True, quant_method="1x128", input_transpose=False + ) + tensor = FP8LinearFunctionBase.padding(tensor, 0) + tensor_t_fp8, tensor_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + tensor, + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + return_transpose_only=True, + ) + return tensor_fp8, tensor_scale, tensor_t_fp8, tensor_t_scale + else: + tensor_fp8, tensor_scale, tensor_t_fp8, tensor_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + tensor, output_scale_transpose=True, quant_method="1x128", input_transpose=True + ) + return tensor_fp8, tensor_scale, tensor_t_fp8, tensor_t_scale + + @staticmethod + def kitchen_gemm( + x_fp8, x_scale, w_fp8, w_scale, is_a_1d_scaled, is_b_1d_scaled, out=None, rtn_dtype=paddle.bfloat16 + ): + if USE_DS_GEMM: + if out is None: + out = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], rtn_dtype) + if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0: + deep_gemm.wgrad_gemm_fp8_fp8_fp32_nt((x_fp8, x_scale), (w_fp8, w_scale), out, num_sms=118) + return out + + if out is not None: + accumulate = True + out_dtype = out.dtype + else: + accumulate = False + out_dtype = rtn_dtype + if numpy.prod(x_fp8.shape) != 0 and numpy.prod(w_fp8.shape) != 0: + y = paddle.incubate.nn.functional.fp8_gemm_blockwise( + a=x_fp8, + a_decode_scale=x_scale, + b=w_fp8, + b_decode_scale=w_scale, + out_dtype=out_dtype, + out=out, + accumulate=accumulate, + use_split_accumulator=True, + is_a_1d_scaled=is_a_1d_scaled, + is_b_1d_scaled=is_b_1d_scaled, + ) + else: + y = paddle.zeros([x_fp8.shape[0], w_fp8.shape[0]], out_dtype) + if out is not None: + out = out + y + return out + + return y + + @staticmethod + def compute_fp8_linear( + input, weight, weight_transpose=False, return_transpose_only=False, return_mode="output_only", *, out=None + ): + """ + FP8 linear computation that supports multiple return modes and both quantized and unquantized inputs. + + Args: + input: Input tensor (either raw, or an already quantized (input_fp8, input_scale) tuple). + weight: Weight tensor. + weight_transpose (bool): Whether to transpose the weight. + return_transpose_only (bool): Whether to return only the transposed weight. + return_mode (str): Return mode, one of: + - "output_only": return only the output tensor. + - "with_input_quant": return the output plus the quantized input (input_fp8, input_scale). + - "with_input_transpose_quant": return the output (out) plus the transposed quantized input (input_t_fp8, input_t_scale). + Returns: + Tensors in different combinations depending on return_mode. + + Raises: + RuntimeError: If return_mode is not supported. + """ + # check input + is_input_quantized = isinstance(input, (tuple, list)) and len(input) == 2 + + if is_input_quantized: + input_fp8, input_scale = input + if return_mode == "with_input_transpose_quant": + raise RuntimeError( + "Cannot return transposed quant if input is already quantized. " "Use raw input instead." 
+ ) + else: + # quant input (with optional transposed output) + if return_mode == "with_input_transpose_quant": + input_fp8, input_scale, input_t_fp8, input_t_scale = FP8LinearFunctionBase.padding_and_quant_input( + input + ) + else: + input_fp8, input_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + input, + output_scale_transpose=True, + quant_method="1x128", + input_transpose=False, + return_transpose_only=False, + ) + + # quant weight + weight_fp8, weight_scale = weight_quant(weight, weight_transpose) + + # FP8 GEMM + if out is None: + out = paddle.empty([input_fp8.shape[0], weight_fp8.shape[0]], dtype=weight.dtype) + + deep_gemm.gemm_fp8_fp8_bf16_nt((input_fp8, input_scale.T), (weight_fp8, weight_scale), out, num_sms=118) + + # Return outputs + if return_mode == "output_only": + return out + elif return_mode == "with_input_quant": + return (out, input_fp8, input_scale) + elif return_mode == "with_input_transpose_quant": + return (out, input_t_fp8, input_t_scale) + else: + raise RuntimeError( + f"Unsupported return_mode: {return_mode}. " + "Supported modes: 'output_only', 'with_input_quant', 'with_input_transpose_quant'" + ) + + @staticmethod + def compute_expert_w_grad( + input_t, + input_t_scale, + dout_t, + dout_t_scale, + is_a_1d_scaled=True, + is_b_1d_scaled=True, + weight=None, + rtn_dtype=paddle.bfloat16, + ): + """ + Unified gradient computation for expert_w (supports both main_grad and regular grad). + """ + + if input_t is None or numpy.prod(input_t.shape) == 0: + return + + if hasattr(weight, "main_grad"): + if weight.main_grad is None: + weight.main_grad = paddle.zeros(shape=weight.shape, dtype=paddle.float32) + if WeightGradStore.enabled: + WeightGradStore.put( + partial( + FP8LinearFunctionBase.kitchen_gemm, + input_t, + input_t_scale, + dout_t, + dout_t_scale, + is_a_1d_scaled, + is_b_1d_scaled, + weight.main_grad, + rtn_dtype, + ) + ) + result = None + + else: + result = FP8LinearFunctionBase.kitchen_gemm( + input_t, + input_t_scale, + dout_t, + dout_t_scale, + is_a_1d_scaled, + is_b_1d_scaled, + weight.main_grad, + rtn_dtype, + ) + else: + if weight.grad is None: + weight.grad = paddle.zeros(shape=weight.shape, dtype=paddle.float32) + result = FP8LinearFunctionBase.kitchen_gemm( + input_t, input_t_scale, dout_t, dout_t_scale, is_a_1d_scaled, is_b_1d_scaled, weight.grad, rtn_dtype + ) + + if hasattr(weight, "_apply_backward_hook"): + weight._apply_backward_hook() + return result + + @staticmethod + def common_fp8_mlp_bwd( + do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=None, x_scale=None, apply_backward_hook=False + ): + if o1 is not None and (x_fp8 is not None or x_scale is not None): + raise ValueError("When o1 is provided, both x_fp8 and x_scale must be None.") + + if o1 is None: + if x_fp8 is None or x_scale is None: + raise ValueError("When o1 is None, both x_fp8 and x_scale must be provided.") + + # ===== [recompute] o1 = deep_gemm(x_fp8, w1_t_fp8) ===== + + # Recompute o1 using deep_gemm(x_fp8, w1_t_fp8) + w1_fp8, w1_scale = weight_quant(w1, True) + o1 = paddle.empty([x_fp8.shape[0], w1_fp8.shape[0]], dtype=do3.dtype) + deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale.T), (w1_fp8, w1_scale), o1, num_sms=118) + + # ===== [recompute] o2 = swiglu(o1) ===== + o2 = swiglu(o1) + + # ===== do2 = deep_gemm(do3_fp8, w2_fp8) + do2, do3_t_fp8, do3_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + do3, w2, return_mode="with_input_transpose_quant" + ) + + # ===== dw2 = deep_gemm(o2_t_fp8, do3_t_fp8) + o2 = FP8LinearFunctionBase.padding(o2, 0) + o2_t_fp8, o2_t_scale = 
paddle.incubate.nn.functional.fp8_quant_blockwise( + o2, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True + ) + if apply_backward_hook: + if WeightGradStore.enabled: + WeightGradStore.put( + partial( + FP8LinearFunctionBase.compute_expert_w_grad, + o2_t_fp8, + o2_t_scale, + do3_t_fp8, + do3_t_scale, + True, + True, + w2, + rtn_dtype=paddle.float32, + ) + ) + else: + + FP8LinearFunctionBase.compute_expert_w_grad( + o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, True, True, w2, rtn_dtype=paddle.float32 + ) + else: + dw2 = FP8LinearFunctionBase.kitchen_gemm( + o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, True, True, rtn_dtype=paddle.float32 + ) + + # ===== do1 = swiglu_grad(o1, None, do2) ===== + do1, _ = paddle._C_ops.swiglu_grad(o1, None, do2) + + # ===== dx = deep_gemm(do1_fp8, w1_fp8) ===== + dx, do1_t_fp8, do1_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + do1, w1, return_mode="with_input_transpose_quant" + ) + + # ===== dw1 = deep_gemm(x_t_fp8, do1_t_fp8) ===== + if apply_backward_hook: + if WeightGradStore.enabled: + WeightGradStore.put( + partial( + FP8LinearFunctionBase.compute_expert_w_grad, + x_t_fp8, + x_t_scale, + do1_t_fp8, + do1_t_scale, + True, + True, + w1, + rtn_dtype=paddle.float32, + ) + ) + + else: + FP8LinearFunctionBase.compute_expert_w_grad( + x_t_fp8, x_t_scale, do1_t_fp8, do1_t_scale, True, True, w1, rtn_dtype=paddle.float32 + ) + else: + dw1 = FP8LinearFunctionBase.kitchen_gemm( + x_t_fp8, x_t_scale, do1_t_fp8, do1_t_scale, True, True, rtn_dtype=paddle.float32 + ) + + if apply_backward_hook: + return dx + else: + assert dw1 is not None and dw2 is not None + return dx, dw1, dw2 + + @staticmethod + def fp8_mlp_fwd(x, w1, w2): + x_orig_shape = x.shape + x = x.reshape([-1, x_orig_shape[-1]]) + + # ===== o1 = deep_gemm(x_fp8, w1_t_fp8) ===== + o1, x_fp8, x_scale = FP8LinearFunctionBase.compute_fp8_linear( + x, w1, weight_transpose=True, return_transpose_only=True, return_mode="with_input_quant" + ) + + # ===== o2 = swiglu(o1) ===== + o2 = swiglu(o1) + + # ===== o3 = deep_gemm(o2_fp8, w2_t_fp8) ===== + o3 = FP8LinearFunctionBase.compute_fp8_linear(o2, w2, weight_transpose=True, return_transpose_only=True) + + if len(x_orig_shape) > 2: + o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]]) + + return o1, x_fp8, x_scale, o3 + + @staticmethod + def fp8_mlp_fwd_norm_rc(x, norm_w, norm_eps, w1, w2): + # ===== compute norm_output ===== + norm_output, _ = fused_ln.fused_rms_norm(x, norm_w, norm_eps) + # ===== compute fp8_mlp_fwd ===== + _, _, _, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2) + return o3 + + @staticmethod + def fp8_mlp_bwd(do3, x, w1, w2, apply_backward_hook=False): + do3_orig_shape = do3.shape + do3 = do3.reshape([-1, do3_orig_shape[-1]]) + + x_orig_shape = x.shape + x = x.reshape([-1, x_orig_shape[-1]]) + + x_fp8, x_scale, x_t_fp8, x_t_scale = FP8LinearFunctionBase.padding_and_quant_input(x) + + if apply_backward_hook: + dx = FP8LinearFunctionBase.common_fp8_mlp_bwd( + do3, + x_t_fp8, + x_t_scale, + w1, + w2, + o1=None, + x_fp8=x_fp8, + x_scale=x_scale, + apply_backward_hook=apply_backward_hook, + ) + if len(x_orig_shape) > 2: + dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]]) + return dx + else: + dx, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd( + do3, + x_t_fp8, + x_t_scale, + w1, + w2, + o1=None, + x_fp8=x_fp8, + x_scale=x_scale, + apply_backward_hook=apply_backward_hook, + ) + if len(x_orig_shape) > 2: + dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]]) + + return dx, dw1, 
dw2 + + @staticmethod + def fp8_mlp_bwd_norm_rc(do3, x, norm_w, norm_eps, w1, w2): + # ===== recompute norm_output ===== + norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps) + + # ===== compute fp8_mlp_fwd ===== + d_norm_output = FP8LinearFunctionBase.fp8_mlp_bwd(do3, norm_output, w1, w2, True) + + if hasattr(norm_w, "_apply_backward_hook"): + norm_w._apply_backward_hook() + + return d_norm_output, norm_output, invar + + +class FP8LinearFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, custom_map, keep_x=False): + weight = custom_map.weight + x_orig_shape = x.shape + + # deep_gemm only support 2D + x = x.reshape([-1, x_orig_shape[-1]]).contiguous() + + if keep_x: + out = FP8LinearFunctionBase.compute_fp8_linear( + x, + weight, + weight_transpose=True, + return_transpose_only=True, + ) + # save for bwd + out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]]) + ctx.save_for_backward(x, weight) + return out + else: + x_t = x.T + out, x_t_fp8, x_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + x, weight, weight_transpose=True, return_transpose_only=True, return_mode="with_input_transpose_quant" + ) + out = out.reshape([x_orig_shape[0], -1, weight.shape[-1]]) + ctx.save_for_backward((x_t_fp8, x_t_scale), weight) + ctx.x_t_shape = x_t.shape + return out + + @staticmethod + def backward(ctx, dout): + x, weight = ctx.saved_tensor() + dout_2d = dout.reshape([-1, dout.shape[-1]]) + + keep_x = not isinstance(x, tuple) + + if keep_x: + # padding x and quant + dx_orig_shape = x.shape + x = FP8LinearFunctionBase.padding(x, 0) + x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + x, output_scale_transpose=True, quant_method="1x128", input_transpose=True, return_transpose_only=True + ) + + # ===== dx = deep_gemm(dout_fp8, w_fp8) + dx, dout_t_fp8, dout_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + dout_2d, weight, weight_transpose=False, return_mode="with_input_transpose_quant" + ) + dx = dx.reshape(dx_orig_shape) + + else: + x_t_fp8, x_t_scale = x + + # ===== dx = deep_gemm(dout_fp8, w_fp8) + dx, dout_t_fp8, dout_t_scale = FP8LinearFunctionBase.compute_fp8_linear( + dout_2d, weight, weight_transpose=False, return_mode="with_input_transpose_quant" + ) + dx_orig_shape = dout.shape[:-1] + dx_orig_shape.append(ctx.x_t_shape[0]) + dx = dx.reshape(dx_orig_shape) + + # ===== dw1 = deep_gemm(x_t_fp8, dout_t_fp8) + FP8LinearFunctionBase.compute_expert_w_grad( + x_t_fp8, x_t_scale, dout_t_fp8, dout_t_scale, True, True, weight, paddle.float32 + ) + return dx + + +class FP8Linear(paddle.nn.Layer): + def __init__(self, in_features: int, out_features: int, bias_attr: bool = False) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.weight = self.create_parameter( + shape=[in_features, out_features], + dtype="bfloat16", + is_bias=False, + ) + + def forward(self, x): + return FP8LinearFunction.apply(x, self, keep_x=False) + + +def cache_fp8_weight(weight): + if hasattr(weight, "fp8_weight"): + return + w_fp8, w_scale, w_t_fp8, w_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + weight, + output_scale_transpose=False, + quant_method="128x128", + input_transpose=True, + return_transpose_only=False, + ) + + setattr(weight, "fp8_weight_transpose", w_t_fp8) + setattr(weight, "fp8_scale_transpose", w_t_scale) + setattr(weight, "fp8_weight", w_fp8) + setattr(weight, "fp8_scale", w_scale) + + +class FP8KeepXLinear(paddle.nn.Layer): + def __init__(self, in_features: int, out_features: int, bias_attr: bool 
= False) -> None: + super().__init__() + self._dtype = self._helper.get_default_dtype() + + self.weight = self.create_parameter( + shape=[in_features, out_features], + dtype="bfloat16", + is_bias=False, + ) + set_parameter_color([self.weight], "attn_out_project") + + def fp8_quant_weight(self): + cache_fp8_weight(self.weight) + + def forward(self, x): + return FP8LinearFunction.apply(x, self, keep_x=True) + + +class FusedNormFP8MLPFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, norm_w, w1, w2, norm_eps): + # ===== compute norm_output ===== + norm_output, invar = fused_ln.fused_rms_norm(x, norm_w, norm_eps) + # ===== reshape for deep_gemm, since deep_gemm only support 2D ===== + x_orig_shape = norm_output.shape + norm_output = norm_output.reshape([-1, x_orig_shape[-1]]) + + # ===== call func fp8_mlp_fwd ===== + _, _, _, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(norm_output, w1, w2) + + # ===== reshape to origin shape ===== + if len(x_orig_shape) > 2: + o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]]) + + # ===== save for backward ===== + ctx.save_for_backward( + norm_output, + invar, + x, + norm_w, + w1, + w2, + norm_eps, + paddle.to_tensor(x_orig_shape, dtype="int64", place=paddle.CPUPlace()), + ) + return o3 + + @staticmethod + def backward(ctx, do3): + # ===== reshape for deep_gemm, since deep_gemm only support 2D ===== + do3_orig_shape = do3.shape + do3 = do3.reshape([-1, do3_orig_shape[-1]]) + + # ===== recive saved tensors ===== + norm_output, invar, x, norm_w, w1, w2, norm_eps, x_orig_shape = ctx.saved_tensor() + + x_fp8, x_scale, x_t_fp8, x_t_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + norm_output, output_scale_transpose=True, quant_method="1x128", input_transpose=True + ) + + # ===== call func common_fp8_mlp_bwd ===== + d_norm_output, dw1, dw2 = FP8LinearFunctionBase.common_fp8_mlp_bwd( + do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=x_fp8, x_scale=x_scale + ) + + # ===== reshape to origin shape ===== + if len(x_orig_shape) > 2: + d_norm_output = d_norm_output.reshape([x_orig_shape[0], -1, d_norm_output.shape[-1]]) + + # ===== compute norm grad ===== + dx, d_rms_norm_weight = fused_ln.fused_rms_norm_grad_func(x, norm_w, invar, d_norm_output, norm_eps) + + return dx, d_rms_norm_weight, dw1, dw2 + + +class FP8MlpFunction(paddle.autograd.PyLayer): + @staticmethod + def forward(ctx, x, w1, w2, recompute_fwd_gate_up): + # ===== reshape for deep_gemm, since deep_gemm only support 2D ===== + x_orig_shape = x.shape + x = x.reshape([-1, x_orig_shape[-1]]) + + # ===== call func fp8_mlp_fwd ===== + o1, x_fp8, x_scale, o3 = FP8LinearFunctionBase.fp8_mlp_fwd(x, w1, w2) + # ===== reshape to origin shape ===== + if len(x_orig_shape) > 2: + o3 = o3.reshape([x_orig_shape[0], -1, o3.shape[-1]]) + + # ===== save for backward ===== + o1 = None if recompute_fwd_gate_up else o1 + ctx.save_for_backward( + o1, + x_fp8, + x_scale, + w1, + w2, + paddle.to_tensor(x_orig_shape, dtype="int64", place=paddle.CPUPlace()), + ) + return o3 + + @staticmethod + def backward(ctx, do3): + # ===== reshape for deep_gemm, since deep_gemm only support 2D ===== + do3_orig_shape = do3.shape + do3 = do3.reshape([-1, do3_orig_shape[-1]]) + + # ===== recive saved tensors ===== + o1, x_fp8, x_scale, w1, w2, x_orig_shape = ctx.saved_tensor() + + # ===== compute x_t_fp8, x_t_scale for dw1 ===== + x_dequant_fp16 = paddle.incubate.nn.functional.fused_act_dequant(x_fp8, x_scale.T.contiguous()) + x_dequant_fp16 = FP8LinearFunctionBase.padding(x_dequant_fp16, 0) + + x_t_fp8, x_t_scale = 
paddle.incubate.nn.functional.fp8_quant_blockwise( + x_dequant_fp16, + output_scale_transpose=True, + quant_method="1x128", + input_transpose=True, + return_transpose_only=True, + ) + + # ===== call func common_fp8_mlp_bwd ===== + if o1 is None: + dx = FP8LinearFunctionBase.common_fp8_mlp_bwd( + do3, x_t_fp8, x_t_scale, w1, w2, o1=None, x_fp8=x_fp8, x_scale=x_scale, apply_backward_hook=True + ) + else: + dx = FP8LinearFunctionBase.common_fp8_mlp_bwd( + do3, x_t_fp8, x_t_scale, w1, w2, o1=o1, x_fp8=None, x_scale=None, apply_backward_hook=True + ) + # ===== reshape to origin shape ===== + if len(x_orig_shape) > 2: + dx = dx.reshape([x_orig_shape[0], -1, dx.shape[-1]]) + + return dx, None, None + + +class FP8Mlp(paddle.nn.Layer): + def __init__( + self, + config, + hidden_size=None, + intermediate_size=None, + is_moe=False, + using_post_norm_recompute=False, + norm_weight=None, + norm_eps=None, + recompute_fwd_gate_up=False, + ): + super().__init__() + self.config = config + self.using_post_norm_recompute = using_post_norm_recompute + if self.using_post_norm_recompute: + assert norm_weight is not None and norm_eps is not None + self.norm_weight = norm_weight + self.norm_eps = norm_eps + + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + + self.recompute_fwd_gate_up = recompute_fwd_gate_up + + self.w1 = self.create_parameter( + shape=[self.hidden_size, self.intermediate_size * 2], + dtype="bfloat16", + is_bias=False, + ) + self.w2 = self.create_parameter( + shape=[self.intermediate_size, self.hidden_size], + dtype="bfloat16", + is_bias=False, + ) + + def fp8_quant_weight(self): + cache_fp8_weight(self.w1) + cache_fp8_weight(self.w2) + + def forward(self, x): + if self.using_post_norm_recompute: + return FusedNormFP8MLPFunction.apply(x, self.norm_weight, self.w1, self.w2, self.norm_eps) + else: + return FP8MlpFunction.apply(x, self.w1, self.w2, self.recompute_fwd_gate_up) + + +def split_group_gemm(x_fp8, x_scale, w_fp8, w_scale, tokens_per_expert, gemm_out): + start_idx = 0 + for i, token_num in enumerate(tokens_per_expert): + if token_num == 0: + continue + end_idx = start_idx + token_num + + x_scale_tma_align = x_scale[start_idx:end_idx].T.contiguous().T + + deep_gemm.gemm_fp8_fp8_bf16_nt( + (x_fp8[start_idx:end_idx], x_scale_tma_align), + (w_fp8[i], w_scale[i]), + gemm_out[start_idx:end_idx], + num_sms=118, + ) + + start_idx = end_idx + + return gemm_out + + +class FP8GroupGemmMlpFunctionNode: + def __init__( + self, + custom_map, + recompute_fwd_gate_up=False, + is_split_group_gemm=False, + name="experts_group_gemm_contiguous_node", + ): + self.experts = custom_map.experts + self.recompute_fwd_gate_up = recompute_fwd_gate_up + self.is_split_group_gemm = is_split_group_gemm + self.m_indices = None + self.input = None + self.input_fp8 = None + self.input_scale = None + self.o1 = None + self.all_unzipped_grad = None + self.fwd_subbatch = None + self.bwd_subbatch = None + + def reset_statue(self): + self.m_indices = None + self.fwd_subbatch = None + self.bwd_subbatch = None + self.clear_activation_tensors() + + def clear_activation_tensors(self): + self.input = None + self.input_fp8 = None + self.input_scale = None + self.o1 = None + self.all_unzipped_grad = None + + def gen_m_indices(self, tokens_per_expert): + tokens = [] + for i in range(len(tokens_per_expert)): + tokens.append(paddle.full([tokens_per_expert[i]], i, dtype="int32")) + out = 
paddle.concat(tokens, axis=0) + return out + + def fwd_gate_up(self, x, expert_w1, num_expert, tokens_per_expert, m_indices=None): + """ + o1 = x * w1 + [m_sum, n] = [m_sum, k] * [num_groups, k, n] (m_sum = sum(tokens_per_expert)) + """ + if not self.is_split_group_gemm and self.m_indices is None: + self.m_indices = self.gen_m_indices(tokens_per_expert) + # concat w1, shape is [num_groups, n, k] + w1_t_quant, w1_t_scale = fused_stack_quant(expert_w1, transpose=True) + w1_t_quant = w1_t_quant.reshape([num_expert, -1, w1_t_quant.shape[-1]]) + w1_t_scale = w1_t_scale.reshape([num_expert, -1, w1_t_scale.shape[-1]]) + + if x is None: + x_fp8, x_scale = self.input_fp8, self.input_scale + assert x_fp8 is not None and x_scale is not None + else: + if isinstance(x, tuple): + (x_fp8, x_scale) = x + x_scale = paddle.transpose(paddle.transpose(x_scale, [1, 0]).contiguous(), [1, 0]) + else: + # quant x_bf16 + x_fp8, x_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + x, output_scale_transpose=True, quant_method="1x128", input_transpose=False + ) + x_scale = x_scale.T + + # compute gemm + o1 = paddle.empty([x_fp8.shape[0], w1_t_quant.shape[1]], dtype=expert_w1[0].dtype) + if numpy.prod(x_fp8.shape) != 0: + if self.is_split_group_gemm: + split_group_gemm(x_fp8, x_scale, w1_t_quant, w1_t_scale, tokens_per_expert, o1) + else: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (x_fp8, x_scale), + (w1_t_quant, w1_t_scale), + o1, + m_indices=self.m_indices if m_indices is None else m_indices, + num_sms=118, + ) + + if m_indices is None: + self.input_fp8 = x_fp8 + self.input_scale = x_scale + return o1 + + def fwd_swiglu(self, o1): + o2 = swiglu(o1) + return o2 + + def fwd_down( + self, o1, unzipped_probs, expert_w2, num_expert, tokens_per_expert, m_indices=None, o3=None, clear_o1=False + ): + """ + o3 = o2 * w2 + [m_sum, k] = [m_sum, n] * [num_groups, n, k] + """ + # concat and transpose w2 + w2_quant, w2_scale = fused_stack_quant(expert_w2, transpose=True) + w2_quant = w2_quant.reshape([num_expert, -1, w2_quant.shape[-1]]) + w2_scale = w2_scale.reshape([num_expert, -1, w2_scale.shape[-1]]) + + # quant o2 + with paddle.amp.auto_cast(False): + unzipped_probs = unzipped_probs.squeeze(-1) + o2_fp8, o2_scale = paddle.incubate.nn.functional.fused_weighted_swiglu_act_quant( + o1, unzipped_probs, using_pow2_scaling=True + ) + o2_scale = paddle.transpose(paddle.transpose(o2_scale, [1, 0]).contiguous(), [1, 0]) + + if clear_o1: + o1._clear_to_zero_allocation() + + # compute gemm + o3_shape = [o2_fp8.shape[0], w2_quant.shape[1]] + if o3 is not None: + assert o3.shape == o3_shape, "{} vs {}".format(o3.shape, o3_shape) + else: + o3 = paddle.empty(o3_shape, dtype=o1.dtype) + if numpy.prod(o2_fp8.shape) != 0: + if self.is_split_group_gemm: + split_group_gemm(o2_fp8, o2_scale, w2_quant, w2_scale, tokens_per_expert, o3) + else: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (o2_fp8, o2_scale), + (w2_quant, w2_scale), + o3, + m_indices=m_indices if self.fwd_subbatch else self.m_indices, + num_sms=118, + ) + + return o3 + + def bwd_dowm_input(self, expert_w2, unzipped_grad, o1, tokens_per_expert, m_indices=None, unzipped_probs=None): + """ + do2 = do3 * w2_t + [m_sum, n] = [m_sum, k] * [num_groups, k, n] + """ + # recompute concated_w2_2d + bw_w2_quant, bw_w2_scale = fused_stack_quant(expert_w2, transpose=False) + bw_w2_quant = bw_w2_quant.reshape([len(expert_w2), -1, bw_w2_quant.shape[-1]]) + bw_w2_scale = bw_w2_scale.reshape([len(expert_w2), -1, bw_w2_scale.shape[-1]]) + + # compute gemm + if 
isinstance(unzipped_grad, tuple): + (unzipped_grad_fp8, unzipped_grad_scale) = unzipped_grad + unzipped_grad_scale = unzipped_grad_scale.T.contiguous().T + else: + unzipped_grad_fp8, unzipped_grad_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + unzipped_grad, output_scale_transpose=True, quant_method="1x128", input_transpose=False + ) + unzipped_grad_scale = unzipped_grad_scale.T + + do2_s = paddle.empty([unzipped_grad_fp8.shape[0], bw_w2_quant.shape[1]], dtype="bfloat16") + if numpy.prod(unzipped_grad_fp8.shape) != 0: + if self.is_split_group_gemm: + split_group_gemm( + unzipped_grad_fp8, unzipped_grad_scale, bw_w2_quant, bw_w2_scale, tokens_per_expert, do2_s + ) + else: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (unzipped_grad_fp8, unzipped_grad_scale), + (bw_w2_quant, bw_w2_scale), + do2_s, + m_indices=m_indices if self.bwd_subbatch else self.m_indices, + num_sms=118, + ) + + with paddle.amp.auto_cast(False): + do1, probs_grad, o2_s = paddle.incubate.nn.functional.fused_swiglu_weighted_bwd(o1, do2_s, unzipped_probs) + + return do1, o2_s, probs_grad + + def bwd_swiglu(self, o1, do2): + do1, _ = paddle._C_ops.swiglu_grad(o1, None, do2) + return do1 + + def bwd_gate_up_input(self, do1, expert_w1, tokens_per_expert, m_indices=None, dx=None): + """ + dx = do1 * w1_t + [m_sum, k] = [m_sum, n] * [num_groups, n, k] + """ + # recompute concated_w1_t + bw_w1_quant, bw_w1_scale = fused_stack_quant(expert_w1, transpose=False) + bw_w1_quant = bw_w1_quant.reshape([len(expert_w1), -1, bw_w1_quant.shape[-1]]) + bw_w1_scale = bw_w1_scale.reshape([len(expert_w1), -1, bw_w1_scale.shape[-1]]) + + # quant do1 + do1_fp8, do1_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + do1, output_scale_transpose=True, quant_method="1x128", input_transpose=False + ) + do1_scale = do1_scale.T + # compute gemm + dx_shape = [do1_fp8.shape[0], bw_w1_quant.shape[1]] + if dx is None or dx.dtype != do1.dtype: + dx = paddle.empty(shape=dx_shape, dtype=do1.dtype) + else: + assert dx.shape == dx_shape, f"{dx.shape} vs {dx_shape}" + if numpy.prod(do1_fp8.shape) != 0: + if self.is_split_group_gemm: + split_group_gemm(do1_fp8, do1_scale, bw_w1_quant, bw_w1_scale, tokens_per_expert, dx) + else: + deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( + (do1_fp8, do1_scale), + (bw_w1_quant, bw_w1_scale), + dx, + m_indices=m_indices if self.bwd_subbatch else self.m_indices, + num_sms=118, + ) + + return dx + + def fused_transpose_split_quant(self, x, scale, tokens_per_expert, pow_2_scales): + out, scale = paddle.incubate.nn.functional.fused_transpose_split_quant( + x, scale, tokens_per_expert, pow_2_scales + ) + return out, scale + + def bwd_down_weight(self, do3, o2, expert_w2, tokens_per_expert): + """ + dw2 = do2_t * do3 + [n, k] = [n, m_sum] * [m_sum, k] (m_sum = sum(tokens_per_expert)) + """ + if isinstance(o2, tuple): + o2_t_fp8, o2_t_scale = o2 + else: + o2_t_fp8, o2_t_scale = self.fused_transpose_split_quant(o2, None, tokens_per_expert, True) + + if isinstance(do3, tuple): + do3_t_fp8, do3_t_scale = do3 + else: + do3_t_fp8, do3_t_scale = self.fused_transpose_split_quant(do3, None, tokens_per_expert, True) + + def cal_weight_fn(o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, expert_w2): + with paddle.no_grad(): + for i in range(len(expert_w2)): + FP8LinearFunctionBase.compute_expert_w_grad( + o2_t_fp8[i], + o2_t_scale[i], + do3_t_fp8[i], + do3_t_scale[i], + True, + True, + expert_w2[i], + paddle.float32, + ) + + if WeightGradStore.enabled: + WeightGradStore.put(partial(cal_weight_fn, o2_t_fp8, 
o2_t_scale, do3_t_fp8, do3_t_scale, expert_w2)) + else: + cal_weight_fn(o2_t_fp8, o2_t_scale, do3_t_fp8, do3_t_scale, expert_w2) + + def bwd_gate_up_weight( + self, + do1, + input_x, + expert_w1, + tokens_per_expert, + input_fp8_slice=None, + input_scale_slice=None, + clear_input=False, + ): + """ + dw1 = dx_t * do1 + [k, n] = [k, m_sum] * [m_sum, n] (m_sum = sum(tokens_per_expert)) + """ + if input_x is None: + inp = (input_fp8_slice, input_scale_slice) if self.bwd_subbatch else (self.input_fp8, self.input_scale) + input_x_t_fp8, input_x_t_scale = self.fused_transpose_split_quant(inp[0], inp[1], tokens_per_expert, True) + + else: + input_x_t_fp8, input_x_t_scale = self.fused_transpose_split_quant(input_x, None, tokens_per_expert, True) + + if clear_input: + if self.input_fp8 is not None: + self.input_fp8._clear_to_zero_allocation() + self.input_fp8 = None + if self.input_scale is not None: + self.input_scale._clear_to_zero_allocation() + self.input_scale = None + if self.input is not None: + self.input._clear_to_zero_allocation() + self.input = None + + do1_t_fp8, do1_t_scale = self.fused_transpose_split_quant(do1, None, tokens_per_expert, True) + + def cal_weight_fn(input_x_t_fp8, input_x_t_scale, do1_t_fp8, do1_t_scale, expert_w1): + with paddle.no_grad(): + for i in range(len(expert_w1)): + FP8LinearFunctionBase.compute_expert_w_grad( + input_x_t_fp8[i], + input_x_t_scale[i], + do1_t_fp8[i], + do1_t_scale[i], + True, + True, + expert_w1[i], + paddle.float32, + ) + + if WeightGradStore.enabled: + WeightGradStore.put( + partial(cal_weight_fn, input_x_t_fp8, input_x_t_scale, do1_t_fp8, do1_t_scale, expert_w1) + ) + else: + cal_weight_fn(input_x_t_fp8, input_x_t_scale, do1_t_fp8, do1_t_scale, expert_w1) + + @paddle.no_grad() + def forward(self, hs_out, unzipped_probs, tokens_per_expert, m_indices=None): + # check subbatch + if self.fwd_subbatch: + assert m_indices is not None + # deal 0 size + dtype = paddle.bfloat16 + if hs_out is None: + assert self.input_fp8 is not None + assert self.input_scale is not None + shape = self.input_fp8.shape + else: + if isinstance(hs_out, tuple): + shape = hs_out[0].shape + else: + shape = hs_out.shape + + if shape[0] == 0: + o3 = paddle.zeros(shape, dtype=dtype) + return o3 + + # get w1/w2 + expert_w1 = [x.w1 for x in self.experts if x is not None] + expert_w2 = [x.w2 for x in self.experts if x is not None] + + num_expert = len(expert_w1) + + # o1 + o1 = self.fwd_gate_up(hs_out, expert_w1, num_expert, tokens_per_expert, m_indices) + if not self.recompute_fwd_gate_up: + self.o1 = o1 + clear_o1 = False + else: + clear_o1 = True + + # o3 + o3 = self.fwd_down( + o1, unzipped_probs, expert_w2, num_expert, tokens_per_expert, clear_o1=clear_o1, m_indices=m_indices + ) + + # save for bwd + return o3 + + @paddle.no_grad() + def backward( + self, + out_grad, + unzipped_probs, + tokens_per_expert, + input_fp8_slice=None, + input_scale_slice=None, + m_indices=None, + reset_status=False, + ): + # check subbatch + if self.bwd_subbatch: + assert ( + m_indices is not None + and input_fp8_slice is not None + and input_scale_slice is not None + and tokens_per_expert is not None + ) + # deal 0 size + dtype = paddle.bfloat16 + shape = out_grad[0].shape if isinstance(out_grad, tuple) else out_grad.shape + if shape[0] == 0: + return paddle.zeros_like(extract_first_if_tuple(out_grad), dtype=dtype), paddle.zeros_like(unzipped_probs) + + # recompute expert_w2 and expert_w1 + expert_w1 = [x.w1 for x in self.experts if x is not None] + expert_w2 = [x.w2 for x in self.experts if x 
is not None] + + if self.recompute_fwd_gate_up: + inp = None if not self.bwd_subbatch else (input_fp8_slice, input_scale_slice) + o1 = self.fwd_gate_up(inp, expert_w1, len(expert_w1), tokens_per_expert, m_indices=m_indices) + else: + o1 = self.o1 + + # do2 + do1, o2_s, probs_grad = self.bwd_dowm_input( + expert_w2, out_grad, o1, tokens_per_expert, unzipped_probs=unzipped_probs, m_indices=m_indices + ) + del o1 + if self.o1 is not None: + self.o1._clear_to_zero_allocation() + self.o1 = None + + # dw1 + self.bwd_gate_up_weight( + do1, + None, + expert_w1, + tokens_per_expert, + input_fp8_slice=input_fp8_slice, + input_scale_slice=input_scale_slice, + clear_input=reset_status, + ) + + if reset_status: + if self.input_fp8 is not None: + self.input_fp8._clear_to_zero_allocation() + self.input_fp8 = None + if self.input_scale is not None: + self.input_scale._clear_to_zero_allocation() + self.input_scale = None + if self.input is not None: + self.input._clear_to_zero_allocation() + self.input = None + + # dx + dx = self.bwd_gate_up_input( + do1, + expert_w1, + tokens_per_expert, + dx=out_grad[0] if isinstance(out_grad, tuple) else out_grad, + m_indices=m_indices, + ) + del do1 + + # dw2 + if isinstance(out_grad, tuple): + do3_fp8, do3_scale = self.fused_transpose_split_quant(out_grad[0], out_grad[1], tokens_per_expert, True) + out_grad[0]._clear_to_zero_allocation() + out_grad[1]._clear_to_zero_allocation() + self.bwd_down_weight((do3_fp8, do3_scale), o2_s, expert_w2, tokens_per_expert) + else: + self.bwd_down_weight(out_grad, o2_s, expert_w2, tokens_per_expert) + + if reset_status: + self.reset_statue() + return dx, probs_grad diff --git a/paddleformers/transformers/fused_a2a.py b/paddleformers/transformers/fused_a2a.py index 7b5fa09c9e0..400f97cd0a4 100644 --- a/paddleformers/transformers/fused_a2a.py +++ b/paddleformers/transformers/fused_a2a.py @@ -72,78 +72,144 @@ def get_buffer(group: Group, hidden_bytes: int): return _buffer +def fused_dispatch_forward_func( + x, + token_indices, + token_probs, + num_experts, + group, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, +): + """Forward pass of fused dispatch.""" + # Calculate layout before actual dispatch + if isinstance(x, tuple): + buffer = get_buffer(group, get_hidden_bytes(x[0])) + else: + buffer = get_buffer(group, get_hidden_bytes(x)) + ( + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert, + is_token_in_rank, + _previous_event, + ) = buffer.get_dispatch_layout( + token_indices, + num_experts, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + # Do MoE dispatch + # NOTES: the CPU will wait for GPU's signal to arrive, + # so this is not compatible with CUDA graph + (recv_x, recv_token_indices, recv_token_probs, num_recv_tokens_per_expert_list, handle, event,) = buffer.dispatch( + x, + topk_idx=token_indices, + topk_weights=token_probs.cast(paddle.float32), + num_tokens_per_rank=num_tokens_per_rank, + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + is_token_in_rank=is_token_in_rank, + num_tokens_per_expert=num_tokens_per_expert, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + states = dict() + states["dispatched_indices"] = recv_token_indices + states["tokens_per_expert"] = num_recv_tokens_per_expert_list + states["handle"] = handle + + return recv_x, recv_token_probs, states, event + + +def fused_dispatch_backward_func( + grad_output, + 
grad_token_probs, + group, + handle, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, +): + """Backward pass of fused dispatch.""" + buffer = get_buffer(group, get_hidden_bytes(grad_output)) + + grad_x, grad_token_probs, event = buffer.combine( + grad_output.contiguous(), + handle, + topk_weights=grad_token_probs.cast(paddle.float32), + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + return grad_x, None, grad_token_probs + + +def fused_combine_forward_func( + x, group, states, previous_event=None, async_finish=False, allocate_on_comm_stream=False +): + """Forward pass of fused combine.""" + handle = states["handle"] + buffer = get_buffer(group, get_hidden_bytes(x)) + combined_x, _, event = buffer.combine( + x, + handle=handle, + async_finish=async_finish, + previous_event=previous_event, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + return combined_x + + +def fused_combine_backward_func( + grad_output, group, handle, previous_event=None, async_finish=False, allocate_on_comm_stream=False +): + """Backward pass of fused combine.""" + if isinstance(grad_output, tuple): + buffer = get_buffer(group, get_hidden_bytes(grad_output[0])) + grad_x, _, _, _, _, event = buffer.dispatch( + (grad_output[0].contiguous(), grad_output[1].contiguous()), + handle=handle, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + else: + buffer = get_buffer(group, get_hidden_bytes(grad_output)) + grad_x, _, _, _, _, event = buffer.dispatch( + grad_output.contiguous(), + handle=handle, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + return grad_x + + class FusedDispatch(PyLayer): """Fused dispatch operation for MoE routing combining computation and communication.""" @staticmethod def forward(ctx, x, token_indices, token_probs, num_experts, group, previous_event=None): """Forward pass of fused dispatch.""" - # Calculate layout before actual dispatch - buffer = get_buffer(group, get_hidden_bytes(x)) - ( - num_tokens_per_rank, - num_tokens_per_rdma_rank, - num_tokens_per_expert, - is_token_in_rank, - previous_event, - ) = buffer.get_dispatch_layout( - token_indices, - num_experts, - previous_event=None, - async_finish=False, - allocate_on_comm_stream=False, - ) - - # Do MoE dispatch - # NOTES: the CPU will wait for GPU's signal to arrive, - # so this is not compatible with CUDA graph - ( - recv_x, - recv_token_indices, - recv_token_probs, - num_recv_tokens_per_expert_list, - handle, - event, - ) = buffer.dispatch( - x, - topk_idx=token_indices, - topk_weights=token_probs.cast(paddle.float32), - num_tokens_per_rank=num_tokens_per_rank, - num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, - is_token_in_rank=is_token_in_rank, - num_tokens_per_expert=num_tokens_per_expert, - previous_event=None, - async_finish=False, - allocate_on_comm_stream=False, + recv_x, recv_token_probs, states, event = fused_dispatch_forward_func( + x, token_indices, token_probs, num_experts, group, previous_event ) ctx.group = group - ctx.handle = handle + ctx.handle = states["handle"] ctx.event = event - tokens_per_expert = paddle.to_tensor(num_recv_tokens_per_expert_list) - - states = dict() - states["dispatched_indices"] = recv_token_indices - states["tokens_per_expert"] = tokens_per_expert - states["handle"] = handle return recv_x, recv_token_probs, states @staticmethod def backward(ctx, grad_output, 
grad_token_probs): """Backward pass of fused dispatch.""" - buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output)) - handle = ctx.handle - - grad_x, grad_token_probs, event = buffer.combine( - grad_output.contiguous(), - handle, - topk_weights=grad_token_probs.cast(paddle.float32), - previous_event=None, - async_finish=False, - allocate_on_comm_stream=False, - ) - return grad_x, None, grad_token_probs + return fused_dispatch_backward_func(grad_output, grad_token_probs, ctx.group, ctx.handle) class FusedCombine(PyLayer): @@ -152,12 +218,9 @@ class FusedCombine(PyLayer): @staticmethod def forward(ctx, x, group, states, previous_event=None): """Forward pass of fused combine.""" - handle = states["handle"] - buffer = get_buffer(group, get_hidden_bytes(x)) - combined_x, _, event = buffer.combine( - x, handle=handle, async_finish=False, previous_event=None, allocate_on_comm_stream=False - ) - ctx.handle = handle + combined_x = fused_combine_forward_func(x, group, states, previous_event) + + ctx.handle = states["handle"] ctx.group = group ctx.previous_event = previous_event @@ -166,15 +229,7 @@ def forward(ctx, x, group, states, previous_event=None): @staticmethod def backward(ctx, grad_output): """Backward pass of fused combine.""" - buffer = get_buffer(ctx.group, get_hidden_bytes(grad_output)) - grad_x, _, _, _, _, event = buffer.dispatch( - grad_output.contiguous(), - handle=ctx.handle, - previous_event=ctx.previous_event, - async_finish=False, - allocate_on_comm_stream=False, - ) - return grad_x + return fused_combine_backward_func(grad_output, ctx.group, ctx.handle, ctx.previous_event) if HAVE_DEEP_EP: @@ -214,3 +269,96 @@ def fused_combine(x, group, handle, previous_event=None): else: fused_dispatch = None fused_combine = None + + +class DispatchNode: + def __init__(self, name="dispatch"): + self.name = name + + def reset_statue(self): + self.handle = None + + def forward( + self, + x, + token_indices, + token_probs, + num_experts, + group, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, + ): + """Forward pass of fused dispatch.""" + recv_x, recv_token_probs, states, event = fused_dispatch_forward_func( + x, + token_indices, + token_probs, + num_experts, + group, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + self.group = group + self.handle = states["handle"] + self.event = event + + return recv_x, recv_token_probs, states + + def backward( + self, grad_output, grad_token_probs, previous_event=None, async_finish=False, allocate_on_comm_stream=False + ): + """Backward pass of fused dispatch.""" + out = fused_dispatch_backward_func( + grad_output, + grad_token_probs, + self.group, + self.handle, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + self.reset_statue() + return out + + +class CombineNode: + def __init__(self, name="combine"): + self.name = name + + def reset_statue(self): + self.handle = None + + def forward(self, x, group, handle, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + """Forward pass of fused combine.""" + states = dict() + states["handle"] = handle + combined_x = fused_combine_forward_func( + x, + group, + states, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + + self.handle = handle + self.group = group + self.previous_event = previous_event + + return combined_x + + def backward(self, 
grad_output, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + """Backward pass of fused combine.""" + out = fused_combine_backward_func( + grad_output, + self.group, + self.handle, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + self.reset_statue() + return out \ No newline at end of file diff --git a/paddleformers/transformers/moe_layer.py b/paddleformers/transformers/moe_layer.py index 340fba1f524..f14e45d13a6 100644 --- a/paddleformers/transformers/moe_layer.py +++ b/paddleformers/transformers/moe_layer.py @@ -16,6 +16,7 @@ # limitations under the License. from __future__ import annotations +import os from typing import Any, List, Tuple import numpy as np @@ -24,8 +25,48 @@ from paddle import Tensor, nn from paddle.distributed.communication.group import Group +from ..utils.log import logger +from .fp8_utils import FP8GroupGemmMlpFunctionNode, extract_first_if_tuple +from .fused_a2a import CombineNode, DispatchNode, get_buffer, get_hidden_bytes from .moe_gate import PretrainedMoEGate -from .token_dispatcher import MoEFlexTokenDispatcher +from .moe_utils import ( + UnZipNode, + ZipNode, + merge_subbatch_cast, + offload, + reload, + tokens_zip_unique_add_with_subbatch, +) +from .token_dispatcher import MoEFlexTokenDispatcher, PreDispatchNode + +try: + import paddle.distributed.communication.deep_ep as deep_ep +except ImportError: + deep_ep = None + +DSV3_USE_FP8_GEMM = os.getenv("DSV3_USE_FP8_GEMM", "False").lower() == "true" + +DSV3_USE_FP8_GROUP_GEMM = os.getenv("DSV3_USE_FP8_GROUP_GEMM", "False").lower() == "true" + +DSV3_USE_FP8_DISPATCH = os.getenv("DSV3_USE_FP8_DISPATCH", "False").lower() == "true" + + +import TokenDispatcherUtils as TDU + + +def record_stream_for_multi_input(x): + if isinstance(x, (tuple, list)): + for i in range(len(x)): + x[i]._record_stream() + else: + x._record_stream() + + +def stop_gradient_for_multi_input(x): + if isinstance(x, (tuple, list)): + x[0].stop_gradient = False + else: + x.stop_gradient = False def dispatching(x, dispatch_mask, scatter_index, num_experts, capacity): @@ -162,6 +203,7 @@ def __init__( capacity: int = 1.0, moe_group: str = "data", all_to_all_dropout=0.0, + using_post_norm_recompute=False, ): super().__init__() @@ -176,12 +218,11 @@ def __init__( except AttributeError: is_fleet_init = False - if ( - is_fleet_init - and dist.fleet.get_hybrid_communicate_group().get_data_parallel_world_size() > 1 - and moe_group == "data" - ): - self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() + if is_fleet_init and dist.get_world_size() > 1: + if moe_group == "data": + self.moe_group = dist.fleet.get_hybrid_communicate_group().get_data_parallel_group() + elif moe_group == "expert": + self.moe_group = dist.fleet.get_hybrid_communicate_group().expert_parallel_group self.moe_rank = dist.get_rank(self.moe_group) self.moe_rank = 0 if self.moe_rank < 0 else self.moe_rank self.expert_parallel_degree = dist.get_world_size(self.moe_group) @@ -210,8 +251,32 @@ def __init__( self.gate = gate self.gate.group = self.moe_group + # for flex token moe layer + self.router = gate + self.ep_size = dist.get_world_size(self.moe_group) + self.moe_router_topk = gate.top_k + self.num_local_experts = moe_num_experts // self.ep_size + self.token_dispatcher = MoEFlexTokenDispatcher( + self.num_local_experts, self.moe_router_topk, self.moe_num_experts, self.moe_group + ) + self.token_drop_steps = config.token_drop_steps + self.using_flex_token = 
False + + self.using_post_norm_recompute = using_post_norm_recompute self._post_init() + def update_flex_token(self): + from paddleformers.transformers.deepseek_v2 import get_global_step + + if (not self.config.using_flex_token) or (get_global_step() < self.token_drop_steps): + self.using_flex_token = False + self.router.using_flex_token = False + else: + if not self.using_flex_token: + logger.info("Changing to flex token moe mode") + self.using_flex_token = True + self.router.using_flex_token = True + def _parse_moe_expert_parallel(self, moe_num_experts, expert_parallel_degree): assert ( moe_num_experts >= expert_parallel_degree @@ -234,8 +299,35 @@ def _post_init(self): # logger.info(f"expert param={p.name}, no-sync={p.no_sync}") def forward( + self, + hidden_states: paddle.Tensor, + probs=None, + routing_map=None, + capacity=None, + topk_weight=None, + topk_ids=None, + token_priority=None, + l_aux=None, + l_zloss=None, + ): + self.update_flex_token() + + if self.using_flex_token: + return self.forward_flex_token(hidden_states, probs, routing_map, l_aux, l_zloss) + else: + return self.forward_drop_token( + hidden_states, capacity, topk_weight, topk_ids, token_priority, l_aux, l_zloss + ) + + def forward_drop_token( self, hidden_state: paddle.Tensor, + capacity=None, + topk_weight=None, + topk_ids=None, + token_priority=None, + l_aux=None, + l_zloss=None, ): """MoE Layer forward function 1. Gate Forward. @@ -257,7 +349,17 @@ def forward( # topk_ids : sk # token_priority : se # self.exp_counts : - capacity, topk_weight, topk_ids, token_priority, l_aux, l_zloss = self.gate(hidden_state) + if self.using_post_norm_recompute: + assert ( + capacity is not None + and topk_weight is not None + and topk_ids is not None + and token_priority is not None + and l_aux is not None + and l_zloss is not None + ) + else: + capacity, topk_weight, topk_ids, token_priority, l_aux, l_zloss = self.gate(hidden_state) """MoE expert dispatch from: https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py""" cnts = paddle.zeros([topk_ids.shape[0], len(self.experts)], dtype=topk_ids.dtype) @@ -336,6 +438,801 @@ def forward( return final_out, l_aux, l_zloss + def forward_flex_token(self, hidden_states: paddle.Tensor, probs=None, routing_map=None, l_aux=None, l_zloss=None): + _, _, d_model = hidden_states.shape + # reshaped_input = hidden_states.reshape([-1, d_model]) + if self.using_post_norm_recompute: + assert probs is not None and routing_map is not None and l_aux is not None and l_zloss is not None + else: + probs, routing_map, l_aux, l_zloss = self.router(hidden_states) + if DSV3_USE_FP8_GEMM: + if DSV3_USE_FP8_DISPATCH: + output = FusionMoe.apply( + hidden_states, + probs, + routing_map, + self, + recompute_fwd_gate_up=self.config.recompute_fwd_gate_up, + is_split_group_gemm=self.config.is_split_group_gemm, + ) + else: + hidden_states, token_indices, token_probs = self.token_dispatcher.pre_dispatch( + hidden_states, probs, routing_map + ) + output = FusionMoe.apply( + hidden_states, + token_indices, + token_probs, + self, + recompute_fwd_gate_up=self.config.recompute_fwd_gate_up, + is_split_group_gemm=self.config.is_split_group_gemm, + ) + else: + ( + dispatched_input, + token_permuted_indices, + prob_permuted_indices, + dispatched_probs, + ) = self.token_dispatcher.token_permutation_fast(hidden_states, probs, routing_map) + + expert_output = self.expert_forward(dispatched_input) + output, _ = self.token_dispatcher.token_unpermutation_fast( + expert_output, token_permuted_indices, 
prob_permuted_indices, dispatched_probs, None + ) + return output, l_aux, l_zloss + + def get_tokens_per_expert(self): + return self.token_dispatcher._comm_manager.tokens_per_expert_list + + def set_tokens_per_expert(self, tokens_per_expert_list): + self.token_dispatcher._comm_manager.tokens_per_expert_list = tokens_per_expert_list + + def expert_forward(self, dispatched_input): + outputs = [] + chunks = paddle.split(dispatched_input, num_or_sections=self.get_tokens_per_expert(), axis=0) + for i, chunk in enumerate(chunks): + chunk = chunk.contiguous() + # assert chunk.shape[0] != 0, "Cannot dispatch empty input" + expert = self.experts[i + self.moe_rank * self.moe_num_experts_per_device] + outputs += [expert(chunk)] + + return paddle.concat(outputs, axis=0) + + def pre_dispatch_compute(self, hidden_states): + _, _, d_model = hidden_states.shape + probs, routing_map, l_aux, l_zloss = self.router(hidden_states) + hidden_states, token_indices, token_probs = self.token_dispatcher.pre_dispatch( + hidden_states, probs, routing_map + ) + return l_aux, l_zloss, hidden_states, token_indices, token_probs + + def post_dispatch_compute(self, hidden_states, dispatched_indices, dispatched_probs): + (global_input_tokens, token_permuted_indices, prob_permuted_indices) = self.token_dispatcher.post_dispatch( + hidden_states, dispatched_indices + ) + return (global_input_tokens, token_permuted_indices, prob_permuted_indices) + + def pre_combine_compute(self, hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs): + hidden_states = self.token_dispatcher.pre_combine( + hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs + ) + return hidden_states + + def post_combine_compute(self, hidden_states): + hidden_states = self.token_dispatcher.post_combine(hidden_states) + return hidden_states + + +class Fp8DispatchQuantNode: + def __init__(self, token_dispatcher, name="fp8_dispatch_quant_node"): + self.token_dispatcher = token_dispatcher + self.pre_dispatch_node = PreDispatchNode(token_dispatcher) + self.name = name + + @paddle.no_grad() + def forward(self, hidden_states, probs, routing_map): + # reshape + self.token_dispatcher.hidden_shape = hidden_states.shape + hs_2d = hidden_states.view([-1, self.token_dispatcher.hidden_shape[-1]]) + + if DSV3_USE_FP8_DISPATCH: + # quant + hs_fp8, hs_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + hs_2d, output_scale_transpose=False, quant_method="1x128", input_transpose=False + ) + + # pre_dispatch + token_indices, token_probs = self.pre_dispatch_node.forward(routing_map, probs) + + self.hidden_states_shape = hidden_states.shape + hs_fp8.stop_gradient = False + token_probs.stop_gradient = False + return (hs_fp8, hs_scale), token_indices, token_probs + else: + # pre_dispatch + token_indices, token_probs = self.pre_dispatch_node.forward(routing_map, probs) + + self.hidden_states_shape = hidden_states.shape + hs_2d.stop_gradient = False + token_probs.stop_gradient = False + return hs_2d, token_indices, token_probs + + @paddle.no_grad() + def backward(self, hs_grad, token_probs_grad): + # predispatch grad + probs_grad = self.pre_dispatch_node.backward(token_probs_grad) + token_probs_grad._record_stream() + + # reshape_grad + hs_grad = hs_grad.view(self.hidden_states_shape) + hs_grad._record_stream() + + return hs_grad, probs_grad, None + + +class Fp8DispatchNode: + def __init__(self, token_dispatcher, name="fp8_dispatch_node"): + self.token_dispatcher = token_dispatcher + self.dispatch_act_node = 
DispatchNode(token_dispatcher) + self.name = name + + @paddle.no_grad() + def forward( + self, + hs_2d, + token_indices, + token_probs, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, + ): + # dispatch + hs_2d_dispatched, dispatched_probs, states = self.dispatch_act_node.forward( + hs_2d, + token_indices, + token_probs, + self.token_dispatcher._comm_manager.num_experts, + self.token_dispatcher._comm_manager.group, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + self.token_dispatcher._comm_manager.handle = states["handle"] + self.token_dispatcher._comm_manager.tokens_per_expert = states["tokens_per_expert"] + dispatched_indices = states["dispatched_indices"] + + stop_gradient_for_multi_input(hs_2d_dispatched) + dispatched_probs.stop_gradient = False + return hs_2d_dispatched, dispatched_indices, dispatched_probs + + @paddle.no_grad() + def backward( + self, + hs_dispatched_grad, + dispatched_probs_grad, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False, + ): + # dispatch grad + hs_grad, _, token_probs_grad = self.dispatch_act_node.backward( + hs_dispatched_grad, + dispatched_probs_grad, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + return hs_grad, token_probs_grad + + +class Fp8CombineNode: + def __init__(self, token_dispatcher, name="fp8_combine_node"): + self.token_dispatcher = token_dispatcher + self.combine_node = CombineNode(token_dispatcher) + self.name = name + + @paddle.no_grad() + def forward(self, hidden_states_out, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + # combine + output_combine = self.combine_node.forward( + hidden_states_out, + self.token_dispatcher._comm_manager.group, + self.token_dispatcher._comm_manager.handle, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + output_combine.stop_gradient = False + self.token_dispatcher._comm_manager.handle = None + return output_combine + + @paddle.no_grad() + def backward(self, output_combine_grad, previous_event=None, async_finish=False, allocate_on_comm_stream=False): + # combine grad -> fp8 + hidden_states_out_grad = self.combine_node.backward( + output_combine_grad, + previous_event=previous_event, + async_finish=async_finish, + allocate_on_comm_stream=allocate_on_comm_stream, + ) + return hidden_states_out_grad + + +class Fp8CombineQuantNode: + def __init__(self, token_dispatcher, moe_group=None, name="fp8_combine_quant_node"): + self.token_dispatcher = token_dispatcher + self.name = name + self.moe_group = moe_group + + @paddle.no_grad() + def forward(self, output_combine): + # post combine + output = output_combine.reshape(self.token_dispatcher.hidden_shape) + output_combine._record_stream() + self.output_combine_shape = output_combine.shape + output.stop_gradient = False + return output + + @paddle.no_grad() + def backward(self, output_grad, event_to_wait=None): + # post combine grad + if DSV3_USE_FP8_DISPATCH: + if event_to_wait is not None: + assert self.moe_group is not None + event_to_wait.comm_stream_wait(self.moe_group.id) + buffer = get_buffer(self.token_dispatcher._comm_manager.group, get_hidden_bytes(output_grad)) + custom_stream = paddle.device.Stream(stream_base=buffer.runtime.get_comm_stream()) + else: + custom_stream = paddle.device.current_stream() + with paddle.device.stream_guard(custom_stream): + 
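+ # Everything under this stream guard (the reshape and the 1x128 blockwise FP8
+ # quantization of the combine gradient) is issued on custom_stream; when
+ # event_to_wait is given, custom_stream wraps the DeepEP communication stream,
+ # so the quantization can proceed alongside communication instead of occupying
+ # the default compute stream.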
output_combine_grad = paddle.reshape(output_grad, [-1, output_grad.shape[-1]]) + # output_combine_grad quant to fp8 + output_combine_grad_fp8, output_combine_grad_scale = paddle.incubate.nn.functional.fp8_quant_blockwise( + output_combine_grad, output_scale_transpose=False, quant_method="1x128", input_transpose=False + ) + output_grad._record_stream() + quant_event = None + if event_to_wait is not None: + quant_event = deep_ep.get_event_from_custom_stream(custom_stream.stream_base) + return (output_combine_grad_fp8, output_combine_grad_scale), quant_event + else: + output_combine_grad = paddle.reshape(output_grad, [-1, output_grad.shape[-1]]) + return output_combine_grad, None + + +class FusionMlpNode: + """ + The FusedMoeLayer class includes operations for unzipping, expert computation, and zipping. + """ + + def __init__( + self, + custom_map, + max_topk, + recompute_fwd_gate_up=False, + is_split_group_gemm=True, + mlp_fwd_subbatch_rows=0, + mlp_bwd_subbatch_rows=0, + output_subbatch_rows=0, + ): + self.token_dispatcher = custom_map.token_dispatcher + self.experts = custom_map.experts + self.unzip_node = UnZipNode() + self.zip_node = ZipNode() + self.experts_group_gemm_node = FP8GroupGemmMlpFunctionNode( + custom_map, + recompute_fwd_gate_up=recompute_fwd_gate_up, + is_split_group_gemm=is_split_group_gemm, + ) + + self.seq_length = custom_map.config.seq_length + self.num_experts_per_tok = custom_map.config.num_experts_per_tok + self.adaptive_remained_O1_recompute_ratio = custom_map.config.adaptive_remained_O1_recompute_ratio + + self.recompute_fwd_gate_up = recompute_fwd_gate_up + self.dispatched_indices = None + self.dispatched_probs = None + self.tokens_per_expert = None + self.padding_token_per_experts = None + self.router_topk = max_topk + self.mlp_fwd_subbatch_rows = mlp_fwd_subbatch_rows + self.mlp_bwd_subbatch_rows = mlp_bwd_subbatch_rows + self.output_subbatch_rows = output_subbatch_rows + + def set_recompute_fwd_gate_up(self, recompute_fwd_gate_up): + self.experts_group_gemm_node.recompute_fwd_gate_up = recompute_fwd_gate_up + + def reset_statue(self): + """ + 重置所有状态变量。 + + Args: + 无。 + + Returns: + 无。 + + """ + self.dispatched_indices = None + self.dispatched_probs = None + self.tokens_per_expert = None + self.padding_token_per_experts = None + self.router_topk = None + + del self.unzip_node + del self.zip_node + self.unzip_node = None + self.zip_node = None + + self.experts_group_gemm_node.reset_statue() + self.experts_group_gemm_node = None + + def prepare_env_subbatch(self, unzipped_tokens=None, unzipped_tokens_scale=None, is_fwd=True): + if is_fwd: + assert unzipped_tokens is not None and unzipped_tokens_scale is not None + self.experts_group_gemm_node.input_fp8 = unzipped_tokens + self.experts_group_gemm_node.input_scale = unzipped_tokens_scale + self.m_indices = self.experts_group_gemm_node.gen_m_indices(self.padding_token_per_experts) + self.experts_group_gemm_node.fwd_subbatch = True + else: + self.m_indices = ( + self.experts_group_gemm_node.gen_m_indices(self.padding_token_per_experts) + if not hasattr(self, "m_indices") + else self.m_indices + ) + self.experts_group_gemm_node.bwd_subbatch = True + reload(self.experts_group_gemm_node.input_fp8) + reload(self.experts_group_gemm_node.input_scale) + + def gemm_forward_subbatch( + self, + unzipped_tokens, + unzipped_tokens_scale, + unzipped_probs, + map_unzipped_indices_to_zipped, + output, + total_zipped_tokens, + padding_token_per_experts, + start_idx=None, + end_idx=None, + output_subbatch_rows=None, + ): + if 
start_idx is None or end_idx is None: + start_idx = 0 + end_idx = unzipped_tokens.shape[0] + start_idx = max(0, start_idx) + end_idx = min(unzipped_tokens.shape[0], end_idx) + + expert_out = self.experts_group_gemm_node.forward( + (unzipped_tokens[start_idx:end_idx], unzipped_tokens_scale[start_idx:end_idx]), + unzipped_probs[start_idx:end_idx], + padding_token_per_experts, + m_indices=self.m_indices[start_idx:end_idx], + ) + + output = tokens_zip_unique_add_with_subbatch( + output, + expert_out, + map_unzipped_indices_to_zipped[start_idx:end_idx], + total_zipped_tokens, + subbatch_rows=output_subbatch_rows, + ) + return output + + def gemm_backward_subbatch( + self, + unzipped_grad, + map_unzipped_indices_to_zipped, + total_zipped_tokens, + output, + padding_token_per_experts, + start_idx=None, + end_idx=None, + output_subbatch_rows=None, + reset_status=False, + ): + def split_list_prefix(l, start, end): + prefix_sum = [0] * (len(l) + 1) + for i in range(len(l)): + prefix_sum[i + 1] = prefix_sum[i] + l[i] + + result = [] + for i in range(len(l)): + segment_start = prefix_sum[i] + segment_end = prefix_sum[i + 1] + overlap_start = max(start, segment_start) + overlap_end = min(end, segment_end) + selected = max(0, overlap_end - overlap_start) + result.append(selected) + return result + + if start_idx is None or end_idx is None: + start_idx = 0 + end_idx = extract_first_if_tuple(unzipped_grad).shape[0] + + start_idx = max(0, start_idx) + end_idx = min(extract_first_if_tuple(unzipped_grad).shape[0], end_idx) + + # m_indices = self.experts_group_gemm_node.gen_m_indices(self.tokens_per_expert) + unzipped_inp_grad = ( + (unzipped_grad[0][start_idx:end_idx].contiguous(), unzipped_grad[1][start_idx:end_idx].contiguous()) + if isinstance(unzipped_grad, tuple) + else unzipped_grad[start_idx:end_idx].contiguous() + ) + unzipped_grad, unzipped_probs_grad = self.experts_group_gemm_node.backward( + unzipped_inp_grad, + self.unzipped_probs[start_idx:end_idx].contiguous(), + input_fp8_slice=self.experts_group_gemm_node.input_fp8[start_idx:end_idx].contiguous(), + input_scale_slice=self.experts_group_gemm_node.input_scale[start_idx:end_idx].contiguous(), + tokens_per_expert=split_list_prefix(padding_token_per_experts, start_idx, end_idx), + m_indices=self.m_indices[start_idx:end_idx].contiguous(), + reset_status=reset_status, + ) + + output = tokens_zip_unique_add_with_subbatch( + output, + unzipped_grad, + map_unzipped_indices_to_zipped[start_idx:end_idx], + zipped_rows=total_zipped_tokens, + subbatch_rows=output_subbatch_rows, + ) + + return output, unzipped_probs_grad + + @paddle.no_grad() + def forward(self, hs_2d_dispatched, dispatched_indices, dispatched_probs): + """ + 对输入数据进行前向传播计算。 + + Args: + hs_fp8_dispatched (Tensor): 表示被分派到各个专家的输入数据。 + dispatched_indices (Tensor):表示输入数据被分派到的专家索引。 + dispatched_probs (Tensor): 表示输入数据被分派到各个专家的概率。 + + Returns: + Tensor: 经过前向传播计算后的输出数据。 + + """ + self.tokens_per_expert = self.token_dispatcher._comm_manager.tokens_per_expert + self.dispatched_probs = dispatched_probs + num_experts = len(self.tokens_per_expert) + padding_token_per_experts = [(x + 127) // 128 * 128 for x in self.tokens_per_expert] + self.padding_token_per_experts = padding_token_per_experts + # 1 unzip + self.dispatched_indices = dispatched_indices.to(paddle.int32) + if DSV3_USE_FP8_DISPATCH: + ( + unzipped_tokens, + zipped_expertwise_rowmap, + unzipped_probs, + unzipped_tokens_scale, + ) = self.unzip_node.forward( + hs_2d_dispatched, + self.dispatched_indices, + dispatched_probs, + 
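+ # (unzip = fused moe_permute: tokens are scattered into contiguous per-expert
+ # segments, each padded to a multiple of 128 rows for the grouped FP8 GEMM, while
+ # zipped_expertwise_rowmap records how to zip the expert outputs back afterwards)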
topk=self.router_topk, + num_experts=num_experts, + tokens_per_expert=self.tokens_per_expert, + ) + record_stream_for_multi_input(hs_2d_dispatched) + dispatched_indices._record_stream() + dispatched_probs._record_stream() + + total_unzipped_tokens = extract_first_if_tuple(unzipped_tokens).shape[0] + total_zipped_tokens = extract_first_if_tuple(hs_2d_dispatched).shape[0] + + # If adaptive O1 recompute is enabled, determine whether to enable recompute O1 based on the degree of imbalance + if self.recompute_fwd_gate_up == -1: + if ( + total_unzipped_tokens + > self.seq_length * self.num_experts_per_tok * self.adaptive_remained_O1_recompute_ratio + ): + # logger.debug(f"recompute_fwd_gate_up changed to True, Because the receives {unzipped_tokens.shape[0]} Tensors greater then {self.seq_length*self.num_experts_per_tok*self.adaptive_remained_O1_recompute_ratio}.") + self.set_recompute_fwd_gate_up(True) + else: + # logger.debug(f"recompute_fwd_gate_up changed to False, Because the receives {unzipped_tokens.shape[0]} Tensors less then {self.seq_length*self.num_experts_per_tok*self.adaptive_remained_O1_recompute_ratio}.") + self.set_recompute_fwd_gate_up(False) + + self.unzipped_probs = unzipped_probs.unsqueeze(-1) + + # if use_mlp_subbatch is enabled, then split the unzipped_tokens into subbatches + if self.mlp_fwd_subbatch_rows != 0 and total_unzipped_tokens > self.mlp_fwd_subbatch_rows * 2: + assert ( + self.experts_group_gemm_node.recompute_fwd_gate_up + ), "recompute_fwd_gate_up must be true when use_mlp_subbatch = True" + map_unzipped_indices_to_zipped = TDU.tokens_unzip_slice( + extract_first_if_tuple(hs_2d_dispatched), + zipped_expertwise_rowmap, + num_experts, + total_unzipped_tokens, + 0, + total_unzipped_tokens + 1, + ) + if isinstance(hs_2d_dispatched, tuple): + hs_2d_dispatched[0]._clear_to_zero_allocation() + hs_2d_dispatched[1]._clear_to_zero_allocation() + else: + hs_2d_dispatched._clear_to_zero_allocation() + + subbatch_rows = min((total_unzipped_tokens // num_experts) // 128 * 128, self.mlp_fwd_subbatch_rows) + nparts = (total_unzipped_tokens + subbatch_rows - 1) // subbatch_rows + output = paddle.empty([0, extract_first_if_tuple(hs_2d_dispatched).shape[-1]], dtype=paddle.float32) + self.prepare_env_subbatch(unzipped_tokens, unzipped_tokens_scale, True) + logger.info( + f"Enable subbatch_forward!! 
total_zipped_tokens:{total_zipped_tokens}, total_unzipped_tokens:{total_unzipped_tokens}, nparts:{nparts}, subbatch_rows:{subbatch_rows}, output_sub_rows:{self.output_subbatch_rows}" + ) + for i in range(nparts): + start_idx = i * subbatch_rows + end_idx = min(start_idx + subbatch_rows, total_unzipped_tokens) + output = self.gemm_forward_subbatch( + unzipped_tokens, + unzipped_tokens_scale, + unzipped_probs, + map_unzipped_indices_to_zipped, + output, + total_zipped_tokens, + padding_token_per_experts, + start_idx=start_idx, + end_idx=end_idx, + output_subbatch_rows=self.output_subbatch_rows, + ) + + output = merge_subbatch_cast(output, paddle.bfloat16) + output.stop_gradient = False + offload(self.experts_group_gemm_node.input_fp8) + offload(self.experts_group_gemm_node.input_scale) + return output + + # 2 experts + expert_out = self.experts_group_gemm_node.forward( + (unzipped_tokens, unzipped_tokens_scale), unzipped_probs, padding_token_per_experts + ) + else: + (unzipped_tokens, zipped_expertwise_rowmap, unzipped_probs, _,) = self.unzip_node.forward( + hs_2d_dispatched, + self.dispatched_indices, + dispatched_probs, + topk=self.router_topk, + num_experts=num_experts, + tokens_per_expert=self.tokens_per_expert, + ) + hs_2d_dispatched._record_stream() + dispatched_indices._record_stream() + dispatched_probs._record_stream() + + # If adaptive O1 recompute is enabled, determine whether to enable recompute O1 based on the degree of imbalance + if self.recompute_fwd_gate_up == -1: + if ( + unzipped_tokens.shape[0] + > self.seq_length * self.num_experts_per_tok * self.adaptive_remained_O1_recompute_ratio + ): + self.set_recompute_fwd_gate_up(True) + else: + self.set_recompute_fwd_gate_up(False) + + # 2 experts + expert_out = self.experts_group_gemm_node.forward( + unzipped_tokens, unzipped_probs, padding_token_per_experts + ) + + # 3 zip + if isinstance(hs_2d_dispatched, tuple): + hs_2d_dispatched[0]._clear_to_zero_allocation() + hs_2d_dispatched[1]._clear_to_zero_allocation() + else: + hs_2d_dispatched._clear_to_zero_allocation() + expert_out_tmp = expert_out.reshape([-1, expert_out.shape[-1]]) + + expert_out_zipped = self.zip_node.forward( + expert_out_tmp, + zipped_expertwise_rowmap, + self.dispatched_indices, + unzipped_probs, + total_zipped_tokens=total_zipped_tokens, + num_experts=num_experts, + ) + + expert_out_zipped.stop_gradient = False + return expert_out_zipped + + @paddle.no_grad() + def backward(self, hidden_states_out_grad): + """ + 反向传播函数。 + + Args: + hidden_states_out_grad_fp8 (Tensor): 隐藏状态梯度。 + + Returns: + Tuple[Tensor, Tensor]: 包含两个元素,分别为hs_fp8_dispatched_grad和dispatched_probs_grad。 + - hs_fp8_dispatched_grad (Tensor): 解压后的隐藏状态梯度。 + - dispatched_probs_grad (Tensor): 分发概率梯度。 + + """ + # zip_grad + unzipped_grad = self.zip_node.backward( + hidden_states_out_grad, + self.dispatched_indices, + self.dispatched_probs, + top_k=self.router_topk, + num_experts=len(self.tokens_per_expert), + tokens_per_expert=self.tokens_per_expert, + ) + record_stream_for_multi_input(hidden_states_out_grad) + + total_zipped_tokens = extract_first_if_tuple(hidden_states_out_grad).shape[0] + total_unzipped_tokens = extract_first_if_tuple(unzipped_grad).shape[0] + hidden_states_size = extract_first_if_tuple(hidden_states_out_grad).shape[-1] + num_experts = len(self.tokens_per_expert) + padding_token_per_experts = [(x + 127) // 128 * 128 for x in self.tokens_per_expert] + + if self.mlp_bwd_subbatch_rows != 0 and total_unzipped_tokens > self.mlp_bwd_subbatch_rows * 2: + map_unzipped_indices_to_zipped = 
TDU.tokens_unzip_slice( + extract_first_if_tuple(hidden_states_out_grad), + self.unzip_node.zipped_expertwise_rowmap, + num_experts, + total_unzipped_tokens, + 0, + total_unzipped_tokens + 1, + ) + if isinstance(hidden_states_out_grad, tuple): + hidden_states_out_grad[0]._clear_to_zero_allocation() + hidden_states_out_grad[1]._clear_to_zero_allocation() + else: + hidden_states_out_grad._clear_to_zero_allocation() + + subbatch_rows = min((total_unzipped_tokens // num_experts) // 128 * 128, self.mlp_bwd_subbatch_rows) + nparts = (total_unzipped_tokens + subbatch_rows - 1) // subbatch_rows + output = paddle.empty([0, hidden_states_size], dtype=paddle.float32) + probs_grad_list = [] + self.prepare_env_subbatch(is_fwd=False) + logger.info( + f"Enable subbatch_backward!! total_zipped_tokens:{total_zipped_tokens}, total_unzipped_tokens:{total_unzipped_tokens}, nparts:{nparts}, subbatch_rows:{subbatch_rows}, output_sub_rows:{self.output_subbatch_rows}" + ) + for i in range(nparts): + reset_status = True if i == nparts - 1 else False # release saved status in the last part. + start_idx = i * subbatch_rows + end_idx = min(start_idx + subbatch_rows, total_unzipped_tokens) + output, probs_grad = self.gemm_backward_subbatch( + unzipped_grad, + map_unzipped_indices_to_zipped, + total_zipped_tokens, + output, + padding_token_per_experts, + start_idx=start_idx, + end_idx=end_idx, + output_subbatch_rows=self.output_subbatch_rows, + reset_status=reset_status, + ) + probs_grad_list.append(probs_grad) + if isinstance(unzipped_grad, tuple): + unzipped_grad[0]._clear_to_zero_allocation() + unzipped_grad[1]._clear_to_zero_allocation() + else: + unzipped_grad._clear_to_zero_allocation() + hs_dispatched_grad = merge_subbatch_cast(output, paddle.bfloat16) + dispatched_probs_grad = TDU.tokens_zip_prob_seq_subbatch( + probs_grad_list, self.unzip_node.zipped_expertwise_rowmap, self.dispatched_indices, subbatch_rows + ) + self.reset_statue() + return hs_dispatched_grad, dispatched_probs_grad + + if isinstance(hidden_states_out_grad, tuple): + hidden_states_out_grad[0]._clear_to_zero_allocation() + hidden_states_out_grad[1]._clear_to_zero_allocation() + else: + hidden_states_out_grad._clear_to_zero_allocation() + + # expert_grad + expert_out, probs_grad = self.experts_group_gemm_node.backward( + unzipped_grad, self.unzipped_probs, padding_token_per_experts + ) + + hs_dispatched_grad, dispatched_probs_grad = self.unzip_node.backward( + expert_out, + total_zipped_tokens, + probs_grad, + self.dispatched_indices, + num_experts=num_experts, + ) + + self.reset_statue() + return hs_dispatched_grad, dispatched_probs_grad + + +class FusionMoeNode: + def __init__( + self, + custom_map, + recompute_fwd_gate_up=False, + is_split_group_gemm=True, + mlp_fwd_subbatch_rows=0, + mlp_bwd_subbatch_rows=0, + output_subbatch_rows=0, + name="fusion_moe_node", + ): + self.token_dispatcher = custom_map.token_dispatcher + self.moe_router_topk = custom_map.moe_router_topk + self.dispatch_quant_node = Fp8DispatchQuantNode(self.token_dispatcher) + self.dispatch_node = Fp8DispatchNode(self.token_dispatcher) + self.mlp_node = FusionMlpNode( + custom_map, + self.moe_router_topk, + recompute_fwd_gate_up=recompute_fwd_gate_up, + is_split_group_gemm=is_split_group_gemm, + mlp_fwd_subbatch_rows=mlp_fwd_subbatch_rows, + mlp_bwd_subbatch_rows=mlp_bwd_subbatch_rows, + output_subbatch_rows=output_subbatch_rows, + ) + self.combine_node = Fp8CombineNode(self.token_dispatcher) + self.combine_quant_node = Fp8CombineQuantNode(self.token_dispatcher, 
custom_map.moe_group) + self.name = name + + @paddle.no_grad() + def forward(self, hidden_states, probs, routing_map): + if DSV3_USE_FP8_DISPATCH: + (hs_fp8, hs_scale), token_indices, token_probs = self.dispatch_quant_node.forward( + hidden_states, probs, routing_map + ) + ( + (hs_fp8_dispatched, hs_scale_dispatched), + dispatched_indices, + dispatched_probs, + ) = self.dispatch_node.forward((hs_fp8, hs_scale), token_indices, token_probs) + hidden_states_out = self.mlp_node.forward( + (hs_fp8_dispatched, hs_scale_dispatched), dispatched_indices, dispatched_probs + ) + output_combine = self.combine_node.forward(hidden_states_out) + output = self.combine_quant_node.forward(output_combine) + output.stop_gradient = False + return output + else: + hs_2d_dispatched, dispatched_indices, dispatched_probs = self.dispatch_node.forward( + hidden_states, probs, routing_map + ) + hidden_states_out = self.mlp_node.forward(hs_2d_dispatched, dispatched_indices, dispatched_probs) + output_combine = self.combine_node.forward(hidden_states_out) + output = self.combine_quant_node.forward(output_combine) + output.stop_gradient = False + return output + + @paddle.no_grad() + def backward(self, output_grad): + output_combine_grad, _ = self.combine_quant_node.backward(output_grad) + hidden_states_out_grad = self.combine_node.backward(output_combine_grad) + + hs_dispatched_grad, dispatched_probs_grad = self.mlp_node.backward(hidden_states_out_grad) + + if DSV3_USE_FP8_DISPATCH: + hs_fp8_grad, token_probs_grad = self.dispatch_node.backward(hs_dispatched_grad, dispatched_probs_grad) + hs_grad, probs_grad, routing_map_grad = self.dispatch_quant_node.backward(hs_fp8_grad, token_probs_grad) + return hs_grad, probs_grad, routing_map_grad + else: + hs_bf16_grad, token_probs_grad = self.dispatch_node.backward(hs_dispatched_grad, dispatched_probs_grad) + return hs_bf16_grad, None, token_probs_grad + + +class FusionMoe(paddle.autograd.PyLayer): + @staticmethod + def forward( + ctx, + hidden_states, + probs, + routing_map, + custom_map, + recompute_fwd_gate_up=False, + is_split_group_gemm=True, + ): + ctx.node = FusionMoeNode( + custom_map, + recompute_fwd_gate_up=recompute_fwd_gate_up, + is_split_group_gemm=is_split_group_gemm, + ) + return ctx.node.forward(hidden_states, probs, routing_map) + + @staticmethod + def backward(ctx, output_grad): + return ctx.node.backward(output_grad) class MoEFlexTokenLayer(nn.Layer): def __init__(self, config, moe_num_experts, expert_class, expert_kwargs, gate, moe_group): diff --git a/paddleformers/transformers/moe_utils.py b/paddleformers/transformers/moe_utils.py index 466591b0638..4c5f3390e86 100644 --- a/paddleformers/transformers/moe_utils.py +++ b/paddleformers/transformers/moe_utils.py @@ -17,6 +17,51 @@ from typing import Optional import paddle +import numpy as np +import TokenDispatcherUtils as TDU + +from .fp8_utils import FP8LinearFunctionBase + +if not hasattr(paddle.Tensor, "_clear_to_zero_allocation"): + + def _clear_to_zero_allocation(self): + """ + _clear_to_zero_allocation + """ + old_shape = self.shape + dst = paddle.empty([0], dtype=self.dtype) + dst_t = dst.value().get_tensor() + src_t = self.value().get_tensor() + src_t._share_data_with(dst_t) + src_t._set_dims(old_shape) + + setattr(paddle.Tensor, "_clear_to_zero_allocation", _clear_to_zero_allocation) + + +if not hasattr(paddle.Tensor, "_holder_size"): + + def _holder_size(self): + """ + _holder_size + """ + if self._is_initialized(): + return int(np.prod(self.shape)) * paddle.core.size_of_dtype(self.dtype) + else: + 
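+ # no memory holder is allocated for this tensor, so it contributes zero bytes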
return 0 + + setattr(paddle.Tensor, "_holder_size", _holder_size) + + +def topk_to_permuted_indices(x, num_tokens_per_expert_list, topk): + x = paddle.flatten(x) + prob_permuted_indices = paddle.concat( + [ + paddle.tensor.search._restrict_nonzero(x == i, total_true_num) + for i, total_true_num in enumerate(num_tokens_per_expert_list) + ] + ).flatten() + token_permuted_indices = prob_permuted_indices // topk + return token_permuted_indices, prob_permuted_indices def permute( @@ -99,3 +144,340 @@ def unpermute( include_self=True, ) return output_tokens + +class UnZipNode: + def __init__(self, name="unzip"): + self.name = name + self.unzipped_probs = None + self.zipped_expertwise_rowmap = None + + def reset_statue(self): + self.unzipped_probs = None + self.zipped_expertwise_rowmap = None + + @paddle.no_grad() + def forward( + self, + hs_2d_dispatched, + dispatched_indices, + dispatched_probs, + topk, + num_experts, + tokens_per_expert, + ): + if isinstance(hs_2d_dispatched, tuple): + with paddle.amp.auto_cast(False): + ( + unzipped_tokens, + zipped_expertwise_rowmap, + unzipped_probs, + unzipped_scale, + ) = paddle.nn.functional.moe_permute( + hs_2d_dispatched[0], + hs_2d_dispatched[1], + dispatched_indices, + dispatched_probs, + num_experts=num_experts, + tokens_per_expert=tokens_per_expert, + padding_alignment=128, + ) + else: + with paddle.amp.auto_cast(False): + ( + unzipped_tokens, + zipped_expertwise_rowmap, + unzipped_probs, + unzipped_scale, + ) = paddle.nn.functional.moe_permute( + hs_2d_dispatched, + None, + dispatched_indices, + dispatched_probs, + num_experts=num_experts, + tokens_per_expert=tokens_per_expert, + padding_alignment=128, + ) + self.unzipped_probs = unzipped_probs + self.zipped_expertwise_rowmap = zipped_expertwise_rowmap + return (unzipped_tokens, zipped_expertwise_rowmap, unzipped_probs, unzipped_scale) + + @paddle.no_grad() + def backward(self, dx, total_zipped_tokens, probs_grad, dispatched_indices, num_experts): + with paddle.amp.auto_cast(False): + weighted_zipped_tokens, probs_grad_zipped = paddle.nn.functional.moe_unpermute( + dx, + self.zipped_expertwise_rowmap, + dispatched_indices, + probs_grad, + total_zipped_tokens=total_zipped_tokens, + num_experts=num_experts, + ) + self.reset_statue() + return weighted_zipped_tokens, probs_grad_zipped + + +class ZipNode: + def __init__(self, name="zip"): + self.name = name + + @paddle.no_grad() + def forward( + self, expert_out, zipped_expertwise_rowmap, routemap_topk, unzipped_probs, total_zipped_tokens, num_experts + ): + with paddle.amp.auto_cast(False): + expert_out_zipped, zipped_probs_topk = paddle.nn.functional.moe_unpermute( + expert_out, zipped_expertwise_rowmap, routemap_topk, unzipped_probs, total_zipped_tokens, num_experts + ) + return expert_out_zipped + + @paddle.no_grad() + def backward( + self, + grad_output, + dispatched_indices, + dispatched_probs, + top_k, + num_experts, + tokens_per_expert, + ): + if isinstance(grad_output, tuple): + with paddle.amp.auto_cast(False): + ( + unzipped_grad, + zipped_expertwise_rowmap_grad, + unzipped_probs_grad, + unzipped_scale_grad, + ) = paddle.nn.functional.moe_permute( + grad_output[0], + grad_output[1], + dispatched_indices, + dispatched_probs, + num_experts, + tokens_per_expert, + padding_alignment=128, + ) + return (unzipped_grad, unzipped_scale_grad) + else: + with paddle.amp.auto_cast(False): + ( + unzipped_grad, + zipped_expertwise_rowmap_grad, + unzipped_probs_grad, + unzipped_scale_grad, + ) = paddle.nn.functional.moe_permute( + grad_output, + None, + 
dispatched_indices, + dispatched_probs, + num_experts, + tokens_per_expert, + padding_alignment=128, + ) + + return unzipped_grad + + +class PermuteNode: + def __init__(self, token_dispatcher, name="permute"): + self.token_dispatcher = token_dispatcher + self.name = name + + def reset_status(self): + self.token_permuted_indices = None + self.prob_permuted_indices = None + + def forward(self, hidden_states, hidden_states_scale, dispatched_indices): + self.token_dispatcher._comm_manager.hidden_shape_before_permute = hidden_states.shape + self.hidden_shape_before_permute = hidden_states.shape + self.token_permuted_indices, self.prob_permuted_indices = topk_to_permuted_indices( + dispatched_indices, + self.token_dispatcher._comm_manager.tokens_per_expert, + self.token_dispatcher._comm_manager.router_topk, + ) + hidden_states = permute(hidden_states, self.token_permuted_indices) + # permute scale + hidden_states_scale = permute(hidden_states_scale, self.token_permuted_indices) + + return hidden_states, hidden_states_scale, self.token_permuted_indices, self.prob_permuted_indices + + def backward(self, out_grad, dispatched_probs): + input_dtype = out_grad.dtype + hidden_states_grad = unpermute( + permuted_tokens=out_grad, + token_permuted_indices=self.token_permuted_indices, + prob_permuted_indices=self.prob_permuted_indices, + restore_shape=self.hidden_shape_before_permute, + probs=dispatched_probs, + ) + self.reset_status() + return hidden_states_grad.to(input_dtype) + + +class UnPermuteNode: + def __init__(self, token_dispatcher, name="unpermute"): + self.token_dispatcher = token_dispatcher + self.name = name + + def reset_status(self): + self.token_permuted_indices = None + self.hidden_states = None + self.prob_permuted_indices = None + self.faltten_dispatched_probs = None + self.hidden = None + self.permuted_tokens = None + self.output_tokens = None + + def forward( + self, + hidden_states, + token_permuted_indices, + prob_permuted_indices, + dispatched_probs, + ): + self.token_permuted_indices = token_permuted_indices + self.input_dtype = hidden_states.dtype + self.hidden_states = hidden_states + self.prob_permuted_indices = prob_permuted_indices + self.dispatched_probs_shape = dispatched_probs.shape + # permute + _, self.hidden = self.token_dispatcher._comm_manager.hidden_shape_before_permute + + self.faltten_dispatched_probs = dispatched_probs.flatten() + + self.permuted_probs = paddle.gather(self.faltten_dispatched_probs, self.prob_permuted_indices) + permuted_tokens = self.hidden_states * self.permuted_probs.unsqueeze(-1) + permuted_tokens = permuted_tokens.cast(self.hidden_states.dtype) + + # Create an output tensor filled with zeros + output_tokens = paddle.zeros( + self.token_dispatcher._comm_manager.hidden_shape_before_permute, dtype=self.hidden_states.dtype + ) + # Scatter add the permuted_input back to the original positions + output_tokens.put_along_axis_( + axis=0, + indices=self.token_permuted_indices.cast("int32").unsqueeze(1).expand([-1, self.hidden]), + values=permuted_tokens, + reduce="add", + include_self=True, + ) + with paddle.base.device_guard("cpu"): + self.output_tokens = paddle.empty(shape=output_tokens.shape, dtype=output_tokens.dtype) + + return output_tokens.to(self.input_dtype) + + def backward(self, out_grad, out_grad_scale): + hidden_states_grad = paddle.gather(out_grad, self.token_permuted_indices) + + output_tokens_grad = FP8LinearFunctionBase.dequantize_fp8_to_fp32(out_grad, out_grad_scale) + permuted_tokens = self.hidden_states * 
self.permuted_probs.unsqueeze(-1) + permuted_tokens = permuted_tokens.cast(self.hidden_states.dtype) + + _, permuted_tokens_grad = paddle._C_ops.put_along_axis_grad( + self.output_tokens, + self.token_permuted_indices.cast("int32").unsqueeze(1).expand([-1, self.hidden]), + permuted_tokens, + self.output_tokens, + output_tokens_grad, + 0, + "add", + True, + ) + + permuted_probs_grad = (permuted_tokens_grad * self.hidden_states).sum(axis=-1) + + faltten_dispatched_probs_grad = paddle._C_ops.gather_grad( + self.faltten_dispatched_probs, self.prob_permuted_indices, permuted_probs_grad, 0 + ) + + # dispatched_probs_grad = paddle._C_ops.flatten_grad(self.dispatched_probs, faltten_dispatched_probs_grad) + dispatched_probs_grad = faltten_dispatched_probs_grad.reshape(self.dispatched_probs_shape) + + self.reset_status() + return hidden_states_grad, dispatched_probs_grad + + +def tokens_zip_unique_add_with_subbatch(zipped, unzipped, index_unzipped, zipped_rows, subbatch_rows=None): + """ + tokens_zip_unique_add_with_subbatch + """ + if subbatch_rows is None or subbatch_rows <= 0 or zipped_rows <= 0: + return TDU.tokens_zip_unique_add(zipped, unzipped, index_unzipped, zipped_rows) + else: + if isinstance(zipped, paddle.Tensor): + num_split = (zipped_rows + subbatch_rows - 1) // subbatch_rows + remainder = zipped_rows % subbatch_rows + if remainder == 0: + rows = [subbatch_rows] * num_split + else: + rows = [subbatch_rows] * (num_split - 1) + [remainder] + + if zipped.shape[0] == 0: + dtype = zipped.dtype + hidden_size = zipped.shape[1] + zipped = [paddle.zeros([r, hidden_size], dtype=dtype) for r in rows] + else: + zipped = paddle.split(zipped, rows, axis=0) + return TDU.tokens_zip_unique_add_subbatch(zipped, unzipped, index_unzipped, zipped_rows, subbatch_rows) + + +def merge_subbatch_cast(x, dtype): + if isinstance(x, (list, tuple)): + if len(x) == 1: + x = x[0] + return x.cast(dtype) if x.dtype != dtype else x + else: + return TDU.merge_subbatch_cast(x, dtype) + else: + return x.cast(dtype) if x.dtype != dtype else x + + +def get_env_device(): + """ + Return the device name of running environment. 
+ """ + if paddle.is_compiled_with_cuda(): + return "gpu" + elif "npu" in paddle.device.get_all_custom_device_type(): + return "npu" + elif "mlu" in paddle.device.get_all_custom_device_type(): + return "mlu" + elif "gcu" in paddle.device.get_all_custom_device_type(): + return "gcu" + elif "intel_hpu" in paddle.device.get_all_custom_device_type(): + return "intel_hpu" + elif paddle.is_compiled_with_rocm(): + return "rocm" + elif paddle.is_compiled_with_xpu(): + return "xpu" + return "cpu" + + +def to_device(tensor, place=None): + if place is None: + place = get_env_device() + + if isinstance(place, str): + place = paddle.device._convert_to_place(place) + + if not tensor.place._equals(place): + new_t = tensor._copy_to(place, True) + dst_tensor = tensor.value().get_tensor() + src_tensor = new_t.value().get_tensor() + dst_tensor._share_data_with(src_tensor) + + return tensor + + +def offload(tensor): + if paddle.is_compiled_with_cuda(): + place = paddle.CUDAPinnedPlace() + else: + place = paddle.CPUPlace() + + new_tensor = to_device(tensor, place) + assert new_tensor is tensor, "to_device must be inplace operation" + + +def reload(tensor): + new_tensor = to_device(tensor) + assert new_tensor is tensor, "to_device must be inplace operation" \ No newline at end of file diff --git a/paddleformers/transformers/token_dispatcher.py b/paddleformers/transformers/token_dispatcher.py index 128f6e52f4d..30a93e7de53 100644 --- a/paddleformers/transformers/token_dispatcher.py +++ b/paddleformers/transformers/token_dispatcher.py @@ -21,7 +21,7 @@ from paddle.distributed.communication.group import Group from .fused_a2a import fused_combine, fused_dispatch -from .moe_utils import permute, unpermute +from .moe_utils import permute, topk_to_permuted_indices, unpermute class _DispatchManager(ABC): @@ -127,7 +127,7 @@ def dispatch(self, hidden_states: paddle.Tensor) -> paddle.Tensor: self.dispatched_indices = states["dispatched_indices"] self.dispatched_probs = dispatched_probs - return hidden_states + return hidden_states, dispatched_indices, dispatched_probs def _indices_to_multihot(self, indices, probs): """ @@ -193,6 +193,34 @@ def get_restored_hidden_states_by_experts(self, hidden_states: paddle.Tensor) -> ) return hidden_states.to(input_dtype) + def get_permuted_hidden_states_by_experts_fast( + self, hidden_states: paddle.Tensor, dispatched_indices: paddle.Tensor + ) -> paddle.Tensor: + self.hidden_shape_before_permute = hidden_states.shape + token_permuted_indices, prob_permuted_indices = topk_to_permuted_indices( + dispatched_indices, self.tokens_per_expert, self.router_topk + ) + hidden_states = permute(hidden_states, token_permuted_indices) + return hidden_states, token_permuted_indices, prob_permuted_indices + + def get_restored_hidden_states_by_experts_fast( + self, + hidden_states: paddle.Tensor, + token_permuted_indices: paddle.Tensor, + prob_permuted_indices: paddle.Tensor, + dispatched_probs: paddle.Tensor, + ) -> paddle.Tensor: + input_dtype = hidden_states.dtype + assert dispatched_probs.dtype == paddle.float32, "DeepEP only supports float32 probs" + hidden_states = unpermute( + permuted_tokens=hidden_states, + token_permuted_indices=token_permuted_indices, + prob_permuted_indices=prob_permuted_indices, + restore_shape=self.hidden_shape_before_permute, + probs=dispatched_probs, + ) + return hidden_states.to(input_dtype) + class MoETokenDispatcher: """ @@ -260,6 +288,34 @@ def __init__(self, num_local_experts: int, moe_router_topk: int, num_moe_experts num_local_experts=self.num_local_experts, ) 
+ def pre_dispatch(self, hidden_states, probs, routing_map): + self.hidden_shape = hidden_states.shape + hidden_states = hidden_states.view([-1, self.hidden_shape[-1]]) + num_tokens = routing_map.shape[0] + routing_map = routing_map.reshape([num_tokens, self._comm_manager.num_experts]) + probs = probs.reshape([num_tokens, self._comm_manager.num_experts]) + # Convert the format of routing map from multihot to indices. + token_probs, token_indices = paddle.topk(probs, self._comm_manager.router_topk, axis=-1) + return hidden_states, token_indices, token_probs + + def post_dispatch(self, hidden_states, dispatched_indices): + ( + global_input_tokens, + token_permuted_indices, + prob_permuted_indices, + ) = self._comm_manager.get_permuted_hidden_states_by_experts_fast(hidden_states, dispatched_indices) + return (global_input_tokens, token_permuted_indices, prob_permuted_indices) + + def pre_combine(self, hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs): + hidden_states = self._comm_manager.get_restored_hidden_states_by_experts_fast( + hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs + ) + return hidden_states + + def post_combine(self, hidden_states): + hidden_states = hidden_states.reshape(self.hidden_shape) + return hidden_states + def token_permutation( self, hidden_states: paddle.Tensor, probs: paddle.Tensor, routing_map: paddle.Tensor ) -> Tuple[paddle.Tensor, paddle.Tensor]: @@ -267,7 +323,7 @@ def token_permutation( hidden_states = hidden_states.view([-1, self.hidden_shape[-1]]) self._comm_manager.setup_metadata(routing_map, probs) - hidden_states = self._comm_manager.dispatch(hidden_states) + hidden_states, _, _ = self._comm_manager.dispatch(hidden_states) global_input_tokens = self._comm_manager.get_permuted_hidden_states_by_experts(hidden_states) tokens_per_expert = self._comm_manager.get_number_of_tokens_per_expert() @@ -282,3 +338,79 @@ def token_unpermutation( hidden_states = hidden_states.reshape(self.hidden_shape) return hidden_states, None + + def token_permutation_fast( + self, hidden_states: paddle.Tensor, probs: paddle.Tensor, routing_map: paddle.Tensor + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + hidden_states, token_indices, token_probs = self.pre_dispatch(hidden_states, probs, routing_map) + hidden_states, dispatched_indices, dispatched_probs = self._comm_manager.dispatch( + hidden_states, token_indices, token_probs + ) + (global_input_tokens, token_permuted_indices, prob_permuted_indices) = self.post_dispatch( + hidden_states, dispatched_indices + ) + + return ( + global_input_tokens, + token_permuted_indices, + prob_permuted_indices, + dispatched_probs, + ) + + def token_unpermutation_fast( + self, + hidden_states: paddle.Tensor, + token_permuted_indices, + prob_permuted_indices, + dispatched_probs, + bias: Optional[paddle.Tensor] = None, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]: + assert bias is None, "Bias is not supported in MoEFlexTokenDispatcher" + hidden_states = self.pre_combine( + hidden_states, token_permuted_indices, prob_permuted_indices, dispatched_probs + ) + hidden_states = self._comm_manager.combine(hidden_states) + + hidden_states = self.post_combine(hidden_states) + return hidden_states, None + + +class PreDispatchNode: + def __init__(self, token_dispatcher): + self.token_dispatcher = token_dispatcher + self.probs_origin_shape = None + + def reset_status(self): + self.probs = None + self.reshaped_probs = None + self.token_indices = None + + @paddle.no_grad() + def forward(self, 
routing_map, probs): + num_tokens = routing_map.shape[0] + self.probs_origin_shape = probs.shape + # routing_map = routing_map.reshape([num_tokens, token_dispatcher._comm_manager.num_experts]) + self.probs = probs + reshaped_probs = probs.reshape([num_tokens, self.token_dispatcher._comm_manager.num_experts]) + self.reshaped_probs = reshaped_probs + token_probs, token_indices = paddle.topk( + reshaped_probs, self.token_dispatcher._comm_manager.router_topk, axis=-1 + ) + self.token_indices = token_indices + token_probs.stop_gradient = False + return token_indices, token_probs + + @paddle.no_grad() + def backward(self, token_probs_g): + probs_grad = paddle._C_ops.topk_grad( + self.reshaped_probs, + self.token_indices, + token_probs_g, + self.token_dispatcher._comm_manager.router_topk, + -1, + True, + True, + ) + probs_reshape_g = paddle._C_ops.reshape_grad(self.probs, probs_grad) + self.reset_status() + return probs_reshape_g \ No newline at end of file diff --git a/paddleformers/transformers/utils.py b/paddleformers/transformers/utils.py index 83c85fc147f..219e4e1d8b6 100644 --- a/paddleformers/transformers/utils.py +++ b/paddleformers/transformers/utils.py @@ -1005,3 +1005,9 @@ def caculate_llm_per_token_flops( # 2 for mul + add in matmul # 1 for forward, 2 for backwards since we caluate gradients for input_x and input_y return 2 * (layer_num * (flops_per_transformer * 3 + flops_recompute_transformer) + 3 * flops_loggits) / seq_length + +def cast_if_needed(x, dtype): + """ + Cast ``x`` to ``dtype`` if needed; return ``x`` unchanged when it already has that dtype. + """ + return x.cast(dtype) if x.dtype != dtype else x diff --git a/paddleformers/utils/download/download.py b/paddleformers/utils/download/download.py index bcc2e5bde70..f36e40f4bff 100644 --- a/paddleformers/utils/download/download.py +++ b/paddleformers/utils/download/download.py @@ -44,6 +44,7 @@ class DownloadSource(str, Enum): HUGGINGFACE = "huggingface" AISTUDIO = "aistudio" MODELSCOPE = "modelscope" + BOS = "bos" MODEL_MAPPINGS = {} @@ -64,6 +65,7 @@ def check_repo(model_name_or_path, download_hub): DownloadSource.HUGGINGFACE, DownloadSource.AISTUDIO, DownloadSource.MODELSCOPE, + DownloadSource.BOS, - ], f"download_hub must be one of {DownloadSource.HUGGINGFACE}, {DownloadSource.AISTUDIO}, {DownloadSource.MODELSCOPE}" + ], f"download_hub must be one of {DownloadSource.HUGGINGFACE}, {DownloadSource.AISTUDIO}, {DownloadSource.MODELSCOPE}, {DownloadSource.BOS}" if model_name_or_path not in HF_MODEL_MAPPINGS.keys(): # repo id set by user @@ -87,6 +89,88 @@ def strtobool(v): f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
) +from .aistudio_hub_download import ( + aistudio_hub_download, + aistudio_hub_file_exists, + aistudio_hub_try_to_load_from_cache, +) +from .bos_download import bos_download, bos_file_exists, bos_try_to_load_from_cache + + +def bos_aistudio_hf_file_exist( + repo_id: str, + filename: str, + *, + subfolder: Optional[str] = None, + repo_type: Optional[str] = None, + revision: Optional[str] = None, + token: Optional[str] = None, + endpoint: Optional[str] = None, + from_bos: bool = True, + from_aistudio: bool = False, + from_hf_hub: bool = False, +): + assert repo_id is not None, "repo_id cannot be None" + assert filename is not None, "filename cannot be None" + + if subfolder is None: + subfolder = "" + filename = os.path.join(subfolder, filename) + if from_aistudio: + out = aistudio_hub_file_exists( + repo_id=repo_id, + filename=filename, + repo_type=repo_type, + revision=revision, + token=token, + endpoint=endpoint, + ) + elif from_hf_hub: + out = hf_hub_file_exists( + repo_id=repo_id, + filename=filename, + repo_type=repo_type, + revision=revision, + token=token, + ) + else: + out = bos_file_exists( + repo_id=repo_id, + filename=filename, + repo_type=repo_type, + revision=revision, + token=token, # do not need token + endpoint=endpoint, + ) + return out + +def bos_aistudio_hf_try_to_load_from_cache( + repo_id: str, + filename: str, + cache_dir: Union[str, Path, None] = None, + subfolder: str = None, + revision: Optional[str] = None, + repo_type: Optional[str] = None, + from_bos: bool = True, + from_aistudio: bool = False, + from_hf_hub: bool = False, +): + if subfolder is None: + subfolder = "" + load_kwargs = dict( + repo_id=repo_id, + filename=os.path.join(subfolder, filename), + cache_dir=cache_dir, + revision=revision, + repo_type=repo_type, + ) + if from_aistudio: + return aistudio_hub_try_to_load_from_cache(**load_kwargs) + elif from_hf_hub: + return hf_hub_try_to_load_from_cache(**load_kwargs) + else: + return bos_try_to_load_from_cache(**load_kwargs) + def resolve_file_path( repo_id: str = None, @@ -132,7 +216,6 @@ def resolve_file_path( if isinstance(filenames, str): filenames = [filenames] - # check repo id if download_hub is None: download_hub = os.environ.get("DOWNLOAD_SOURCE", "huggingface") @@ -238,6 +321,28 @@ def resolve_file_path( ) if cached_file is not None: return cached_file + else: + log_endpoint = "BOS" + for filename in filenames: + download_kwargs["filename"] = filename + is_available = bos_aistudio_hf_file_exist( + repo_id, + filename, + subfolder=subfolder, + repo_type=repo_type, + revision=revision, + token=token, + endpoint=endpoint, + from_bos=True, + from_aistudio=False, + from_hf_hub=False, + ) + if is_available: + cached_file = bos_download( + **download_kwargs, + ) + if cached_file is not None: + return cached_file except LocalEntryNotFoundError: raise EnvironmentError( "Cannot find the requested files in the cached path and"