@@ -929,13 +929,13 @@ def train(
         self._memory_tracker.start()

         if not self.args.enable_auto_parallel:
-            if not self.args.should_load_sharding_stage1_model and not self.args.using_flex_checkpoint:
+            if not self.args.should_load_sharding_stage1_model and not self.args.load_flex_checkpoint:
                 self._load_from_checkpoint(resume_from_checkpoint)

             if self.args.should_load_sharding_stage1_model:
                 model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint)

-            elif self.args.should_save_sharding_stage1_model:
+            elif self.args.should_save_sharding_stage1_model and not self.args.load_flex_checkpoint:
                 # In the non-sharded mode, should invoke _load_from_checkpoint before _wrap_model.
                 # In this mode, the rank0 load all params and the _wrap_model implicitly broadcast params from rank0 to the other ranks.
                 model = self._wrap_model(self.model_wrapped)
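
Reader note: this hunk, together with the ones below, splits the old `using_flex_checkpoint` switch into separate `load_flex_checkpoint` and `save_flex_checkpoint` switches, so loading and saving can be toggled independently. The sketch below shows one way the two flags could be declared on the training arguments; the field names follow this diff, but the dataclass layout, defaults, and help strings are illustrative assumptions rather than the actual definition added elsewhere in the PR.

```python
# Hypothetical sketch only: the real flags are defined on TrainingArguments elsewhere in this PR.
from dataclasses import dataclass, field


@dataclass
class FlexCheckpointFlagsSketch:
    # Assumed boolean switch: resume model/optimizer state through the flex (sharded) checkpoint loader.
    load_flex_checkpoint: bool = field(
        default=False,
        metadata={"help": "Load model and optimizer state from a flex (sharded) checkpoint."},
    )
    # Assumed boolean switch: write checkpoints in the flex (sharded) format during _save_checkpoint.
    save_flex_checkpoint: bool = field(
        default=False,
        metadata={"help": "Save model and optimizer state in the flex (sharded) checkpoint format."},
    )
```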
@@ -949,36 +949,44 @@ def train(
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
-            elif not self.args.using_flex_checkpoint:
+
+            elif self.args.load_flex_checkpoint:
                 model = self._wrap_model(self.model_wrapped)
-                # for the rest of this function `model` is the outside model, whether it was wrapped or not
                 if model is not self.model:
                     self.model_wrapped = model
+
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
-                self._load_optimizer_and_scheduler(resume_from_checkpoint)
-            else:
-                assert self.args.using_flex_checkpoint, "default using flex_checkpoint!"

+                if resume_from_checkpoint is not None:
+                    if not self.args.ignore_load_lr_and_optim:
+                        model_sharded_state_dict = self.model.sharded_state_dict()
+                        accessible_files = os.listdir(resume_from_checkpoint)
+                        metadata_files = [file for file in accessible_files if file.endswith(".metadata")]
+                        assert len(metadata_files) == 1, "Only support one metadata file now."
+                        metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))
+                        state_dict_metadata = metadata.state_dict_metadata
+                        init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
+                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                        sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
+                        dist.load_state_dict(
+                            sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config, offload=False
+                        )
+                        self._load_scheduler(resume_from_checkpoint)
+                    else:
+                        model_sharded_state_dict = self.model.sharded_state_dict()
+                        sharded_state_dict = model_sharded_state_dict
+                        dist.load_state_dict(
+                            sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config
+                        )
+            else:
                 model = self._wrap_model(self.model_wrapped)
+                # for the rest of this function `model` is the outside model, whether it was wrapped or not
                 if model is not self.model:
                     self.model_wrapped = model
-
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
-
-                if resume_from_checkpoint is not None:
-                    model_sharded_state_dict = self.model.sharded_state_dict()
-                    accessible_files = os.listdir(resume_from_checkpoint)
-                    metadata_files = [file for file in accessible_files if file.endswith(".metadata")]
-                    assert len(metadata_files) == 1, "Only support one metadata file now."
-                    metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))
-                    state_dict_metadata = metadata.state_dict_metadata
-                    init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
-                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
-                    sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
-                    dist.load_state_dict(sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config)
-                    self._load_scheduler(resume_from_checkpoint)
+                self._load_optimizer_and_scheduler(resume_from_checkpoint)
         else:
             model = self.model_wrapped
             if delay_optimizer_creation:
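
For review purposes, the `load_flex_checkpoint` resume path above can be read as a standalone sequence: find the single `.metadata` file, rebuild the optimizer's sharded entries with `init_optimizer`, merge the model and optimizer sharded state dicts, and hand the result to `dist.load_state_dict`. The sketch below restates that flow outside the Trainer; it assumes `dist` is `paddle.distributed` as aliased in this module and that `init_optimizer` is the same helper already imported by trainer.py (both names come from the diff, not defined here), and the helper name and signature are hypothetical.

```python
import os

import paddle
import paddle.distributed as dist  # assumed alias, matching the trainer module


def resume_from_flex_checkpoint(trainer, resume_from_checkpoint):
    """Hypothetical helper mirroring the load_flex_checkpoint branch above."""
    model_sharded_state_dict = trainer.model.sharded_state_dict()

    if not trainer.args.ignore_load_lr_and_optim:
        # Exactly one *.metadata file is expected next to the sharded checkpoint files.
        metadata_files = [f for f in os.listdir(resume_from_checkpoint) if f.endswith(".metadata")]
        assert len(metadata_files) == 1, "Only support one metadata file now."
        metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))

        # init_optimizer (assumed in scope, as in trainer.py) recreates the optimizer's
        # sharded entries so they can be filled from the checkpoint.
        init_optimizer(trainer.optimizer, model_sharded_state_dict, metadata.state_dict_metadata)
        optimizer_sharded_state_dict = trainer.optimizer.sharded_state_dict(model_sharded_state_dict)

        dist.load_state_dict(
            {**model_sharded_state_dict, **optimizer_sharded_state_dict},
            resume_from_checkpoint,
            aoa_config=trainer.args.aoa_config,
            offload=False,
        )
        trainer._load_scheduler(resume_from_checkpoint)
    else:
        # Weights only: skip optimizer and LR scheduler state.
        dist.load_state_dict(
            model_sharded_state_dict,
            resume_from_checkpoint,
            aoa_config=trainer.args.aoa_config,
        )
```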
@@ -2738,7 +2746,7 @@ def _save_checkpoint(self, model, metrics=None):
         else:
             self.save_model(output_dir)

-        if self.args.using_flex_checkpoint:
+        if self.args.save_flex_checkpoint:
             model_sharded_state_dict = self.model.sharded_state_dict()
             os.makedirs(output_dir, exist_ok=True)

@@ -2801,7 +2809,18 @@ def _save_checkpoint(self, model, metrics=None):
                        signal_dir,
                    )
                else:
-                    if not self.args.using_flex_checkpoint:
+                    if self.args.save_flex_checkpoint:
+                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                        dist.save_state_dict(
+                            {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                            output_dir,
+                        )
+                        if self.args.should_save:
+                            if self.tokenizer is not None and self.args.save_tokenizer:
+                                self.tokenizer.save_pretrained(output_dir)
+                            # Good practice: save your training arguments together with the trained model
+                            paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+                    else:
                        if self.dp_group.rank > 0:  # this should only work for MoE saving
                            self._save_ckpt_func(
                                self._filter_moe_no_sync_optimizer_params(),
@@ -2821,12 +2840,7 @@ def _save_checkpoint(self, model, metrics=None):
                            )
                        else:
                            self._save_ckpt_func(state_dict, save_path, saved_signal_path)
-                    else:
-                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
-                        dist.save_state_dict(
-                            {**model_sharded_state_dict, **optimizer_sharded_state_dict},
-                            output_dir,
-                        )
+
        else:
            if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
@@ -2852,7 +2866,18 @@ def _save_checkpoint(self, model, metrics=None):
                    output_dir,
                    signal_dir,
                )
-            elif not self.args.using_flex_checkpoint:
+            elif self.args.save_flex_checkpoint:
+                optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                dist.save_state_dict(
+                    {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                    output_dir,
+                )
+                if self.args.should_save:
+                    if self.tokenizer is not None and self.args.save_tokenizer:
+                        self.tokenizer.save_pretrained(output_dir)
+                    # Good practice: save your training arguments together with the trained model
+                    paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+            else:
                if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
                    self._save_ckpt_func(
                        self._filter_moe_no_sync_optimizer_params(),
@@ -2866,13 +2891,6 @@ def _save_checkpoint(self, model, metrics=None):
                        saved_signal_path,
                    )

-            else:
-                optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
-                dist.save_state_dict(
-                    {**model_sharded_state_dict, **optimizer_sharded_state_dict},
-                    output_dir,
-                )
-
            # FIXME: maybe only save one copy
            paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))

@@ -2893,6 +2911,18 @@ def _save_checkpoint(self, model, metrics=None):
            if self.args.unified_checkpoint and (self.args.offload_optim or self.args.tensorwise_offload_optimizer):
                self._offload_optimizer()

+        else:
+            if self.args.save_flex_checkpoint:
+                dist.save_state_dict(
+                    model_sharded_state_dict,
+                    output_dir,
+                )
+                if self.args.should_save:
+                    if self.tokenizer is not None and self.args.save_tokenizer:
+                        self.tokenizer.save_pretrained(output_dir)
+                    # Good practice: save your training arguments together with the trained model
+                    paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
        self.runtime_timer.stop()

        # Maybe delete some older checkpoints.
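
The save-side hunks all follow the same pattern: merge the model and optimizer sharded state dicts, write them with `dist.save_state_dict`, and leave the tokenizer and training-arguments extras to the rank selected by `should_save`. A compact sketch of that pattern is below; it assumes `dist` is `paddle.distributed` as aliased in this module, and the helper name and the `TRAINING_ARGS_NAME` value shown are assumptions for illustration.

```python
import os

import paddle
import paddle.distributed as dist  # assumed alias, matching the trainer module

TRAINING_ARGS_NAME = "training_args.bin"  # assumed value of the constant used in the trainer


def save_flex_checkpoint(trainer, output_dir):
    """Hypothetical helper mirroring the save_flex_checkpoint branches above."""
    os.makedirs(output_dir, exist_ok=True)

    model_sharded_state_dict = trainer.model.sharded_state_dict()
    optimizer_sharded_state_dict = trainer.optimizer.sharded_state_dict(model_sharded_state_dict)

    # Every rank contributes its own shards; dist.save_state_dict coordinates the collective write.
    dist.save_state_dict(
        {**model_sharded_state_dict, **optimizer_sharded_state_dict},
        output_dir,
    )

    # Non-sharded extras are written once, on the rank selected by should_save.
    if trainer.args.should_save:
        if trainer.tokenizer is not None and trainer.args.save_tokenizer:
            trainer.tokenizer.save_pretrained(output_dir)
        # Good practice: save the training arguments together with the trained model.
        paddle.save(trainer.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
```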
@@ -3107,6 +3137,7 @@ def _save(
        else:
            if isinstance(self.model, PretrainedModel) and self.args.should_save_sharding_stage1_model:
                config_to_save = None
+                self.sharding_io.set_optimizer(self.optimizer)
                state_dict, config_to_save, weight_name_suffix = self.sharding_io.manipulate_state_dict_and_config(
                    self.model, merge_tensor_parallel=merge_tensor_parallel
                )