
Commit cbf4c17

add save/load_checkpoint_mode flag
1 parent 6c87ea1 commit cbf4c17

File tree

2 files changed (+158, -45 lines)

paddlenlp/trainer/trainer.py

Lines changed: 68 additions & 37 deletions
@@ -929,13 +929,13 @@ def train(
         self._memory_tracker.start()

         if not self.args.enable_auto_parallel:
-            if not self.args.should_load_sharding_stage1_model and not self.args.using_flex_checkpoint:
+            if not self.args.should_load_sharding_stage1_model and not self.args.load_flex_checkpoint:
                 self._load_from_checkpoint(resume_from_checkpoint)

             if self.args.should_load_sharding_stage1_model:
                 model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint)

-            elif self.args.should_save_sharding_stage1_model:
+            elif self.args.should_save_sharding_stage1_model and not self.args.load_flex_checkpoint:
                 # In the non-sharded mode, should invoke _load_from_checkpoint before _wrap_model.
                 # In this mode, the rank0 load all params and the _wrap_model implicitly broadcast params from rank0 to the other ranks.
                 model = self._wrap_model(self.model_wrapped)
@@ -949,36 +949,44 @@ def train(
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
-            elif not self.args.using_flex_checkpoint:
+
+            elif self.args.load_flex_checkpoint:
                 model = self._wrap_model(self.model_wrapped)
-                # for the rest of this function `model` is the outside model, whether it was wrapped or not
                 if model is not self.model:
                     self.model_wrapped = model
+
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
-                self._load_optimizer_and_scheduler(resume_from_checkpoint)
-            else:
-                assert self.args.using_flex_checkpoint, "default using flex_checkpoint!"
 
+                if resume_from_checkpoint is not None:
+                    if not self.args.ignore_load_lr_and_optim:
+                        model_sharded_state_dict = self.model.sharded_state_dict()
+                        accessible_files = os.listdir(resume_from_checkpoint)
+                        metadata_files = [file for file in accessible_files if file.endswith(".metadata")]
+                        assert len(metadata_files) == 1, "Only support one metadata file now."
+                        metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))
+                        state_dict_metadata = metadata.state_dict_metadata
+                        init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
+                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                        sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
+                        dist.load_state_dict(
+                            sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config, offload=False
+                        )
+                        self._load_scheduler(resume_from_checkpoint)
+                    else:
+                        model_sharded_state_dict = self.model.sharded_state_dict()
+                        sharded_state_dict = model_sharded_state_dict
+                        dist.load_state_dict(
+                            sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config
+                        )
+            else:
                 model = self._wrap_model(self.model_wrapped)
+                # for the rest of this function `model` is the outside model, whether it was wrapped or not
                 if model is not self.model:
                     self.model_wrapped = model
-
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
-
-                if resume_from_checkpoint is not None:
-                    model_sharded_state_dict = self.model.sharded_state_dict()
-                    accessible_files = os.listdir(resume_from_checkpoint)
-                    metadata_files = [file for file in accessible_files if file.endswith(".metadata")]
-                    assert len(metadata_files) == 1, "Only support one metadata file now."
-                    metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))
-                    state_dict_metadata = metadata.state_dict_metadata
-                    init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
-                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
-                    sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
-                    dist.load_state_dict(sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config)
-                    self._load_scheduler(resume_from_checkpoint)
+                self._load_optimizer_and_scheduler(resume_from_checkpoint)
         else:
             model = self.model_wrapped
             if delay_optimizer_creation:
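
For readers skimming the hunk above: the new `elif self.args.load_flex_checkpoint:` branch locates the single `.metadata` file in the checkpoint directory, rebuilds the optimizer state from that metadata, and then loads the merged model-plus-optimizer sharded state dicts with one `dist.load_state_dict` call. The following condensed sketch restates that flow outside the Trainer. The wrapper name `resume_from_flex_checkpoint` and the injected `init_optimizer_fn` parameter are hypothetical stand-ins for the helpers trainer.py imports; they are not part of this commit.

import os

import paddle
import paddle.distributed as dist


def resume_from_flex_checkpoint(model, optimizer, resume_dir, init_optimizer_fn, aoa_config=None, load_optim=True):
    """Hypothetical standalone restatement of the load path added above."""
    model_sharded_state_dict = model.sharded_state_dict()
    if load_optim:
        # A FlexCheckpoint directory is expected to contain exactly one *.metadata file.
        metadata_files = [f for f in os.listdir(resume_dir) if f.endswith(".metadata")]
        assert len(metadata_files) == 1, "Only support one metadata file now."
        metadata = paddle.load(os.path.join(resume_dir, metadata_files[0]))
        # Materialize optimizer state tensors so dist.load_state_dict has shards to fill.
        init_optimizer_fn(optimizer, model_sharded_state_dict, metadata.state_dict_metadata)
        optimizer_sharded_state_dict = optimizer.sharded_state_dict(model_sharded_state_dict)
        sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
    else:
        # Mirrors args.ignore_load_lr_and_optim: restore model weights only.
        sharded_state_dict = model_sharded_state_dict
    dist.load_state_dict(sharded_state_dict, resume_dir, aoa_config=aoa_config)

Passing `init_optimizer_fn` in keeps the sketch self-contained; in the actual diff, `init_optimizer` is a module-level helper already imported by trainer.py.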
@@ -2738,7 +2746,7 @@ def _save_checkpoint(self, model, metrics=None):
         else:
             self.save_model(output_dir)

-        if self.args.using_flex_checkpoint:
+        if self.args.save_flex_checkpoint:
             model_sharded_state_dict = self.model.sharded_state_dict()
             os.makedirs(output_dir, exist_ok=True)

@@ -2801,7 +2809,18 @@ def _save_checkpoint(self, model, metrics=None):
                         signal_dir,
                     )
                 else:
-                    if not self.args.using_flex_checkpoint:
+                    if self.args.save_flex_checkpoint:
+                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                        dist.save_state_dict(
+                            {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                            output_dir,
+                        )
+                        if self.args.should_save:
+                            if self.tokenizer is not None and self.args.save_tokenizer:
+                                self.tokenizer.save_pretrained(output_dir)
+                            # Good practice: save your training arguments together with the trained model
+                            paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+                    else:
                         if self.dp_group.rank > 0:  # this should only work for MoE saving
                             self._save_ckpt_func(
                                 self._filter_moe_no_sync_optimizer_params(),
@@ -2821,12 +2840,7 @@ def _save_checkpoint(self, model, metrics=None):
                            )
                        else:
                            self._save_ckpt_func(state_dict, save_path, saved_signal_path)
-                    else:
-                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
-                        dist.save_state_dict(
-                            {**model_sharded_state_dict, **optimizer_sharded_state_dict},
-                            output_dir,
-                        )
+
         else:
             if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                 global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
@@ -2852,7 +2866,18 @@ def _save_checkpoint(self, model, metrics=None):
                     output_dir,
                     signal_dir,
                 )
-            elif not self.args.using_flex_checkpoint:
+            elif self.args.save_flex_checkpoint:
+                optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                dist.save_state_dict(
+                    {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                    output_dir,
+                )
+                if self.args.should_save:
+                    if self.tokenizer is not None and self.args.save_tokenizer:
+                        self.tokenizer.save_pretrained(output_dir)
+                    # Good practice: save your training arguments together with the trained model
+                    paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+            else:
                 if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
                     self._save_ckpt_func(
                         self._filter_moe_no_sync_optimizer_params(),
@@ -2866,13 +2891,6 @@ def _save_checkpoint(self, model, metrics=None):
                        saved_signal_path,
                    )

-            else:
-                optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
-                dist.save_state_dict(
-                    {**model_sharded_state_dict, **optimizer_sharded_state_dict},
-                    output_dir,
-                )
-
         # FIXME: maybe only save one copy
         paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))

@@ -2893,6 +2911,18 @@ def _save_checkpoint(self, model, metrics=None):
            if self.args.unified_checkpoint and (self.args.offload_optim or self.args.tensorwise_offload_optimizer):
                self._offload_optimizer()

+        else:
+            if self.args.save_flex_checkpoint:
+                dist.save_state_dict(
+                    model_sharded_state_dict,
+                    output_dir,
+                )
+                if self.args.should_save:
+                    if self.tokenizer is not None and self.args.save_tokenizer:
+                        self.tokenizer.save_pretrained(output_dir)
+                    # Good practice: save your training arguments together with the trained model
+                    paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
         self.runtime_timer.stop()

         # Maybe delete some older checkpoints.
@@ -3107,6 +3137,7 @@ def _save(
         else:
             if isinstance(self.model, PretrainedModel) and self.args.should_save_sharding_stage1_model:
                 config_to_save = None
+                self.sharding_io.set_optimizer(self.optimizer)
                 state_dict, config_to_save, weight_name_suffix = self.sharding_io.manipulate_state_dict_and_config(
                     self.model, merge_tensor_parallel=merge_tensor_parallel
                 )
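
On the save side, every `save_flex_checkpoint` branch added to `_save_checkpoint` follows the same pattern: collect the model (and optionally optimizer) sharded state dicts, write them with `dist.save_state_dict`, and let the saving rank persist the tokenizer and training args. Below is a minimal sketch of that pattern under the same `sharded_state_dict()` interfaces; the wrapper name `save_flex_checkpoint_dir` is hypothetical, and the value of `TRAINING_ARGS_NAME` is assumed to be the `training_args.bin` constant trainer.py already imports.

import os

import paddle
import paddle.distributed as dist

TRAINING_ARGS_NAME = "training_args.bin"  # assumed value of the constant used in the diff


def save_flex_checkpoint_dir(model, optimizer, output_dir, args, tokenizer=None, include_optimizer=True):
    """Hypothetical restatement of the save_flex_checkpoint branches in _save_checkpoint."""
    os.makedirs(output_dir, exist_ok=True)
    model_sharded_state_dict = model.sharded_state_dict()
    sharded_state_dict = dict(model_sharded_state_dict)
    if include_optimizer:
        # Optimizer shards live in the same flat mapping as the model shards.
        sharded_state_dict.update(optimizer.sharded_state_dict(model_sharded_state_dict))
    dist.save_state_dict(sharded_state_dict, output_dir)
    if args.should_save:
        if tokenizer is not None and args.save_tokenizer:
            tokenizer.save_pretrained(output_dir)
        # Good practice: save your training arguments together with the trained model
        paddle.save(args, os.path.join(output_dir, TRAINING_ARGS_NAME))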

paddlenlp/trainer/training_args.py

Lines changed: 90 additions & 8 deletions
@@ -407,10 +407,12 @@ class TrainingArguments:
             Whether to release gradients during training. Default is `False`.
         ckpt_quant_stage (`str`, *optional*):
             Whether activate checkpoint quantization. O0: deactivate, O1: Int8 compression, O2: Int4 compression. (default: O0).
-        using_flex_checkpoint(`bool`, *optional*):
-            Whether to use FlexCheckpoint for save and load. Default is False.
         aoa_config (`Optional[dict[str, list[str]]]`, *optional*):
             The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None.
+        save_checkpoint_mode (`str`, *optional*):
+            Specifies the method for saving checkpoints. Options are: None, 'sharding_io', 'unified_checkpoint', 'flex_checkpoint', and 'safetensor'. (default: None). This setting is ignored if the corresponding switch is configured.
+        load_checkpoint_mode (`str`, *optional*):
+            Specifies the method for loading checkpoints. Options are: None, 'sharding_io', 'unified_checkpoint', 'flex_checkpoint', and 'safetensor'. (default: None). This setting is ignored if the corresponding switch is configured.
     """

     output_dir: str = field(
@@ -935,10 +937,6 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Whether to use async_save instead of paddle.save."},
     )
-    using_flex_checkpoint: Optional[bool] = field(
-        default=False,
-        metadata={"help": "Whether use FlexCheckpoint."},
-    )
     ordered_save_group_size: int = field(
         default=0,
         metadata={
@@ -1111,6 +1109,30 @@ class TrainingArguments:
         },
     )

+    save_checkpoint_mode: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Specifies the method used to save checkpoints. "
+                "Available options: 'sharding_io', 'unified_checkpoint', "
+                "'flex_checkpoint', 'safetensor'."
+                "This setting is ignored if the corresponding switch is configured."
+            )
+        },
+    )
+
+    load_checkpoint_mode: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Specifies the method used to load checkpoints. "
+                "Available options: 'sharding_io', 'unified_checkpoint', "
+                "'flex_checkpoint', 'safetensor'."
+                "This setting is ignored if the corresponding switch is configured."
+            )
+        },
+    )
+
     def __post_init__(self):
         world_size = paddle.distributed.get_world_size()
         if in_auto_parallel_align_mode():
@@ -1210,6 +1232,8 @@ def __post_init__(self):
             raise ValueError("AdamW Mini currently doesn't support tensor parallelism.")

         self._post_init_parallel_degree()
+        self._post_init_save_checkpoint_mode()
+        self._post_init_load_checkpoint_mode()

         if self.to_static:
             assert world_size == 1 or self.enable_auto_parallel, (
@@ -1862,7 +1886,7 @@ def is_context_parallel_supported():
                 # DP use hybrid group
                 strategy = fleet.DistributedStrategy()
                 fleet.init(is_collective=True, strategy=strategy)
-            elif self.using_flex_checkpoint:
+            elif self.save_flex_checkpoint or self.load_flex_checkpoint:
                 strategy = fleet.DistributedStrategy()
                 fleet.init(is_collective=True, strategy=strategy)
             else:
@@ -2131,6 +2155,64 @@ def _post_init_parallel_degree(self):
         if self.use_hybrid_parallel and self.enable_auto_parallel:
             self.use_hybrid_parallel = False

+    def _post_init_save_checkpoint_mode(self):
+        if not self.save_checkpoint_mode:
+            return
+
+        # Ensure that only one checkpoint mode is set at a time
+        if self.unified_checkpoint or self.save_sharded_model:
+            return
+
+        self.save_flex_checkpoint = False
+
+        valid_modes = ["unified_checkpoint", "sharding_io", "safetensor", "flex_checkpoint"]
+        assert (
+            self.save_checkpoint_mode in valid_modes
+        ), f"Invalid save_checkpoint_mode: {self.save_checkpoint_mode}, Only these modes are allowed: {valid_modes}."
+
+        if self.save_checkpoint_mode == "safetensor":
+            raise NotImplementedError("safetensor checkpoint saving is not implemented yet.")
+        elif self.save_checkpoint_mode == "unified_checkpoint":
+            assert (
+                getattr(self, "load_checkpoint_mode", None) == "unified_checkpoint"
+            ), "When saving in unified_checkpoint mode, load_checkpoint_mode must also be 'unified_checkpoint'."
+            self.unified_checkpoint = True
+        elif self.save_checkpoint_mode == "sharding_io":
+            self.save_sharded_model = True
+        elif self.save_checkpoint_mode == "flex_checkpoint":
+            self.save_flex_checkpoint = True
+        else:
+            raise NotImplementedError(f"Checkpoint mode '{self.save_checkpoint_mode}' is not supported.")
+
+    def _post_init_load_checkpoint_mode(self):
+        if not self.load_checkpoint_mode:
+            return
+
+        self.load_flex_checkpoint = False
+
+        # Ensure that only one checkpoint mode is set at a time
+        if self.unified_checkpoint or self.load_sharded_model:
+            return
+
+        valid_modes = ["unified_checkpoint", "sharding_io", "safetensor", "flex_checkpoint"]
+        assert (
+            self.load_checkpoint_mode in valid_modes
+        ), f"Invalid load_checkpoint_mode: {self.load_checkpoint_mode}, Only these modes are allowed: {valid_modes}."
+
+        if self.load_checkpoint_mode == "safetensor":
+            raise NotImplementedError("safetensor checkpoint loading is not implemented yet.")
+        elif self.load_checkpoint_mode == "unified_checkpoint":
+            assert (
+                getattr(self, "save_checkpoint_mode", None) == "unified_checkpoint"
+            ), "When loading in unified_checkpoint mode, save_checkpoint_mode must also be 'unified_checkpoint'."
+            self.unified_checkpoint = True
+        elif self.load_checkpoint_mode == "sharding_io":
+            self.load_sharded_model = True
+        elif self.load_checkpoint_mode == "flex_checkpoint":
+            self.load_flex_checkpoint = True
+        else:
+            raise NotImplementedError(f"Checkpoint mode '{self.load_checkpoint_mode}' is not supported.")
+
     def add_moe_comm_group(self):
         hybrid_configs = fleet.fleet._user_defined_strategy.hybrid_configs
         hcg = fleet.get_hybrid_communicate_group()
@@ -2459,7 +2541,7 @@ def should_save_model_state(self):
             return True
         elif self.enable_auto_parallel:
             return True
-        elif self.using_flex_checkpoint:
+        elif self.save_flex_checkpoint:
            return False
        elif self.use_hybrid_parallel:
            # save on dataset rank 0
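
End to end, the new flags replace the removed `using_flex_checkpoint` switch. A minimal usage sketch follows, assuming a single-process run with the other checkpoint switches left at their defaults; the output path is a placeholder. The same values can be passed on the command line as `--save_checkpoint_mode flex_checkpoint --load_checkpoint_mode flex_checkpoint`.

from paddlenlp.trainer import TrainingArguments

# The mode strings map onto internal switches in __post_init__:
#   "flex_checkpoint"    -> save_flex_checkpoint / load_flex_checkpoint
#   "sharding_io"        -> save_sharded_model / load_sharded_model
#   "unified_checkpoint" -> unified_checkpoint (must be set for both save and load)
#   "safetensor"         -> raises NotImplementedError for now
args = TrainingArguments(
    output_dir="./checkpoints",  # placeholder
    save_checkpoint_mode="flex_checkpoint",
    load_checkpoint_mode="flex_checkpoint",
)
print(args.save_flex_checkpoint, args.load_flex_checkpoint)  # expected: True True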
