Commit 6ad0a96

improvements to MTP implementation (#218)
1 parent f12b70a commit 6ad0a96

File tree: 12 files changed, +1383 -57 lines

fast_llm/data/dataset/gpt/config.py

Lines changed: 4 additions & 1 deletion
@@ -74,6 +74,9 @@ class GPTSamplingParameters(SamplingParameters):
     vocab_size: int
     use_loss_masking_spans: bool = False
     cross_document_attention: bool = True
+    # How many extra tokens to add to the sequence length.
+    # This is used to provide labels even for the last tokens in the sequence.
+    extra_tokens: int = 1


 @dataclasses.dataclass(kw_only=True)
@@ -258,7 +261,7 @@ def build(self) -> SamplableDataset:
         return config.build()

     def _load_config(self):
-        assert self.path.is_file()
+        assert self.path.is_file(), f"File {self.path} does not exist."
         return GPTSampledDatasetConfig.from_dict(self._convert_paths(yaml.safe_load(self.path.open("r"))))

     def _convert_paths(self, config):
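To see why the hard-coded `+ 1` becomes a configurable `extra_tokens`, consider multi-token prediction: head `d` predicts the token `d + 1` positions ahead, so every input position needs `extra_tokens` labels past the end of the input window. A minimal numeric sketch (values assumed for illustration, not taken from this commit):

import numpy as np

sequence_length, extra_tokens = 8, 2                 # e.g. two prediction heads
sample = np.arange(sequence_length + extra_tokens)   # 10 token ids per sample

inputs = sample[:sequence_length]                    # positions 0..7
labels_head_0 = sample[1 : 1 + sequence_length]      # next-token labels
labels_head_1 = sample[2 : 2 + sequence_length]      # labels one step further out
assert len(labels_head_0) == len(labels_head_1) == len(inputs)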

fast_llm/data/dataset/gpt/sampled.py

Lines changed: 23 additions & 14 deletions
@@ -145,17 +145,19 @@ def _sample(self) -> None:
             raise RuntimeError(
                 f" > No documents shorter than {self._parameters.sequence_length+1} tokens found in dataset {self._indexed_dataset.name}."
             )
-        # TODO MTP: Produce more labels to provide labels for the multi-token prediction heads?
-        # We produce sequences of length `self._sequence_length + 1` so the last token has a label,
-        # but in case of truncations we also include that last label in the following sample,
-        # so we need `sequence_length * num_samples + 1` tokens in total.
-        num_epochs = math.ceil(
-            (
-                (self._parameters.sequence_length + 1 - self._truncate_documents) * self._parameters.num_samples
-                + 1 * self._truncate_documents
+        # We produce sequences of length `self._sequence_length + extra_tokens` so the last token has a label for all prediction heads,
+        # but in case of truncations we also include those last labels in the following sample,
+        # so we need `sequence_length * num_samples + extra_tokens` tokens in total.
+        if self._truncate_documents:
+            num_epochs = math.ceil(
+                (self._parameters.sequence_length * self._parameters.num_samples + self._parameters.extra_tokens)
+                / tokens_per_epoch
+            )
+        else:
+            num_epochs = math.ceil(
+                ((self._parameters.sequence_length + self._parameters.extra_tokens) * self._parameters.num_samples)
+                / tokens_per_epoch
             )
-            / tokens_per_epoch
-        )

         # Prepare for shuffling.
         generator = torch.Generator(device=self._device)
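A quick check of the two token budgets above (numbers assumed for illustration): with truncation, consecutive samples share their boundary labels, so only one set of `extra_tokens` is needed overall; without truncation, every sample carries its own extra labels.

import math

sequence_length, num_samples, extra_tokens = 8, 100, 2
tokens_per_epoch = 450  # assumed dataset size, for illustration only

epochs_truncate = math.ceil((sequence_length * num_samples + extra_tokens) / tokens_per_epoch)
epochs_no_truncate = math.ceil((sequence_length + extra_tokens) * num_samples / tokens_per_epoch)
print(epochs_truncate, epochs_no_truncate)  # 2 3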
@@ -349,8 +351,13 @@ def __getitem__(self, index: int) -> typing.Any:
         self._lazy_load()
         # tokens at the boundary are included in only one sample when we pack without truncations
         # in case of packing with truncations, the last token from the previous sample is also the first token of the next sample
-        token_start = index * (self._parameters.sequence_length + 1 - self._truncate_documents)
-        token_end = token_start + self._parameters.sequence_length + 1
+        sample_length = (
+            self._parameters.sequence_length
+            if self._truncate_documents
+            else self._parameters.sequence_length + self._parameters.extra_tokens
+        )
+        token_start = index * sample_length
+        token_end = token_start + self._parameters.sequence_length + self._parameters.extra_tokens

         if token_start < self._unshuffled_tokens:
             token_start_array = self._token_cumsum_unshuffled.array
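The same distinction shows up in how a sample index maps to a token window; a small sketch with assumed parameters:

sequence_length, extra_tokens = 8, 2

def token_window(index: int, truncate_documents: bool) -> tuple[int, int]:
    # Mirrors the logic above: with truncation, consecutive windows overlap by
    # `extra_tokens` tokens; without it, each sample owns its full window.
    sample_length = sequence_length if truncate_documents else sequence_length + extra_tokens
    token_start = index * sample_length
    token_end = token_start + sequence_length + extra_tokens
    return token_start, token_end

print(token_window(0, True), token_window(1, True))    # (0, 10) (8, 18)  -> 2-token overlap
print(token_window(0, False), token_window(1, False))  # (0, 10) (10, 20) -> no overlap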
@@ -410,7 +417,9 @@ def __getitem__(self, index: int) -> typing.Any:
             if self._parameters.use_loss_masking_spans:
                 for loss_masking_span in sample.loss_masking_spans:
                     span = np.clip(
-                        loss_masking_span + token_count - token_start, 0, self._parameters.sequence_length + 1
+                        loss_masking_span + token_count - token_start,
+                        0,
+                        self._parameters.sequence_length + self._parameters.extra_tokens,
                     )
                     if span[1] > span[0]:
                         loss_masking_spans.append(span)
@@ -430,7 +439,7 @@ def __getitem__(self, index: int) -> typing.Any:
             if self._parameters.use_loss_masking_spans
             else None
         )
-        Assert.eq(len(token_ids), self._parameters.sequence_length + 1)
+        Assert.eq(len(token_ids), self._parameters.sequence_length + self._parameters.extra_tokens)

         return GPTSample(token_ids=token_ids, loss_masking_spans=loss_masking_spans, sequence_lengths=sequence_lengths)

fast_llm/engine/checkpoint/huggingface.py

Lines changed: 33 additions & 1 deletion
@@ -1,12 +1,14 @@
 import abc
 import json
 import pathlib
+import shutil
 import typing

 import safetensors
 import torch
+from transformers.configuration_utils import PretrainedConfig

-from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveMetadataConfig
+from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, CheckpointSaveConfig, CheckpointSaveMetadataConfig
 from fast_llm.engine.checkpoint.external import (
     ConstantExportParamConverter,
     ExternalStateDictCheckpointHandler,
@@ -118,3 +120,33 @@ def _load_weights(
             yield from torch.load(path)
         else:
             raise NotImplementedError(f"Unknown file format for {path}")
+
+
+class CustomModelingExportMixin:
+    """
+    Mixin class for HuggingfaceStateDictCheckpointHandler to handle custom modeling files.
+    """
+
+    modeling_file: typing.ClassVar[str]
+    configuration_file: typing.ClassVar[str]
+    configuration_cls: typing.ClassVar[type[PretrainedConfig]]
+
+    # Use custom config instead of relying on the transformers library
+    @classmethod
+    def _load_config(cls, directory: pathlib.Path | str) -> dict:
+        config = cls.configuration_cls.from_pretrained(directory).to_dict()
+        Assert.eq(config["model_type"], cls.get_huggingface_model_type())
+        return config
+
+    @classmethod
+    def _save_config(cls, directory: pathlib.Path | str, config: dict[str, typing.Any]) -> None:
+        cls.configuration_cls.from_dict(config).save_pretrained(directory)
+
+    def save(self, config: CheckpointSaveConfig, metadata: CheckpointMetadata) -> None:
+        super().save(config, metadata)
+        self._copy_modeling_files(config)
+
+    def _copy_modeling_files(self, config: CheckpointSaveConfig) -> None:
+        # Copy the modeling files to the output directory
+        shutil.copy(self.modeling_file, config.path)
+        shutil.copy(self.configuration_file, config.path)
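For context, a handler that opts into this mixin would look roughly like the sketch below; the class name, file paths, and `MTPLlamaConfig` are illustrative placeholders rather than names confirmed by this diff, and the base handler and format classes are assumed to come from the existing Fast-LLM code.

# Hypothetical usage sketch only; concrete names are placeholders.
class MTPLlamaHuggingfaceCheckpointHandler(CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler):
    format: typing.ClassVar[type[CheckpointFormat]] = MTPLlamaGPTHuggingfaceCheckpointFormat
    modeling_file = str(pathlib.Path(__file__).parent / "external" / "modeling_mtp_llama.py")
    configuration_file = str(pathlib.Path(__file__).parent / "external" / "configuration_mtp_llama.py")
    configuration_cls: typing.ClassVar[type[PretrainedConfig]] = MTPLlamaConfig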

fast_llm/layers/language_model/head.py

Lines changed: 15 additions & 15 deletions
@@ -70,11 +70,6 @@ def __init__(
         Assert.geq(prediction_distance, 0)
         self._prediction_distance = prediction_distance
         self.is_last_head = self._prediction_distance == config.prediction_heads - 1
-        if self._prediction_distance > 0:
-            assert (
-                not self._sequence_parallel_logits
-            ), "Sequence parallel logits not supported for multi-token prediction."
-            assert not self._cross_entropy_splits, "Cross-entropy splits not supported for multi-token prediction."

         self._init_output_weights(hidden_dim, config)

@@ -137,8 +132,9 @@ def forward(
             # Last head should return the loss for backward.
             return language_model_loss
         else:
-            # Backward hook to compute the gradient of the loss
-            shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0)
+            if self.training:
+                # Backward hook to compute the gradient of the loss
+                shared_hidden = AuxiliaryLoss.apply(shared_hidden, language_model_loss, 1.0)
             # MTP: Return shared_hidden to be used by the next head.
             return shared_hidden

@@ -147,18 +143,22 @@ def _forward_backward(
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
         labels = kwargs[LanguageModelKwargs.labels] if LanguageModelKwargs.labels in kwargs else None
         # MTP: Shift the labels
-        labels = labels[:, self._prediction_distance :].flatten() if labels is not None else None
+        if labels is not None:
+            labels = (
+                labels[self._prediction_distance : self._prediction_distance + input_.size(0),]
+                if kwargs[TransformerKwargs.sequence_first]
+                else labels[
+                    :,
+                    self._prediction_distance : self._prediction_distance + input_.size(1),
+                ]
+            )
+            labels = labels.flatten()
         if self._sequence_parallel_logits:
             labels = split_op(labels, self._tensor_space.distributed.tensor_group, 0)
         do_grad = labels is not None and self.training
         input_ = input_.detach().requires_grad_(do_grad)
         with torch.enable_grad():
-            # MTP: truncate the input
-            if self._prediction_distance > 0:
-                truncated_input = input_[:, : -self._prediction_distance, :].contiguous()
-            else:
-                truncated_input = input_
-            ln_output = self.final_norm(truncated_input)
+            ln_output = self.final_norm(input_)

         grad_output = kwargs[TransformerKwargs.grad_output] / (
             self._group_size if self._sequence_parallel_logits else 1
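The new slicing replaces the old approach of truncating the hidden states per head: every head now keeps the full sequence and instead takes the label slice that starts `prediction_distance` positions in, along whichever axis holds the sequence dimension. A standalone sketch of that indexing (shapes assumed for illustration, not taken from the diff):

import torch

seq, batch, extra_tokens, prediction_distance = 8, 2, 2, 1
# Assumed layout: sequence-first labels are (sequence + extra_tokens, batch).
labels = torch.arange((seq + extra_tokens) * batch).reshape(seq + extra_tokens, batch)
hidden = torch.zeros(seq, batch, 16)  # shared hidden states fed to every head

# Head `prediction_distance` trains on labels shifted that many positions forward.
shifted = labels[prediction_distance : prediction_distance + hidden.size(0)]
assert shifted.shape == (seq, batch)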
@@ -197,7 +197,7 @@ def _logits_cross_entropy_forward_backward_split(
         )
         if labels is None:
             # TODO: Make a proper way of returning the model output.
-            kwargs["logits"] = loss
+            kwargs["logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}"] = loss
             return None, None
         else:
             loss = None

fast_llm/models/gpt/config.py

Lines changed: 7 additions & 0 deletions
@@ -19,6 +19,7 @@

 class GPTHuggingfaceCheckpointFormat(CheckpointFormat):
     support_optimizer: typing.ClassVar[bool] = False
+    trust_remote_code: typing.ClassVar[bool] = False

     @classmethod
     def get_handler_class(cls) -> type[CheckpointHandler]:
@@ -51,6 +52,11 @@ class MixtralGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
     name: typing.ClassVar[str] = "mixtral"


+class MTPLlamaGPTHuggingfaceCheckpointFormat(GPTHuggingfaceCheckpointFormat):
+    name: typing.ClassVar[str] = "mtp_llama"
+    trust_remote_code: typing.ClassVar[bool] = True
+
+
 @config_class()
 class GPTArchitectureConfig(LanguageModelArchitectureConfig):
     _abstract = False
@@ -145,6 +151,7 @@ class GPTModelConfig(FastLLMModelConfig):
         Qwen2GPTHuggingfaceCheckpointFormat,
         MistralGPTHuggingfaceCheckpointFormat,
         MixtralGPTHuggingfaceCheckpointFormat,
+        MTPLlamaGPTHuggingfaceCheckpointFormat,
     )

     @classmethod
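Since the new format flags its checkpoints as carrying custom modeling code, loading an exported `mtp_llama` checkpoint through `transformers` would look roughly like this (the directory path is a placeholder, and the export is assumed to include the copied modeling and configuration files):

from transformers import AutoConfig, AutoModelForCausalLM

checkpoint_dir = "path/to/exported_checkpoint"  # placeholder for an exported directory
config = AutoConfig.from_pretrained(checkpoint_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(checkpoint_dir, trust_remote_code=True)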
