
Commit 6a90f36
add save/load_checkpoint_mode flag
1 parent 6c87ea1
File tree: 2 files changed (+162, -45 lines)
  paddlenlp/trainer/trainer.py
  paddlenlp/trainer/training_args.py

paddlenlp/trainer/trainer.py
Lines changed: 72 additions & 37 deletions

@@ -223,6 +223,10 @@ def in_auto_parallel_align_mode():
 
 __all__ = ["Trainer"]
 
+FLEX_CKPT_MODEL_STATE_DIR_NAME = "model_state"
+FLEX_CKPT_OPT_STATE_DIR_NAME = "optimizer_states"
+FLEC_CKPT_MASTER_WEIGHTS_INDEX_NAME = "master_weights"
+
 
 class Trainer:
     """

@@ -929,13 +933,13 @@ def train(
         self._memory_tracker.start()
 
         if not self.args.enable_auto_parallel:
-            if not self.args.should_load_sharding_stage1_model and not self.args.using_flex_checkpoint:
+            if not self.args.should_load_sharding_stage1_model and not self.args.load_flex_checkpoint:
                 self._load_from_checkpoint(resume_from_checkpoint)
 
             if self.args.should_load_sharding_stage1_model:
                 model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint)
 
-            elif self.args.should_save_sharding_stage1_model:
+            elif self.args.should_save_sharding_stage1_model and not self.args.load_flex_checkpoint:
                 # In the non-sharded mode, should invoke _load_from_checkpoint before _wrap_model.
                 # In this mode, the rank0 load all params and the _wrap_model implicitly broadcast params from rank0 to the other ranks.
                 model = self._wrap_model(self.model_wrapped)

@@ -949,36 +953,44 @@ def train(
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
-            elif not self.args.using_flex_checkpoint:
+
+            elif self.args.load_flex_checkpoint:
                 model = self._wrap_model(self.model_wrapped)
-                # for the rest of this function `model` is the outside model, whether it was wrapped or not
                 if model is not self.model:
                     self.model_wrapped = model
+
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
-                self._load_optimizer_and_scheduler(resume_from_checkpoint)
-            else:
-                assert self.args.using_flex_checkpoint, "default using flex_checkpoint!"
 
+                if resume_from_checkpoint is not None:
+                    if not self.args.ignore_load_lr_and_optim:
+                        model_sharded_state_dict = self.model.sharded_state_dict()
+                        accessible_files = os.listdir(resume_from_checkpoint)
+                        metadata_files = [file for file in accessible_files if file.endswith(".metadata")]
+                        assert len(metadata_files) == 1, "Only support one metadata file now."
+                        metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))
+                        state_dict_metadata = metadata.state_dict_metadata
+                        init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
+                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                        sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
+                        dist.load_state_dict(
+                            sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config, offload=False
+                        )
+                        self._load_scheduler(resume_from_checkpoint)
+                    else:
+                        model_sharded_state_dict = self.model.sharded_state_dict()
+                        sharded_state_dict = model_sharded_state_dict
+                        dist.load_state_dict(
+                            sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config
+                        )
+            else:
                 model = self._wrap_model(self.model_wrapped)
+                # for the rest of this function `model` is the outside model, whether it was wrapped or not
                 if model is not self.model:
                     self.model_wrapped = model
-
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
-
-                if resume_from_checkpoint is not None:
-                    model_sharded_state_dict = self.model.sharded_state_dict()
-                    accessible_files = os.listdir(resume_from_checkpoint)
-                    metadata_files = [file for file in accessible_files if file.endswith(".metadata")]
-                    assert len(metadata_files) == 1, "Only support one metadata file now."
-                    metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))
-                    state_dict_metadata = metadata.state_dict_metadata
-                    init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
-                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
-                    sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
-                    dist.load_state_dict(sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config)
-                    self._load_scheduler(resume_from_checkpoint)
+                self._load_optimizer_and_scheduler(resume_from_checkpoint)
         else:
             model = self.model_wrapped
             if delay_optimizer_creation:

@@ -2738,7 +2750,7 @@ def _save_checkpoint(self, model, metrics=None):
         else:
             self.save_model(output_dir)
 
-        if self.args.using_flex_checkpoint:
+        if self.args.save_flex_checkpoint:
            model_sharded_state_dict = self.model.sharded_state_dict()
            os.makedirs(output_dir, exist_ok=True)
 

@@ -2801,7 +2813,18 @@ def _save_checkpoint(self, model, metrics=None):
                        signal_dir,
                    )
                else:
-                    if not self.args.using_flex_checkpoint:
+                    if self.args.save_flex_checkpoint:
+                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                        dist.save_state_dict(
+                            {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                            output_dir,
+                        )
+                        if self.args.should_save:
+                            if self.tokenizer is not None and self.args.save_tokenizer:
+                                self.tokenizer.save_pretrained(output_dir)
+                            # Good practice: save your training arguments together with the trained model
+                            paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+                    else:
                        if self.dp_group.rank > 0:  # this should only work for MoE saving
                            self._save_ckpt_func(
                                self._filter_moe_no_sync_optimizer_params(),

@@ -2821,12 +2844,7 @@ def _save_checkpoint(self, model, metrics=None):
                            )
                        else:
                            self._save_ckpt_func(state_dict, save_path, saved_signal_path)
-                    else:
-                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
-                        dist.save_state_dict(
-                            {**model_sharded_state_dict, **optimizer_sharded_state_dict},
-                            output_dir,
-                        )
+
        else:
            if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1

@@ -2852,7 +2870,18 @@ def _save_checkpoint(self, model, metrics=None):
                    output_dir,
                    signal_dir,
                )
-            elif not self.args.using_flex_checkpoint:
+            elif self.args.save_flex_checkpoint:
+                optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                dist.save_state_dict(
+                    {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                    output_dir,
+                )
+                if self.args.should_save:
+                    if self.tokenizer is not None and self.args.save_tokenizer:
+                        self.tokenizer.save_pretrained(output_dir)
+                    # Good practice: save your training arguments together with the trained model
+                    paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+            else:
                if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
                    self._save_ckpt_func(
                        self._filter_moe_no_sync_optimizer_params(),

@@ -2866,13 +2895,6 @@ def _save_checkpoint(self, model, metrics=None):
                        saved_signal_path,
                    )
 
-            else:
-                optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
-                dist.save_state_dict(
-                    {**model_sharded_state_dict, **optimizer_sharded_state_dict},
-                    output_dir,
-                )
-
            # FIXME: maybe only save one copy
            paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
 

@@ -2893,6 +2915,18 @@ def _save_checkpoint(self, model, metrics=None):
            if self.args.unified_checkpoint and (self.args.offload_optim or self.args.tensorwise_offload_optimizer):
                self._offload_optimizer()
 
+        else:
+            if self.args.save_flex_checkpoint:
+                dist.save_state_dict(
+                    model_sharded_state_dict,
+                    output_dir,
+                )
+                if self.args.should_save:
+                    if self.tokenizer is not None and self.args.save_tokenizer:
+                        self.tokenizer.save_pretrained(output_dir)
+                    # Good practice: save your training arguments together with the trained model
+                    paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
        self.runtime_timer.stop()
 
        # Maybe delete some older checkpoints.

@@ -3107,6 +3141,7 @@ def _save(
        else:
            if isinstance(self.model, PretrainedModel) and self.args.should_save_sharding_stage1_model:
                config_to_save = None
+                self.sharding_io.set_optimizer(self.optimizer)
                state_dict, config_to_save, weight_name_suffix = self.sharding_io.manipulate_state_dict_and_config(
                    self.model, merge_tensor_parallel=merge_tensor_parallel
                )
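
Taken together, the trainer changes reduce to one FlexCheckpoint pattern: build a flat sharded state dict (model shards, plus optimizer shards when they are wanted) and hand it to dist.save_state_dict / dist.load_state_dict, with aoa_config describing any weight-name remapping. The snippet below is a minimal sketch of that round trip and is not part of the commit; it assumes `dist` is `paddle.distributed` (the alias trainer.py uses), that `model` and `optimizer` expose `sharded_state_dict()` as in the diff, and the helper names `save_flex` / `load_flex` are made up for illustration.

    # Illustrative sketch only -- mirrors the calls made in _save_checkpoint() and train().
    import paddle.distributed as dist

    def save_flex(model, optimizer, output_dir):
        model_ssd = model.sharded_state_dict()
        optimizer_ssd = optimizer.sharded_state_dict(model_ssd)
        # One flat mapping of shard keys to shards; FlexCheckpoint also writes the
        # *.metadata file that records the global layout of every shard.
        dist.save_state_dict({**model_ssd, **optimizer_ssd}, output_dir)

    def load_flex(model, optimizer, ckpt_dir, aoa_config=None):
        model_ssd = model.sharded_state_dict()
        optimizer_ssd = optimizer.sharded_state_dict(model_ssd)
        # aoa_config maps checkpoint weight names onto the current model's names.
        dist.load_state_dict({**model_ssd, **optimizer_ssd}, ckpt_dir, aoa_config=aoa_config)

In the trainer itself the load path does a little more: it locates the single `*.metadata` file, calls `init_optimizer(...)` so the optimizer shards exist before `dist.load_state_dict` fills them, and only merges in the optimizer dict when `ignore_load_lr_and_optim` is off; the model-only branch in `_save_checkpoint` likewise saves just `model_sharded_state_dict`.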

paddlenlp/trainer/training_args.py
Lines changed: 90 additions & 8 deletions

@@ -407,10 +407,12 @@ class TrainingArguments:
            Whether to release gradients during training. Default is `False`.
        ckpt_quant_stage (`str`, *optional*):
            Whether activate checkpoint quantization. O0: deactivate, O1: Int8 compression, O2: Int4 compression. (default: O0).
-        using_flex_checkpoint(`bool`, *optional*):
-            Whether to use FlexCheckpoint for save and load. Default is False.
        aoa_config (`Optional[dict[str, list[str]]]`, *optional*):
            The AoA configuration of FlexCheckpoint, used to describe the mapping between model weights and the checkpoint content. Default is None.
+        save_checkpoint_mode (`str`, *optional*):
+            Specifies the method for saving checkpoints. Options are: None, 'sharding_io', 'unified_checkpoint', 'flex_checkpoint', and 'safetensor'. (default: None). This setting is ignored if the corresponding switch is configured.
+        load_checkpoint_mode (`str`, *optional*):
+            Specifies the method for loading checkpoints. Options are: None, 'sharding_io', 'unified_checkpoint', 'flex_checkpoint', and 'safetensor'. (default: None). This setting is ignored if the corresponding switch is configured.
    """
 
    output_dir: str = field(

@@ -935,10 +937,6 @@ class TrainingArguments:
        default=False,
        metadata={"help": "Whether to use async_save instead of paddle.save."},
    )
-    using_flex_checkpoint: Optional[bool] = field(
-        default=False,
-        metadata={"help": "Whether use FlexCheckpoint."},
-    )
    ordered_save_group_size: int = field(
        default=0,
        metadata={

@@ -1111,6 +1109,30 @@ class TrainingArguments:
        },
    )
 
+    save_checkpoint_mode: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Specifies the method used to save checkpoints. "
+                "Available options: 'sharding_io', 'unified_checkpoint', "
+                "'flex_checkpoint', 'safetensor'."
+                "This setting is ignored if the corresponding switch is configured."
+            )
+        },
+    )
+
+    load_checkpoint_mode: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Specifies the method used to load checkpoints. "
+                "Available options: 'sharding_io', 'unified_checkpoint', "
+                "'flex_checkpoint', 'safetensor'."
+                "This setting is ignored if the corresponding switch is configured."
+            )
+        },
+    )
+
    def __post_init__(self):
        world_size = paddle.distributed.get_world_size()
        if in_auto_parallel_align_mode():

@@ -1210,6 +1232,8 @@ def __post_init__(self):
            raise ValueError("AdamW Mini currently doesn't support tensor parallelism.")
 
        self._post_init_parallel_degree()
+        self._post_init_save_checkpoint_mode()
+        self._post_init_load_checkpoint_mode()
 
        if self.to_static:
            assert world_size == 1 or self.enable_auto_parallel, (

@@ -1862,7 +1886,7 @@ def is_context_parallel_supported():
                # DP use hybrid group
                strategy = fleet.DistributedStrategy()
                fleet.init(is_collective=True, strategy=strategy)
-            elif self.using_flex_checkpoint:
+            elif self.save_flex_checkpoint or self.load_flex_checkpoint:
                strategy = fleet.DistributedStrategy()
                fleet.init(is_collective=True, strategy=strategy)
            else:

@@ -2131,6 +2155,64 @@ def _post_init_parallel_degree(self):
        if self.use_hybrid_parallel and self.enable_auto_parallel:
            self.use_hybrid_parallel = False
 
+    def _post_init_save_checkpoint_mode(self):
+        if not self.save_checkpoint_mode:
+            return
+
+        # Ensure that only one checkpoint mode is set at a time
+        if self.unified_checkpoint or self.save_sharded_model:
+            return
+
+        self.save_flex_checkpoint = False
+
+        valid_modes = ["unified_checkpoint", "sharding_io", "safetensor", "flex_checkpoint"]
+        assert (
+            self.save_checkpoint_mode in valid_modes
+        ), f"Invalid save_checkpoint_mode: {self.save_checkpoint_mode}, Only these modes are allowed: {valid_modes}."
+
+        if self.save_checkpoint_mode == "safetensor":
+            raise NotImplementedError("safetensor checkpoint saving is not implemented yet.")
+        elif self.save_checkpoint_mode == "unified_checkpoint":
+            assert (
+                getattr(self, "load_checkpoint_mode", None) == "unified_checkpoint"
+            ), "When saving in unified_checkpoint mode, load_checkpoint_mode must also be 'unified_checkpoint'."
+            self.unified_checkpoint = True
+        elif self.save_checkpoint_mode == "sharding_io":
+            self.save_sharded_model = True
+        elif self.save_checkpoint_mode == "flex_checkpoint":
+            self.save_flex_checkpoint = True
+        else:
+            raise NotImplementedError(f"Checkpoint mode '{self.save_checkpoint_mode}' is not supported.")
+
+    def _post_init_load_checkpoint_mode(self):
+        if not self.load_checkpoint_mode:
+            return
+
+        self.load_flex_checkpoint = False
+
+        # Ensure that only one checkpoint mode is set at a time
+        if self.unified_checkpoint or self.load_sharded_model:
+            return
+
+        valid_modes = ["unified_checkpoint", "sharding_io", "safetensor", "flex_checkpoint"]
+        assert (
+            self.load_checkpoint_mode in valid_modes
+        ), f"Invalid load_checkpoint_mode: {self.load_checkpoint_mode}, Only these modes are allowed: {valid_modes}."
+
+        if self.load_checkpoint_mode == "safetensor":
+            raise NotImplementedError("safetensor checkpoint loading is not implemented yet.")
+        elif self.load_checkpoint_mode == "unified_checkpoint":
+            assert (
+                getattr(self, "save_checkpoint_mode", None) == "unified_checkpoint"
+            ), "When loading in unified_checkpoint mode, save_checkpoint_mode must also be 'unified_checkpoint'."
+            self.unified_checkpoint = True
+        elif self.load_checkpoint_mode == "sharding_io":
+            self.load_sharded_model = True
+        elif self.load_checkpoint_mode == "flex_checkpoint":
+            self.load_flex_checkpoint = True
+        else:
+            raise NotImplementedError(f"Checkpoint mode '{self.load_checkpoint_mode}' is not supported.")
+
    def add_moe_comm_group(self):
        hybrid_configs = fleet.fleet._user_defined_strategy.hybrid_configs
        hcg = fleet.get_hybrid_communicate_group()

@@ -2459,7 +2541,7 @@ def should_save_model_state(self):
            return True
        elif self.enable_auto_parallel:
            return True
-        elif self.using_flex_checkpoint:
+        elif self.save_flex_checkpoint:
            return False
        elif self.use_hybrid_parallel:
            # save on dataset rank 0
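
For orientation, the two new arguments replace the boolean `using_flex_checkpoint` flag removed above: each is a plain string validated in `__post_init__`, and each simply sets the pre-existing switch it names (`unified_checkpoint`, `save_sharded_model` / `load_sharded_model`, or the new `save_flex_checkpoint` / `load_flex_checkpoint`). If one of those switches is already configured, the early returns above skip the mode string, which is what the docstring means by the setting being ignored. A hypothetical usage sketch follows; the `output_dir` value and the asserts are illustrative and not part of the commit.

    from paddlenlp.trainer import TrainingArguments

    # Route both saving and loading through FlexCheckpoint.
    args = TrainingArguments(
        output_dir="./checkpoints",  # illustrative path
        save_checkpoint_mode="flex_checkpoint",
        load_checkpoint_mode="flex_checkpoint",
    )

    # __post_init__ has already run, so the derived switches are populated.
    assert args.save_flex_checkpoint and args.load_flex_checkpoint

    # Note: "safetensor" passes validation but currently raises NotImplementedError,
    # and "unified_checkpoint" must be selected for save and load together.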
