@@ -223,6 +223,10 @@ def in_auto_parallel_align_mode():
 
 __all__ = ["Trainer"]
 
+FLEX_CKPT_MODEL_STATE_DIR_NAME = "model_state"
+FLEX_CKPT_OPT_STATE_DIR_NAME = "optimizer_states"
+FLEX_CKPT_MASTER_WEIGHTS_INDEX_NAME = "master_weights"
+
 
 class Trainer:
     """
@@ -929,13 +933,13 @@ def train(
         self._memory_tracker.start()
 
         if not self.args.enable_auto_parallel:
-            if not self.args.should_load_sharding_stage1_model and not self.args.using_flex_checkpoint:
+            if not self.args.should_load_sharding_stage1_model and not self.args.load_flex_checkpoint:
                 self._load_from_checkpoint(resume_from_checkpoint)
 
             if self.args.should_load_sharding_stage1_model:
                 model = self._wrap_model_and_load_sharded_checkpoint(resume_from_checkpoint)
 
-            elif self.args.should_save_sharding_stage1_model:
+            elif self.args.should_save_sharding_stage1_model and not self.args.load_flex_checkpoint:
                 # In the non-sharded mode, should invoke _load_from_checkpoint before _wrap_model.
                 # In this mode, the rank0 load all params and the _wrap_model implicitly broadcast params from rank0 to the other ranks.
                 model = self._wrap_model(self.model_wrapped)
@@ -949,36 +953,44 @@ def train(
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
                 self._load_optimizer_and_scheduler(resume_from_checkpoint)
-            elif not self.args.using_flex_checkpoint:
+
+            elif self.args.load_flex_checkpoint:
                 model = self._wrap_model(self.model_wrapped)
-                # for the rest of this function `model` is the outside model, whether it was wrapped or not
                 if model is not self.model:
                     self.model_wrapped = model
+
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
-                self._load_optimizer_and_scheduler(resume_from_checkpoint)
-            else:
-                assert self.args.using_flex_checkpoint, "default using flex_checkpoint!"
 
+                if resume_from_checkpoint is not None:
+                    if not self.args.ignore_load_lr_and_optim:
+                        model_sharded_state_dict = self.model.sharded_state_dict()
+                        accessible_files = os.listdir(resume_from_checkpoint)
+                        metadata_files = [file for file in accessible_files if file.endswith(".metadata")]
+                        assert len(metadata_files) == 1, "Only support one metadata file now."
+                        metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))
+                        state_dict_metadata = metadata.state_dict_metadata
+                        init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
+                        optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                        sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
+                        dist.load_state_dict(
+                            sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config, offload=False
+                        )
+                        self._load_scheduler(resume_from_checkpoint)
+                    else:
+                        model_sharded_state_dict = self.model.sharded_state_dict()
+                        sharded_state_dict = model_sharded_state_dict
+                        dist.load_state_dict(
+                            sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config
+                        )
+            else:
                 model = self._wrap_model(self.model_wrapped)
+                # for the rest of this function `model` is the outside model, whether it was wrapped or not
                 if model is not self.model:
                     self.model_wrapped = model
-
                 if delay_optimizer_creation:
                     self.create_optimizer_and_scheduler(num_training_steps=max_steps)
-
-                if resume_from_checkpoint is not None:
-                    model_sharded_state_dict = self.model.sharded_state_dict()
-                    accessible_files = os.listdir(resume_from_checkpoint)
-                    metadata_files = [file for file in accessible_files if file.endswith(".metadata")]
-                    assert len(metadata_files) == 1, "Only support one metadata file now."
-                    metadata = paddle.load(os.path.join(resume_from_checkpoint, metadata_files[0]))
-                    state_dict_metadata = metadata.state_dict_metadata
-                    init_optimizer(self.optimizer, model_sharded_state_dict, state_dict_metadata)
-                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
-                    sharded_state_dict = {**model_sharded_state_dict, **optimizer_sharded_state_dict}
-                    dist.load_state_dict(sharded_state_dict, resume_from_checkpoint, aoa_config=self.args.aoa_config)
-                    self._load_scheduler(resume_from_checkpoint)
+                self._load_optimizer_and_scheduler(resume_from_checkpoint)
         else:
             model = self.model_wrapped
             if delay_optimizer_creation:
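
For context, a minimal usage sketch of the resume path added above. It assumes the PR also wires a `load_flex_checkpoint` field (and the `aoa_config` already referenced here) into `TrainingArguments`; the model and dataset are placeholders, not part of this diff:

from paddlenlp.trainer import Trainer, TrainingArguments

# Placeholders: `model` and `train_ds` are assumed to be built elsewhere.
args = TrainingArguments(
    output_dir="./out",
    load_flex_checkpoint=True,  # assumed flag backing self.args.load_flex_checkpoint above
)
trainer = Trainer(model=model, args=args, train_dataset=train_ds)

# The checkpoint directory must contain exactly one "*.metadata" file; the branch
# above loads it with paddle.load(), rebuilds optimizer state via init_optimizer(),
# and then restores model and optimizer shards with dist.load_state_dict().
trainer.train(resume_from_checkpoint="./out/checkpoint-1000")
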
@@ -2738,7 +2750,7 @@ def _save_checkpoint(self, model, metrics=None):
         else:
             self.save_model(output_dir)
 
-        if self.args.using_flex_checkpoint:
+        if self.args.save_flex_checkpoint:
             model_sharded_state_dict = self.model.sharded_state_dict()
             os.makedirs(output_dir, exist_ok=True)
 
@@ -2801,7 +2813,18 @@ def _save_checkpoint(self, model, metrics=None):
                             signal_dir,
                         )
                     else:
-                        if not self.args.using_flex_checkpoint:
+                        if self.args.save_flex_checkpoint:
+                            optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                            dist.save_state_dict(
+                                {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                                output_dir,
+                            )
+                            if self.args.should_save:
+                                if self.tokenizer is not None and self.args.save_tokenizer:
+                                    self.tokenizer.save_pretrained(output_dir)
+                                # Good practice: save your training arguments together with the trained model
+                                paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+                        else:
                             if self.dp_group.rank > 0:  # this should only work for MoE saving
                                 self._save_ckpt_func(
                                     self._filter_moe_no_sync_optimizer_params(),
@@ -2821,12 +2844,7 @@ def _save_checkpoint(self, model, metrics=None):
                                 )
                             else:
                                 self._save_ckpt_func(state_dict, save_path, saved_signal_path)
-                        else:
-                            optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
-                            dist.save_state_dict(
-                                {**model_sharded_state_dict, **optimizer_sharded_state_dict},
-                                output_dir,
-                            )
+
             else:
                 if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                     global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
@@ -2852,7 +2870,18 @@ def _save_checkpoint(self, model, metrics=None):
                         output_dir,
                         signal_dir,
                     )
-                elif not self.args.using_flex_checkpoint:
+                elif self.args.save_flex_checkpoint:
+                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
+                    dist.save_state_dict(
+                        {**model_sharded_state_dict, **optimizer_sharded_state_dict},
+                        output_dir,
+                    )
+                    if self.args.should_save:
+                        if self.tokenizer is not None and self.args.save_tokenizer:
+                            self.tokenizer.save_pretrained(output_dir)
+                        # Good practice: save your training arguments together with the trained model
+                        paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+                else:
                     if self.args.data_parallel_rank > 0 and self.args.use_expert_parallel:
                         self._save_ckpt_func(
                             self._filter_moe_no_sync_optimizer_params(),
@@ -2866,13 +2895,6 @@ def _save_checkpoint(self, model, metrics=None):
                             saved_signal_path,
                         )
 
-                else:
-                    optimizer_sharded_state_dict = self.optimizer.sharded_state_dict(model_sharded_state_dict)
-                    dist.save_state_dict(
-                        {**model_sharded_state_dict, **optimizer_sharded_state_dict},
-                        output_dir,
-                    )
-
             # FIXME: maybe only save one copy
             paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
 
@@ -2893,6 +2915,18 @@ def _save_checkpoint(self, model, metrics=None):
             if self.args.unified_checkpoint and (self.args.offload_optim or self.args.tensorwise_offload_optimizer):
                 self._offload_optimizer()
 
+        else:
+            if self.args.save_flex_checkpoint:
+                dist.save_state_dict(
+                    model_sharded_state_dict,
+                    output_dir,
+                )
+                if self.args.should_save:
+                    if self.tokenizer is not None and self.args.save_tokenizer:
+                        self.tokenizer.save_pretrained(output_dir)
+                    # Good practice: save your training arguments together with the trained model
+                    paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
         self.runtime_timer.stop()
 
         # Maybe delete some older checkpoints.
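
Taken together, the `save_flex_checkpoint` branches above reduce to the sketch below. The helper name and the `include_optimizer` flag are illustrative; the `sharded_state_dict()` and `dist.save_state_dict()` calls mirror what the diff itself does:

import paddle.distributed as dist

def save_flex_checkpoint_sketch(trainer, output_dir, include_optimizer=True):
    # Illustrative condensation of the branches above, not a drop-in replacement.
    model_sharded_state_dict = trainer.model.sharded_state_dict()
    if include_optimizer:
        # Optimizer shards are merged with the model shards and written in one collective save.
        optimizer_sharded_state_dict = trainer.optimizer.sharded_state_dict(model_sharded_state_dict)
        dist.save_state_dict({**model_sharded_state_dict, **optimizer_sharded_state_dict}, output_dir)
    else:
        # When lr/optimizer saving is skipped, only the model shards are written.
        dist.save_state_dict(model_sharded_state_dict, output_dir)
    # In the diff, the rank with should_save additionally persists the tokenizer and TrainingArguments.
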
@@ -3107,6 +3141,7 @@ def _save(
         else:
             if isinstance(self.model, PretrainedModel) and self.args.should_save_sharding_stage1_model:
                 config_to_save = None
+                self.sharding_io.set_optimizer(self.optimizer)
                 state_dict, config_to_save, weight_name_suffix = self.sharding_io.manipulate_state_dict_and_config(
                     self.model, merge_tensor_parallel=merge_tensor_parallel
                 )