
Commit cf184b5

fix startend_row_indices bug and docs
1 parent 21073db commit cf184b5

8 files changed: +42 / -16 lines changed


examples/README.md

Lines changed: 22 additions & 1 deletion
@@ -20,18 +20,39 @@
 wget https://bj.bcebos.com/paddlenlp/datasets/examples/alpaca_demo.gz
 tar -xvf alpaca_demo.gz
 ```
+### Model Download
+```bash
+# PaddleNLP/Qwen2-0.5B-Instruct
+aistudio download --model PaddleNLP/Qwen2-0.5B-Instruct --local_dir PaddleNLP/Qwen2-0.5B-Instruct
+
+# baidu/ERNIE-4.5-0.3B-PT
+aistudio download --model PaddlePaddle/ERNIE-4.5-0.3B-PT --local_dir baidu/ERNIE-4.5-0.3B-PT
+
+# baidu/ERNIE-4.5-21B-A3B-PT
+aistudio download --model PaddlePaddle/ERNIE-4.5-21B-A3B-PT --local_dir baidu/ERNIE-4.5-21B-A3B-PT
+```

 ### Full-Parameter Fine-Tuning: SFT

 Single GPU
 ```bash
-# Requires about 12 GB of GPU memory
+# Fine-tuning Qwen2-0.5B-Instruct requires about 12 GB of GPU memory
 python -u run_finetune.py ./config/qwen/sft_argument_qwen2_0p5b.json
+
+# Fine-tune ERNIE-4.5-0.3B-PT
+python -u run_finetune.py ./config/ernie4_5/sft_argument_ernie4_5_0p3b.json
 ```

 Multi-GPU
 ```bash
+# SFT Qwen2-0.5B-Instruct
 python -u -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" run_finetune.py ./config/qwen/sft_argument_qwen2_0p5b.json
+
+# SFT ERNIE-4.5-0.3B-PT
+python -u -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" run_finetune.py ./config/ernie4_5/sft_argument_ernie4_5_0p3b.json
+
+# SFT ERNIE-4.5-21B-A3B-PT
+python -u -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" run_finetune.py ./config/ernie4_5_moe/sft_argument_ernie4_5_21b_a3b.json
 ```

 ### LoRA

examples/config/ernie4_5/sft_argument_ernie4_5_0p3b.json

Lines changed: 5 additions & 5 deletions
@@ -21,8 +21,6 @@
 "max_steps": 100,
 "evaluation_strategy": "epoch",
 "save_strategy": "epoch",
-"src_length": 1024,
-"max_length": 2048,
 "bf16": true,
 "fp16_opt_level": "O2",
 "do_train": true,
@@ -33,15 +31,17 @@
 "metric_for_best_model": "accuracy",
 "recompute": true,
 "save_total_limit": 1,
-"tensor_parallel_degree": 2,
-"pipeline_parallel_degree": 2,
+"tensor_parallel_degree": 1,
+"pipeline_parallel_degree": 1,
 "sharding": "stage2",
 "zero_padding": true,
 "flash_mask": true,
 "unified_checkpoint": true,
 "use_flash_attention": true,
-"sequence_parallel": true,
+"use_attn_mask_startend_row_indices": true,
+"sequence_parallel": false,
 "report_to": "none",
 "convert_from_hf": true,
+"save_to_hf": true,
 "pp_seg_method": "layer:DecoderLayer|EmptyLayer"
 }
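As a quick sanity check, a minimal sketch (the path and the checks are assumptions drawn from this diff, not utilities shipped by the repository) that loads the updated config and confirms the settings this commit changes:

```python
# Assumption: run from the repository root; the checks mirror the diff above.
import json

with open("examples/config/ernie4_5/sft_argument_ernie4_5_0p3b.json") as f:
    cfg = json.load(f)

# Single-card friendly parallelism after this commit.
assert cfg["tensor_parallel_degree"] == 1 and cfg["pipeline_parallel_degree"] == 1
# sequence_parallel goes hand in hand with tensor parallelism, so it is switched off here.
assert cfg["sequence_parallel"] is False
# FlashMask path: sparse row indices + flash attention + zero padding.
assert cfg["flash_mask"] and cfg["use_attn_mask_startend_row_indices"]
assert cfg["use_flash_attention"] and cfg["zero_padding"]
print("config consistent with the FlashMask setup")
```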

examples/config/ernie4_5_moe/sft_argument_ernie4_5_21b_a3b.json

Lines changed: 1 addition & 2 deletions
@@ -21,8 +21,6 @@
 "max_steps": 100,
 "evaluation_strategy": "epoch",
 "save_strategy": "epoch",
-"src_length": 1024,
-"max_length": 2048,
 "bf16": true,
 "fp16_opt_level": "O2",
 "do_train": true,
@@ -43,5 +41,6 @@
 "sequence_parallel": true,
 "report_to": "none",
 "convert_from_hf": true,
+"save_to_hf": true,
 "pp_seg_method": "layer:DecoderLayer|EmptyLayer"
 }

examples/run_finetune.py

Lines changed: 7 additions & 1 deletion
@@ -153,6 +153,13 @@ def main():
     logger.info(f"Final model config: {model_config}")
     logger.info("Creating model")

+    if model_args.flash_mask and model_args.use_attn_mask_startend_row_indices:
+        model_config._attn_implementation = "flashmask"
+    elif model_args.flash_mask and not model_args.use_attn_mask_startend_row_indices:
+        model_config._attn_implementation = "sdpa"
+    else:
+        model_config._attn_implementation = "eager"
+
     model_class = AutoModelForCausalLM
     if training_args.pipeline_parallel_degree > 1:
         if data_args.eval_with_do_generation and training_args.do_eval:
@@ -174,7 +181,6 @@ def main():
         logger.warning("`flash_mask` must use with zero padding and flash attention.")
         data_args.zero_padding = True
         model.config.use_flash_attention = True
-        model.config._attn_implementation = "flashmask"

     if model_args.flash_mask and not any(isinstance(model, cls) for cls in flash_mask_support_list):
         raise NotImplementedError(f"{model.__class__} not support flash mask.")
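To make the new selection logic explicit, here is a standalone sketch of the flag-to-implementation mapping; the helper name is hypothetical and only restates the branch added above:

```python
# Hypothetical helper, not part of run_finetune.py; it mirrors the branch added above.
def select_attn_implementation(flash_mask: bool, use_attn_mask_startend_row_indices: bool) -> str:
    if flash_mask and use_attn_mask_startend_row_indices:
        # Sparse startend_row_indices are consumed directly by the FlashMask kernel.
        return "flashmask"
    if flash_mask:
        # flash_mask without sparse indices falls back to dense SDPA-style attention.
        return "sdpa"
    return "eager"


assert select_attn_implementation(True, True) == "flashmask"
assert select_attn_implementation(True, False) == "sdpa"
assert select_attn_implementation(False, True) == "eager"
```

This also explains why the unconditional `model.config._attn_implementation = "flashmask"` assignment further down was removed: the implementation is now chosen once, up front.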

paddleformers/transformers/ernie4_5/modeling.py

Lines changed: 2 additions & 1 deletion
@@ -834,7 +834,8 @@ def forward(

         # Pretrain & Eval must have labels
         assert labels is not None
-        return self.criterion(logits, labels, loss_mask)
+        loss, _ = self.criterion(logits, labels, loss_mask)
+        return loss, logits


 class Ernie4_5ForCausalLMPipe(GeneralModelForCausalLMPipe):

paddleformers/transformers/ernie4_5_moe/modeling.py

Lines changed: 2 additions & 1 deletion
@@ -1157,7 +1157,8 @@ def forward(
         # Pretrain & Eval must have labels
         assert labels is not None

-        return self.criterion(logits, labels, loss_mask, router_loss=router_loss, mtp_logits=mtp_logits)
+        loss, _ = self.criterion(logits, labels, loss_mask, router_loss=router_loss, mtp_logits=mtp_logits)
+        return loss, logits


 class Ernie4_5_MoeForCausalLMPipe(GeneralModelForCausalLMPipe):
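Both ERNIE-4.5 forward passes now unpack the criterion's (loss, aux) tuple and return (loss, logits) instead of passing the raw tuple through. A toy sketch of the same pattern, where the criterion is a stand-in and not the repository's class:

```python
# Stand-in criterion that, like the real one, returns a (loss, auxiliary) tuple.
import paddle
import paddle.nn.functional as F


def toy_criterion(logits, labels):
    loss = F.cross_entropy(logits.reshape([-1, logits.shape[-1]]), labels.reshape([-1]))
    return loss, None  # second element stands in for auxiliary outputs


logits = paddle.randn([2, 4, 16])
labels = paddle.randint(0, 16, shape=[2, 4])

loss, _ = toy_criterion(logits, labels)  # unpack, as the patched forward() does
outputs = (loss, logits)                 # what forward() now returns
print(float(loss))
```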

paddleformers/transformers/llama/fusion_ops.py

Lines changed: 1 addition & 3 deletions
@@ -248,16 +248,14 @@ def fusion_flash_attention(
     else:
         if attn_mask_startend_row_indices is not None:
             assert alibi is None, "flashmask_attention or flash_attention_with_sparse_mask not support alibi"
-            if len(attn_mask_startend_row_indices.shape) == 2:
-                attn_mask_startend_row_indices = paddle.unsqueeze(attn_mask_startend_row_indices, axis=1)

             if hasattr(F, "flashmask_attention"):
                 attn_output = no_recompute(
                     F.flashmask_attention,
                     query_states,
                     key_states,
                     value_states,
-                    startend_row_indices=attn_mask_startend_row_indices.unsqueeze(-1),
+                    startend_row_indices=attn_mask_startend_row_indices,
                     causal=True,
                     enable=skip_recompute,
                 )
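After this change, callers are expected to hand fusion_flash_attention a startend_row_indices tensor that is already 4-D, in the [batch, num_heads, seq_len, 1] layout consumed by F.flashmask_attention, rather than relying on the unsqueezes removed here. A small sketch of what that tensor looks like for two packed documents (the layout is an assumption based on this diff and the FlashMask convention):

```python
# Sketch only: builds a 4-D int32 startend_row_indices for a batch of one
# packed sequence holding two documents of lengths 3 and 5 (seq_len = 8).
# Entry [b, h, j, 0] is the query row at which key column j becomes masked,
# i.e. the end row of the document that column j belongs to.
import paddle

doc_lens = [3, 5]
seq_len = sum(doc_lens)

start_rows = []
end = 0
for length in doc_lens:
    end += length
    start_rows.extend([end] * length)  # columns of this document mask out at its end row

startend_row_indices = paddle.to_tensor(start_rows, dtype="int32").reshape([1, 1, seq_len, 1])
print(startend_row_indices.shape)  # [1, 1, 8, 1]
```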

paddleformers/utils/masking_utils.py

Lines changed: 2 additions & 2 deletions
@@ -31,9 +31,9 @@ def _gen_from_sparse_attn_mask_indices(attn_mask_start_row_indices, dtype):
     Returns:
         paddle.Tensor: The dense attention mask recovered from attn_mask_start_row_indices.
     """
-    batch_size, _, max_seq_len = attn_mask_start_row_indices.shape
+    batch_size, _, max_seq_len, _ = attn_mask_start_row_indices.shape
     base = paddle.arange(max_seq_len, dtype="int32").unsqueeze(1).expand([batch_size, -1, max_seq_len]).unsqueeze(1)
-    mask_indices = attn_mask_start_row_indices.unsqueeze(1)
+    mask_indices = attn_mask_start_row_indices
     
     tril = paddle.tril(
         paddle.ones([max_seq_len, max_seq_len], dtype="bool").expand([batch_size, 1, max_seq_len, max_seq_len])
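The helper now unpacks four dimensions because it receives the same 4-D tensor described above, already in the shape handed to F.flashmask_attention. As a rough cross-check of the start-row convention (an assumption for illustration, not the library's implementation), the dense mask for a small packed batch can be recovered like this:

```python
# Toy recovery of a dense mask from 4-D start-row indices; convention assumed:
# key column j is visible to query row r iff j <= r < start_rows[b, 0, j, 0].
import paddle

seq_len = 4
# Two packed documents of length 2: columns 0-1 mask out at row 2, columns 2-3 at row 4.
start_rows = paddle.to_tensor([2, 2, 4, 4], dtype="int32").reshape([1, 1, seq_len, 1])

rows = paddle.arange(seq_len, dtype="int32").reshape([1, 1, seq_len, 1])  # query row index
causal = paddle.tril(paddle.ones([seq_len, seq_len], dtype="bool"))
dense = paddle.logical_and(rows < start_rows.transpose([0, 1, 3, 2]), causal)  # [1, 1, 4, 4]
print(dense.astype("int32"))
# Expected block-diagonal causal pattern:
# [[1 0 0 0]
#  [1 1 0 0]
#  [0 0 1 0]
#  [0 0 1 1]]
```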
