     CheckpointLoadConfig,
     CheckpointLoadMetadataConfig,
     CheckpointSaveConfig,
+    CheckpointSaveMetadataConfig,
     DistributedCheckpointFormat,
     ModelConfigType,
     export_safetensors_metadata,
@@ -28,7 +29,13 @@ class DistributedCheckpointHandler(CheckpointHandler):
     format: typing.ClassVar[type[CheckpointFormat]] = DistributedCheckpointFormat
 
     @classmethod
-    def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata:
+    def save_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata):
+        config.path.mkdir(parents=True, exist_ok=True)
+        serialized_metadata = metadata.to_dict()
+        yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w"))
+
+    @classmethod
+    def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata:
         return CheckpointMetadata.from_dict(yaml.safe_load((config.path / "metadata.yaml").open("r")))
 
     def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None:
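The new `save_metadata` / `_load_metadata` pair amounts to a YAML round-trip of the serialized metadata under `<checkpoint path>/metadata.yaml`. Below is a minimal standalone sketch of that round-trip, using a plain dict in place of `CheckpointMetadata` and a made-up checkpoint directory; neither the dict contents nor the path come from this change:

```python
import pathlib

import yaml

# Stand-in for CheckpointMetadata.to_dict(); the contents are illustrative only.
serialized_metadata = {"format": "distributed", "shards": ["weights", "optimizer_state"]}

checkpoint_dir = pathlib.Path("/tmp/example_checkpoint")  # assumed path for the sketch
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Mirrors save_metadata: dump the serialized metadata to metadata.yaml.
with (checkpoint_dir / "metadata.yaml").open("w") as f:
    yaml.safe_dump(serialized_metadata, f)

# Mirrors _load_metadata: read it back (CheckpointMetadata.from_dict would wrap this dict).
with (checkpoint_dir / "metadata.yaml").open("r") as f:
    assert yaml.safe_load(f) == serialized_metadata
```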
@@ -41,17 +48,16 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None:
             metadata=export_safetensors_metadata(serialized_metadata),
         )
 
-    def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None:
+    def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None:
         # TODO: More safety checks
-        loaded_config_dict = config.to_copy({"load_config": ModelConfigType.fast_llm})
-        loaded_config = self._model.config_class.from_metadata(loaded_config_dict, metadata)
+        loaded_metadata = self._model.config.load_metadata(config.to_copy({"load_config": ModelConfigType.fast_llm}))
         shard_names = self.get_shard_names(config)
         # Make sure all shards to load are in the checkpoint.
-        Assert.leq(set(self.get_shard_names(config)), set(metadata.shards))
-        Assert.eq(metadata.shards[: len(shard_names)], list(shard_names))
+        Assert.leq(set(self.get_shard_names(config)), set(loaded_metadata.shards))
+        Assert.eq(loaded_metadata.shards[: len(shard_names)], list(shard_names))
 
         # Using `log_fn=bool` sets the output to true if the error list is non-empty.
-        same_format = config.optimizer_state and not loaded_config.compare(self._model.config, log_fn=bool)
+        same_format = config.optimizer_state and not loaded_metadata.config.compare(self._model.config, log_fn=bool)
         # Make sure all nodes agree on which loading scheme to use.
         # Note: they may not agree before the broadcast because of the rank comparison, but that's ok.
         same_format = broadcast_scalar(same_format, torch.uint8, self._model.distributed.world_group)
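The two `Assert` calls above encode the compatibility contract between the requested shards and the checkpoint: every requested shard must exist in the checkpoint, and the checkpoint's shard list must start with the requested shards in the same order so shard indices line up. A standalone illustration with made-up shard names:

```python
# Made-up shard lists; only the two checks themselves mirror the code above.
checkpoint_shards = ["weights", "momentum", "velocity"]
requested_shards = ["weights", "momentum"]

# Every requested shard must be present in the checkpoint (Assert.leq on sets)...
assert set(requested_shards) <= set(checkpoint_shards)
# ...and must appear as a prefix of the checkpoint's shard list, in the same order
# (Assert.eq on the sliced list), so indices agree between model and checkpoint.
assert checkpoint_shards[: len(requested_shards)] == requested_shards
```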
@@ -70,7 +76,7 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None:
                     log_main_rank("Using legacy distributed checkpoint loader.", log_fn=logger.warning)
                     for shard_name in shard_names:
                         self._model.get_shard(shard_name).copy_(
-                            f.get_slice("state_shard")[metadata.shards.index(shard_name)]
+                            f.get_slice("state_shard")[loaded_metadata.shards.index(shard_name)]
                         )
                 else:
                     # TODO: Does this copy twice?
@@ -79,11 +85,11 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None:
 
         else:
             log_main_rank("Checkpoint format doesn't match, using safe load", log_fn=logger.info)
-            self._model.config.base_model.compare_architecture(loaded_config.base_model, config.compare_log_fn)
+            self._model.config.base_model.compare_architecture(loaded_metadata.config.base_model, logger.warning)
             with SafeLoad(self._model, shard_names=shard_names, timeout=config.timeout) as context:
-                for rank in range(loaded_config.distributed.world_size):
+                for rank in range(loaded_metadata.config.distributed.world_size):
                     loaded_model = self._model.__class__(
-                        loaded_config.to_copy({("distributed", "rank"): rank}),
+                        loaded_metadata.config.to_copy({("distributed", "rank"): rank}),
                         optimizer_state_names=shard_names[1:],
                         verbose=False,
                     )
@@ -97,7 +103,7 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None:
                         # TODO v0.3: Use checkpoint version? Drop support?
                         log_main_rank("Using legacy distributed checkpoint loader.", log_fn=logger.warning)
                         loaded_shards = {
-                            shard_name: f.get_slice("state_shard")[metadata.shards.index(shard_name)]
+                            shard_name: f.get_slice("state_shard")[loaded_metadata.shards.index(shard_name)]
                             for shard_name in shard_names
                         }
                     else:
@@ -122,3 +128,5 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None:
                     )
 
         context.mark_as_loaded(counter.item())
+
+        return loaded_metadata.metadata