
Block interface: refactor #339

Draft · wants to merge 31 commits into base: tp_mamba
2 changes: 1 addition & 1 deletion Megatron-LM
30 changes: 15 additions & 15 deletions docs/developer_guide/conversion.md
@@ -230,21 +230,21 @@ Continuing our `AwesomeModel` handler example, we define:

```python
def _create_weight_converters(self) -> list[WeightConverter]:
converters = []
# The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`.
num_layers = self._model.config.base_model.transformer.num_layers

# A simple renaming example, for the word embeddings.
converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight"))

# We usually want to loop dynamically over layers
for i in range(num_layers):
# A `SplitWeightConverter` example, splitting a weight in two.
converters.append(SplitWeightConverter(
f"layers.{i + 1}.weight",
(f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"),
))
return converters
converters = []
# The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`.
num_layers = self._model.config.base_model.transformer.num_layers

# A simple renaming example, for the word embeddings.
converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight"))

# We usually want to loop dynamically over layers
for i in range(num_layers):
# A `SplitWeightConverter` example, splitting a weight in two.
converters.append(SplitWeightConverter(
f"layers.{i + 1}.weight",
(f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"),
))
return converters
```

And that's it! We're ready to use the new checkpoint format in Fast-LLM.
22 changes: 22 additions & 0 deletions fast_llm/config.py
@@ -1028,6 +1028,28 @@ def __init__(self, config: ConfigType, *args, **kwargs):
# Handle multiple inheritance.
super().__init__(*args, **kwargs)

def __init_subclass__(cls):
# Automatically set `config_class` based on the bound type.
# Make sure `ConfigType` is bound and respects class hierarchy.
try:
config_class = None
for base in types.get_original_bases(cls):
if hasattr(base, "__origin__") and issubclass(base.__origin__, Configurable):
for arg in base.__args__:
if arg.__name__ == "ConfigType":
if config_class is None:
config_class = arg.__bound__
else:
assert arg.__bound__ is config_class
assert config_class is not None
except Exception as e:
raise TypeError(
f"Could not determine the configuration class for the configurable class {cls.__name__}: {e.args}. "
"Please make sure to declare in the format "
f"`class {cls.__name__}[ConfigType: ConfigClass](BaseConfigurable[ConfigType])`.] "
)
cls.config_class = config_class

@property
def config(self) -> ConfigType:
return self._config
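Editor's note: the practical effect of this hook is that `Configurable` subclasses no longer need to declare `config_class` by hand (hence the removals in the preparator and base-model files below); it is inferred from the `ConfigType` bound. A minimal sketch of the intended declaration style — `MyConfig`/`MyWorker` are hypothetical names, and this assumes the hook lives on `Configurable` as shown in this diff:

```python
from fast_llm.config import Config, Configurable, config_class


@config_class()
class MyConfig(Config):
    pass


# PEP 695 generics: binding `ConfigType` to `MyConfig` is all that's needed.
# `__init_subclass__` walks the original bases, finds `Configurable[ConfigType]`,
# and sets `MyWorker.config_class = MyConfig` automatically.
class MyWorker[ConfigType: MyConfig](Configurable[ConfigType]):
    pass


assert MyWorker.config_class is MyConfig  # no explicit ClassVar required anymore
```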
2 changes: 0 additions & 2 deletions fast_llm/data/preparator/config.py
@@ -19,8 +19,6 @@ def _get_runnable(self) -> typing.Callable[[], None]:


class DatasetPreparator[ConfigType: DatasetPreparatorConfig](Configurable[ConfigType], abc.ABC):
config_class: typing.ClassVar[type[DatasetPreparatorConfig]] = DatasetPreparatorConfig

@abc.abstractmethod
def run(self) -> None:
raise NotImplementedError
19 changes: 0 additions & 19 deletions fast_llm/data/preparator/gpt_memmap/config.py
@@ -30,25 +30,6 @@ class SourceSchemaConfig(Config):
pass


@config_class(dynamic_type={SourceSchemaConfig: "prompt_completion"})
class PromptCompletionConfig(SourceSchemaConfig):
prompt_column: str = Field(
default="prompt",
desc="Field of the dataset to use.",
hint=FieldHint.optional,
)
completion_column: str = Field(
default="completion",
desc="Field of the dataset to use.",
hint=FieldHint.optional,
)
delimiter: str = Field(
default="",
desc="Delimiter between prompt and completion.",
hint=FieldHint.optional,
)


@config_class(dynamic_type={SourceSchemaConfig: "text_column"})
class TextColumnConfig(SourceSchemaConfig):
input_column: str = Field(
94 changes: 28 additions & 66 deletions fast_llm/data/preparator/gpt_memmap/prepare.py
@@ -24,11 +24,7 @@
from fast_llm.data.dataset.gpt.memmap import GPTMemmapDataset
from fast_llm.data.dataset.gpt.sampled import GPTSample
from fast_llm.data.preparator.config import DatasetPreparator
from fast_llm.data.preparator.gpt_memmap.config import (
GPTMemmapDatasetPreparatorConfig,
PromptCompletionConfig,
TextColumnConfig,
)
from fast_llm.data.preparator.gpt_memmap.config import GPTMemmapDatasetPreparatorConfig, TextColumnConfig
from fast_llm.data.tokenizer import Tokenizer
from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type
from fast_llm.utils import Assert, normalize_probabilities, padded_cumsum
@@ -37,8 +33,6 @@


class GPTMemmapDatasetPreparator[ConfigType: GPTMemmapDatasetPreparatorConfig](DatasetPreparator[ConfigType]):
config_class: typing.ClassVar[type[GPTMemmapDatasetPreparatorConfig]] = GPTMemmapDatasetPreparatorConfig

_tokenizer: Tokenizer
_data_type: DataType
_text_column: str
@@ -54,30 +48,6 @@ def _tokenize_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[
"num_tokens": num_tokens,
}

def _tokenize_prompt_completion_batch(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]:
"""
Tokenize prompt and completion columns separately, then concatenate.
Returns input_ids, token_spans (prompt len), and num_tokens.
"""
prompt_col = self._config.dataset.source_schema.prompt_column
completion_col = self._config.dataset.source_schema.completion_column
delimiter = self._config.dataset.source_schema.delimiter
input_ids = []
token_spans = []
for prompt, completion in zip(batch[prompt_col], batch[completion_col]):
prompt_tokens = self._tokenizer.tokenize(prompt, begin=True, end=False)
completion_tokens = self._tokenizer.tokenize(f"{delimiter}{completion}", begin=False, end=True)
combined = prompt_tokens + completion_tokens
input_ids.append(np.array(combined, dtype=self._data_type.numpy))
token_spans.append(np.array((0, len(prompt_tokens) - 1), dtype=np.int32).reshape(-1, 2))

num_tokens = [len(x) for x in input_ids]
return {
"input_ids": input_ids,
"token_spans": token_spans,
"num_tokens": num_tokens,
}

def _tokenize_batch_with_spans(self, batch: dict[str, list[typing.Any]]) -> dict[str, list[typing.Any]]:
input_ids, token_spans = map(
list,
@@ -171,7 +141,7 @@ def _save_shard(self, args: tuple[int, datasets.Dataset]) -> GPTMemmapDatasetCon
shard_output_path = self._config.output_path / prefix

def _document_generator():
if "token_spans" in shard_dataset.column_names:
if "token_spans" in shard_dataset.column_names and self._loss_masking_spans_column is not None:
for item in tqdm.tqdm(shard_dataset, desc=f"Saving shard {shard_idx}", unit="docs"):
yield GPTSample(
np.array(item["input_ids"], dtype=self._data_type.numpy),
@@ -317,46 +287,37 @@ def run(self) -> None:
)

# Set data column and loss masking spans column based on source schema
source_schema = self._config.dataset.source_schema
if isinstance(source_schema, TextColumnConfig):
self._text_column = source_schema.input_column
self._loss_masking_spans_column = source_schema.loss_masking_spans_column
elif isinstance(source_schema, PromptCompletionConfig):
Assert.incl(source_schema.prompt_column, dataset.column_names)
Assert.incl(source_schema.completion_column, dataset.column_names)
tokenize_fn = self._tokenize_prompt_completion_batch
if isinstance(self._config.dataset.source_schema, TextColumnConfig):
self._text_column = self._config.dataset.source_schema.input_column
self._loss_masking_spans_column = self._config.dataset.source_schema.loss_masking_spans_column
else:
raise ValueError(
f"Dataset source_schema set incorrectly. source_schema: '{self._config.dataset.source_schema}'."
)

# TODO: Add a new schema for preference datasets then drop class vars _loss_masking_spans_column & _text_column
if isinstance(source_schema, TextColumnConfig):
if self._text_column not in dataset.column_names:
raise ValueError(f"Dataset does not have field '{self._text_column}'.")
if self._text_column not in dataset.column_names:
raise ValueError(f"Dataset does not have field '{self._text_column}'.")

if self._config.dataset.source_schema.loss_masking_spans_column is not None and (
self._config.dataset.chosen_text is not None or self._config.dataset.rejected_text is not None
):
raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.")
if (self._config.dataset.chosen_text is None) != (self._config.dataset.rejected_text is None):
raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.")

# route tokenize function
if self._loss_masking_spans_column is not None:
if self._loss_masking_spans_column not in dataset.column_names:
raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.")
tokenize_fn = self._tokenize_batch_with_spans
elif self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None:
if self._config.dataset.chosen_text not in dataset.column_names:
raise ValueError(f"Dataset does not have chosen spans field '{self._config.dataset.chosen_text}'.")
if self._config.dataset.rejected_text not in dataset.column_names:
raise ValueError(
f"Dataset does not have rejected spans field '{self._config.dataset.rejected_text}'."
)
tokenize_fn = self._tokenize_preference_batch_with_spans
else:
tokenize_fn = self._tokenize_batch
if self._config.dataset.source_schema.loss_masking_spans_column is not None and (
self._config.dataset.chosen_text is not None or self._config.dataset.rejected_text is not None
):
raise ValueError(f"Can not enable both loss masking spans and chosen/rejected loss masking spans.")
if (self._config.dataset.chosen_text is None) != (self._config.dataset.rejected_text is None):
raise ValueError(f"Both chosen and rejected loss masking spans must be specified if one is specified.")

# route tokenize function
if self._loss_masking_spans_column is not None:
if self._loss_masking_spans_column not in dataset.column_names:
raise ValueError(f"Dataset does not have spans field '{self._loss_masking_spans_column}'.")
tokenize_fn = self._tokenize_batch_with_spans
elif self._config.dataset.chosen_text is not None and self._config.dataset.rejected_text is not None:
if self._config.dataset.chosen_text not in dataset.column_names:
raise ValueError(f"Dataset does not have chosen spans field '{self._config.dataset.chosen_text}'.")
if self._config.dataset.rejected_text not in dataset.column_names:
raise ValueError(f"Dataset does not have rejected spans field '{self._config.dataset.rejected_text}'.")
tokenize_fn = self._tokenize_preference_batch_with_spans
else:
tokenize_fn = self._tokenize_batch

# Tokenize the dataset in parallel
tokenized_dataset = dataset.map(
@@ -368,6 +329,7 @@ def run(self) -> None:

# Calculate total number of tokens
total_tokens = sum(tqdm.tqdm(tokenized_dataset["num_tokens"], desc="Counting tokens", unit="tokens"))

# Split dataset into shards based on number of tokens
num_shards = int(np.ceil(total_tokens / self._config.tokens_per_shard))
shards = [
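Editor's note: taken together, the routing above reduces to a fixed precedence — spans, then preference data, then plain text. The following is a simplified, standalone sketch for illustration only; the real logic lives inside `GPTMemmapDatasetPreparator.run()` and reads these values from `self._config`:

```python
import typing


def select_tokenize_fn(
    loss_masking_spans_column: str | None,
    chosen_text: str | None,
    rejected_text: str | None,
    with_spans_fn: typing.Callable,
    preference_fn: typing.Callable,
    plain_fn: typing.Callable,
) -> typing.Callable:
    # Loss-masking spans win when the source schema declares a spans column.
    if loss_masking_spans_column is not None:
        return with_spans_fn
    # Preference data (both chosen and rejected text set) routes to the
    # span-aware preference tokenizer.
    if chosen_text is not None and rejected_text is not None:
        return preference_fn
    # Otherwise, plain text-column tokenization.
    return plain_fn
```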
58 changes: 26 additions & 32 deletions fast_llm/engine/base_model/base_model.py
@@ -7,7 +7,6 @@

from fast_llm.config import Configurable
from fast_llm.engine.base_model.config import BaseModelConfig
from fast_llm.engine.config_utils.tensor_space import TensorSpace
from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.tensor import ParameterMeta, TensorMeta
@@ -20,11 +19,18 @@
class Module(torch.nn.Module, abc.ABC):
""" """

def forward(self, input_, kwargs):
"""
Run a forward pass for the module, with autograd support.
"""
raise NotImplementedError()
_is_setup: bool = False
_distributed: Distributed

def __init__(self, distributed_config: DistributedConfig):
self._distributed_config = distributed_config
super().__init__()

def setup(self, distributed: Distributed) -> None:
assert not self._is_setup
distributed.check_config(self._distributed_config)
self._distributed = distributed
self._is_setup = True


class Layer(Module):
@@ -39,9 +45,9 @@ def forward(


class Sequential(Layer):
def __init__(self, layers: list[Layer]):
super().__init__()
self.layers = torch.nn.ModuleList(layers)
def __init__(self, distributed_config: DistributedConfig):
super().__init__(distributed_config)
self.layers = torch.nn.ModuleList(self.get_layers())

def __getitem__(self, item):
return self.layers[item]
@@ -59,6 +65,15 @@ def forward(
input_ = layer(input_, kwargs, losses, metrics)
return input_

@abc.abstractmethod
def get_layers(self) -> list[Layer]:
pass

def setup(self, distributed: Distributed) -> None:
super().setup(distributed)
for layer in self.layers:
layer.setup(distributed)


@dataclasses.dataclass()
class LossDef:
@@ -71,29 +86,14 @@ class LossDef:
dtype: torch.dtype = torch.float32


class SequentialLayers(Sequential, abc.ABC):
# Small class defined to fix the MRO of BaseModel.__init__
def __init__(self):
super().__init__(self.get_layers())

@abc.abstractmethod
def get_layers(self) -> list[Layer]:
pass


class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], SequentialLayers, abc.ABC):
config_class: typing.ClassVar[type[BaseModelConfig]] = BaseModelConfig
_is_setup: bool = False
class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], Sequential):

def __init__(
self,
config: BaseModelConfig,
distributed_config: DistributedConfig,
):
self._tensor_space: TensorSpace = TensorSpace(distributed_config)
config.setup_tensor_space(self._tensor_space)

super().__init__(config)
super().__init__(config, distributed_config)

for key, value in self.named_parameters():
Assert.custom(isinstance, value, ParameterMeta)
@@ -104,12 +104,6 @@ def __init__(
# TODO: Add basic handling (preprocessor) in this class.
self._reference_models: dict[str, "InferenceRunner"] = {}

def setup(self, distributed: Distributed) -> None:
assert not self._is_setup
distributed.check_config(self._tensor_space.distributed_config)
self._tensor_space.setup(distributed)
self._is_setup = True

@abc.abstractmethod
def get_layers(self) -> list[Layer]:
pass
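Editor's note: under the new contract, layers and sequential stacks are constructed from a `DistributedConfig` and wired to the live `Distributed` object later via `setup()`. A rough sketch of a conforming subclass — `NoOpLayer` and `MyStack` are hypothetical names, and the real `BaseModel` additionally takes a `BaseModelConfig`:

```python
from fast_llm.engine.base_model.base_model import Layer, Sequential
from fast_llm.engine.distributed.config import DistributedConfig


class NoOpLayer(Layer):
    # Hypothetical layer: passes its input through unchanged.
    def forward(self, input_, kwargs, losses=None, metrics=None):
        return input_


class MyStack(Sequential):
    # Layers are built in `get_layers()`, which `Sequential.__init__` calls
    # after the distributed config has been stored on the module.
    def get_layers(self) -> list[Layer]:
        return [NoOpLayer(self._distributed_config), NoOpLayer(self._distributed_config)]


# Assuming a default-constructed DistributedConfig is sufficient for illustration.
stack = MyStack(DistributedConfig())
# Later, the engine calls `stack.setup(distributed)` with a live `Distributed`
# instance, which checks the config and recursively sets up every layer.
```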
7 changes: 2 additions & 5 deletions fast_llm/engine/base_model/config.py
@@ -6,7 +6,7 @@
from fast_llm.utils import compare_nested, log

if typing.TYPE_CHECKING:
from fast_llm.engine.config_utils.tensor_space import TensorSpace
import torch


@config_class()
@@ -18,9 +18,6 @@ class BaseModelConfig(Config):

_abstract = True

def setup_tensor_space(self, tensor_space: "TensorSpace") -> None:
raise NotImplementedError()

def compare_architecture(
self,
model_config: typing.Self,
@@ -64,5 +61,5 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None:
pass

@abc.abstractmethod
def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None:
def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None:
pass