
Commit 1550bd1

Make the specified config parameters update the pretrained config (#211)
Co-authored-by: Torsten Scholak <[email protected]>
1 parent 6ad0a96 commit 1550bd1
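
In short: config parameters given alongside a pretrained checkpoint now override the values stored in that checkpoint, instead of being routed through the old `default`/`from_metadata` path. The model config's `from_pretrained` becomes `from_pretrained(pretrained, *updates)` and merges the updates with `update_type=UpdateType.update` (see `fast_llm/engine/multi_stage/config.py` below). A minimal usage sketch; the checkpoint path and `MyModelConfig` are illustrative, and the import location of `FastLLMCheckpointFormat` is assumed rather than shown in this diff:

    import pathlib

    from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat

    # MyModelConfig stands in for a concrete FastLLMModelConfig subclass.
    pretrained = CheckpointLoadConfig(
        path=pathlib.Path("/checkpoints/my_run"),  # assumed location
        format=FastLLMCheckpointFormat,
    )
    # Keys use the nested-tuple form shown in the diff; this override now wins
    # over whatever zero_stage the checkpoint was saved with.
    config = MyModelConfig.from_pretrained(pretrained, {("multi_stage", "zero_stage"): 3})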

File tree: 13 files changed, +282 −144 lines changed


fast_llm/engine/checkpoint/config.py

Lines changed: 22 additions & 8 deletions
@@ -201,7 +201,7 @@ class CheckpointSaveConfig(CheckpointSaveMetadataConfig, CheckpointStateSaveConf
 @config_class()
 class CheckpointLoadMetadataConfig(CheckpointPathConfigBase):
     _abstract = False
-
+    # TODO: Set default to model? (Not backward compatible)
     load_config: ModelConfigType = Field(
         default=ModelConfigType.architecture,
         desc="Configuration to save/load.",
@@ -213,10 +213,6 @@ def _validate(self) -> None:
         if self.format.enforce_architecture_match:
             assert self.load_config.load_architecture
 
-    @property
-    def compare_log_fn(self):
-        return ValueError if self.load_config.load_architecture else logger.warning
-
 
 @config_class()
 class CheckpointLoadConfig(CheckpointLoadMetadataConfig, CheckpointStateConfigBase):
@@ -237,19 +233,37 @@ class CheckpointHandler(abc.ABC):
     def __init__(self, model: "FastLLMModel"):
         self._model = model
 
-    # TODO: save_metadata?
-
     @classmethod
     @abc.abstractmethod
+    def save_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: "CheckpointMetadata"):
+        pass
+
+    @classmethod
     def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata":
+        updates = {}
+        metadata = cls._load_metadata(config)
+        if not config.load_config.load_fast_llm:
+            updates[("config", "multi_stage")] = {}
+            updates[("config", "distributed")] = {}
+        if not config.load_config.load_architecture:
+            updates[("config", "base_model")] = {}
+        elif not config.load_config.load_base_model:
+            updates[("config", "base_model")] = metadata.config.base_model.get_architecture().to_dict()
+        if updates:
+            metadata = metadata.to_copy(updates)
+        return metadata
+
+    @classmethod
+    @abc.abstractmethod
+    def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata":
         pass
 
     @abc.abstractmethod
     def save(self, config: CheckpointSaveConfig, metadata: "CheckpointMetadata"):
         pass
 
     @abc.abstractmethod
-    def load(self, config: CheckpointLoadConfig, metadata: "CheckpointMetadata"):
+    def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None:
         pass
 
     def get_shard_names(self, config: CheckpointStateConfigBase) -> tuple[str, ...]:
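
With this split, the shared filtering (which parts of the stored config to keep, driven by `load_config`) lives once in `CheckpointHandler.load_metadata`, and concrete formats only implement `save_metadata` and `_load_metadata`. A minimal handler sketch under the new interface, assuming a YAML metadata file; the class is illustrative, and the `CheckpointMetadata` import is omitted because its module is not shown in this diff:

    import yaml

    from fast_llm.engine.checkpoint.config import (
        CheckpointHandler,
        CheckpointLoadMetadataConfig,
        CheckpointSaveMetadataConfig,
    )

    class MyYamlCheckpointHandler(CheckpointHandler):
        """Illustrative subclass; a real handler also sets `format` and implements save()/load()."""

        @classmethod
        def save_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: "CheckpointMetadata"):
            # Mirrors the distributed handler below: create the directory, dump the metadata.
            config.path.mkdir(parents=True, exist_ok=True)
            yaml.safe_dump(metadata.to_dict(), (config.path / "metadata.yaml").open("w"))

        @classmethod
        def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata":
            # The base-class load_metadata applies the load_config filtering on top of this.
            return CheckpointMetadata.from_dict(yaml.safe_load((config.path / "metadata.yaml").open("r")))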

fast_llm/engine/checkpoint/distributed.py

Lines changed: 20 additions & 12 deletions
@@ -12,6 +12,7 @@
     CheckpointLoadConfig,
     CheckpointLoadMetadataConfig,
     CheckpointSaveConfig,
+    CheckpointSaveMetadataConfig,
     DistributedCheckpointFormat,
     ModelConfigType,
     export_safetensors_metadata,
@@ -28,7 +29,13 @@ class DistributedCheckpointHandler(CheckpointHandler):
     format: typing.ClassVar[type[CheckpointFormat]] = DistributedCheckpointFormat
 
     @classmethod
-    def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata:
+    def save_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata):
+        config.path.mkdir(parents=True, exist_ok=True)
+        serialized_metadata = metadata.to_dict()
+        yaml.safe_dump(serialized_metadata, (config.path / "metadata.yaml").open("w"))
+
+    @classmethod
+    def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata:
         return CheckpointMetadata.from_dict(yaml.safe_load((config.path / "metadata.yaml").open("r")))
 
     def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None:
@@ -41,17 +48,16 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No
             metadata=export_safetensors_metadata(serialized_metadata),
         )
 
-    def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None:
+    def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None:
         # TODO: More safety checks
-        loaded_config_dict = config.to_copy({"load_config": ModelConfigType.fast_llm})
-        loaded_config = self._model.config_class.from_metadata(loaded_config_dict, metadata)
+        loaded_metadata = self._model.config.load_metadata(config.to_copy({"load_config": ModelConfigType.fast_llm}))
         shard_names = self.get_shard_names(config)
         # Make sure all shards to load are in the checkpoint.
-        Assert.leq(set(self.get_shard_names(config)), set(metadata.shards))
-        Assert.eq(metadata.shards[: len(shard_names)], list(shard_names))
+        Assert.leq(set(self.get_shard_names(config)), set(loaded_metadata.shards))
+        Assert.eq(loaded_metadata.shards[: len(shard_names)], list(shard_names))
 
         # Using `log_fn=bool` sets the output to true if the error list is non-empty.
-        same_format = config.optimizer_state and not loaded_config.compare(self._model.config, log_fn=bool)
+        same_format = config.optimizer_state and not loaded_metadata.config.compare(self._model.config, log_fn=bool)
         # Make sure all nodes agree on which loading scheme to use.
         # Note: they may not agree before the broadcast because of the rank comparison, but that's ok.
         same_format = broadcast_scalar(same_format, torch.uint8, self._model.distributed.world_group)
@@ -70,7 +76,7 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No
                     log_main_rank("Using legacy distributed checkpoint loader.", log_fn=logger.warning)
                     for shard_name in shard_names:
                         self._model.get_shard(shard_name).copy_(
-                            f.get_slice("state_shard")[metadata.shards.index(shard_name)]
+                            f.get_slice("state_shard")[loaded_metadata.shards.index(shard_name)]
                         )
                 else:
                     # TODO: Does this copy twice?
@@ -79,11 +85,11 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No
 
         else:
             log_main_rank("Checkpoint format doesn't match, using safe load", log_fn=logger.info)
-            self._model.config.base_model.compare_architecture(loaded_config.base_model, config.compare_log_fn)
+            self._model.config.base_model.compare_architecture(loaded_metadata.config.base_model, logger.warning)
             with SafeLoad(self._model, shard_names=shard_names, timeout=config.timeout) as context:
-                for rank in range(loaded_config.distributed.world_size):
+                for rank in range(loaded_metadata.config.distributed.world_size):
                     loaded_model = self._model.__class__(
-                        loaded_config.to_copy({("distributed", "rank"): rank}),
+                        loaded_metadata.config.to_copy({("distributed", "rank"): rank}),
                         optimizer_state_names=shard_names[1:],
                         verbose=False,
                     )
@@ -97,7 +103,7 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No
                             # TODO v0.3: Use checkpoint version? Drop support?
                            log_main_rank("Using legacy distributed checkpoint loader.", log_fn=logger.warning)
                            loaded_shards = {
-                                shard_name: f.get_slice("state_shard")[metadata.shards.index(shard_name)]
+                                shard_name: f.get_slice("state_shard")[loaded_metadata.shards.index(shard_name)]
                                 for shard_name in shard_names
                             }
                         else:
@@ -122,3 +128,5 @@ def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> No
                     )
 
             context.mark_as_loaded(counter.item())
+
+        return loaded_metadata.metadata
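
Note that `load` no longer takes a `metadata` argument: the handler resolves the stored config itself through `self._model.config.load_metadata(...)` and hands the checkpoint's extra metadata dict back to the caller. A sketch of the calling pattern this enables (the helper function itself is illustrative, not part of the commit):

    def load_checkpoint(model: "FastLLMModel", config: "CheckpointLoadConfig") -> dict | None:
        # Pick the handler class registered for the requested format and bind it to the model.
        handler = config.format.get_handler_class()(model)
        # The handler reads and filters the stored metadata on its own; whatever extra
        # metadata the checkpoint carried is now returned instead of being passed in.
        return handler.load(config)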

fast_llm/engine/checkpoint/external.py

Lines changed: 1 addition & 1 deletion
@@ -226,7 +226,7 @@ def __init__(self, model: "FastLLMModel"):
         }
 
     @classmethod
-    def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata:
+    def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata:
         imported_model_config = cls._import_config(cls._load_config(config.path), True)
         return CheckpointMetadata(
             fast_llm_version=__version__,

fast_llm/engine/checkpoint/huggingface.py

Lines changed: 8 additions & 5 deletions
@@ -22,8 +22,10 @@
 
 class HuggingfaceStateDictCheckpointHandler(ExternalStateDictCheckpointHandler, abc.ABC):
 
-    def _save_serialized_metadata(self, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None:
-        path = config.path / f"{self.base_file_name}.safetensors.index.json"
+    @classmethod
+    def _save_serialized_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None:
+        config.path.mkdir(parents=True, exist_ok=True)
+        path = config.path / f"{cls.base_file_name}.safetensors.index.json"
         logger.info(f"Saving index to {path}")
         # Save the index.
         json.dump(
@@ -41,10 +43,11 @@ def _serialize_metadata(self, config: CheckpointSaveMetadataConfig, metadata: Ch
             "format": "pt",
         }
 
-    def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None:
+    def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None:
         assert not config.optimizer_state
-        self._model.config.base_model.compare_architecture(metadata.config.base_model, config.compare_log_fn)
-        super().load(config, metadata)
+        metadata = self._model.config.load_metadata(config)
+        self._model.config.base_model.compare_architecture(metadata.config.base_model, logger.warning)
+        super().load(config)
 
     @classmethod
     def get_huggingface_model_type(self) -> str:
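
As before, Hugging Face checkpoints carry weights only, so the handler keeps the `assert not config.optimizer_state` guard and downgrades architecture mismatches to a warning. A hedged usage sketch; the format class is a placeholder (each Fast-LLM model registers its own Hugging Face format, none is defined in this diff) and `model` is assumed to be an existing FastLLMModel:

    import pathlib

    from fast_llm.engine.checkpoint.config import CheckpointLoadConfig

    # "MyHuggingfaceCheckpointFormat" is hypothetical; substitute the format registered
    # by the model you are loading.
    load_config = CheckpointLoadConfig(
        path=pathlib.Path("/path/to/hf_checkpoint"),
        format=MyHuggingfaceCheckpointFormat,
    )
    # Weights only: the handler asserts that optimizer state was not requested, and it
    # compares the stored architecture against the model's with a warning, not an error.
    handler = load_config.format.get_handler_class()(model)
    handler.load(load_config)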

fast_llm/engine/checkpoint/state_dict.py

Lines changed: 16 additions & 5 deletions
@@ -30,6 +30,13 @@
 class StateDictCheckpointHandler(CheckpointHandler):
     base_file_name: typing.ClassVar[str] = "model"
 
+    @classmethod
+    def save_metadata(
+        cls, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata, index: dict | None = None
+    ):
+        serialized_metadata = cls._serialize_metadata(config, metadata)
+        cls._save_serialized_metadata(config, serialized_metadata, {} if index is None else index)
+
     def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None:
         serialized_metadata = self._serialize_metadata(config, metadata)
         saver = StateDictSaver(
@@ -64,16 +71,18 @@ def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> No
         if self._model.config.distributed.rank == 0:
             self._save_serialized_metadata(config, serialized_metadata, index)
 
+    @classmethod
     @abc.abstractmethod
-    def _save_serialized_metadata(self, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None:
+    def _save_serialized_metadata(cls, config: CheckpointSaveMetadataConfig, metadata: dict, index: dict) -> None:
         pass
 
+    @classmethod
     def _serialize_metadata(
-        self, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata
+        cls, config: CheckpointSaveMetadataConfig, metadata: CheckpointMetadata
     ) -> dict[str, typing.Any]:
         return metadata.to_dict()
 
-    def load(self, config: CheckpointLoadConfig, metadata: CheckpointMetadata) -> None:
+    def load(self, config: CheckpointLoadConfig) -> dict[str, typing.Any] | None:
         with SafeLoad(self._model, shard_names=self.get_shard_names(config), timeout=config.timeout) as context:
             # The tensor mapping may not be one-to-one. `convert_state_dict` pops all tensors from
             # `state_dict` that are ready for conversion,
@@ -116,14 +125,16 @@ class FastLLMCheckpointHandler(StateDictCheckpointHandler):
     format: typing.ClassVar[type[CheckpointFormat]] = FastLLMCheckpointFormat
 
     @classmethod
-    def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata:
+    def _load_metadata(cls, config: CheckpointLoadMetadataConfig) -> CheckpointMetadata:
         path = config.path / f"metadata.yaml"
         logger.warning(f"Loading metadata from {path}")
         return CheckpointMetadata.from_dict(yaml.safe_load(path.open("r")))
 
+    @classmethod
     def _save_serialized_metadata(
-        self, config: CheckpointSaveMetadataConfig, serialized_metadata: dict, index: dict
+        cls, config: CheckpointSaveMetadataConfig, serialized_metadata: dict, index: dict
     ) -> None:
+        config.path.mkdir(parents=True, exist_ok=True)
         path = config.path / f"metadata.yaml"
         logger.info(f"Saving metadata to {path}")
         if "metadata" not in serialized_metadata:
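
Because `save_metadata`, `_serialize_metadata`, and `_save_serialized_metadata` are now classmethods, checkpoint metadata can be written without instantiating a handler (and hence without building a model), which is what the new model-config `save_metadata` in `fast_llm/engine/multi_stage/config.py` below relies on. A sketch of that call chain; the output path is an assumption, `model_config` is an existing, validated model config, and the import location of `FastLLMCheckpointFormat` is assumed:

    import pathlib

    from fast_llm.engine.checkpoint.config import CheckpointSaveMetadataConfig, FastLLMCheckpointFormat

    save_config = CheckpointSaveMetadataConfig(
        path=pathlib.Path("/checkpoints/my_run"),  # assumed output directory
        format=FastLLMCheckpointFormat,
    )
    # The model config's save_metadata (added in this commit) resolves the handler class
    # for the format and calls the classmethod save_metadata above.
    model_config.save_metadata(save_config)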

fast_llm/engine/inference/huggingface.py

Lines changed: 3 additions & 3 deletions
@@ -54,14 +54,14 @@ def from_pretrained(
                 format=FastLLMCheckpointFormat,
             )
 
-        config_updates = {}
+        updates = {}
         torch_dtype = kwargs.pop("torch_dtype", None)
         if torch_dtype is not None:
-            config_updates[("distributed", "training_dtype")] = torch_dtype
+            updates[("distributed", "training_dtype")] = torch_dtype
 
         # Create the model
         fast_llm_model = cls.runner_class.model_class.from_pretrained(
-            pretrained_model_name_or_path, config_updates=config_updates, mode=mode
+            pretrained_model_name_or_path, updates, mode=mode
         )
         config = cls.config_class(fast_llm_model.config)
 
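
On the Hugging Face wrapper side, the `torch_dtype` override is now forwarded as a positional update dict, matching the new `from_pretrained(pretrained, *updates)` signature. A usage sketch; the wrapper class name and checkpoint path are assumptions, while the update key comes from the diff:

    import torch

    # Any torch_dtype passed through the Hugging Face-style entry point ends up as
    # {("distributed", "training_dtype"): torch.bfloat16} and overrides the stored value.
    model = HuggingfaceGPTModelForCausalLM.from_pretrained(
        "/checkpoints/my_run",      # assumed checkpoint location
        torch_dtype=torch.bfloat16,
    )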

fast_llm/engine/multi_stage/config.py

Lines changed: 16 additions & 45 deletions
@@ -10,6 +10,7 @@
     Field,
     FieldHint,
     NoAutoValidate,
+    UpdateType,
     ValidationError,
     check_field,
     config_class,
@@ -186,11 +187,12 @@ class MultiStageConfig(StageConfig):
     def _validate(self) -> None:
         super()._validate()
         if self.zero_stage is not None:
-            Assert.in_range_incl(self.zero_stage, 1, 3)
-            if self.zero_stage >= 2:
-                self.num_grad_buffers = 2
-            if self.zero_stage >= 3:
-                self.num_weight_buffers = 2
+            with self._set_implicit_default():
+                Assert.in_range_incl(self.zero_stage, 1, 3)
+                if self.zero_stage >= 2:
+                    self.num_grad_buffers = 2
+                if self.zero_stage >= 3:
+                    self.num_weight_buffers = 2
         if self.num_grad_buffers is not None:
             Assert.geq(self.num_grad_buffers, 1)
         if self.num_weight_buffers is not None:
@@ -254,49 +256,13 @@ def get_base_model_config_class(cls) -> type[BaseModelConfig]:
 
     @classmethod
     def from_pretrained(
-        cls,
-        pretrained: CheckpointLoadMetadataConfig,
-        default: typing.Self | None = None,
-    ) -> typing.Self:
-        # TODO: Add *updates?
-        assert pretrained.path is not None
-        metadata = cls.load_metadata(pretrained)
-        return cls.from_metadata(pretrained, metadata, default)
-
-    @classmethod
-    def from_metadata(
-        cls,
-        pretrained: CheckpointLoadMetadataConfig,
-        metadata: "CheckpointMetadata",
-        default: typing.Self | None = None,
-        updates: dict[str | tuple[str, ...], typing.Any] | None = None,
+        cls, pretrained: CheckpointLoadMetadataConfig, *updates: Config | dict[str | tuple[str, ...], typing.Any]
     ) -> typing.Self:
-        # TODO: Standardize to *updates?
-        # TODO v0.3: Update, remove support for older checkpoints.
-        if metadata.fast_llm_version.major != 0 or metadata.fast_llm_version.minor not in (0, 1, 2):
-            raise ValueError(f"Invalid checkpoint version: {metadata.fast_llm_version}")
-        pretrained_config = cls.from_dict(metadata.config)
-        if not pretrained.load_config.load_architecture:
-            assert default is not None
-            config = default.to_copy()
-            config.base_model.compare_architecture(pretrained_config.base_model, pretrained.compare_log_fn)
-        elif pretrained.load_config.load_fast_llm:
-            config = pretrained_config
-        else:
-            with NoAutoValidate():
-                config = cls() if default is None else default.to_copy()
-            if pretrained.load_config.load_base_model:
-                config.base_model = pretrained_config.base_model
-            else:
-                config.base_model = config.base_model.to_copy(pretrained_config.base_model.get_architecture())
-            config.validate()
-
-        if updates:
-            config = config.to_copy(updates)
-        return config
+        return cls.from_dict(cls.load_metadata(pretrained).config, *updates, update_type=UpdateType.update)
 
     @classmethod
     def load_metadata(cls, config: CheckpointLoadMetadataConfig) -> "CheckpointMetadata":
+        assert config.path is not None
         with NoAutoValidate():
             metadata = config.format.get_handler_class().load_metadata(config)
         try:
@@ -316,6 +282,9 @@ def to_metadata(self, config: CheckpointSaveMetadataConfig, **kwargs) -> "Checkp
             **kwargs,
         )
 
+    def save_metadata(self, config: CheckpointSaveMetadataConfig, **kwargs) -> None:
+        self.get_checkpoint_handler_class(config.format).save_metadata(config, self.to_metadata(config, **kwargs))
+
 
 @config_class()
 class PretrainedFastLLMModelConfig(Config):
@@ -336,7 +305,7 @@ def _validate(self) -> None:
         self.pretrained.setup(self.model)
         self.pretrained.validate()
         if self.pretrained.path is not None:
-            self.model = self.model.from_pretrained(self.pretrained, default=self.model)
+            self.model = self.model.from_pretrained(self.pretrained, self.model)
         self._setup()
         super()._validate()
 
@@ -388,6 +357,8 @@ def _validate(self) -> None:
 
         self.format = self.model.get_checkpoint_format(self.format)
         super()._validate()
+        if self.fast_llm_version.major != 0 or self.fast_llm_version.minor not in (0, 1, 2):
+            raise ValueError(f"Invalid checkpoint version: {self.fast_llm_version}")
         Assert.eq(self.config.__class__, self.model)
 
     @classmethod
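
The net effect of this file: `from_pretrained` shrinks to a single `from_dict(..., update_type=UpdateType.update)` call, the checkpoint-version check moves into the metadata config's `_validate`, and explicitly specified parameters override the stored ones. A conceptual sketch of the update semantics using plain dicts (this is not the actual `Config.from_dict` machinery, just an illustration of the merge order):

    def nested_update(base: dict, update: dict) -> dict:
        """Recursively overlay `update` on `base`, mimicking UpdateType.update on plain dicts."""
        merged = dict(base)
        for key, value in update.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                merged[key] = nested_update(merged[key], value)
            else:
                merged[key] = value
        return merged

    stored = {"base_model": {"hidden_size": 1024}, "multi_stage": {"zero_stage": 1}}
    explicit = {"multi_stage": {"zero_stage": 3}}
    # The pretrained config acts as the base; the specified parameters update it.
    assert nested_update(stored, explicit) == {
        "base_model": {"hidden_size": 1024},
        "multi_stage": {"zero_stage": 3},
    }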
