Commit 4badb71

[cherry-pick] examples: update yaml training config (#2644) (#2665)

Co-authored-by: llbdyiu66 <[email protected]>
1 parent 4a9bc53 commit 4badb71

File tree

10 files changed: +213 −164 lines changed


examples/README.md

Lines changed: 11 additions & 11 deletions

````diff
@@ -19,12 +19,12 @@ export DOWNLOAD_SOURCE=aistudio
 
 ### Paddle-format weight usage notes
 
-To use **Paddle**-format weights, manually add the following parameters to the config file (e.g. `sft_full.json`, `sft_lora.json`) to avoid conflicts with the **HuggingFace** format:
+To use **Paddle**-format weights, manually add the following parameters to the config file (e.g. `sft_full.yaml`, `sft_lora.yaml`) to avoid conflicts with the **HuggingFace** format:
 
-```json
-"model_name_or_path": "your_model_name",
-"convert_from_hf": false,
-"save_to_hf": false,
+```yaml
+model_name_or_path: your_model_name_or_path
+convert_from_hf: false
+save_to_hf: false
 ```
@@ -55,19 +55,19 @@ tar -xvf alpaca_demo.gz
 
 Single GPU
 ```bash
-python -u run_finetune.py ./config/sft_full.json
+python -u run_finetune.py ./config/sft_full.yaml
 ```
 
 Multi-GPU
 ```bash
-python -u -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" run_finetune.py ./config/sft_full.json
+python -u -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" run_finetune.py ./config/sft_full.yaml
 ```
 
 ### 1.3 LoRA SFT
 
 Reference launch command for LoRA SFT
 ```bash
-python -u run_finetune.py ./config/sft_lora.json
+python -u run_finetune.py ./config/sft_lora.yaml
 ```
@@ -109,19 +109,19 @@ tar -zxvf ultrafeedback_binarized.tar.gz
 
 Single GPU
 ```bash
-python -u ./alignment/dpo/run_dpo.py ./config/dpo_full.json
+python -u ./alignment/dpo/run_dpo.py ./config/dpo_full.yaml
 ```
 
 Multi-GPU
 ```bash
-python -u -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" ./alignment/dpo/run_dpo.py ./config/dpo_full.json
+python -u -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" ./alignment/dpo/run_dpo.py ./config/dpo_full.yaml
 ```
 
 ### 2.3 LoRA DPO
 
 Reference launch command for LoRA DPO
 ```bash
-python -u ./alignment/dpo/run_dpo.py ./config/dpo_lora.json
+python -u ./alignment/dpo/run_dpo.py ./config/dpo_lora.yaml
 ```
````

examples/alignment/dpo/run_dpo.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -76,6 +76,8 @@ def main():
     parser = PdArgumentParser((DPOModelArgument, DPODataArgument, DPOTrainingArguments, DPOConfig))
     if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
         model_args, data_args, training_args, dpo_config = parser.parse_json_file_and_cmd_lines()
+    elif len(sys.argv) >= 2 and sys.argv[1].endswith(".yaml"):
+        model_args, data_args, training_args, dpo_config = parser.parse_yaml_file_and_cmd_lines()
     else:
         model_args, data_args, training_args, dpo_config = parser.parse_args_into_dataclasses()
```
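The new branch simply dispatches on the config file's extension before falling back to plain command-line arguments. Stripped of the repo-specific argument parser, the pattern can be sketched as follows (the `load_config` helper is illustrative, not a function from the repo; the YAML branch assumes PyYAML is installed):

```python
import json

import yaml  # third-party: pip install pyyaml


def load_config(path: str) -> dict:
    """Dispatch on the config file's extension, mirroring the
    .json / .yaml branches added in run_dpo.py."""
    if path.endswith(".json"):
        with open(path) as f:
            return json.load(f)
    elif path.endswith(".yaml"):
        with open(path) as f:
            return yaml.safe_load(f)
    raise ValueError(f"unsupported config format: {path}")
```

As in `run_dpo.py`, anything that is neither a `.json` nor a `.yaml` path falls through to the error branch instead of being parsed as a config file.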

examples/config/dpo_full.json

Lines changed: 0 additions & 37 deletions
This file was deleted.

examples/config/dpo_full.yaml

Lines changed: 49 additions & 0 deletions

```yaml
### data
train_dataset_type: erniekit
eval_dataset_type: erniekit
train_dataset_path: ./data/dpo/train.jsonl
train_dataset_prob: "1.0"
eval_dataset_path: ./data/dpo/dev.jsonl
eval_dataset_prob: "1.0"
max_seq_len: 8192
num_samples_each_epoch: 6000000
packing: false
mix_strategy: concat

### model
model_name_or_path: Qwen/Qwen3-0.6B-Base
attn_impl: flashmask

### finetuning
# base
seed: 23
do_train: true
do_eval: true
per_device_eval_batch_size: 1
per_device_train_batch_size: 1
num_train_epochs: 1
max_steps: -1
eval_steps: 100
evaluation_strategy: steps
save_steps: 100
save_total_limit: 1
save_strategy: steps
logging_steps: 1
gradient_accumulation_steps: 4
logging_dir: ./vdl_log
output_dir: ./checkpoints/qwen3_hf_0p6b_dpo_ckpts
disable_tqdm: true
eval_accumulation_steps: 16

# train
warmup_steps: 20
learning_rate: 1.0e-6

# performance
tensor_parallel_degree: 1
pipeline_parallel_degree: 1
sharding: stage2
recompute: true
bf16: true
fp16_opt_level: O2
unified_checkpoint: true
```
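One detail worth noting in these configs: `train_dataset_prob` and `eval_dataset_prob` are quoted, so a YAML loader keeps them as strings, while unquoted scalars such as `learning_rate` and `packing` become floats and booleans. A quick check (assuming PyYAML-style parsing; the repo's actual loader may differ):

```python
import yaml

snippet = """
train_dataset_prob: "1.0"  # quoted: loaded as a string
learning_rate: 1.0e-6      # unquoted: loaded as a float
packing: false             # unquoted: loaded as a boolean
"""

cfg = yaml.safe_load(snippet)
print(type(cfg["train_dataset_prob"]).__name__)  # str
print(type(cfg["learning_rate"]).__name__)       # float
print(type(cfg["packing"]).__name__)             # bool
```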

examples/config/dpo_lora.json

Lines changed: 0 additions & 39 deletions
This file was deleted.

examples/config/dpo_lora.yaml

Lines changed: 51 additions & 0 deletions

```yaml
### data
train_dataset_type: erniekit
eval_dataset_type: erniekit
train_dataset_path: ./data/dpo/train.jsonl
train_dataset_prob: "1.0"
eval_dataset_path: ./data/dpo/dev.jsonl
eval_dataset_prob: "1.0"
max_seq_len: 8192
num_samples_each_epoch: 6000000
packing: false
mix_strategy: concat

### model
model_name_or_path: Qwen/Qwen3-0.6B-Base
attn_impl: flashmask
lora: true
lora_rank: 8

### finetuning
# base
seed: 23
do_train: true
do_eval: true
per_device_eval_batch_size: 1
per_device_train_batch_size: 1
num_train_epochs: 1
max_steps: -1
eval_steps: 100
evaluation_strategy: steps
save_steps: 100
save_total_limit: 1
save_strategy: steps
logging_steps: 1
gradient_accumulation_steps: 4
logging_dir: ./vdl_log
output_dir: ./checkpoints/qwen3_hf_0p6b_dpo_lora_ckpts
disable_tqdm: true
eval_accumulation_steps: 16

# train
warmup_steps: 20
learning_rate: 1.0e-5

# performance
tensor_parallel_degree: 1
pipeline_parallel_degree: 1
sharding: stage2
recompute: true
bf16: true
fp16_opt_level: O2
unified_checkpoint: true
```
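`lora: true` with `lora_rank: 8` swaps full-weight updates for low-rank adapters: for a weight of shape (d_in, d_out), LoRA trains two factors A (d_in × r) and B (r × d_out), i.e. r·(d_in + d_out) parameters instead of d_in·d_out. A back-of-the-envelope sketch (the 4096 dimensions are illustrative, not Qwen3-0.6B's actual shapes):

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters LoRA adds for one (d_in, d_out) weight:
    factor A has shape (d_in, rank), factor B has shape (rank, d_out)."""
    return rank * (d_in + d_out)

full = 4096 * 4096                      # parameters in the frozen base weight
lora = lora_param_count(4096, 4096, 8)  # parameters in the rank-8 adapter
print(lora)                  # 65536
print(f"{lora / full:.2%}")  # 0.39%
```

This is why the LoRA config can afford a higher learning rate (1.0e-5 vs 1.0e-6 for full DPO): only a small fraction of the parameters is being updated.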

examples/config/sft_full.json

Lines changed: 0 additions & 38 deletions
This file was deleted.

examples/config/sft_full.yaml

Lines changed: 49 additions & 0 deletions

```yaml
### data
train_dataset_type: erniekit
eval_dataset_type: erniekit
train_dataset_path: ./data/sft/train.json
train_dataset_prob: "1.0"
eval_dataset_path: ./data/sft/dev.json
eval_dataset_prob: "1.0"
max_seq_len: 8192
num_samples_each_epoch: 6000000
packing: false
mix_strategy: concat

### model
model_name_or_path: Qwen/Qwen3-0.6B-Base
attn_impl: flashmask

### finetuning
# base
seed: 23
do_train: true
do_eval: true
per_device_eval_batch_size: 1
per_device_train_batch_size: 1
num_train_epochs: 1
max_steps: -1
eval_steps: 100
evaluation_strategy: steps
save_steps: 100
save_total_limit: 1
save_strategy: steps
logging_steps: 1
gradient_accumulation_steps: 4
logging_dir: ./vdl_log
output_dir: ./checkpoints/qwen3_hf_0p6b_sft_ckpts
disable_tqdm: true
eval_accumulation_steps: 16

# train
warmup_steps: 20
learning_rate: 1.0e-5

# performance
tensor_parallel_degree: 1
pipeline_parallel_degree: 1
sharding: stage2
recompute: true
bf16: true
fp16_opt_level: O2
unified_checkpoint: true
```
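For the 8-device launch command in the README, the effective global batch size follows from these settings: with `tensor_parallel_degree` and `pipeline_parallel_degree` both 1, all 8 devices act as data-parallel replicas. A small sketch of the standard relation (not a function from the repo):

```python
def global_batch_size(per_device_bs: int, grad_accum: int, num_devices: int,
                      tp_degree: int = 1, pp_degree: int = 1) -> int:
    """Effective batch size per optimizer step: number of data-parallel
    replicas times per-device micro-batch times gradient accumulation."""
    dp_degree = num_devices // (tp_degree * pp_degree)
    return per_device_bs * grad_accum * dp_degree

# sft_full.yaml: per_device_train_batch_size=1, gradient_accumulation_steps=4,
# launched with --devices "0,1,2,3,4,5,6,7"
print(global_batch_size(1, 4, 8))  # 32
```

Raising `tensor_parallel_degree` or `pipeline_parallel_degree` shrinks the data-parallel degree and therefore the global batch size, all else being equal.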

examples/config/sft_lora.json

Lines changed: 0 additions & 39 deletions
This file was deleted.

0 commit comments
