From 3ed07e56b136415573abbef95143fadec34f45bf Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 1 Aug 2025 17:01:10 -0400 Subject: [PATCH 01/19] stuff --- docs/developer_guide/conversion.md | 30 ++-- fast_llm/config.py | 7 + .../engine/config_utils/initialization.py | 121 +++++++++++++++ fast_llm/layers/block/block.py | 87 ++++++----- fast_llm/layers/block/config.py | 118 +++++++++++---- fast_llm/layers/block/mlp/config.py | 140 +++++++++++------- .../layers/block/mlp/mixture_of_experts.py | 9 +- fast_llm/layers/block/mlp/mlp.py | 37 ++--- fast_llm/layers/common/config.py | 76 +++++----- fast_llm/layers/common/normalization.py | 114 ++++++++------ fast_llm/layers/language_model/config.py | 92 ++++++------ fast_llm/layers/language_model/embedding.py | 13 +- fast_llm/layers/language_model/head.py | 7 +- .../layers/language_model/preprocessing.py | 2 +- fast_llm/layers/transformer/attention.py | 38 ++--- fast_llm/layers/transformer/config.py | 134 ++++++++--------- fast_llm/models/custom/model.py | 2 +- fast_llm/models/gpt/config.py | 5 +- fast_llm/models/gpt/conversion.py | 26 ++-- fast_llm/models/gpt/model.py | 12 +- fast_llm/models/ssm/config.py | 12 +- fast_llm/models/ssm/conversion.py | 6 +- tests/models/test_generate.py | 2 +- 23 files changed, 631 insertions(+), 459 deletions(-) diff --git a/docs/developer_guide/conversion.md b/docs/developer_guide/conversion.md index 35a324db0..19d3ba926 100644 --- a/docs/developer_guide/conversion.md +++ b/docs/developer_guide/conversion.md @@ -230,21 +230,21 @@ Continuing our `AwesomeModel` handler example, we define: ```python def _create_weight_converters(self) -> list[WeightConverter]: - converters = [] - # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. - num_layers = self._model.config.base_model.transformer.num_layers - - # A simple renaming example, for the word embeddings. 
- converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) - - # We usually want to loop dynamically over layers - for i in range(num_layers): - # A `SplitWeightConverter` example, splitting a weight in two. - converters.append(SplitWeightConverter( - f"layers.{i + 1}.weight", - (f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"), - )) - return converters + converters = [] + # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. + num_layers = self._model.config.base_model.transformer.num_blocks + + # A simple renaming example, for the word embeddings. + converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) + + # We usually want to loop dynamically over layers + for i in range(num_layers): + # A `SplitWeightConverter` example, splitting a weight in two. + converters.append(SplitWeightConverter( + f"layers.{i + 1}.weight", + (f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"), + )) + return converters ``` And that's it! We're ready to use the new checkpoint format in Fast-LLM. diff --git a/fast_llm/config.py b/fast_llm/config.py index c534b11f3..c36110790 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1032,6 +1032,13 @@ def __init__(self, config: ConfigType, *args, **kwargs): def config(self) -> ConfigType: return self._config + def __init_subclass__(cls): + # Automatically set `config_class` based on the bound type. + # Make sure `ConfigType` is bound and respects class hierarchy. + # TODO: Remove manual sets. 
+ Assert.custom(issubclass, config_class := ConfigType.__bound__, cls.config_class) + cls.config_class = config_class + def set_nested_dict_value[ KeyType, ValueType diff --git a/fast_llm/engine/config_utils/initialization.py b/fast_llm/engine/config_utils/initialization.py index b60070562..cdee37935 100644 --- a/fast_llm/engine/config_utils/initialization.py +++ b/fast_llm/engine/config_utils/initialization.py @@ -1,12 +1,133 @@ import abc import typing +from fast_llm.config import Config, Field, FieldHint, check_field, config_class +from fast_llm.utils import Assert + if typing.TYPE_CHECKING: import torch from fast_llm.tensor import ParameterMeta +@config_class(registry=True) +class InitializationConfig(Config): + _abstract = True + has_initialization: typing.ClassVar[bool] = True + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is InitializationConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. + return DefaultInitializationConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + def get_initializer(self) -> "Initializer": + raise NotImplementedError() + + +@config_class(dynamic_type={InitializationConfig: "default"}) +class DefaultInitializationConfig(InitializationConfig): + # A placeholder indicating that the class default should be used instead. 
+ _abstract = False + has_initialization = False + + +@config_class(dynamic_type={InitializationConfig: "fill"}) +class NormalInitializationConfig(InitializationConfig): + """ + Normal initialization: normal(mean, std).clamp(min,max) + """ + + _abstract = False + + value: float = Field( + default=1, + desc="Initialization value.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + + def get_initializer(self): + return init_fill_(self.value) + + +@config_class(dynamic_type={InitializationConfig: "zeros"}) +class ZerosInitializationConfig(InitializationConfig): + def get_initializer(self): + return init_zeros_ + + +@config_class(dynamic_type={InitializationConfig: "ones"}) +class OnesInitializationConfig(InitializationConfig): + def get_initializer(self): + return init_ones_ + + +@config_class(dynamic_type={InitializationConfig: "normal"}) +class NormalInitializationConfig(InitializationConfig): + """ + Normal initialization: normal(mean, std).clamp(min,max) + """ + + _abstract = False + + std: float = Field( + default=1, + desc="Standard deviation for normal initialization.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + mean: float = Field( + default=0, + desc="Mean for normal initialization.", + hint=FieldHint.optional, + ) + min: float | None = Field( + default=None, + desc="Min value for initialization clamping.", + hint=FieldHint.optional, + ) + max: float | None = Field( + default=None, + desc="Min value for initialization clamping.", + hint=FieldHint.optional, + ) + + def get_initializer(self): + return init_normal_(self.mean, self.std, self.min, self.max) + + +@config_class(dynamic_type={InitializationConfig: "uniform"}) +class UniformInitializationConfig(InitializationConfig): + """ + Uniform initialization: uniform(mean - scale, mean + scale).clamp(min,max) + """ + + _abstract = False + + scale: float = Field( + default=None, + desc="Initialization scale.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 
0), + ) + mean: float = Field( + default=None, + desc="Initialization mean.", + hint=FieldHint.optional, + valid=check_field(Assert.geq, 0), + ) + + def get_initializer(self) -> "Initializer": + return init_uniform_centered_(self.scale, self.mean) + + class Initializer(abc.ABC): @abc.abstractmethod def __call__(self, meta: "ParameterMeta", tensor: "torch.Tensor", generator: "torch.Generator") -> None: diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 03e0df928..ba11675b7 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -8,9 +8,10 @@ from fast_llm.config import Configurable from fast_llm.core.distributed import set_generator from fast_llm.engine.base_model.base_model import Layer +from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs +from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs, BlockLayerConfig from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -87,13 +88,15 @@ def __call__[ ) -class BlockLayer(torch.nn.Module, abc.ABC): +class BlockLayerBase[ConfigType: BaseModelConfig](Configurable[ConfigType], torch.nn.Module): """ - Base class for mixer and MLP modules. + Base class for blocks, mixer and MLP modules. 
""" - def __init__(self, tensor_space: TensorSpace, block_index: int, name: str, debug_level: int, debug_memory: bool): - super().__init__() + def __init__( + self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str, block_config: BlockConfig + ): + super().__init__(config) self._tensor_space = tensor_space self._block_index = block_index self._name = name @@ -101,10 +104,23 @@ def __init__(self, tensor_space: TensorSpace, block_index: int, name: str, debug self._debug = DebugLayer( tensor_space, self._name, - debug_level, - debug_memory, + block_config.debug_transformer, + block_config.debug_transformer_memory, ) + # @property + # def name(self) -> str: + # return self._name + + +class BlockLayer[ConfigType: BlockLayerConfig](Configurable[ConfigType], torch.nn.Module): + """ + Base class for mixer and MLP modules. + """ + + def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): + super().__init__(config, tensor_space, block_index, name, config.block) + @abc.abstractmethod def forward( self, @@ -116,68 +132,49 @@ def forward( pass -class Block[ConfigType: BlockConfig](Configurable[ConfigType], Layer): +class Block[ConfigType: BlockConfig](BlockLayerBase[ConfigType], Layer): """ A transformer-like decoder base block with abstract mixer. """ - config_class: typing.ClassVar[type[BlockConfig]] = BlockConfig # TODO: Standardize to `mixer` _mixer_module_name: typing.ClassVar[str] = "mixer" - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): - super().__init__(config) - # TODO: Argument? 
- self._block_index = block_index - self._name = f"Block {self._block_index}" - self._tensor_space: TensorSpace = tensor_space - self._dropout_p: float = self._config.hidden_dropout + def __init__( + self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str, return_input: bool = False + ): + super().__init__(config, tensor_space, block_index, name, config) # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input: bool = return_input - self._debug = DebugLayer( - tensor_space, - self._name, - self._config.debug_transformer, - self._config.debug_transformer_memory, - ) hidden_dim = self._tensor_space[BlockDimNames.hidden] # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale - self.norm_1 = self._config.normalization.get_layer(hidden_dim) - self.norm_2 = self._config.normalization.get_layer(hidden_dim) + self.norm_1 = self._config.peft.apply_other(self._config.normalization.get_layer(hidden_dim)) + self.norm_2 = self._config.peft.apply_other(self._config.normalization.get_layer(hidden_dim)) # The mixer needs to be created here for backward-compatible weight ordering. - setattr(self, self._mixer_module_name, self._create_mixer()) - - # TODO: Use dynamic type. - from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP - from fast_llm.layers.block.mlp.mlp import MLP - - self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, + setattr( + self, + self._mixer_module_name, + self._config.mixer.get_layer( + self._tensor_space, + self._block_index, + f"{self._name} mixer", + ), + ) + self._config.mlp.get_layer( self._tensor_space, self._block_index, + f"{self._name} mlp", ) - # PEFT. 
- self.norm_1 = self._config.peft.apply_other(self.norm_1) - self.norm_2 = self._config.peft.apply_other(self.norm_2) - - @abc.abstractmethod - def _create_mixer(self) -> BlockLayer: - pass - @torch.compile def _bias_dropout_add( self, input_: torch.Tensor, bias: torch.Tensor | None, residual: torch.Tensor ) -> torch.Tensor: if bias is not None: input_ = input_ + bias - return residual + torch.dropout(input_, self._dropout_p, self.training) - - # @property - # def name(self) -> str: - # return f"{self._name} {self._block_index}" + return residual + torch.dropout(input_, self._config.hidden_dropout, self.training) def forward( self, diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 919f95b3f..10cb88485 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,13 +1,16 @@ import enum +import typing from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.block.peft import TransformerPeftConfig from fast_llm.layers.common.config import NormalizationConfig from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.layers.block.block import BlockLayer + class BlockDimNames: # A set of common tensor dim names packed into a namespace. @@ -34,15 +37,92 @@ class BlockKwargs: class AddLinearBiasChoices(str, enum.Enum): + # TODO: Review nowhere = "nowhere" everywhere = "everywhere" only_attn_qkv = "only_attn_qkv" @config_class() -# TODO: Use composition instead -class BlockConfig(MLPConfig, BaseModelConfig): +class BlockLayerConfig(BaseModelConfig): + """ + A common class for mixers and mlps, which have the exact same interface. 
+ """ + + _abstract = True + block: "BlockConfig" = Field(init=False) + + @property + def layer_class(self) -> "type[BlockLayer]": + raise NotImplementedError() + + def get_layer(self, tensor_space: TensorSpace, block_index: int, name: str) -> "BlockLayer": + return self.layer_class(self, tensor_space, block_index, name) + + +@config_class(registry=True) +class MixerConfig(BlockLayerConfig): + _abstract = True + + # Needed for backward compatibility. + module_name: typing.ClassVar[str] = "mixer" + + def _validate(self) -> None: + assert hasattr(self, "block") + Assert.is_(self.block.mixer, self) + super()._validate() + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is MixerConfig and cls.get_subclass(default.get("type")) is None: + from fast_llm.layers.transformer.config import AttentionConfig + + # Default subclass. + return AttentionConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + +@config_class(registry=True) +class MLPBaseConfig(BlockLayerConfig): + _abstract = True + + def _validate(self) -> None: + assert hasattr(self, "block") + Assert.is_(self.block.mlp, self) + super()._validate() + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is MLPBaseConfig and cls.get_subclass(default.get("type")) is None: + from fast_llm.layers.block.mlp.config import MLPConfig + + # Default subclass. 
+ return MLPConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + +@config_class() +# TODO: Use composition instead +class BlockConfig(BaseModelConfig): + _abstract = False + mixer: MixerConfig = Field( + desc="Configuration for the mixer.", + hint=FieldHint.architecture, + ) + mlp: MLPBaseConfig = Field( + desc="Configuration for the MLP.", + hint=FieldHint.architecture, + ) # TODO: Review names normalization: NormalizationConfig = Field( desc="Configuration for the normalization layers architecture.", @@ -58,11 +138,6 @@ class BlockConfig(MLPConfig, BaseModelConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - full_precision_residual: bool = Field( - default=False, - desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", - hint=FieldHint.stability, - ) debug_transformer: int = Field( default=0, desc="Log the output of each operation in a transformer layer.", @@ -81,9 +156,14 @@ class BlockConfig(MLPConfig, BaseModelConfig): ) # TODO: Move these, not specific to a single block. - num_layers: int = Field( + full_precision_residual: bool = Field( + default=False, + desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", + hint=FieldHint.stability, + ) + num_blocks: int = Field( default=12, - desc="Number of layers in the transformer.", + desc="Number of blocks in the model.", hint=FieldHint.architecture, valid=check_field(Assert.geq, 0), ) @@ -100,24 +180,6 @@ class BlockConfig(MLPConfig, BaseModelConfig): hint=FieldHint.feature, ) - # TODO: Review initialization - init_method_std: float = Field( - default=None, - desc="Default scale for weight initialization. Default: hidden_size**-0.5", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max: float | None = Field( - default=None, - desc="Max value for clamping initialized weights. 
Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min: float | None = Field( - default=None, - desc="Min value for clamping initialized weights. Default: -float('inf')", - hint=FieldHint.optional, - ) - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: super().setup_tensor_space(tensor_space) diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index a99debacc..bde775a27 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -1,11 +1,18 @@ import enum +import functools +import typing -from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_normal_, init_zeros_ from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType, MLPRecomputeLevel +from fast_llm.layers.block.config import BlockLayerConfig from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.layers.block.mlp.mlp import MLPBase + class MLPDimNames: # MLP dimensions @@ -32,9 +39,10 @@ class RoutingType(str, enum.Enum): sinkhorn = "sinkhorn" -@config_class() -class MLPConfig(Config): +@config_class(dynamic_type={BlockLayerConfig: "mlp"}) +class MLPConfig(BlockLayerConfig): # TODO: Review names + # TODO: Separate MoE? _abstract = False ffn_hidden_size: int = Field( default=None, @@ -122,75 +130,58 @@ class MLPConfig(Config): " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", hint=FieldHint.expert, ) - # TODO: Review initialization - init_method_std_mlp_1: float = Field( - default=None, - desc="Scale for the MLP first layer weight initialization. 
Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_mlp_1: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for MLP first layer. Default: float('inf')", - hint=FieldHint.optional, + layer_1_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for the first mlp layer weights. Default: normal(0, hidden_size**-0.5).", + hint=FieldHint.feature, ) - init_method_min_mlp_1: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for MLP first layer. Default: -float('inf')", - hint=FieldHint.optional, + layer_1_bias_initialization: InitializationConfig = Field( + desc="Initialization configuration for the first mlp layer biases. Default: fill with zeros.", + hint=FieldHint.feature, ) - init_method_std_mlp_2: float = Field( - default=None, - desc="Scale for the MLP second layer weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), + layer_2_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for the second mlp layer weights." + " Default: normal((2 * num_blocks * hidden_size)**-0.5)", + hint=FieldHint.feature, ) - init_method_max_mlp_2: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for MLP second layer. Default: float('inf')", - hint=FieldHint.optional, + layer_2_bias_initialization: InitializationConfig = Field( + desc="Initialization configuration for the second mlp layer biases. Default: fill with zeros.", + hint=FieldHint.feature, ) - init_method_min_mlp_2: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for MLP second layer. Default: -float('inf')", - hint=FieldHint.optional, + router_weight_initialization: InitializationConfig = Field( + # TODO: Improve default? 
+ desc="Initialization configuration for the MoE router weight. Default: normal(0, hidden_size**-0.5).", + hint=FieldHint.feature, ) @property - def add_mlp_bias(self) -> bool: + def layer_class(self) -> "type[MLPBase]": + if self.num_experts > 1: + from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP + + return MixtureOfExpertMLP + else: + from fast_llm.layers.block.mlp.mlp import MLP + + return MLP + + @property + def add_bias(self) -> bool: from fast_llm.layers.block.config import AddLinearBiasChoices - # TODO: Make this work without inheritance. - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.everywhere: + if isinstance(self.block.add_linear_biases, bool): + return self.block.add_linear_biases + if self.block.add_linear_biases == AddLinearBiasChoices.everywhere: return True return False def _validate(self) -> None: + assert hasattr(self, "block") with self._set_implicit_default(): if self.activation_type is None: self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu - # TODO: Make this work without inheritance. + # TODO: `hidden_size` not yet validated. 
if self.ffn_hidden_size is None: - self.ffn_hidden_size = 4 * self.hidden_size - # TODO: Review initialization - if self.init_method_std_mlp_1 is None: - self.init_method_std_mlp_1 = self.init_method_std - if self.init_method_std_mlp_2 is None: - self.init_method_std_mlp_2 = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 - if self.init_method_max_mlp_1 is None: - self.init_method_max_mlp_1 = self.init_method_max - if self.init_method_min_mlp_1 is None: - self.init_method_min_mlp_1 = self.init_method_min - if self.init_method_max_mlp_2 is None: - self.init_method_max_mlp_2 = self.init_method_max - if self.init_method_min_mlp_2 is None: - self.init_method_min_mlp_2 = self.init_method_min - if self.init_method_min_mlp_1 is not None and self.init_method_max_mlp_1 is not None: - Assert.leq(self.init_method_min_mlp_1, self.init_method_max_mlp_1) - if self.init_method_min_mlp_2 is not None and self.init_method_max_mlp_2 is not None: - Assert.leq(self.init_method_min_mlp_2, self.init_method_max_mlp_2) + self.ffn_hidden_size = 4 * self.block.hidden_size self.num_unshared_experts = self.num_experts - self.num_shared_experts @@ -207,6 +198,45 @@ def _validate(self) -> None: elif self.mlp_lr_scale is not None: Assert.geq(self.mlp_lr_scale, 0) + if self.layer_1_bias_initialization.has_initialization or self.layer_2_bias_initialization.has_initialization: + assert self.add_bias + + @functools.cached_property + def layer_1_weight_initialization_method(self) -> Initializer: + if self.layer_1_weight_initialization.has_initialization: + return self.layer_1_weight_initialization.get_initializer() + else: + return init_normal_(0, self.block.hidden_size**-0.5) + + @functools.cached_property + def layer_1_bias_initialization_method(self) -> Initializer: + if self.layer_1_bias_initialization.has_initialization: + return self.layer_1_bias_initialization.get_initializer() + else: + return init_zeros_ + + @functools.cached_property + def layer_2_weight_initialization_method(self) 
-> Initializer: + if self.layer_2_weight_initialization.has_initialization: + return self.layer_2_weight_initialization.get_initializer() + else: + return init_normal_(0, self.block.hidden_size**-0.5 / max(2 * self.block.num_blocks, 1)) + + @functools.cached_property + def layer_2_bias_initialization_method(self) -> Initializer: + if self.layer_2_bias_initialization.has_initialization: + return self.layer_2_bias_initialization.get_initializer() + else: + return init_zeros_ + + @functools.cached_property + def router_weight_initialization_method(self) -> Initializer: + if self.router_weight_initialization.has_initialization: + assert self.add_bias + return self.router_weight_initialization.get_initializer() + else: + return init_zeros_ + def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index e53693460..f401371a4 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -4,12 +4,11 @@ import torch from fast_llm.core.distributed import ProcessGroup, set_generator -from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map -from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs -from fast_llm.layers.block.mlp.config import MLPDimNames, MLPLossNames, RoutingType +from fast_llm.layers.block.config import BlockDimNames, BlockKwargs +from fast_llm.layers.block.mlp.config import MLPConfig, MLPDimNames, MLPLossNames, RoutingType from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from 
fast_llm.layers.common.linear import Linear @@ -18,7 +17,7 @@ logger = logging.getLogger(__name__) -class MixtureOfExpertMLP(MLPBase): +class MixtureOfExpertMLP[ConfigType: MLPConfig](MLPBase[ConfigType]): """ MoeLayer following implementation from https://github.com/NVIDIA/Megatron-LM/blob/46ebc0e4202c980d98900000d455f754a7ff9d4b/megatron/model/transformer.py#L346 @@ -32,7 +31,7 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: int = 0, name: str = "mlp"): + def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 06850c8d0..8e68e6274 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -2,39 +2,20 @@ import torch -from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockDimNames -from fast_llm.layers.block.mlp.config import MLPDimNames +from fast_llm.layers.block.mlp.config import MLPConfig, MLPDimNames from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import LinearBase from fast_llm.utils import Assert, get_lr_scale -class MLPBase(BlockLayer): +class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: int = 0, name: str = "mlp"): - super().__init__( - tensor_space, - block_index, - name, - 
debug_level=config.debug_transformer, - debug_memory=config.debug_transformer_memory, - ) - self._config = config - - init_method_1 = init_normal_( - std=self._config.init_method_std_mlp_1, - min_val=self._config.init_method_min_mlp_1, - max_val=self._config.init_method_max_mlp_1, - ) - init_method_2 = init_normal_( - std=self._config.init_method_std_mlp_2, - min_val=self._config.init_method_min_mlp_2, - max_val=self._config.init_method_max_mlp_2, - ) + super().__init__(config, tensor_space, block_index, name) hidden_dim = self._tensor_space[BlockDimNames.hidden] self._intermediate_dim = self._tensor_space[MLPDimNames.composite_expert_mlp] @@ -52,17 +33,17 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: self.layer_1 = LinearBase( hidden_dim, self._tensor_space[MLPDimNames.composite_gated_expert_mlp], - bias=self._config.add_mlp_bias, - weight_init_method=init_method_1, - bias_init_method=init_zeros_, + bias=self._config.add_bias, + weight_init_method=self._config.layer_1_weight_initialization_method, + bias_init_method=self._config.layer_1_bias_initialization_method, lr_scale=lr_scale, ) self.layer_2 = LinearBase( self._intermediate_dim, hidden_dim, - bias=self._config.add_mlp_bias, - weight_init_method=init_method_2, - bias_init_method=init_zeros_, + bias=self._config.add_bias, + weight_init_method=self._config.layer_2_weight_initialization_method, + bias_init_method=self._config.layer_2_bias_initialization_method, auto_bias_grad_accumulation=self._tensor_space.distributed_config.tensor_parallel > 1, transposed_weight=True, lr_scale=lr_scale, diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 2f45fdf9f..8483dc573 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -1,17 +1,17 @@ import abc import enum +import functools import typing from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import 
BaseModelConfig +from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_ones_, init_zeros_ +from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.utils import Assert if typing.TYPE_CHECKING: - import torch - - from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.layers.common.linear import LinearBase, LinearLike - from fast_llm.layers.common.normalization import LayerNorm, RMSNorm + from fast_llm.layers.common.normalization import Normalization class NormalizationImplementation(str, enum.Enum): @@ -30,10 +30,14 @@ class NormalizationImplementation(str, enum.Enum): class NormalizationConfig(BaseModelConfig): pass + @property @abc.abstractmethod - def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": + def module_class(self) -> type["Normalization"]: pass + def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "Normalization": + return self.module_class(self, hidden_dim, lr_scale) + @classmethod def _from_dict( cls, @@ -51,8 +55,11 @@ def _from_dict( class NoNormalizationConfig(NormalizationConfig): _abstract = False - def get_layer(self, hidden_dim: "TensorDim") -> "torch.nn.Module": - return torch.nn.Identity() + @property + def module_class(self) -> type["Normalization"]: + from fast_llm.layers.common.normalization import NoNormalization + + return NoNormalization @config_class() @@ -78,34 +85,11 @@ class LayerNormalizationBaseConfig(NormalizationConfig): desc="The implementation to use for the normalization layer.", hint=FieldHint.performance, ) - # TODO: Rename to normalization_init_range - initialization_range: float = Field( - default=0.0, - desc="Randomize the initialization with a uniform noise. 
Used to test for issues that may not be visible with the default initialization.", - hint=FieldHint.testing, - valid=check_field(Assert.geq, 0), + weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for the normalization weights. Default: fill with ones", + hint=FieldHint.feature, ) - def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": - from fast_llm.engine.config_utils.initialization import init_uniform_centered_ - - kwargs = { - "hidden_dim": hidden_dim, - "eps": self.epsilon, - "implementation": self.implementation, - "zero_centered": self.zero_centered, - "lr_scale": lr_scale, - } - if self.initialization_range: - mean = 0 if self.zero_centered else 1 - kwargs["weight_init_method"] = init_uniform_centered_(self.initialization_range, mean=mean) - return self.module_class(**kwargs) - - @property - @abc.abstractmethod - def module_class(self): - pass - @classmethod def _from_dict( cls, @@ -120,16 +104,34 @@ def _from_dict( cls._handle_renamed_field(default, "layer_norm_init_range", "initialization_range") return super()._from_dict(default, strict, flat) + @functools.cached_property + def weight_initialization_method(self) -> Initializer: + if self.weight_initialization.has_initialization: + return self.weight_initialization.get_initializer() + else: + return init_ones_ + @config_class(dynamic_type={NormalizationConfig: "layer_norm"}) class LayerNormalizationConfig(LayerNormalizationBaseConfig): _abstract = False + bias_initialization: InitializationConfig = Field( + desc="Initialization configuration for the normalization biases. 
Default: fill with zeros", + hint=FieldHint.feature, + ) + + @functools.cached_property + def bias_initialization_method(self) -> Initializer: + if self.bias_initialization.has_initialization: + return self.bias_initialization.get_initializer() + else: + return init_zeros_ @property def module_class(self): - from fast_llm.layers.common.normalization import LayerNorm + from fast_llm.layers.common.normalization import LayerNormalization - return LayerNorm + return LayerNormalization @config_class(dynamic_type={NormalizationConfig: "rms_norm"}) @@ -138,9 +140,9 @@ class RMSNormalizationConfig(LayerNormalizationBaseConfig): @property def module_class(self): - from fast_llm.layers.common.normalization import RMSNorm + from fast_llm.layers.common.normalization import RMSNormalization - return RMSNorm + return RMSNormalization @config_class() diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index d44be3297..cedfd2294 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -1,11 +1,19 @@ +import abc + import torch -from fast_llm.engine.config_utils.initialization import init_ones_, init_zeros_ +from fast_llm.config import Configurable from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd -from fast_llm.layers.common.config import NormalizationImplementation +from fast_llm.layers.common.config import ( + LayerNormalizationConfig, + NoNormalizationConfig, + NormalizationConfig, + NormalizationImplementation, + RMSNormalizationConfig, +) from fast_llm.tensor import ParameterMeta, accumulate_gradient from fast_llm.utils import Assert @@ -139,7 +147,29 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None, return grad_input, None, None, None -class 
LayerNorm(torch.nn.Module): +class Normalization[ConfigType: NormalizationConfig](Configurable[ConfigType], torch.nn.Module): + def __init__( + self, + config: NormalizationConfig, + hidden_dim: TensorDim, + lr_scale: float | None = None, + ): + super().__init__(config) + self._hidden_dim = hidden_dim + self._lr_scale = lr_scale + assert not self._hidden_dim.is_parallel + + @abc.abstractmethod + def forward(self, input_: torch.Tensor) -> torch.Tensor: + pass + + +class NoNormalization[ConfigType: NoNormalizationConfig](Normalization[ConfigType]): + def forward(self, input_: torch.Tensor) -> torch.Tensor: + return input_ + + +class LayerNormalization[ConfigType: LayerNormalizationConfig](Normalization[ConfigType]): """ A layer normalization layer, supporting multiple implementations. Note: Converting input automatically to training dtype to match Apex behaviour, @@ -149,23 +179,20 @@ class LayerNorm(torch.nn.Module): def __init__( self, + config: LayerNormalizationConfig, hidden_dim: TensorDim, - *, - eps=1e-5, - implementation: NormalizationImplementation = NormalizationImplementation.auto, - weight_init_method=None, - bias_init_method=init_zeros_, - zero_centered: bool = False, lr_scale: float | None = None, ): - super().__init__() - assert not hidden_dim.is_parallel - self._eps = eps - self._zero_centered = zero_centered + super().__init__(config, hidden_dim, lr_scale) + implementation = self._config.implementation if implementation == NormalizationImplementation.auto: - if _fast_normalization_available and hidden_dim.size in _PERSIST_LN_SIZES and not self._zero_centered: + if ( + _fast_normalization_available + and hidden_dim.size in _PERSIST_LN_SIZES + and not self._config.zero_centered + ): implementation = NormalizationImplementation.fast - elif TritonConfig.TRITON_ENABLED or self._zero_centered: + elif TritonConfig.TRITON_ENABLED or self._config.zero_centered: log_main_rank("Fast layer norm unavailable, using backup triton implementation.") implementation = 
NormalizationImplementation.triton elif _fused_normalization_available: @@ -174,7 +201,7 @@ def __init__( else: log_main_rank("Fast and fused layer norm unavailable, using backup pytorch implementation.") implementation = NormalizationImplementation.torch - if self._zero_centered: + if self._config.zero_centered: assert implementation == NormalizationImplementation.triton if implementation == NormalizationImplementation.triton: self._forward = self._forward_triton @@ -187,44 +214,43 @@ def __init__( else: raise NotImplementedError(implementation) - if weight_init_method is None: - weight_init_method = init_zeros_ if self._zero_centered else init_ones_ - self.weight = ParameterMeta.from_dims( (hidden_dim,), - init_method=weight_init_method, + init_method=self._config.weight_initialization_method, weight_decay=False, auto_grad_accumulation=implementation == NormalizationImplementation.torch, lr_scale=lr_scale, ) self.bias = ParameterMeta.from_dims( (hidden_dim,), - init_method=bias_init_method, + init_method=self._config.bias_initialization_method, weight_decay=False, auto_grad_accumulation=implementation == NormalizationImplementation.torch, lr_scale=lr_scale, ) - self.normalized_shape = self.weight.shape + self._normalized_shape = self.weight.shape def forward(self, input_: torch.Tensor) -> torch.Tensor: - return self._forward(input_.view(-1, *self.normalized_shape)).view_as(input_) + return self._forward(input_.view(-1, *self._normalized_shape)).view_as(input_) def _forward_triton(self, input_: torch.Tensor) -> torch.Tensor: return triton_normalization_autograd( - input_, self.weight, self.bias, self._eps, self.training, self._zero_centered + input_, self.weight, self.bias, self._config.epsilon, self.training, self._config.zero_centered ) def _forward_fast(self, input_: torch.Tensor) -> torch.Tensor: - return FastLayerNorm.apply(input_, self.normalized_shape, self.weight, self.bias, self._eps) + return FastLayerNorm.apply(input_, self._normalized_shape, 
self.weight, self.bias, self._config.epsilon) def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor: - return FusedLayerNorm.apply(input_, self.normalized_shape, self.weight, self.bias, self._eps) + return FusedLayerNorm.apply(input_, self._normalized_shape, self.weight, self.bias, self._config.epsilon) def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: - return torch.layer_norm(input_.to(self.weight.dtype), self.normalized_shape, self.weight, self.bias, self._eps) + return torch.layer_norm( + input_.to(self.weight.dtype), self._normalized_shape, self.weight, self.bias, self._config.epsilon + ) -class RMSNorm(torch.nn.Module): +class RMSNormalization[ConfigType: RMSNormalizationConfig](Configurable[ConfigType], torch.nn.Module): """ A RMS normalization layer. Note: Converting input automatically to training dtype to match Apex behaviour, @@ -234,20 +260,15 @@ class RMSNorm(torch.nn.Module): def __init__( self, + config: RMSNormalizationConfig, hidden_dim: TensorDim, - *, - eps=1e-5, - implementation: NormalizationImplementation = NormalizationImplementation.auto, - weight_init_method=None, - zero_centered: bool = False, lr_scale: float | None = None, ): - super().__init__() + super().__init__(config, hidden_dim, lr_scale) assert not hidden_dim.is_parallel - self._eps = eps - self._zero_centered = zero_centered + implementation = self._config.implementation if implementation == NormalizationImplementation.auto: - if TritonConfig.TRITON_ENABLED or self._zero_centered: + if TritonConfig.TRITON_ENABLED or self._config.zero_centered: implementation = NormalizationImplementation.triton elif _fused_normalization_available: log_main_rank("Triton RMS norm unavailable, using fused implementation.") @@ -255,7 +276,7 @@ def __init__( else: log_main_rank("Fused RMS norm unavailable, using backup implementation.") implementation = NormalizationImplementation.torch - if self._zero_centered: + if self._config.zero_centered: assert implementation == 
NormalizationImplementation.triton if implementation == NormalizationImplementation.triton: self._forward = self._forward_triton @@ -266,26 +287,25 @@ def __init__( else: raise NotImplementedError(implementation) - if weight_init_method is None: - weight_init_method = init_zeros_ if self._zero_centered else init_ones_ - self.weight = ParameterMeta.from_dims( (hidden_dim,), - init_method=weight_init_method, + init_method=self._config.weight_initialization_method, weight_decay=False, auto_grad_accumulation=True, lr_scale=lr_scale, ) - self.normalized_shape = self.weight.shape + self._normalized_shape = self.weight.shape def forward(self, input_: torch.Tensor) -> torch.Tensor: - return self._forward(input_.view(-1, *self.normalized_shape)).view_as(input_) + return self._forward(input_.view(-1, *self._normalized_shape)).view_as(input_) def _forward_triton(self, input_: torch.Tensor) -> torch.Tensor: - return triton_normalization_autograd(input_, self.weight, None, self._eps, self.training, self._zero_centered) + return triton_normalization_autograd( + input_, self.weight, None, self._config.epsilon, self.training, self._config.zero_centered + ) def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor: - return FusedRMSNorm.apply(input_, self.normalized_shape, self.weight, self._eps) + return FusedRMSNorm.apply(input_, self._normalized_shape, self.weight, self._config.epsilon) def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: - return torch.rms_norm(input_.to(self.weight.dtype), self.normalized_shape, self.weight, self._eps) + return torch.rms_norm(input_.to(self.weight.dtype), self._normalized_shape, self.weight, self._config.epsilon) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index b667e5318..a5d1b6a2e 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,13 +1,12 @@ -import typing +import functools from fast_llm.config import Field, 
FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl -from fast_llm.layers.block.config import BlockDimNames, BlockKwargs -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.layers.transformer.rotary.config import NoRotaryConfig +from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs from fast_llm.utils import Assert @@ -48,25 +47,20 @@ class LanguageModelKwargs(BlockKwargs): @config_class() class LanguageModelBaseConfig(BaseModelConfig): # TODO: block - transformer: TransformerConfig = Field( + transformer: BlockConfig = Field( desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) - max_position_embeddings: int = Field( - default=2048, - desc="Number of absolute position embeddings, if applicable.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) vocab_size: int = Field( default=49152, desc="Size of the vocabulary, i.e., number of vocabulary embeddings and logits.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - use_position_embeddings: bool = Field( + absolute_position_embeddings: int | None = Field( + # TODO: backward compatibility? default=None, - desc="Enable absolute position embeddings. 
Default: Enable unless using rotary embeddings.", + desc="Number of absolute position embeddings, if applicable.", hint=FieldHint.architecture, ) tie_word_embeddings: bool = Field( @@ -203,6 +197,18 @@ class LanguageModelBaseConfig(BaseModelConfig): doc="If not provided, all heads are equally weighted.", hint=FieldHint.feature, ) + word_embedding_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for word embeddings. Default: normal(std=hidden_size**-0.5)", + hint=FieldHint.feature, + ) + position_embedding_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for position embeddings. Default: normal(hidden_size**-0.5)", + hint=FieldHint.feature, + ) + output_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for untied output weights. Default: normal(hidden_size**-0.5)", + hint=FieldHint.feature, + ) def _validate(self) -> None: self.transformer.validate() @@ -212,14 +218,6 @@ def _validate(self) -> None: self.language_model_loss_factor = 1.0 else: self.language_model_loss_factor = 0.0 - if self.use_position_embeddings is None: - self.use_position_embeddings = isinstance(self.transformer.rotary, NoRotaryConfig) - if self.init_method_std_embed is None: - self.init_method_std_embed = self.transformer.init_method_std - if self.init_method_max_embed is None: - self.init_method_max_embed = self.transformer.init_method_max - if self.init_method_min_embed is None: - self.init_method_min_embed = self.transformer.init_method_min super()._validate() if self.init_method_max_embed is not None and self.init_method_min_embed is not None: Assert.leq(self.init_method_min_embed, self.init_method_max_embed) @@ -234,39 +232,47 @@ def _validate(self) -> None: # -1 because the first prediction head's transformer layer is accounted for in num_layers # +1 because the layer index starts at 1 Assert.eq( - len(self.transformer.per_layer_lr_scale), self.transformer.num_layers + 
self.prediction_heads - 1 + 1
+                len(self.transformer.per_layer_lr_scale), self.transformer.num_blocks + self.prediction_heads - 1 + 1
             )
+        if self.position_embedding_weight_initialization.has_initialization:
+            assert self.use_absolute_position_embeddings
+        if self.output_weight_initialization.has_initialization:
+            assert not self.tie_word_embeddings
 
     def setup_tensor_space(self, tensor_space: TensorSpace) -> None:
         self.transformer.setup_tensor_space(tensor_space)
         tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor)
 
         # Embedding dimensions
-        tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.position_embed, self.max_position_embeddings))
+        if self.use_absolute_position_embeddings:
+            tensor_space.add_tensor_dim(
+                TensorDim(LanguageModelDimNames.position_embed, self.absolute_position_embeddings)
+            )
         # TODO: Need both?
         tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab, self.vocab_size))
         tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab_tp, self.vocab_size, tensor))
 
-    @property
-    def num_absolute_position_embeddings(self) -> int:
-        # TODO: Rename from max embeddings.
-        return self.max_position_embeddings if self.use_absolute_position_embeddings else None
-
     @property
     def use_absolute_position_embeddings(self) -> int:
-        # TODO: Set through num embeddings instead instead. 
- return self.use_position_embeddings + return self.absolute_position_embeddings is not None + + @functools.cached_property + def word_embedding_weight_initialization_method(self) -> Initializer: + if self.word_embedding_weight_initialization.has_initialization: + return self.word_embedding_weight_initialization.get_initializer() + else: + return self.transformer.hidden_size**-0.5 + + @functools.cached_property + def position_embedding_weight_initialization_method(self) -> Initializer: + if self.position_embedding_weight_initialization.has_initialization: + return self.position_embedding_weight_initialization.get_initializer() + else: + return self.transformer.hidden_size**-0.5 - @classmethod - def from_flat_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - ) -> typing.Self: - # The backward compatibility fix in `NormalizationArchitectureConfig` - # won't work for older checkpoints saved with a flat config. - # TODO v0.3: Remove flat format - cls._handle_renamed_field(default, "normalization_type", "type") - cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") - cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") - return super().from_flat_dict(default, strict) + @functools.cached_property + def output_weight_initialization_method(self) -> Initializer: + if self.output_weight_initialization.has_initialization: + return self.output_weight_initialization.get_initializer() + else: + return self.transformer.hidden_size**-0.5 diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 1ecafb344..a546159dd 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -6,7 +6,6 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import reduce_forward, split from fast_llm.engine.base_model.base_model import Layer -from fast_llm.engine.config_utils.initialization import init_normal_ from 
fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelDimNames, LanguageModelKwargs from fast_llm.tensor import ParameterMeta, TensorMeta @@ -59,21 +58,13 @@ def __init__( self.word_embeddings_weight = ParameterMeta.from_dims( (vocab_dim, hidden_dim), - init_method=init_normal_( - std=config.init_method_std_embed, - min_val=config.init_method_min_embed, - max_val=config.init_method_max_embed, - ), + init_method=self._config.word_embedding_weight_initialization_method, lr_scale=config.embeddings_lr_scale, ) if self._use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( (self._tensor_space[LanguageModelDimNames.position_embed], hidden_dim), - init_method=init_normal_( - std=config.init_method_std_embed, - min_val=config.init_method_min_embed, - max_val=config.init_method_max_embed, - ), + init_method=self._config.position_embedding_weight_initialization_method, allow_sequence_tensor_parallel=not config.parallel_embeddings, lr_scale=config.embeddings_lr_scale, ) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 6d1fedd26..691914b86 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -8,7 +8,6 @@ from fast_llm.config import Configurable from fast_llm.core.ops import split_op from fast_llm.engine.base_model.base_model import Layer -from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward @@ -96,11 +95,7 @@ def __init__( ] self.output_weights = ParameterMeta.from_dims( (vocab_dim, hidden_dim), - init_method=init_normal_( - std=self._config.init_method_std_embed, - 
min_val=self._config.init_method_min_embed, - max_val=self._config.init_method_max_embed, - ), + init_method=self._config.output_weight_initialization_method, lr_scale=self._config.output_lr_scale, ) diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index f5d915855..440ce9580 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -34,7 +34,7 @@ def _create_tensors(self, sequence_length: int) -> None: return self._tensor_cache_max_sequence_length = sequence_length - Assert.leq(sequence_length, self._config.num_absolute_position_embeddings) + Assert.leq(sequence_length, self._config.absolute_position_embeddings) self._position_ids = torch.arange( 0, sequence_length, device=self._tensor_space.distributed.device, dtype=torch.int64 ) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index ba7f2bb6e..1978597fd 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -4,13 +4,12 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim -from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs, TransformerConfig +from fast_llm.layers.transformer.config import AttentionConfig, AttentionDimNames, AttentionKwargs from fast_llm.utils import get_lr_scale try: @@ -46,7 +45,7 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no 
return grad, None -class Attention(BlockLayer): +class Attention[ConfigType: AttentionConfig](BlockLayer[ConfigType]): """ A self-attention layer. """ @@ -71,28 +70,11 @@ class Attention(BlockLayer): AttentionDimNames.composite_dense, ) - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): - super().__init__( - tensor_space, - block_index, - self._mixer_name, - debug_level=config.debug_transformer, - debug_memory=config.debug_transformer_memory, - ) + def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): + super().__init__(config, tensor_space, block_index, name) self._config = config self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) - init_method_qkv = init_normal_( - std=self._config.init_method_std_qkv, - min_val=self._config.init_method_min_qkv, - max_val=self._config.init_method_max_qkv, - ) - init_method_std_attn_proj = init_normal_( - std=self._config.init_method_std_attn_proj, - min_val=self._config.init_method_min_attn_proj, - max_val=self._config.init_method_max_attn_proj, - ) - self._kv_channels = self._tensor_space[AttentionDimNames.kv_channels].size self._head_groups = self._tensor_space[AttentionDimNames.head_groups].global_size self._local_head_groups = self._tensor_space[AttentionDimNames.head_groups].size @@ -110,8 +92,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i hidden_dim, self._tensor_space[AttentionDimNames.composite_query], bias=self._config.add_qkv_bias, - weight_init_method=init_method_qkv, - bias_init_method=init_zeros_, + weight_init_method=self._config.qkv_weight_initialization_method, + bias_init_method=self._config.qkv_bias_initialization_method, sequence_parallel=self._sequence_parallel, lr_scale=attention_lr_scale, ) @@ -119,8 +101,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i hidden_dim, 
self._tensor_space[AttentionDimNames.composite_key_value], bias=self._config.add_qkv_bias, - weight_init_method=init_method_qkv, - bias_init_method=init_zeros_, + weight_init_method=self._config.qkv_weight_initialization_method, + bias_init_method=self._config.qkv_bias_initialization_method, sequence_parallel=self._sequence_parallel, lr_scale=attention_lr_scale, ) @@ -134,8 +116,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i self._tensor_space[AttentionDimNames.composite_dense], hidden_dim, bias=self._config.add_dense_bias, - weight_init_method=init_method_std_attn_proj, - bias_init_method=init_zeros_, + weight_init_method=self._config.dense_weight_initialization_method, + bias_init_method=self._config.dense_bias_initialization_method, sequence_parallel=self._sequence_parallel, lr_scale=attention_lr_scale, ) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index f7c7fea9c..d1759dc2f 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -3,12 +3,13 @@ import typing import warnings -from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_zeros_ from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import TritonConfig -from fast_llm.layers.block.config import AddLinearBiasChoices, BlockConfig, BlockDimNames, BlockKwargs +from fast_llm.layers.block.config import AddLinearBiasChoices, BlockConfig, BlockDimNames, BlockKwargs, MixerConfig from fast_llm.layers.transformer.rotary.config import 
RotaryConfig from fast_llm.utils import Assert, div @@ -45,8 +46,8 @@ class AttentionKwargs(BlockKwargs): past_key_values = "past_key_values" -@config_class() -class AttentionConfig(Config): +@config_class(dynamic_type={MixerConfig: "attention"}) +class AttentionConfig(MixerConfig): # TODO: Make mixer class dynamic. _abstract = False @@ -106,64 +107,28 @@ class AttentionConfig(Config): " Under muP (if scaling number of heads instead of kv_channels): use 0.5.", valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) - # TODO: Review initialization - init_method_std_qkv: float = Field( - default=None, - desc="Scale for the query, key and value weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), - ) - init_method_max_qkv: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for query, key and value matrices. Default: float('inf')", - hint=FieldHint.optional, - ) - init_method_min_qkv: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for query, key and value matrices. Default: -float('inf')", - hint=FieldHint.optional, + qkv_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for the query, key and value layer weights. Default: normal(std=hidden_size**-0.5)", + hint=FieldHint.feature, ) - init_method_std_attn_proj: float = Field( - default=None, - desc="Scale for the attention projection weight initialization. Default: init_method_std", - hint=FieldHint.optional, - valid=check_field(Assert.geq, 0), + qkv_bias_initialization: InitializationConfig = Field( + desc="Initialization configuration for the query, key and value layer biases. Default: fill with zeros.", + hint=FieldHint.feature, ) - init_method_max_attn_proj: float | None = Field( - default=None, - desc="Max value for clamping initialized weights for attention projection. 
Default: float('inf')", - hint=FieldHint.optional, + dense_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for the dense layer weight. Default: normal(std=(2 * num_blocks * hidden_size)**-0.5)", + hint=FieldHint.feature, ) - init_method_min_attn_proj: float | None = Field( - default=None, - desc="Min value for clamping initialized weights for attention projection. Default: -float('inf')", - hint=FieldHint.optional, + dense_bias_initialization: InitializationConfig = Field( + desc="Initialization configuration for the dense layer biases. Default: fill with zeros.", + hint=FieldHint.feature, ) def _validate(self) -> None: with self._set_implicit_default(): # TODO: Make this work without inheritance. if self.kv_channels is None: - self.kv_channels = div(self.hidden_size, self.num_attention_heads) - # TODO: Review initialization - if self.init_method_std_qkv is None: - self.init_method_std_qkv = self.init_method_std - if self.init_method_std_attn_proj is None: - self.init_method_std_attn_proj = self.init_method_std / max(2 * self.num_layers, 1) ** 0.5 - if self.init_method_max_qkv is None: - self.init_method_max_qkv = self.init_method_max - if self.init_method_min_qkv is None: - self.init_method_min_qkv = self.init_method_min - if self.init_method_max_attn_proj is None: - self.init_method_max_attn_proj = self.init_method_max - if self.init_method_min_attn_proj is None: - self.init_method_min_attn_proj = self.init_method_min - if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: - Assert.leq(self.init_method_min, self.init_method_max) - if self.init_method_min_qkv is not None and self.init_method_max_qkv is not None: - Assert.leq(self.init_method_min_qkv, self.init_method_max_qkv) - if self.init_method_min_attn_proj is not None and self.init_method_max_attn_proj is not None: - Assert.leq(self.init_method_min_attn_proj, self.init_method_max_attn_proj) + self.kv_channels = div(self.block.hidden_size, 
self.num_attention_heads) super()._validate() @@ -171,6 +136,10 @@ def _validate(self) -> None: warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") Assert.multiple(self.num_attention_heads, self.head_groups) + if self.qkv_bias_initialization.has_initialization: + assert self.add_qkv_bias + if self.dense_bias_initialization.has_initialization: + assert self.add_dense_bias @functools.cached_property def projection_size(self): @@ -213,34 +182,47 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: @property def add_qkv_bias(self) -> bool: # TODO: Make this work without inheritance. - if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.nowhere: - return False - return True + if isinstance(self.block.add_linear_biases, bool): + return self.block.add_linear_biases + return self.block.add_linear_biases != AddLinearBiasChoices.nowhere @property def add_dense_bias(self) -> bool: # TODO: Make this work without inheritance. 
- if isinstance(self.add_linear_biases, bool): - return self.add_linear_biases - if self.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False + if isinstance(self.block.add_linear_biases, bool): + return self.block.add_linear_biases + return self.block.add_linear_biases == AddLinearBiasChoices.everywhere + @functools.cached_property + def qkv_weight_initialization_method(self) -> Initializer: + if self.qkv_weight_initialization.has_initialization: + return self.qkv_weight_initialization.get_initializer() + else: + return self.block.hidden_size**-0.5 -@config_class() -# TODO: Use composition instead -class TransformerConfig(AttentionConfig, BlockConfig): - _abstract = False + @functools.cached_property + def qkv_bias_initialization_method(self) -> Initializer: + if self.qkv_bias_initialization.has_initialization: + return self.qkv_bias_initialization.get_initializer() + else: + return init_zeros_ - def _validate(self) -> None: - with self._set_implicit_default(): - # Kept here for initialization order. 
-            # TODO: Review initialization
-            if self.init_method_std is None:
-                self.init_method_std = self.hidden_size**-0.5
-            if self.init_method_min is not None and self.init_method_max is not None:
-                Assert.leq(self.init_method_min, self.init_method_max)
+    @functools.cached_property
+    def dense_weight_initialization_method(self) -> Initializer:
+        if self.dense_weight_initialization.has_initialization:
+            return self.dense_weight_initialization.get_initializer()
+        else:
+            return init_normal_(self.block.hidden_size**-0.5 / max(2 * self.block.num_blocks, 1))
 
-        super()._validate()
+    @functools.cached_property
+    def dense_bias_initialization_method(self) -> Initializer:
+        if self.dense_bias_initialization.has_initialization:
+            return self.dense_bias_initialization.get_initializer()
+        else:
+            return init_zeros_
+
+
+@config_class()
+# TODO: Remove
+class TransformerConfig(BlockConfig):
+    pass
diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py
index 3c0ad8ab4..eb24ef183 100644
--- a/fast_llm/models/custom/model.py
+++ b/fast_llm/models/custom/model.py
@@ -36,7 +36,7 @@ def get_layers(self) -> list[Layer]:
                     self._tensor_space,
                     block_index=i + 1,
                 )
-                for i in range(self._config.transformer.num_layers)
+                for i in range(self._config.transformer.num_blocks)
             ],
             CustomHead(self._config, self._tensor_space),
         ]
diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py
index 0da16428e..84960c0f2 100644
--- a/fast_llm/models/gpt/config.py
+++ b/fast_llm/models/gpt/config.py
@@ -192,15 +192,12 @@ class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig):
     reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate()
 
     def _validate(self) -> None:
-        if self.batch.sequence_length is None:
-            # TODO: Drop this.
- self.batch.sequence_length = self.model.base_model.max_position_embeddings if self.model.base_model.use_megatron_initialization: set_megatron_distributed_seeds(self.model.distributed) super()._validate() if self.model.base_model.use_absolute_position_embeddings: - Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) + Assert.geq(self.model.base_model.absolute_position_embeddings, self.batch.sequence_length) distillation_model = self.model.base_model.distillation_model dpo_reference_model = self.model.base_model.dpo_reference_model diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 6e79388b0..0ef970db2 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -176,7 +176,7 @@ def _create_weight_converters( self, ) -> list[WeightConverter]: converters = [] - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks # Embeddings converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) @@ -241,13 +241,13 @@ def _create_transformer_layer_converters( converters += self._get_weight_and_bias_converters( f"{fast_llm_layer_name}.mlp.layer_1", (), - transformer_config.add_mlp_bias, + transformer_config.add_bias, cls=IgnoreExportWeightConverter, ) converters += self._get_weight_and_bias_converters( f"{fast_llm_layer_name}.mlp.layer_2", (), - transformer_config.add_mlp_bias, + transformer_config.add_bias, cls=IgnoreExportWeightConverter, ) converters += [IgnoreExportWeightConverter(f"{fast_llm_layer_name}.mlp.router.weight", ())] @@ -256,7 +256,7 @@ def _create_transformer_layer_converters( return converters def _create_lm_head_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks prediction_heads = 
self._model.config.base_model.prediction_heads norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) converters = [] @@ -344,12 +344,12 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig transformer_config: TransformerConfig = self._model.config.base_model.transformer return [ *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", transformer_config.add_mlp_bias + f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", transformer_config.add_bias ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.c_proj", - transformer_config.add_mlp_bias, + transformer_config.add_bias, MLPLayer2Converter, ), ] @@ -463,13 +463,13 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - transformer_config.add_mlp_bias, + transformer_config.add_bias, SplitWeightConverter, ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.down_proj", - transformer_config.add_mlp_bias, + transformer_config.add_bias, MLPLayer2Converter, ), ] @@ -531,13 +531,13 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - transformer_config.add_mlp_bias, + transformer_config.add_bias, SplitWeightConverter, ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.down_proj", - transformer_config.add_mlp_bias, + transformer_config.add_bias, MLPLayer2Converter, ), ] @@ -641,20 +641,20 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig *self._get_weight_and_bias_converters( 
f"{fast_llm_prefix}.mlp.layer_1", (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - transformer_config.add_mlp_bias, + transformer_config.add_bias, SplitWeightConverter, ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.down_proj", - transformer_config.add_mlp_bias, + transformer_config.add_bias, MLPLayer2Converter, ), ] # Override base method to handle the MTP heads def _create_lm_head_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks prediction_heads = self._model.config.base_model.prediction_heads norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) converters = [] diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 187ca618d..30842597d 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -68,7 +68,7 @@ def get_output_layers(self) -> list[Layer]: self._config.transformer, self._tensor_space, # TODO MTP: which index? - block_index=max(self._config.transformer.num_layers + i, 1), + block_index=max(self._config.transformer.num_blocks + i, 1), # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=i < self._config.prediction_heads - 1, @@ -93,9 +93,9 @@ def get_layers(self) -> list[Layer]: block_index=i + 1, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. 
- return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, + return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_blocks - 1, ) - for i in range(self._config.transformer.num_layers) + for i in range(self._config.transformer.num_blocks) ], *self.get_output_layers(), ] @@ -372,7 +372,7 @@ def loss_defs(self) -> list[LossDef]: LossDef( name=MLPLossNames.load_balancing_loss, formatted_name="load balancing loss", - count=self._config.transformer.num_layers, + count=self._config.transformer.num_blocks, ) ) if self._config.transformer.expert_z_loss_coefficient: @@ -380,7 +380,7 @@ def loss_defs(self) -> list[LossDef]: LossDef( name=MLPLossNames.router_z_loss, formatted_name="router z loss", - count=self._config.transformer.num_layers, + count=self._config.transformer.num_blocks, ) ) if self._config.logit_z_loss: @@ -421,7 +421,7 @@ def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration, batch_size, s consumed_tokens_per_iteration = sequence_length * batch_size - num_transformer_layers = transformer_config.num_layers + self._config.base_model.prediction_heads - 1 + num_transformer_layers = transformer_config.num_blocks + self._config.base_model.prediction_heads - 1 transformer_flops_base = ( 2 * checkpoint_activations_factor * consumed_tokens_per_iteration * num_transformer_layers ) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 9427f69be..ecbcb0c35 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -62,13 +62,13 @@ def _validate(self): if self.hybrid_block_layout is None: with self._set_implicit_default(): - self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_layers + self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_blocks - if len(self.hybrid_block_layout) != self.transformer.num_layers: - message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does 
not match num_layers {self.transformer.num_layers}"
-            if self.transformer.num_layers % len(self.hybrid_block_layout) != 0:
+        if len(self.hybrid_block_layout) != self.transformer.num_blocks:
+            message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_blocks {self.transformer.num_blocks}"
+            if self.transformer.num_blocks % len(self.hybrid_block_layout) != 0:
                 raise ValueError(message)
-            num_repeats = self.transformer.num_layers // len(self.hybrid_block_layout)
+            num_repeats = self.transformer.num_blocks // len(self.hybrid_block_layout)
             logger.warning(f"{message}, will repeat {self.hybrid_block_layout} {num_repeats} times.")
             self.hybrid_block_layout = self.hybrid_block_layout * num_repeats
 
@@ -179,7 +179,7 @@ def _validate(self) -> None:
             else:
                 Assert.eq(self.reference_models.keys(), {name})
         if self.model.base_model.use_absolute_position_embeddings:
-            Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length)
+            Assert.geq(self.model.base_model.absolute_position_embeddings, self.batch.sequence_length)
         # if self.model.base_model.distillation_model is not None:
         #     # TODO: Support loss masking for distillation?
# assert not self.batch.use_loss_masking_spans diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 43e3c67e5..fb24c1aec 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -219,7 +219,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: def _create_weight_converters(self) -> list[WeightConverter]: converters = super()._create_weight_converters() or [] - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear for i in range(num_layers): @@ -383,7 +383,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: def _create_weight_converters(self) -> list[WeightConverter]: # not using super() because LLamba model is called backbone in the checkpoints converters = [] - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks norm_bias: bool = False ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear @@ -572,7 +572,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: def _create_weight_converters(self) -> list[WeightConverter]: converters = super()._create_weight_converters() - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks norm_bias: bool = False # Embedding and output diff --git a/tests/models/test_generate.py b/tests/models/test_generate.py index 7f0b902f8..cb9c69ccb 100644 --- a/tests/models/test_generate.py +++ b/tests/models/test_generate.py @@ -354,7 +354,7 @@ def _test_forward_return_hidden_states( # hidden_states include embeddings layer assert ( - len(res_fast_llm.hidden_states) - 1 == fast_llm_model.config.fast_llm_config.base_model.transformer.num_layers + len(res_fast_llm.hidden_states) - 1 == 
fast_llm_model.config.fast_llm_config.base_model.transformer.num_blocks ) From 94bf7ac75f7c20f93ee27c0cadaa1cb6d7bed128 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 1 Aug 2025 17:03:10 -0400 Subject: [PATCH 02/19] stuff --- docs/developer_guide/conversion.md | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/docs/developer_guide/conversion.md b/docs/developer_guide/conversion.md index 19d3ba926..a465cb9a2 100644 --- a/docs/developer_guide/conversion.md +++ b/docs/developer_guide/conversion.md @@ -230,21 +230,21 @@ Continuing our `AwesomeModel` handler example, we define: ```python def _create_weight_converters(self) -> list[WeightConverter]: - converters = [] - # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. - num_layers = self._model.config.base_model.transformer.num_blocks - - # A simple renaming example, for the word embeddings. - converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) - - # We usually want to loop dynamically over layers - for i in range(num_layers): - # A `SplitWeightConverter` example, splitting a weight in two. - converters.append(SplitWeightConverter( - f"layers.{i + 1}.weight", - (f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"), - )) - return converters + converters = [] + # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. + num_layers = self._model.config.base_model.transformer.num_blocks + + # A simple renaming example, for the word embeddings. + converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) + + # We usually want to loop dynamically over layers + for i in range(num_layers): + # A `SplitWeightConverter` example, splitting a weight in two. 
+ converters.append(SplitWeightConverter( + f"layers.{i + 1}.weight", + (f"model.layers.{i}.weight_1", f"model.layers.{i}.weight_2"), + )) + return converters ``` And that's it! We're ready to use the new checkpoint format in Fast-LLM. From a9d1d56e6c3c904c1447b9c0c253fb2d81b1b61c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 1 Aug 2025 18:29:32 -0400 Subject: [PATCH 03/19] stuff --- fast_llm/layers/block/block.py | 23 ++++++++++++----------- fast_llm/layers/block/config.py | 2 +- fast_llm/layers/transformer/attention.py | 2 -- fast_llm/layers/transformer/config.py | 3 +++ 4 files changed, 16 insertions(+), 14 deletions(-) diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index ba11675b7..528523bd0 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -113,7 +113,7 @@ def __init__( # return self._name -class BlockLayer[ConfigType: BlockLayerConfig](Configurable[ConfigType], torch.nn.Module): +class BlockLayer[ConfigType: BlockLayerConfig](BlockLayerBase[ConfigType], torch.nn.Module): """ Base class for mixer and MLP modules. """ @@ -137,8 +137,13 @@ class Block[ConfigType: BlockConfig](BlockLayerBase[ConfigType], Layer): A transformer-like decoder base block with abstract mixer. """ - # TODO: Standardize to `mixer` - _mixer_module_name: typing.ClassVar[str] = "mixer" + # TODO: Needed for pycharm? + _config: ConfigType + _tensor_space: TensorSpace + _block_index: int + _name: str + _sequence_parallel: bool + _debug: DebugLayer def __init__( self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str, return_input: bool = False @@ -152,21 +157,17 @@ def __init__( self.norm_1 = self._config.peft.apply_other(self._config.normalization.get_layer(hidden_dim)) self.norm_2 = self._config.peft.apply_other(self._config.normalization.get_layer(hidden_dim)) - # The mixer needs to be created here for backward-compatible weight ordering. 
+ # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix. setattr( self, - self._mixer_module_name, + self._config.mixer.module_name, self._config.mixer.get_layer( self._tensor_space, self._block_index, f"{self._name} mixer", ), ) - self._config.mlp.get_layer( - self._tensor_space, - self._block_index, - f"{self._name} mlp", - ) + self.mlp = self._config.mlp.get_layer(self._tensor_space, self._block_index, f"{self._name} mlp") @torch.compile def _bias_dropout_add( @@ -199,7 +200,7 @@ def forward( hidden_states = self.norm_1(input_) if self._debug.enabled: self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) - hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs) + hidden_states, bias = getattr(self, self._config.mixer.module_name)(hidden_states, kwargs) if self._debug.enabled: self._debug( hidden_states if bias is None else hidden_states + bias, diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 10cb88485..786474262 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -64,7 +64,7 @@ def get_layer(self, tensor_space: TensorSpace, block_index: int, name: str) -> " class MixerConfig(BlockLayerConfig): _abstract = True - # Needed for backward compatibility. + # Needed for backward compatibility. TODO: Standardize to `mixer` module_name: typing.ClassVar[str] = "mixer" def _validate(self) -> None: diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 1978597fd..0bea58d9a 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -50,8 +50,6 @@ class Attention[ConfigType: AttentionConfig](BlockLayer[ConfigType]): A self-attention layer. 
""" - _mixer_name: typing.ClassVar[str] = "attn" - _QUERY_DIMS = ( AttentionDimNames.batch, AttentionDimNames.sequence_q, diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index d1759dc2f..6630397eb 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -51,6 +51,9 @@ class AttentionConfig(MixerConfig): # TODO: Make mixer class dynamic. _abstract = False + # Needed for backward compatibility. TODO: remove + module_name: typing.ClassVar[str] = "attn" + # TODO: Review names rotary: RotaryConfig = Field( desc="Configuration for the rotary positional embeddings.", From 87988d5e789d0f63e8349762b43c0559f096dfc0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 1 Aug 2025 19:14:51 -0400 Subject: [PATCH 04/19] stuff --- fast_llm/layers/block/block.py | 17 +++---------- fast_llm/layers/block/config.py | 13 +++++----- fast_llm/layers/block/mlp/mlp.py | 8 ++---- fast_llm/layers/language_model/config.py | 27 +++------------------ fast_llm/layers/language_model/embedding.py | 25 +++++++++---------- fast_llm/layers/language_model/head.py | 23 +++--------------- 6 files changed, 31 insertions(+), 82 deletions(-) diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 528523bd0..3283aef9c 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -88,7 +88,7 @@ def __call__[ ) -class BlockLayerBase[ConfigType: BaseModelConfig](Configurable[ConfigType], torch.nn.Module): +class BlockLayerBase[ConfigType: BaseModelConfig](Configurable[ConfigType], torch.nn.Module, abc.ABC): """ Base class for blocks, mixer and MLP modules. 
""" @@ -120,6 +120,7 @@ class BlockLayer[ConfigType: BlockLayerConfig](BlockLayerBase[ConfigType], torch def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): super().__init__(config, tensor_space, block_index, name, config.block) + self._block_index = block_index @abc.abstractmethod def forward( @@ -137,14 +138,6 @@ class Block[ConfigType: BlockConfig](BlockLayerBase[ConfigType], Layer): A transformer-like decoder base block with abstract mixer. """ - # TODO: Needed for pycharm? - _config: ConfigType - _tensor_space: TensorSpace - _block_index: int - _name: str - _sequence_parallel: bool - _debug: DebugLayer - def __init__( self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str, return_input: bool = False ): @@ -161,11 +154,7 @@ def __init__( setattr( self, self._config.mixer.module_name, - self._config.mixer.get_layer( - self._tensor_space, - self._block_index, - f"{self._name} mixer", - ), + self._config.mixer.get_layer(self._tensor_space, self._block_index, f"{self._name} mixer"), ) self.mlp = self._config.mlp.get_layer(self._tensor_space, self._block_index, f"{self._name} mlp") diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 786474262..680a122eb 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -11,6 +11,8 @@ if typing.TYPE_CHECKING: from fast_llm.layers.block.block import BlockLayer +# TODO: Generalize these beyond language models? (Ex. vision) + class BlockDimNames: # A set of common tensor dim names packed into a namespace. @@ -112,7 +114,6 @@ def _from_dict( @config_class() -# TODO: Use composition instead class BlockConfig(BaseModelConfig): _abstract = False mixer: MixerConfig = Field( @@ -156,11 +157,6 @@ class BlockConfig(BaseModelConfig): ) # TODO: Move these, not specific to a single block. 
- full_precision_residual: bool = Field( - default=False, - desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", - hint=FieldHint.stability, - ) num_blocks: int = Field( default=12, desc="Number of blocks in the model.", @@ -173,6 +169,11 @@ class BlockConfig(BaseModelConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) + full_precision_residual: bool = Field( + default=False, + desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", + hint=FieldHint.stability, + ) per_layer_lr_scale: list[float] | None = Field( default=None, desc="Custom learning rate scale for each layer.", diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 8e68e6274..a96755e0a 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -10,7 +10,7 @@ from fast_llm.layers.block.mlp.config import MLPConfig, MLPDimNames from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import LinearBase -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.utils import get_lr_scale class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): @@ -54,11 +54,7 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: self.layer_2 = self._config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) -class MLP(MLPBase): - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: int = 0, name: str = "mlp"): - Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, block_index, name) - +class MLP[ConfigType: MLPConfig](MLPBase[ConfigType]): def forward( self, input_: torch.Tensor, diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index a5d1b6a2e..943c64d01 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -2,7 +2,7 @@ from fast_llm.config 
import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer +from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_normal_ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl @@ -74,22 +74,6 @@ class LanguageModelBaseConfig(BaseModelConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - init_method_std_embed: float = Field( - default=None, - desc="Initialization scale for the vocabulary embedding and output weights (logits).", - hint=FieldHint.feature, - valid=check_field(Assert.geq, 0), - ) - init_method_max_embed: float | None = Field( - default=None, - desc="Max value for clamping initialized weights of the vocabulary embedding and output (logits).", - hint=FieldHint.feature, - ) - init_method_min_embed: float | None = Field( - default=None, - desc="Min value for clamping initialized weights of the vocabulary embedding and output (logits).", - hint=FieldHint.feature, - ) enable_dpo: bool | None = Field( default=False, desc="Whether to enable DPO loss", @@ -211,7 +195,6 @@ class LanguageModelBaseConfig(BaseModelConfig): ) def _validate(self) -> None: - self.transformer.validate() with self._set_implicit_default(): if self.language_model_loss_factor is None: if self.distillation_model is None: @@ -219,8 +202,6 @@ def _validate(self) -> None: else: self.language_model_loss_factor = 0.0 super()._validate() - if self.init_method_max_embed is not None and self.init_method_min_embed is not None: - Assert.leq(self.init_method_min_embed, self.init_method_max_embed) if self.distillation_model is not None: if self.prediction_heads > 1: raise NotImplementedError("Multi-token prediction 
not supported with distillation.") @@ -261,18 +242,18 @@ def word_embedding_weight_initialization_method(self) -> Initializer: if self.word_embedding_weight_initialization.has_initialization: return self.word_embedding_weight_initialization.get_initializer() else: - return self.transformer.hidden_size**-0.5 + return init_normal_(self.transformer.hidden_size**-0.5) @functools.cached_property def position_embedding_weight_initialization_method(self) -> Initializer: if self.position_embedding_weight_initialization.has_initialization: return self.position_embedding_weight_initialization.get_initializer() else: - return self.transformer.hidden_size**-0.5 + return init_normal_(self.transformer.hidden_size**-0.5) @functools.cached_property def output_weight_initialization_method(self) -> Initializer: if self.output_weight_initialization.has_initialization: return self.output_weight_initialization.get_initializer() else: - return self.transformer.hidden_size**-0.5 + return init_normal_(self.transformer.hidden_size**-0.5) diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index a546159dd..d99144e4c 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -36,17 +36,14 @@ def __init__( self._distributed_config = self._tensor_space.distributed_config self._residual_dtype = ( self._distributed_config.optimization_dtype - if config.transformer.full_precision_residual + if self._config.transformer.full_precision_residual else self._distributed_config.training_dtype ).torch self._group_size = self._distributed_config.tensor_parallel self._sequence_parallel = self._distributed_config.sequence_tensor_parallel self._parallel_embeddings = ( - self._tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._tensor_space.distributed_config.tensor_parallel > 1 and self._config.parallel_embeddings ) - self._dropout_p = config.transformer.hidden_dropout 
- self._use_absolute_position_embeddings = config.use_absolute_position_embeddings - hidden_dim = self._tensor_space[LanguageModelDimNames.hidden] vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab @@ -59,14 +56,14 @@ def __init__( self.word_embeddings_weight = ParameterMeta.from_dims( (vocab_dim, hidden_dim), init_method=self._config.word_embedding_weight_initialization_method, - lr_scale=config.embeddings_lr_scale, + lr_scale=self._config.embeddings_lr_scale, ) - if self._use_absolute_position_embeddings: + if self._config.use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( (self._tensor_space[LanguageModelDimNames.position_embed], hidden_dim), init_method=self._config.position_embedding_weight_initialization_method, - allow_sequence_tensor_parallel=not config.parallel_embeddings, - lr_scale=config.embeddings_lr_scale, + allow_sequence_tensor_parallel=not self._config.parallel_embeddings, + lr_scale=self._config.embeddings_lr_scale, ) # PEFT. 
@@ -78,21 +75,21 @@ def __init__( @torch.compile def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask_inputs: bool) -> torch.Tensor: - Assert.eq(position_ids is not None, self._use_absolute_position_embeddings) + Assert.eq(position_ids is not None, self._config.use_absolute_position_embeddings) group = self._tensor_space.distributed.tensor_group if self._parallel_embeddings: input_mask = (input_ >= self._vocab_start_index) * (input_ < self._vocab_end_index) masked_input = (input_ - self._vocab_start_index) * input_mask embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * input_mask.unsqueeze(2) # noqa embeddings = reduce_forward(embeddings, group) - if self._use_absolute_position_embeddings: + if self._config.use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) else: if self._sequence_parallel: input_ = split(input_, group=group, dim=0) - if self._use_absolute_position_embeddings: + if self._config.use_absolute_position_embeddings: position_ids = split(position_ids, group=group, dim=0) # handle masked tokens if mask_inputs: @@ -101,7 +98,7 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask embeddings = torch.embedding(self.word_embeddings_weight, masked_input) else: embeddings = torch.embedding(self.word_embeddings_weight, input_) - if self._use_absolute_position_embeddings: + if self._config.use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if mask_inputs: embeddings = embeddings * input_mask.unsqueeze(2) @@ -110,7 +107,7 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask if self._sequence_parallel else self._tensor_space.distributed.pp_generator ): - embeddings = torch.dropout(embeddings, self._dropout_p, 
self.training) + embeddings = torch.dropout(embeddings, self._config.transformer.hidden_dropout, self.training) return embeddings.to(dtype=self._residual_dtype) def forward( diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 691914b86..8f13a3582 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -5,7 +5,6 @@ from torch._C._distributed_c10d import ReduceOp # noqa from torch.distributed import all_reduce -from fast_llm.config import Configurable from fast_llm.core.ops import split_op from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace @@ -15,7 +14,7 @@ from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward -from fast_llm.layers.block.block import DebugLayer +from fast_llm.layers.block.block import BlockLayerBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.language_model.config import ( LanguageModelBaseConfig, @@ -32,30 +31,16 @@ OUTPUT_WEIGHTS = "output_weights" -class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[ConfigType], Layer): +class LanguageModelHead[ConfigType: LanguageModelBaseConfig](BlockLayerBase[ConfigType], Layer): """ A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). 
""" config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig - def __init__( - self, - config: ConfigType, - tensor_space: TensorSpace, - prediction_distance: int, - ): - super().__init__(config) - self._debug = DebugLayer( - tensor_space, - f"Language model head", - self._config.transformer.debug_transformer, - self._config.transformer.debug_transformer_memory, - ) - self._tensor_space = tensor_space - + def __init__(self, config: ConfigType, tensor_space: TensorSpace, prediction_distance: int): + super().__init__(config, tensor_space, None, "embedding layer", config.transformer) self._group_size = tensor_space.distributed_config.tensor_parallel - self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._parallel_embeddings = ( tensor_space.distributed_config.tensor_parallel > 1 and self._config.parallel_embeddings ) From 2a2b764b826497abb2b596028f106640b13914b9 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 1 Aug 2025 19:31:47 -0400 Subject: [PATCH 05/19] stuff --- fast_llm/layers/block/block.py | 1 - fast_llm/layers/block/mlp/mlp.py | 2 +- fast_llm/layers/language_model/head.py | 14 +++++++++++--- fast_llm/layers/transformer/config.py | 24 ++++++++++-------------- tests/utils/model_configs.py | 2 ++ 5 files changed, 24 insertions(+), 19 deletions(-) diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 3283aef9c..070f5dc67 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -120,7 +120,6 @@ class BlockLayer[ConfigType: BlockLayerConfig](BlockLayerBase[ConfigType], torch def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): super().__init__(config, tensor_space, block_index, name, config.block) - self._block_index = block_index @abc.abstractmethod def forward( diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index a96755e0a..0716bf777 100644 --- 
a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -14,7 +14,7 @@ class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: int = 0, name: str = "mlp"): + def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: int, name: str): super().__init__(config, tensor_space, block_index, name) hidden_dim = self._tensor_space[BlockDimNames.hidden] diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 8f13a3582..fcd1fae2b 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -5,6 +5,7 @@ from torch._C._distributed_c10d import ReduceOp # noqa from torch.distributed import all_reduce +from fast_llm.config import Configurable from fast_llm.core.ops import split_op from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace @@ -14,7 +15,7 @@ from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward -from fast_llm.layers.block.block import BlockLayerBase +from fast_llm.layers.block.block import DebugLayer from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.language_model.config import ( LanguageModelBaseConfig, @@ -31,7 +32,7 @@ OUTPUT_WEIGHTS = "output_weights" -class LanguageModelHead[ConfigType: LanguageModelBaseConfig](BlockLayerBase[ConfigType], Layer): +class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[ConfigType], Layer): """ A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). 
""" @@ -39,8 +40,15 @@ class LanguageModelHead[ConfigType: LanguageModelBaseConfig](BlockLayerBase[Conf config_class: typing.ClassVar[type[LanguageModelBaseConfig]] = LanguageModelBaseConfig def __init__(self, config: ConfigType, tensor_space: TensorSpace, prediction_distance: int): - super().__init__(config, tensor_space, None, "embedding layer", config.transformer) + super().__init__(config) + self._debug = DebugLayer( + tensor_space, + f"Language model head", + self._config.transformer.debug_transformer, + self._config.transformer.debug_transformer_memory, + ) self._group_size = tensor_space.distributed_config.tensor_parallel + self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._parallel_embeddings = ( tensor_space.distributed_config.tensor_parallel > 1 and self._config.parallel_embeddings ) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 6630397eb..e8c319b0f 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -5,7 +5,7 @@ from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_zeros_ +from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_normal_, init_zeros_ from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import TritonConfig @@ -13,9 +13,6 @@ from fast_llm.layers.transformer.rotary.config import RotaryConfig from fast_llm.utils import Assert, div -if typing.TYPE_CHECKING: - pass - logger = logging.getLogger(__name__) @@ -48,7 +45,6 @@ class AttentionKwargs(BlockKwargs): @config_class(dynamic_type={MixerConfig: "attention"}) class 
AttentionConfig(MixerConfig): - # TODO: Make mixer class dynamic. _abstract = False # Needed for backward compatibility. TODO: remove @@ -111,7 +107,8 @@ class AttentionConfig(MixerConfig): valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) qkv_weight_initialization: InitializationConfig = Field( - desc="Initialization configuration for the query, key and value layer weights. Default: normal(std=hidden_size**-0.5)", + desc="Initialization configuration for the query, key and value layer weights." + " Default: normal(std=hidden_size**-0.5)", hint=FieldHint.feature, ) qkv_bias_initialization: InitializationConfig = Field( @@ -119,7 +116,8 @@ class AttentionConfig(MixerConfig): hint=FieldHint.feature, ) dense_weight_initialization: InitializationConfig = Field( - desc="Initialization configuration for the dense layer weight. Default: normal(std=(2 * num_blocks * hidden_size)**-0.5)", + desc="Initialization configuration for the dense layer weight." + " Default: normal(std=(2 * num_blocks * hidden_size)**-0.5)", hint=FieldHint.feature, ) dense_bias_initialization: InitializationConfig = Field( @@ -129,7 +127,7 @@ class AttentionConfig(MixerConfig): def _validate(self) -> None: with self._set_implicit_default(): - # TODO: Make this work without inheritance. + # TODO: hidden_size not yet validated. if self.kv_channels is None: self.kv_channels = div(self.block.hidden_size, self.num_attention_heads) @@ -182,16 +180,14 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: CompositeTensorDim(AttentionDimNames.composite_dense, (head_groups, group_heads, kv_channels)) ) - @property + @functools.cached_property def add_qkv_bias(self) -> bool: - # TODO: Make this work without inheritance. if isinstance(self.block.add_linear_biases, bool): return self.block.add_linear_biases return self.block.add_linear_biases != AddLinearBiasChoices.nowhere - @property + @functools.cached_property def add_dense_bias(self) -> bool: - # TODO: Make this work without inheritance. 
if isinstance(self.block.add_linear_biases, bool): return self.block.add_linear_biases return self.block.add_linear_biases == AddLinearBiasChoices.everywhere @@ -201,7 +197,7 @@ def qkv_weight_initialization_method(self) -> Initializer: if self.qkv_weight_initialization.has_initialization: return self.qkv_weight_initialization.get_initializer() else: - return self.block.hidden_size**-0.5 + return init_normal_(0, self.block.hidden_size**-0.5) @functools.cached_property def qkv_bias_initialization_method(self) -> Initializer: @@ -215,7 +211,7 @@ def dense_weight_initialization_method(self) -> Initializer: if self.dense_weight_initialization.has_initialization: return self.dense_weight_initialization.get_initializer() else: - return self.block.hidden_size**-0.5 / max(2 * self.block.num_blocks, 1) + return init_normal_(0, self.block.hidden_size**-0.5 / max(2 * self.block.num_blocks, 1)) @functools.cached_property def dense_bias_initialization_method(self) -> Initializer: diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 722d8d63a..4705ebb79 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -162,6 +162,7 @@ def _update_and_add_testing_config( "model.base_model.transformer.num_attention_heads=8", "model.base_model.transformer.head_groups=8", "model.base_model.transformer.init_method_std=0.022", + "model.base_model.transformer.use_position_embeddings=True", f"model.base_model.vocab_size={MODEL_TEST_VOCAB_SIZE}", f"model.multi_stage.debug_param_init={_LOG_LEVEL}", f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", @@ -258,6 +259,7 @@ def _update_and_add_testing_config( extra_args=[ "model.base_model.transformer.head_groups=4", "model.base_model.transformer.rotary.type=default", + "model.base_model.transformer.use_position_embeddings=False", # Unused, but prevents issues with conversion tests. 
"model.base_model.max_position_embeddings=2048", ], From dfe4780b19a61a488d07175c990ceaa19ebe4a1e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 8 Aug 2025 16:37:39 -0400 Subject: [PATCH 06/19] stuff --- fast_llm/config.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/fast_llm/config.py b/fast_llm/config.py index 87cac34ad..099670625 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1038,13 +1038,6 @@ def __init_subclass__(cls): def config(self) -> ConfigType: return self._config - def __init_subclass__(cls): - # Automatically set `config_class` based on the bound type. - # Make sure `ConfigType` is bound and respects class hierarchy. - # TODO: Remove manual sets. - Assert.custom(issubclass, config_class := ConfigType.__bound__, cls.config_class) - cls.config_class = config_class - def set_nested_dict_value[ KeyType, ValueType From 45618432cf29756d190c434d1eccee560302abee Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 Aug 2025 17:05:54 -0400 Subject: [PATCH 07/19] stuff --- Megatron-LM | 2 +- fast_llm/config.py | 18 +- fast_llm/engine/base_model/base_model.py | 57 ++-- fast_llm/engine/base_model/config.py | 7 +- fast_llm/engine/config_utils/tensor_space.py | 267 ------------------ fast_llm/engine/multi_stage/fsdp.py | 16 +- fast_llm/engine/multi_stage/multi_stage.py | 2 +- fast_llm/engine/multi_stage/stage.py | 13 +- fast_llm/engine/multi_stage/stage_base.py | 4 +- fast_llm/engine/schedule/runner.py | 2 +- fast_llm/layers/common/peft.py | 2 +- .../layers/language_model/preprocessing.py | 38 +-- fast_llm/layers/transformer/preprocessing.py | 51 ++-- fast_llm/layers/transformer/rotary/config.py | 6 +- .../transformer/rotary/preprocessing.py | 68 ----- fast_llm/layers/transformer/rotary/rotary.py | 57 ++-- fast_llm/logging.py | 6 +- fast_llm/models/custom/model.py | 29 +- fast_llm/tensor.py | 25 +- tests/functional/test_triton_kernels.py | 4 +- tests/test_attention.py | 35 +-- tests/test_mlp.py | 29 -- 
tests/utils/global_variables.py | 4 +- tests/utils/utils.py | 7 +- 24 files changed, 160 insertions(+), 589 deletions(-) delete mode 100644 fast_llm/engine/config_utils/tensor_space.py delete mode 100644 fast_llm/layers/transformer/rotary/preprocessing.py delete mode 100644 tests/test_mlp.py diff --git a/Megatron-LM b/Megatron-LM index 75b0d9787..f02b413f7 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit 75b0d97876006c4b6b23fce302100d18dbf7db37 +Subproject commit f02b413f793af05ade3893bccd8aef6d644d3edf diff --git a/fast_llm/config.py b/fast_llm/config.py index 099670625..3352f3570 100644 --- a/fast_llm/config.py +++ b/fast_llm/config.py @@ -1031,7 +1031,23 @@ def __init__(self, config: ConfigType, *args, **kwargs): def __init_subclass__(cls): # Automatically set `config_class` based on the bound type. # Make sure `ConfigType` is bound and respects class hierarchy. - Assert.custom(issubclass, config_class := ConfigType.__bound__, cls.config_class) + try: + config_class = None + for base in types.get_original_bases(cls): + if hasattr(base, "__origin__") and issubclass(base.__origin__, Configurable): + for arg in base.__args__: + if arg.__name__ == "ConfigType": + if config_class is None: + config_class = arg.__bound__ + else: + assert arg.__bound__ is config_class + assert config_class is not None + except Exception as e: + raise TypeError( + f"Could not determine the configuration class for the configurable class {cls.__name__}: {e.args}. " + "Please make sure to declare in the format " + f"`class {cls.__name__}[ConfigType: ConfigClass](BaseConfigurable[ConfigType])`. 
" + ) cls.config_class = config_class @property diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index caaf94794..832225803 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -7,7 +7,6 @@ from fast_llm.config import Configurable from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.tensor import ParameterMeta, TensorMeta @@ -20,11 +19,18 @@ class Module(torch.nn.Module, abc.ABC): """ """ - def forward(self, input_, kwargs): - """ - Run a forward pass for the module, with autograd support. - """ - raise NotImplementedError() + _is_setup: bool = False + _distributed: Distributed + + def __init__(self, distributed_config: DistributedConfig): + self._distributed_config = distributed_config + super().__init__() + + def setup(self, distributed: Distributed) -> None: + assert not self._is_setup + distributed.check_config(self._distributed_config) + self._distributed = distributed + self._is_setup = True class Layer(Module): @@ -39,9 +45,9 @@ def forward( class Sequential(Layer): - def __init__(self, layers: list[Layer]): - super().__init__() - self.layers = torch.nn.ModuleList(layers) + def __init__(self, distributed_config: DistributedConfig): + super().__init__(distributed_config) + self.layers = torch.nn.ModuleList(self.get_layers()) def __getitem__(self, item): return self.layers[item] @@ -59,6 +65,15 @@ def forward( input_ = layer(input_, kwargs, losses, metrics) return input_ + @abc.abstractmethod + def get_layers(self) -> list[Layer]: + pass + + def setup(self, distributed: Distributed) -> None: + super().setup(distributed) + for layer in self.layers: + layer.setup(distributed) + @dataclasses.dataclass() class LossDef: @@ -71,28 +86,14 @@ 
class LossDef: dtype: torch.dtype = torch.float32 -class SequentialLayers(Sequential, abc.ABC): - # Small class defined to fix the MRO of BaseModel.__init__ - def __init__(self): - super().__init__(self.get_layers()) - - @abc.abstractmethod - def get_layers(self) -> list[Layer]: - pass - - -class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], SequentialLayers, abc.ABC): - _is_setup: bool = False +class BaseModel[ConfigType: BaseModelConfig](Configurable[ConfigType], Sequential): def __init__( self, config: BaseModelConfig, distributed_config: DistributedConfig, ): - self._tensor_space: TensorSpace = TensorSpace(distributed_config) - config.setup_tensor_space(self._tensor_space) - - super().__init__(config) + super().__init__(config, distributed_config) for key, value in self.named_parameters(): Assert.custom(isinstance, value, ParameterMeta) @@ -103,12 +104,6 @@ def __init__( # TODO: Add basic handling (preprocessor) in this class. self._reference_models: dict[str, "InferenceRunner"] = {} - def setup(self, distributed: Distributed) -> None: - assert not self._is_setup - distributed.check_config(self._tensor_space.distributed_config) - self._tensor_space.setup(distributed) - self._is_setup = True - @abc.abstractmethod def get_layers(self) -> list[Layer]: pass diff --git a/fast_llm/engine/base_model/config.py b/fast_llm/engine/base_model/config.py index 4be42e069..22abb021b 100644 --- a/fast_llm/engine/base_model/config.py +++ b/fast_llm/engine/base_model/config.py @@ -6,7 +6,7 @@ from fast_llm.utils import compare_nested, log if typing.TYPE_CHECKING: - from fast_llm.engine.config_utils.tensor_space import TensorSpace + import torch @config_class() @@ -18,9 +18,6 @@ class BaseModelConfig(Config): _abstract = True - def setup_tensor_space(self, tensor_space: "TensorSpace") -> None: - raise NotImplementedError() - def compare_architecture( self, model_config: typing.Self, @@ -64,5 +61,5 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> 
None: pass @abc.abstractmethod - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + def preprocess(self, batch: "torch.Tensor", kwargs: dict[str, typing.Any]) -> None: pass diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py deleted file mode 100644 index 6c4b95b20..000000000 --- a/fast_llm/engine/config_utils/tensor_space.py +++ /dev/null @@ -1,267 +0,0 @@ -import logging -import math -import typing - -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDim -from fast_llm.utils import Assert, div - -if typing.TYPE_CHECKING: - import torch - - from fast_llm.core.distributed import ProcessGroup - from fast_llm.engine.distributed.distributed import Distributed - -logger = logging.getLogger(__name__) - - -class TensorDim: - def __init__(self, name: str, global_size: int | None, parallel_dim: DistributedDim | None = None): - # TODO: Handle None for unknown sizes? - self._name = name - self._global_size = global_size - self._size = self._global_size if parallel_dim is None else div(global_size, parallel_dim.size) - self._parallel_dim = parallel_dim - - def __repr__(self) -> str: - return ( - f"{type(self).__name__}(" - f"name={self._name}," - f" size={self._size}," - f" global_size={self._global_size}," - f" parallel_dim={self._parallel_dim}" - f")" - ) - - def __eq__(self, other): - return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ - - @property - def name(self) -> str: - return self._name - - @property - def size(self) -> int: - return self._size - - @property - def global_size(self) -> int: - return self._global_size - - @property - def is_parallel(self) -> bool: - return self._parallel_dim is not None and self._parallel_dim.size > 1 - - @property - def parallel_dim(self) -> DistributedDim | None: - # TODO: Make more flexible for derived classes? 
- return self._parallel_dim - - @property - def parallel_group(self) -> "ProcessGroup|None": - # TODO: Make more flexible for derived classes? - return None if self._parallel_dim is None else self._parallel_dim.group - - def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - assert self.is_parallel - return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim) - - def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": - if self.is_parallel: - from fast_llm.core.ops import gather_op - - return gather_op(tensor, self.parallel_group, dim) - else: - return tensor - - def local_to_global_partial( - self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 - ) -> "torch.Tensor": - if self.is_parallel: - output = tensor.new_full((*tensor.shape[:dim], self.parallel_dim.size, *tensor.shape[dim:]), fill_value) - output.narrow(dim, self.parallel_dim.rank, 1).copy_(tensor.unsqueeze(dim)).squeeze(dim) - return output.flatten(dim, dim + 1) - else: - return tensor - - def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": - return ( - tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] - if self.parallel_dim is not None and self.parallel_dim.size > 1 - else tensor - ) - - -class CompositeTensorDim(TensorDim): - def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): - parallel_dim = None - for dim, tensor_dim in enumerate(tensor_dims): - if tensor_dim.parallel_dim is not None: - # TODO: Allow more than one parallel subdim? 
- assert parallel_dim is None - parallel_dim = tensor_dim.parallel_dim - self._parallel_dim_index = dim - - super().__init__( - name=name, - global_size=math.prod(dim.global_size for dim in tensor_dims), - parallel_dim=parallel_dim, - ) - self._tensor_dims = tensor_dims - - def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - assert self._parallel_dim_index is not None - dims = list(self._tensor_dims) - dims[self._parallel_dim_index] = dims[self._parallel_dim_index].replace_parallel_dim(distributed_dim) - return CompositeTensorDim(self.name, tuple(dims)) - - def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": - tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) - for i, tensor_dim in enumerate(self._tensor_dims): - tensor = tensor_dim.local_to_global(tensor, dim + i) - - return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - - def local_to_global_partial( - self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 - ) -> "torch.Tensor": - tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) - for i, tensor_dim in enumerate(self._tensor_dims): - tensor = tensor_dim.local_to_global_partial(tensor, dim + i) - - return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - - def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": - tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) - for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): - tensor = tensor_dim.global_to_local(tensor, dim + i) - return tensor if expand else tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - - -class ConcatenatedTensorDim(TensorDim): - def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): - parallel_dim = tensor_dims[0].parallel_dim - for dim, tensor_dim in enumerate(tensor_dims[1:]): - # TODO: Allow more flexibility? 
- Assert.is_(tensor_dim.parallel_dim, parallel_dim) - - super().__init__( - name=name, - global_size=sum(dim.global_size for dim in tensor_dims), - parallel_dim=parallel_dim, - ) - self._tensor_dims = tensor_dims - - def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - assert self.is_parallel - return ConcatenatedTensorDim( - self.name, tuple(tensor_dim.replace_parallel_dim(distributed_dim) for tensor_dim in self._tensor_dims) - ) - - def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": - import torch - - return ( - torch.concatenate( - [ - tensor_dim.local_to_global(tensor_, dim) - for tensor_, tensor_dim in zip( - tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), - self._tensor_dims, - strict=True, - ) - ], - dim, - ) - if self.is_parallel - else tensor - ) - - def local_to_global_partial( - self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 - ) -> "torch.Tensor": - import torch - - return ( - torch.concatenate( - [ - tensor_dim.local_to_global_partial(tensor_, dim) - for tensor_, tensor_dim in zip( - tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), - self._tensor_dims, - strict=True, - ) - ], - dim, - ) - if self.is_parallel - else tensor - ) - - def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": - if self.is_parallel and expand: - raise NotImplementedError() - import torch - - return ( - torch.concatenate( - [ - tensor_dim.global_to_local(tensor_, dim) - for tensor_, tensor_dim in zip( - tensor.split([tensor_dim.global_size for tensor_dim in self._tensor_dims], dim), - self._tensor_dims, - strict=True, - ) - ], - dim, - ) - if self.is_parallel - else tensor - ) - - -class DefaultDimNames: - # Scalar - scalar = "scalar" - - -class TensorSpace: - _is_setup: bool = False - _distributed: "Distributed" - - def __init__(self, distributed_config: DistributedConfig): - 
self._distributed_config = distributed_config - self._tensor_dims: dict[str, TensorDim] = {} - self.add_tensor_dim(TensorDim(DefaultDimNames.scalar, 1)) - - def setup(self, distributed: "Distributed") -> None: - assert not self._is_setup - if distributed.config is not self._distributed_config: - distributed.config.compare(self._distributed_config, ValueError) - self._is_setup = True - self._distributed = distributed - - @property - def distributed_config(self) -> DistributedConfig: - return self._distributed_config - - @property - def distributed(self) -> "Distributed": - assert self._is_setup - return self._distributed - - def add_tensor_dim(self, tensor_dim: TensorDim) -> None: - if tensor_dim.name in self._tensor_dims: - Assert.eq(tensor_dim, self._tensor_dims[tensor_dim.name]) - else: - if tensor_dim.parallel_dim is not None: - assert ( - tensor_dim.parallel_dim.name in self._distributed_config.distributed_dims - ), tensor_dim.parallel_dim.name - Assert.eq( - tensor_dim.parallel_dim.__dict__, - self._distributed_config.distributed_dims[tensor_dim.parallel_dim.name].__dict__, - ) - self._tensor_dims[tensor_dim.name] = tensor_dim - - def __getitem__(self, name: str) -> TensorDim: - return self._tensor_dims[name] diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index be15cd37a..cb0a02a67 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -9,7 +9,7 @@ from fast_llm.core.distributed import ProcessGroup from fast_llm.core.ops import gather_op from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import SHARD_PAD_TO_MULTIPLE, ShardName, StageMode @@ 
-320,27 +320,31 @@ def import_state_tensor( return end - begin def export_shard( - self, shard: torch.Tensor, distributed: Distributed, data_type: DataType | None = None + self, shard: torch.Tensor, data_type: DataType | None = None ) -> typing.Generator[tuple[str, torch.Tensor], None, None]: if data_type is not None: shard = shard.to(dtype=data_type.torch) tensors = self.split_buffer(self.reconstruct_from_shard(shard)) for name, meta in self._parameter_metas.items(): - yield name, meta.local_to_global(tensors[name], distributed=distributed)[0] + yield name, meta.local_to_global(tensors[name])[0] def log_shard(self, name, shard, *, distributed: Distributed, level, global_: bool) -> None: # if global_ is None: # global_ = self._config.debug_global_tensors parameters = self.split_buffer(self.reconstruct_from_shard(shard)) if global_ else self.split_shard(shard) for parameter_name, parameter in parameters.items(): + meta = self.get_parameter_meta(parameter_name) log_distributed_tensor( name, parameter, level=level, - distributed=distributed, global_=global_, - duplicate_groups=(distributed.data_group,), - meta=self.get_parameter_meta(parameter_name), + # Assuming all tensors are either duplicated or parallel in the TP direction. 
+ duplicate_groups=( + distributed.data_group, + distributed.tensor_group, + ), + meta=meta, ) def restore_parameters(self) -> None: diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index e17bc4ff8..d939bda2b 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -12,7 +12,7 @@ from fast_llm.engine.base_model.base_model import BaseModel from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.run import log_main_rank, log_model_parallel_main_rank -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames, PhaseType from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, ShardName, StageMode diff --git a/fast_llm/engine/multi_stage/stage.py b/fast_llm/engine/multi_stage/stage.py index 87eac31c4..35547cd87 100644 --- a/fast_llm/engine/multi_stage/stage.py +++ b/fast_llm/engine/multi_stage/stage.py @@ -7,7 +7,7 @@ from fast_llm.core.distributed import check_parallel_match from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.engine.multi_stage.config import StageMode +from fast_llm.engine.multi_stage.config import StageConfig, StageMode from fast_llm.engine.multi_stage.stage_base import StageBase from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage, log_tensor from fast_llm.tensor import ParameterMeta, TensorMeta, accumulate_gradient @@ -30,7 +30,7 @@ def hook(grad_inputs, grad_outputs): # noqa return hook -class Stage(StageBase): +class Stage[ConfigType: StageConfig](StageBase[ConfigType]): _is_restored: bool _training: bool | None = None # TODO: Handle all buffer 
sharing in multi_stage @@ -123,7 +123,7 @@ def forward( # Last layer does not provide output if output is not None: meta = self._meta_outputs[i] - output_global, _ = meta.local_to_global(output.detach(), distributed=self._distributed) + output_global, _ = meta.local_to_global(output.detach()) kwargs["hidden_states"][self._layer_range[i]] = { "layer_type": type(layer).__name__, "tensor": output_global, @@ -216,11 +216,13 @@ def _log_layer_forward(self, output: torch.Tensor, kwargs: dict[str, typing.Any] if (nms := kwargs.get("micro_batch_splits", 1)) > 1: name = f"{name}, ms={kwargs.get('micro_batch_split',0)}/{nms}" + # Assuming all tensors are either duplicated of parallel in the TP direction. log_distributed_tensor( name, output, level=self._config.debug_layer_outputs, - distributed=self._distributed, + # Assuming all tensors are either duplicated of parallel in the TP direction. + duplicate_groups=(self._distributed.tensor_group,), global_=self._config.debug_global_tensors, meta=self._meta_outputs[i], ) @@ -250,8 +252,9 @@ def _log_layer_backward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any name, input_, level=self._config.debug_layer_gradients, - distributed=self._distributed, grad_fn=lambda grad: grad / self._fsdp_size, + # Assuming all tensors are either duplicated of parallel in the TP direction. 
+ duplicate_groups=(self._distributed.tensor_group,), global_=self._config.debug_global_tensors, meta=self._meta_inputs[i], ) diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 387a53a03..ded24e538 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -20,7 +20,7 @@ logger = logging.getLogger(__name__) -class StageBase(Configurable[StageConfig]): +class StageBase[ConfigType: StageConfig](Configurable[ConfigType]): _distributed: Distributed _mode: StageMode @@ -314,7 +314,7 @@ def _export_shard( self, shards: tuple[torch.Tensor], data_type: DataType | None = None ) -> typing.Generator[tuple[str, torch.Tensor], None, None]: for fsdp, shard in zip(self._fsdps, shards, strict=True): - yield from fsdp.export_shard(shard, self._distributed, data_type) + yield from fsdp.export_shard(shard, data_type) def _get_parameter_metas(self) -> tuple[list[ParameterMeta], list[ParameterMeta]]: # Get all the stage parameters, diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 8eca4559d..21ecbe476 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -63,7 +63,7 @@ def __repr__(self): ) -class ScheduleRunner[ConfigType: ScheduleConfig](Configurable[ScheduleConfig]): +class ScheduleRunner[ConfigType: ScheduleConfig](Configurable[ConfigType]): _is_setup: bool = False _compute_stream: torch.cuda.Stream _data_stream: torch.cuda.Stream diff --git a/fast_llm/layers/common/peft.py b/fast_llm/layers/common/peft.py index 08f3e535b..87991ef29 100644 --- a/fast_llm/layers/common/peft.py +++ b/fast_llm/layers/common/peft.py @@ -2,7 +2,7 @@ import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.common.linear import Linear, LinearBase diff --git 
a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index 440ce9580..5ba31c0d0 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -4,7 +4,8 @@ import torch from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs from fast_llm.tensor import TensorMeta from fast_llm.utils import Assert @@ -13,40 +14,31 @@ class PositionEmbeddingPreprocessor(Preprocessor): - _scalar_dim: TensorDim _rotary_embedding_frequencies: torch.Tensor _position_ids: torch.Tensor _tensor_cache_max_sequence_length: int = -1 - def __init__( - self, - config: LanguageModelBaseConfig, - tensor_space: TensorSpace, - ): + def __init__(self, config: LanguageModelBaseConfig, distributed_config: DistributedConfig): self._config = config assert config.use_absolute_position_embeddings - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._distributed_config = distributed_config - def _create_tensors(self, sequence_length: int) -> None: + def _create_tensors(self, sequence_length: int, device: torch.device) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length - Assert.leq(sequence_length, self._config.absolute_position_embeddings) - self._position_ids = torch.arange( - 0, sequence_length, device=self._tensor_space.distributed.device, dtype=torch.int64 - ) + Assert.leq(sequence_length, self._config.num_absolute_position_embeddings) + self._position_ids = torch.arange(0, 
sequence_length, device=device, dtype=torch.int64) - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[LanguageModelKwargs.sequence_length]) + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + self._create_tensors(kwargs[LanguageModelKwargs.sequence_length], batch.device) sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size if (sequence_lengths := kwargs.get(LanguageModelKwargs.sequence_lengths)) is not None: position_ids = torch.stack( [torch.cat([torch.arange(x) for x in sample_lens]) for sample_lens in sequence_lengths] - ).to(self._tensor_space.distributed.device, dtype=torch.int64) + ).to(batch.device, dtype=torch.int64) position_ids = position_ids[:, sequence_k - sequence_q : sequence_k] if kwargs[LanguageModelKwargs.sequence_first]: position_ids = position_ids.transpose(0, 1) @@ -61,9 +53,9 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: sequence_q_dim = kwargs[LanguageModelKwargs.sequence_q_dim] kwargs[LanguageModelKwargs.position_ids] = TensorMeta.from_dims( ( - (sequence_q_dim, self._scalar_dim) + (sequence_q_dim, scalar_dim) if kwargs[LanguageModelKwargs.sequence_first] - else (self._scalar_dim, sequence_q_dim) + else (scalar_dim, sequence_q_dim) ), tensor_name=LanguageModelKwargs.position_ids, dtype=torch.int64, @@ -71,11 +63,9 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: class PreferenceSpanPreprocessor(Preprocessor): - def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): + def __init__(self, config: LanguageModelBaseConfig, distributed_config: DistributedConfig): self._config = config - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._distributed_config = distributed_config def preprocess_meta(self, kwargs: 
dict[str, typing.Any]) -> None: return diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index 16e5811e6..769177668 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -4,7 +4,8 @@ import torch from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.transformer.config import AttentionConfig, AttentionKwargs from fast_llm.tensor import TensorMeta @@ -12,25 +13,18 @@ class BackupAttentionPreprocessor(Preprocessor): - _scalar_dim: TensorDim _kv_channels_dim: TensorDim _rotary_embedding_frequencies: torch.Tensor _mask: torch.Tensor _mask_value: torch.Tensor _tensor_cache_max_sequence_length: int = -1 - def __init__( - self, - config: AttentionConfig, - tensor_space: TensorSpace, - ): + def __init__(self, config: AttentionConfig, distributed_config: DistributedConfig): self._config = config - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config + self._distributed_config = distributed_config assert not self._config.do_use_flash_attention(self._distributed_config) - self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - def _create_tensors(self, sequence_length: int) -> None: + def _create_tensors(self, sequence_length: int, device: torch.device) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length @@ -38,7 +32,7 @@ def _create_tensors(self, sequence_length: int) -> None: self._mask = torch.ones( (sequence_length, sequence_length), dtype=torch.bool, - device=self._tensor_space.distributed.device, + device=device, ).tril_() if self._config.window_size is not None: @@ 
-47,11 +41,11 @@ def _create_tensors(self, sequence_length: int) -> None: [], torch.finfo(self._distributed_config.training_dtype.torch).min, dtype=self._distributed_config.training_dtype.torch, - device=self._tensor_space.distributed.device, + device=device, ) - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[AttentionKwargs.sequence_length]) + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + self._create_tensors(kwargs[AttentionKwargs.sequence_length], batch.device) sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size sequence_q = kwargs[AttentionKwargs.sequence_q_dim].size kwargs[AttentionKwargs.attention_mask] = self._mask[ @@ -64,7 +58,7 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: for sample_lens in sequence_lengths ] ) - document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(self._tensor_space.distributed.device) + document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None]).to(batch.device) kwargs[AttentionKwargs.attention_mask] = ( kwargs[AttentionKwargs.attention_mask] & document_mask[:, None, sequence_k - sequence_q : sequence_k, None, :sequence_k] @@ -74,30 +68,29 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: kwargs[AttentionKwargs.attention_mask] = TensorMeta.from_dims( ( - self._scalar_dim, - self._scalar_dim, + scalar_dim, + scalar_dim, kwargs[AttentionKwargs.sequence_q_dim], - self._scalar_dim, + scalar_dim, kwargs[AttentionKwargs.sequence_k_dim], ), tensor_name=AttentionKwargs.attention_mask, dtype=torch.bool, ) kwargs[AttentionKwargs.attention_mask_value] = TensorMeta.from_dims( - (self._scalar_dim,), + (scalar_dim,), tensor_name=AttentionKwargs.attention_mask_value, - dtype=self._tensor_space.distributed_config.training_dtype.torch, + dtype=self._distributed_config.training_dtype.torch, ) class 
FlashAttnVarlenPreprocessor(Preprocessor): - def __init__(self, config: AttentionConfig, tensor_space: TensorSpace): + def __init__(self, config: AttentionConfig, distributed_config: DistributedConfig): self._config = config - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config + self._distributed_config = distributed_config assert self._config.do_use_flash_attention(self._distributed_config) - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: """ Prepares cu_seqlens_q and cu_seqlens_k for flash_attn_varlen_func: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_interface.py#L1375 @@ -148,14 +141,14 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: seqlens_k = torch.cat(sequence_lengths) kwargs[AttentionKwargs.cu_seqlens_q] = torch.cat( ( - torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), - torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), + torch.zeros(1, dtype=torch.int32, device=batch.device), + torch.cumsum(seqlens_q, dim=0, dtype=torch.int32).to(batch.device), ) ) kwargs[AttentionKwargs.cu_seqlens_k] = torch.cat( ( - torch.zeros(1, dtype=torch.int32, device=self._tensor_space.distributed.device), - torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(self._tensor_space.distributed.device), + torch.zeros(1, dtype=torch.int32, device=batch.device), + torch.cumsum(seqlens_k, dim=0, dtype=torch.int32).to(batch.device), ) ) kwargs[AttentionKwargs.max_seqlen_q] = seqlens_q.max() diff --git a/fast_llm/layers/transformer/rotary/config.py b/fast_llm/layers/transformer/rotary/config.py index 748f2af28..f0e0079c7 100644 --- a/fast_llm/layers/transformer/rotary/config.py +++ b/fast_llm/layers/transformer/rotary/config.py @@ -5,7 +5,7 @@ from fast_llm.config import Field, FieldHint, config_class from 
fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.utils import Assert @@ -29,8 +29,8 @@ def _from_dict( return NoRotaryConfig._from_dict(default, strict, flat) return super()._from_dict(default, strict=strict, flat=flat) - def build(self, tensor_space: TensorSpace | None = None) -> "Rotary": - return self._get_configurable_class()(self, tensor_space) + def build(self, kv_channels_dim: TensorDim) -> "Rotary": + return self._get_configurable_class()(self, kv_channels_dim) @classmethod @abc.abstractmethod diff --git a/fast_llm/layers/transformer/rotary/preprocessing.py b/fast_llm/layers/transformer/rotary/preprocessing.py deleted file mode 100644 index 9f8732f85..000000000 --- a/fast_llm/layers/transformer/rotary/preprocessing.py +++ /dev/null @@ -1,68 +0,0 @@ -import typing - -import torch - -from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig -from fast_llm.tensor import TensorMeta - - -class RotaryEmbeddingPreprocessor(Preprocessor): - _scalar_dim: TensorDim - _kv_channels_dim: TensorDim - _rotary_embedding_frequencies: torch.Tensor - _mask: torch.Tensor - _mask_value: torch.Tensor - _tensor_cache_max_sequence_length: int = -1 - - def __init__( - self, - config: DefaultRotaryConfig, - tensor_space: TensorSpace, - ): - self._config = config - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - self._kv_channels_dim = self._tensor_space[AttentionDimNames.kv_channels] - - def preprocess(self, batch, 
kwargs: dict[str, typing.Any]) -> None: - self._create_tensors(kwargs[AttentionKwargs.sequence_length]) - sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size - kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ - :, sequence_k - kwargs[AttentionKwargs.sequence_q_dim].size : sequence_k - ] - kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] - - def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - kwargs[AttentionKwargs.rotary_freq_q] = TensorMeta.from_dims( - ( - self._scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], - self._scalar_dim, - self._kv_channels_dim, - ), - tensor_name=AttentionKwargs.rotary_freq_q, - ) - kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( - ( - self._scalar_dim, - kwargs[AttentionKwargs.sequence_q_dim], - self._scalar_dim, - self._kv_channels_dim, - ), - tensor_name=AttentionKwargs.rotary_freq_k, - ) - - def _create_tensors(self, sequence_length: int) -> None: - if sequence_length <= self._tensor_cache_max_sequence_length: - return - self._tensor_cache_max_sequence_length = sequence_length - - self._rotary_embedding_frequencies = self._config.get_frequencies( - sequence_length, - self._kv_channels_dim.global_size, - device=self._tensor_space.distributed.device, - ) diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index ebb629aa1..bbf8b524a 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -6,9 +6,9 @@ from fast_llm.config import Configurable from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.functional.triton.rotary import triton_rotary_autograd_ -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs +from 
fast_llm.layers.transformer.config import AttentionKwargs from fast_llm.layers.transformer.rotary.config import ( DefaultRotaryConfig, Llama3RotaryConfig, @@ -41,14 +41,14 @@ def apply_rotary_embeddings(tensor: torch.Tensor, rope_frequencies: torch.Tensor return torch.view_as_real(complex_tensor * rope_frequencies).view_as(tensor).type_as(tensor) -class Rotary[ConfigType: RotaryConfig](Configurable[RotaryConfig], torch.nn.Module, Preprocessor): +class Rotary[ConfigType: RotaryConfig](Configurable[ConfigType], torch.nn.Module, Preprocessor): def __init__( self, config: ConfigType, - # The tensor space is only needed for preprocessing, so we make it optional. - tensor_space: TensorSpace | None = None, + kv_channels_dim: TensorDim, ): super().__init__(config) + self._kv_channels_dim = kv_channels_dim @abc.abstractmethod def forward( @@ -57,7 +57,7 @@ def forward( pass -class NoRotary[ConfigType: NoRotaryConfig](Rotary[NoRotaryConfig]): +class NoRotary[ConfigType: NoRotaryConfig](Rotary[ConfigType]): def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: @@ -70,24 +70,12 @@ def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: pass -class DefaultRotary[ConfigType: DefaultRotaryConfig](Rotary[DefaultRotaryConfig]): +class DefaultRotary[ConfigType: DefaultRotaryConfig](Rotary[ConfigType]): _rotary_embedding_frequencies: torch.Tensor _tensor_cache_max_sequence_length: int = -1 - def __init__( - self, - config: ConfigType, - tensor_space: TensorSpace | None = None, - ): - super().__init__(config, tensor_space) - self._tensor_space = tensor_space - if self._tensor_space is not None: - self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] - self._kv_channels_dim = self._tensor_space[AttentionDimNames.kv_channels] - - def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: - assert self._tensor_space is not None - 
self._create_tensors(kwargs[AttentionKwargs.sequence_length]) + def preprocess(self, batch: torch.Tensor, kwargs: dict[str, typing.Any]) -> None: + self._create_tensors(kwargs[AttentionKwargs.sequence_length], batch.device) sequence_k = kwargs[AttentionKwargs.sequence_k_dim].size kwargs[AttentionKwargs.rotary_freq_q] = self._rotary_embedding_frequencies[ :, sequence_k - kwargs[AttentionKwargs.sequence_q_dim].size : sequence_k @@ -95,21 +83,20 @@ def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: kwargs[AttentionKwargs.rotary_freq_k] = self._rotary_embedding_frequencies[:, :sequence_k] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: - assert self._tensor_space is not None kwargs[AttentionKwargs.rotary_freq_q] = TensorMeta.from_dims( ( - self._scalar_dim, + scalar_dim, kwargs[AttentionKwargs.sequence_q_dim], - self._scalar_dim, + scalar_dim, self._kv_channels_dim, ), tensor_name=AttentionKwargs.rotary_freq_q, ) kwargs[AttentionKwargs.rotary_freq_k] = TensorMeta.from_dims( ( - self._scalar_dim, + scalar_dim, kwargs[AttentionKwargs.sequence_q_dim], - self._scalar_dim, + scalar_dim, self._kv_channels_dim, ), tensor_name=AttentionKwargs.rotary_freq_k, @@ -123,7 +110,7 @@ def forward( key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) return query, key - def _create_tensors(self, sequence_length: int) -> None: + def _create_tensors(self, sequence_length: int, device: torch.device) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: return self._tensor_cache_max_sequence_length = sequence_length @@ -131,10 +118,10 @@ def _create_tensors(self, sequence_length: int) -> None: self._rotary_embedding_frequencies = self._get_frequencies( sequence_length, self._kv_channels_dim.global_size, - device=self._tensor_space.distributed.device, + device=device, ) - def _get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") -> torch.Tensor: + def _get_frequencies(self, sequence_length: int, kv_channels: 
int, device: torch.device) -> torch.Tensor: # Calculate the complex frequencies (https://blog.eleuther.ai/rotary-embeddings/) # `exp(i * n * a) = cos(n * a) + i sin(n * a)`, # `a = theta ** - (2 * (channel // 2) / kv_channels)`, @@ -149,12 +136,12 @@ def _get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda" ).contiguous() return frequencies - def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: + def _get_angle_scales(self, kv_channels: int, device: torch.device) -> torch.Tensor: return self._config.theta ** -torch.arange(0, 1, 2 / kv_channels, device=device, dtype=torch.float64) -class Llama3Rotary[ConfigType: Llama3RotaryConfig](DefaultRotary[Llama3RotaryConfig]): - def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: +class Llama3Rotary[ConfigType: Llama3RotaryConfig](DefaultRotary[ConfigType]): + def _get_angle_scales(self, kv_channels: int, device: torch.device) -> torch.Tensor: scales = super()._get_angle_scales(kv_channels, device) low_frequency_wavelength = self._config.original_context_length / self._config.low_frequency_factor high_frequency_wavelength = self._config.original_context_length / self._config.high_frequency_factor @@ -173,17 +160,17 @@ def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: return torch.stack(new_scales) -class YarnRotary[ConfigType: YarnRotaryConfig](DefaultRotary[YarnRotaryConfig]): +class YarnRotary[ConfigType: YarnRotaryConfig](DefaultRotary[ConfigType]): """ Yarn scaling: https://github.com/huggingface/transformers/blob/006d9249ec0270ff6c4d3840979d23fe94bdc763/src/transformers/modeling_rope_utils.py#L163 [original paper](https://arxiv.org/abs/2309.00071) """ - def _get_frequencies(self, sequence_length: int, kv_channels: int, device="cuda") -> torch.Tensor: + def _get_frequencies(self, sequence_length: int, kv_channels: int, device: torch.device) -> torch.Tensor: return super()._get_frequencies(sequence_length, kv_channels, 
device) * self._config.attention_factor - def _get_angle_scales(self, kv_channels: int, device="cuda") -> torch.Tensor: + def _get_angle_scales(self, kv_channels: int, device: torch.device) -> torch.Tensor: scales = super()._get_angle_scales(kv_channels, device) # TODO: max_position_embeddings or original_context_length? # see https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L304 diff --git a/fast_llm/logging.py b/fast_llm/logging.py index 6d555a0bb..024d7d79c 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -14,7 +14,6 @@ if typing.TYPE_CHECKING: from fast_llm.core.distributed import ProcessGroup - from fast_llm.engine.distributed.distributed import Distributed logger = logging.getLogger(__name__) @@ -254,7 +253,6 @@ def log_distributed_tensor[ scale: float = 1.0, level: int = 2, storage: bool = False, - distributed: "Distributed", duplicate_groups: tuple[typing.Optional["ProcessGroup"], ...] = (), global_: bool = True, log_fn: type[BaseException] | typing.Callable[[str], T] | None = logger.info, @@ -263,7 +261,7 @@ def log_distributed_tensor[ if level <= 0: return if global_: - tensor, is_first_rank = meta.local_to_global(tensor, distributed=distributed) + tensor, is_first_rank = meta.local_to_global(tensor) storage = False is_first_rank = is_first_rank and all(group.rank() == 0 for group in duplicate_groups if group) if not is_first_rank: @@ -289,7 +287,6 @@ def log_distributed_grad[ scale: float = 1.0, level: int = 2, storage: bool = False, - distributed: "Distributed", duplicate_groups: tuple[typing.Optional["ProcessGroup"], ...] 
= (), grad_fn: typing.Callable[[torch.Tensor], torch.Tensor] | None = None, global_: bool = True, @@ -305,7 +302,6 @@ def log_distributed_grad[ scale=scale, level=level, storage=storage, - distributed=distributed, duplicate_groups=duplicate_groups, global_=global_, log_fn=log_fn, diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index ea56b7b5a..3afd88ce1 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -3,11 +3,9 @@ import torch from fast_llm.data.data.gpt.data import GPTBatch -from fast_llm.engine.base_model.base_model import Layer, LossDef +from fast_llm.engine.base_model.base_model import LossDef from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.schedule.config import BatchConfig -from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.transformer.block import TransformerBlock from fast_llm.models.custom.config import CustomBaseModelConfig from fast_llm.models.custom.head import CustomHead from fast_llm.models.gpt.model import GPTBaseModel, GPTModel @@ -17,26 +15,21 @@ class CustomBaseModel[ConfigType: CustomBaseModelConfig](GPTBaseModel[ConfigType]): def __init__( self, - config: CustomBaseModelConfig, + config: ConfigType, distributed_config: DistributedConfig, ): # TODO: Implement / update. super().__init__(config, distributed_config) - def get_layers(self) -> list[Layer]: - # TODO: Adjust as needed. 
- return [ - LanguageModelEmbedding(self._config, self._tensor_space), - *[ - TransformerBlock( - self._config.transformer, - self._tensor_space, - block_index=i + 1, - ) - for i in range(self._config.transformer.num_blocks) - ], - CustomHead(self._config, self._tensor_space), - ] + def _get_head(self, prediction_distance): + return CustomHead( + self._config, + self._distributed_config, + self._hidden_dim, + max(self._config.transformer.num_layers + prediction_distance, 1), + f"Language model head {prediction_distance}", + prediction_distance=prediction_distance, + ) def preprocess_meta( self, batch_meta: BatchConfig | torch.Tensor, phase: PhaseType diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index b12d12072..b6180c190 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -7,7 +7,7 @@ from fast_llm.core.distributed import ReduceOp from fast_llm.core.ops import reduce_op from fast_llm.engine.config_utils.initialization import Initializer, LambdaInitializer -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed from fast_llm.functional.triton.pointwise import triton_add, triton_copy @@ -138,30 +138,11 @@ def from_dims( **kwargs, ) - @classmethod - def from_tensor_space( - cls, - dim_names: tuple[str, ...], - tensor_space: TensorSpace, - *, - tensor_name: str = "", - dtype: torch.dtype = torch.float32, - reductions: tuple[tuple[str, ReduceOp], ...] = (), - **kwargs: typing.Any, - ) -> typing.Self: - dims = tuple(tensor_space[dim_name] for dim_name in dim_names) - if reductions: - # kwarg not available for ParameterMeta, so we only provide if necessary. 
- kwargs["reductions"] = tuple( - (tensor_space.distributed_config.get_distributed_dim(name), op) for name, op in reductions - ) - return cls.from_dims(dims, tensor_name=tensor_name, dtype=dtype, **kwargs) - @property def global_shape(self) -> torch.Size: return torch.Size([dim.global_size for dim in self.dims]) - def local_to_global(self, tensor: torch.Tensor, *, distributed: Distributed) -> tuple[torch.Tensor, ...]: + def local_to_global(self, tensor: torch.Tensor) -> tuple[torch.Tensor, ...]: """ Reconstruct a global tensor from its distributed slices. Support lazy-loaded safetensor slices. Returns a view of the input tensor (or the input tensor itself) when possible. @@ -171,7 +152,7 @@ def local_to_global(self, tensor: torch.Tensor, *, distributed: Distributed) -> Assert.eq(tensor.shape, self.shape) # Tensors are always either split or duplicated in the tensor-parallel direction. # TODO: Avoid hard-coded assumptions on duplication - is_first_rank, modified = distributed.config.tensor_rank == 0, False + is_first_rank, modified = True, False for dim, tensor_dim in enumerate(self.dims): if tensor_dim.is_parallel: diff --git a/tests/functional/test_triton_kernels.py b/tests/functional/test_triton_kernels.py index e61f72244..e4ad937b7 100644 --- a/tests/functional/test_triton_kernels.py +++ b/tests/functional/test_triton_kernels.py @@ -92,7 +92,7 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels): y1 = apply_rotary_embeddings( x, DefaultRotaryConfig(triton=False) - .build() + .build(None) ._get_frequencies( sequence_length, kv_channels, @@ -103,7 +103,7 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels): y2 = convert_rotary_real_to_complex( triton_rotary_( convert_rotary_complex_to_real(x, kv_channels, 3), - DefaultRotaryConfig(triton=True).build()._get_frequencies(sequence_length, kv_channels, device="cuda"), + DefaultRotaryConfig(triton=True).build(None)._get_frequencies(sequence_length, kv_channels, 
device="cuda"), ), kv_channels, 3, diff --git a/tests/test_attention.py b/tests/test_attention.py index 534e3800e..7d05e0a66 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -2,11 +2,12 @@ import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs, TransformerConfig +from fast_llm.layers.transformer.config import AttentionKwargs, TransformerConfig from fast_llm.layers.transformer.preprocessing import FlashAttnVarlenPreprocessor from fast_llm.utils import Assert @@ -30,19 +31,6 @@ def test_decide_window_size(): assert attention._decide_window_size() == 512 -def test_attention_constructor(): - transformer_conf = TransformerConfig( - num_layers=2, - num_attention_heads=2, - hidden_size=16, - ) - distributed_config = DistributedConfig() - tensor_space = TensorSpace(distributed_config=distributed_config) - transformer_conf.setup_tensor_space(tensor_space) - - Attention(transformer_conf, tensor_space, 1) - - def test_varlen_preprocessor(): sequence_lengths = [torch.tensor([8, 13, 4, 11], dtype=torch.int32), torch.tensor([11, 16, 9], dtype=torch.int32)] # First micro-sequence: @@ -63,27 +51,24 @@ def test_varlen_preprocessor(): ] micro_sequence_length = 12 sequence_length = 36 - transformer_cfg = TransformerConfig( + transformer_config = TransformerConfig( num_layers=2, num_attention_heads=2, hidden_size=16, use_flash_attention=True, ) - distributed_cfg = DistributedConfig(training_dtype="bfloat16") - distributed = Distributed(distributed_cfg, use_cpu=True) - tensor_space = TensorSpace(distributed_config=distributed_cfg) - 
tensor_space.setup(distributed) - transformer_cfg.setup_tensor_space(tensor_space) - varlen_preprocessor = FlashAttnVarlenPreprocessor(transformer_cfg, tensor_space=tensor_space) + distributed_config = DistributedConfig(training_dtype="bfloat16") + distributed = Distributed(distributed_config, use_cpu=True) + varlen_preprocessor = FlashAttnVarlenPreprocessor(transformer_config, distributed_config=distributed_config) for micro_seq_idx in range(int(sequence_length / micro_sequence_length)): kwargs = { - AttentionKwargs.sequence_q_dim: TensorDim(AttentionDimNames.sequence_k, micro_sequence_length), + AttentionKwargs.sequence_q_dim: TensorDim(BlockDimNames.sequence_k, micro_sequence_length), AttentionKwargs.sequence_k_dim: TensorDim( - AttentionDimNames.sequence_k, (micro_seq_idx + 1) * micro_sequence_length + BlockDimNames.sequence_k, (micro_seq_idx + 1) * micro_sequence_length ), AttentionKwargs.sequence_length: sequence_length, AttentionKwargs.sequence_lengths: sequence_lengths, } - varlen_preprocessor.preprocess(None, kwargs) + varlen_preprocessor.preprocess(torch.empty(1, device="cpu"), kwargs) Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_q], cumulative_sequences_q[micro_seq_idx]) Assert.all_equal(kwargs[AttentionKwargs.cu_seqlens_k], cumulative_sequences_k[micro_seq_idx]) diff --git a/tests/test_mlp.py b/tests/test_mlp.py deleted file mode 100644 index 802833eb2..000000000 --- a/tests/test_mlp.py +++ /dev/null @@ -1,29 +0,0 @@ -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP -from fast_llm.layers.block.mlp.mlp import MLP -from fast_llm.layers.transformer.config import TransformerConfig - - -def test_mlp_constructor(): - transformer_conf = TransformerConfig( - num_layers=2, - num_attention_heads=2, - hidden_size=16, - ) - distributed_config = DistributedConfig() - tensor_space = 
TensorSpace(distributed_config=distributed_config) - transformer_conf.setup_tensor_space(tensor_space) - - MLP(transformer_conf, tensor_space, 0, "name") - - -def test_moe_mlp_constructor(): - transformer_conf = TransformerConfig( - num_layers=2, num_attention_heads=2, hidden_size=16, num_experts=2, add_linear_biases=False - ) - distributed_config = DistributedConfig() - tensor_space = TensorSpace(distributed_config=distributed_config) - transformer_conf.setup_tensor_space(tensor_space) - - MixtureOfExpertMLP(transformer_conf, tensor_space, 0, "name") diff --git a/tests/utils/global_variables.py b/tests/utils/global_variables.py index 80232bf53..42e588911 100644 --- a/tests/utils/global_variables.py +++ b/tests/utils/global_variables.py @@ -29,8 +29,8 @@ def set_testing_global_variables(): num_gpus = len(gpus) gpus = [gpus[(i + worker_id) % num_gpus] for i in range(num_gpus)] os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpus) - os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(SHARED_RESULT_PATH / "torchinductor_cache") - os.environ["TRITON_CACHE_DIR"] = str(SHARED_RESULT_PATH / "triton_cache") + # os.environ["TORCHINDUCTOR_CACHE_DIR"] = str(SHARED_RESULT_PATH / "torchinductor_cache") + # os.environ["TRITON_CACHE_DIR"] = str(SHARED_RESULT_PATH / "triton_cache") # TODO: Fixtures diff --git a/tests/utils/utils.py b/tests/utils/utils.py index 88303a0f4..0dc3462eb 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -13,7 +13,6 @@ from fast_llm.core.distributed import ProcessGroup, allreduce_scalar, safe_barrier from fast_llm.engine.base_model.base_model import BaseModel, Layer from fast_llm.engine.config_utils.logging import configure_logging -from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageConfig from fast_llm.engine.multi_stage.stage import Stage @@ -33,12 +32,8 @@ def result_path(): def 
get_base_model(config: FastLLMModelConfig): # Create a base model (and distributed). # Using a full model config so we have the model type and distributed config in the same argument. - distributed = Distributed(config.distributed) - tensor_space = TensorSpace(config.distributed) - config.base_model.setup_tensor_space(tensor_space) - tensor_space.setup(distributed) base_model = config.get_model_class().base_model_class(config.base_model, config.distributed) - base_model.setup(distributed) + base_model.setup(distributed := Distributed(config.distributed)) return base_model, distributed From ccbb38f563da603a8c8a20fa70a08c7ccf74de5d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 Aug 2025 17:07:22 -0400 Subject: [PATCH 08/19] stuff --- fast_llm/engine/config_utils/tensor_dim.py | 221 +++++++++++++++++++++ 1 file changed, 221 insertions(+) create mode 100644 fast_llm/engine/config_utils/tensor_dim.py diff --git a/fast_llm/engine/config_utils/tensor_dim.py b/fast_llm/engine/config_utils/tensor_dim.py new file mode 100644 index 000000000..f67916a66 --- /dev/null +++ b/fast_llm/engine/config_utils/tensor_dim.py @@ -0,0 +1,221 @@ +import logging +import math +import typing + +from fast_llm.engine.distributed.config import DistributedDim +from fast_llm.utils import Assert, div + +if typing.TYPE_CHECKING: + import torch + + from fast_llm.core.distributed import ProcessGroup + +logger = logging.getLogger(__name__) + + +class TensorDim: + def __init__(self, name: str, global_size: int | None, parallel_dim: DistributedDim | None = None): + # TODO: Handle None for unknown sizes? 
+ self._name = name + self._global_size = global_size + self._size = self._global_size if parallel_dim is None else div(global_size, parallel_dim.size) + self._parallel_dim = parallel_dim + + def __repr__(self) -> str: + return ( + f"{type(self).__name__}(" + f"name={self._name}," + f" size={self._size}," + f" global_size={self._global_size}," + f" parallel_dim={self._parallel_dim}" + f")" + ) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ + + @property + def name(self) -> str: + return self._name + + @property + def size(self) -> int: + return self._size + + @property + def global_size(self) -> int: + return self._global_size + + @property + def is_parallel(self) -> bool: + return self._parallel_dim is not None and self._parallel_dim.size > 1 + + @property + def parallel_dim(self) -> DistributedDim | None: + # TODO: Make more flexible for derived classes? + return self._parallel_dim + + @property + def parallel_group(self) -> "ProcessGroup|None": + # TODO: Make more flexible for derived classes? 
+ return None if self._parallel_dim is None else self._parallel_dim.group + + def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + assert self.is_parallel + return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim) + + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + if self.is_parallel: + from fast_llm.core.ops import gather_op + + return gather_op(tensor, self.parallel_group, dim) + else: + return tensor + + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + if self.is_parallel: + output = tensor.new_full((*tensor.shape[:dim], self.parallel_dim.size, *tensor.shape[dim:]), fill_value) + output.narrow(dim, self.parallel_dim.rank, 1).copy_(tensor.unsqueeze(dim)).squeeze(dim) + return output.flatten(dim, dim + 1) + else: + return tensor + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + return ( + tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] + if self.parallel_dim is not None and self.parallel_dim.size > 1 + else tensor + ) + + +class CompositeTensorDim(TensorDim): + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): + parallel_dim = None + for dim, tensor_dim in enumerate(tensor_dims): + if tensor_dim.parallel_dim is not None: + # TODO: Allow more than one parallel subdim? 
+ assert parallel_dim is None + parallel_dim = tensor_dim.parallel_dim + self._parallel_dim_index = dim + + super().__init__( + name=name, + global_size=math.prod(dim.global_size for dim in tensor_dims), + parallel_dim=parallel_dim, + ) + self._tensor_dims = tensor_dims + + def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + assert self._parallel_dim_index is not None + dims = list(self._tensor_dims) + dims[self._parallel_dim_index] = dims[self._parallel_dim_index].replace_parallel_dim(distributed_dim) + return CompositeTensorDim(self.name, tuple(dims)) + + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in enumerate(self._tensor_dims): + tensor = tensor_dim.local_to_global(tensor, dim + i) + + return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) + + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in enumerate(self._tensor_dims): + tensor = tensor_dim.local_to_global_partial(tensor, dim + i) + + return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): + tensor = tensor_dim.global_to_local(tensor, dim + i) + return tensor if expand else tensor.flatten(dim, dim + len(self._tensor_dims) - 1) + + +class ConcatenatedTensorDim(TensorDim): + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): + parallel_dim = tensor_dims[0].parallel_dim + for dim, tensor_dim in enumerate(tensor_dims[1:]): + # TODO: Allow more flexibility? 
+ Assert.is_(tensor_dim.parallel_dim, parallel_dim) + + super().__init__( + name=name, + global_size=sum(dim.global_size for dim in tensor_dims), + parallel_dim=parallel_dim, + ) + self._tensor_dims = tensor_dims + + def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + assert self.is_parallel + return ConcatenatedTensorDim( + self.name, tuple(tensor_dim.replace_parallel_dim(distributed_dim) for tensor_dim in self._tensor_dims) + ) + + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + import torch + + return ( + torch.concatenate( + [ + tensor_dim.local_to_global(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) + + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + import torch + + return ( + torch.concatenate( + [ + tensor_dim.local_to_global_partial(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + if self.is_parallel and expand: + raise NotImplementedError() + import torch + + return ( + torch.concatenate( + [ + tensor_dim.global_to_local(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.global_size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) + + +scalar_dim = TensorDim("scalar", 1) From b70dd19d73ff0f7d1ef8bea374a91199e744b94f Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 Aug 2025 18:29:48 -0400 Subject: [PATCH 09/19] stuff --- 
fast_llm/layers/block/block.py | 77 ++++---- fast_llm/layers/block/config.py | 15 +- fast_llm/layers/block/mlp/config.py | 44 ----- .../layers/block/mlp/mixture_of_experts.py | 49 +++-- fast_llm/layers/block/mlp/mlp.py | 56 ++++-- fast_llm/layers/common/config.py | 2 +- fast_llm/layers/common/linear.py | 2 +- fast_llm/layers/common/normalization.py | 2 +- fast_llm/layers/language_model/config.py | 26 +-- fast_llm/layers/language_model/embedding.py | 53 +++--- fast_llm/layers/language_model/head.py | 164 ++++++++--------- fast_llm/layers/ssm/block.py | 21 ++- fast_llm/layers/ssm/config.py | 119 +----------- fast_llm/layers/ssm/discrete_mamba2.py | 74 +++++--- .../layers/ssm/{mamba_layer.py => mamba.py} | 54 +++--- fast_llm/layers/ssm/mamba2.py | 95 ++++++---- fast_llm/layers/transformer/attention.py | 173 +++++++++++------- fast_llm/layers/transformer/block.py | 9 +- fast_llm/layers/transformer/config.py | 33 +--- fast_llm/models/gpt/model.py | 73 +++++--- fast_llm/models/ssm/config.py | 9 - fast_llm/models/ssm/model.py | 107 ++++------- 22 files changed, 583 insertions(+), 674 deletions(-) rename fast_llm/layers/ssm/{mamba_layer.py => mamba.py} (79%) diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index 070f5dc67..d63ac78c1 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -5,13 +5,13 @@ import torch -from fast_llm.config import Configurable +from fast_llm.config import Config, Configurable from fast_llm.core.distributed import set_generator -from fast_llm.engine.base_model.base_model import Layer -from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.base_model.base_model import Layer, Module from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs, BlockLayerConfig +from 
fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockLayerConfig from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -20,8 +20,7 @@ class DebugLayer: # TODO: Move elsewhere? - def __init__(self, tensor_space: TensorSpace, name: str, debug_level: int = 0, debug_memory: bool = False): - self._tensor_space = tensor_space + def __init__(self, name: str, debug_level: int = 0, debug_memory: bool = False): self._name = name self._debug_level = debug_level self._debug_memory = debug_memory @@ -37,9 +36,9 @@ def _get_meta( ( dim if isinstance(dim, TensorDim) - else hidden_dims[dim] if dim in hidden_dims else self._tensor_space[dim] + else hidden_dims[dim] if dim in hidden_dims else TensorDim(dim, tensor.size(i)) ) - for dim in dims + for i, dim in enumerate(dims) ), tensor_name=f"{self._name} {name}", dtype=tensor.dtype, @@ -70,7 +69,6 @@ def __call__[ tensor, level=self._debug_level, meta=self._get_meta(tensor, name, dims, kwargs), - distributed=self._tensor_space.distributed, global_=global_, log_fn=log_fn, scale=scale, @@ -81,46 +79,44 @@ def __call__[ tensor, level=self._debug_level, meta=self._get_meta(tensor, name + " grad", dims, kwargs), - distributed=self._tensor_space.distributed, global_=global_, log_fn=log_fn, scale=scale, ) -class BlockLayerBase[ConfigType: BaseModelConfig](Configurable[ConfigType], torch.nn.Module, abc.ABC): +class BlockLayerBase[ConfigType: Config](Configurable[ConfigType], Module): """ Base class for blocks, mixer and MLP modules. 
""" def __init__( - self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str, block_config: BlockConfig + self, + config: ConfigType, + block_config: BlockConfig, + distributed_config: DistributedConfig, + # TODO: Review `hidden_dim` and `block_index` + hidden_dim: TensorDim, + block_index: int, + name: str, ): - super().__init__(config) - self._tensor_space = tensor_space + super().__init__(config, distributed_config) + self._hidden_dim = hidden_dim self._block_index = block_index self._name = name - self._sequence_parallel: bool = self._tensor_space.distributed_config.sequence_tensor_parallel + self._sequence_parallel: bool = self._distributed_config.sequence_tensor_parallel self._debug = DebugLayer( - tensor_space, self._name, block_config.debug_transformer, block_config.debug_transformer_memory, ) - # @property - # def name(self) -> str: - # return self._name - class BlockLayer[ConfigType: BlockLayerConfig](BlockLayerBase[ConfigType], torch.nn.Module): """ Base class for mixer and MLP modules. """ - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): - super().__init__(config, tensor_space, block_index, name, config.block) - @abc.abstractmethod def forward( self, @@ -138,24 +134,39 @@ class Block[ConfigType: BlockConfig](BlockLayerBase[ConfigType], Layer): """ def __init__( - self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str, return_input: bool = False + self, + config: ConfigType, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + return_input: bool = False, ): - super().__init__(config, tensor_space, block_index, name, config) + super().__init__( + config, + config, + distributed_config, + hidden_dim, + block_index, + name, + ) # For multi-token prediction, return a stack of shared_hidden and transformer_output. 
self._return_input: bool = return_input - hidden_dim = self._tensor_space[BlockDimNames.hidden] # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale self.norm_1 = self._config.peft.apply_other(self._config.normalization.get_layer(hidden_dim)) self.norm_2 = self._config.peft.apply_other(self._config.normalization.get_layer(hidden_dim)) - # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix. setattr( self, self._config.mixer.module_name, - self._config.mixer.get_layer(self._tensor_space, self._block_index, f"{self._name} mixer"), + self._config.mixer.get_layer( + self._distributed_config, self._hidden_dim, self._block_index, f"{self._name} mixer" + ), + ) + self.mlp = self._config.mlp.get_layer( + self._distributed_config, self._hidden_dim, self._block_index, f"{self._name} mlp" ) - self.mlp = self._config.mlp.get_layer(self._tensor_space, self._block_index, f"{self._name} mlp") @torch.compile def _bias_dropout_add( @@ -177,11 +188,7 @@ def forward( if self._return_input: dims = (TensorDim("stacked_input_output", 2),) + dims return TensorMeta.from_dims(dims, tensor_name=f"{self._name} output", dtype=input_.dtype) - generator = ( - self._tensor_space.distributed.tp_generator - if self._tensor_space.distributed_config.sequence_tensor_parallel - else self._tensor_space.distributed.pp_generator - ) + generator = self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator if self._debug.enabled: self._debug(None, "begin", kwargs[BlockKwargs.hidden_dims], kwargs) fw_input = input_ diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 680a122eb..e5f1020e1 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -3,7 +3,8 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig -from 
fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.block.peft import TransformerPeftConfig from fast_llm.layers.common.config import NormalizationConfig from fast_llm.utils import Assert @@ -58,8 +59,10 @@ class BlockLayerConfig(BaseModelConfig): def layer_class(self) -> "type[BlockLayer]": raise NotImplementedError() - def get_layer(self, tensor_space: TensorSpace, block_index: int, name: str) -> "BlockLayer": - return self.layer_class(self, tensor_space, block_index, name) + def get_layer( + self, distributed_config: DistributedConfig, hidden_dim: TensorDim, block_index: int, name: str + ) -> "BlockLayer": + return self.layer_class(self, distributed_config, hidden_dim, block_index, name) @config_class(registry=True) @@ -180,9 +183,3 @@ class BlockConfig(BaseModelConfig): doc="May be used to freeze some layers by setting their scale to zero.", hint=FieldHint.feature, ) - - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - super().setup_tensor_space(tensor_space) - - # Hidden dimension - tensor_space.add_tensor_dim(TensorDim(BlockDimNames.hidden, self.hidden_size)) diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index bde775a27..1d918b0d1 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -4,8 +4,6 @@ from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_normal_, init_zeros_ -from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace -from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType, MLPRecomputeLevel from fast_llm.layers.block.config import BlockLayerConfig 
from fast_llm.utils import Assert @@ -14,21 +12,6 @@ from fast_llm.layers.block.mlp.mlp import MLPBase -class MLPDimNames: - # MLP dimensions - mlp = "mlp" - gate_and_up = "gate_and_up" - composite_gated_mlp = "composite_gated_mlp" - experts = "experts" - top_experts = "top_experts" - shared_experts = "shared_experts" - unshared_experts = "unshared_experts" - composite_expert_mlp = "composite_expert_mlp" - composite_gated_expert_mlp = "composite_gated_expert_mlp" - composite_shared_expert_mlp = "composite_shared_expert_mlp" - composite_gated_shared_expert_mlp = "composite_gated_shared_expert_mlp" - - class MLPLossNames: load_balancing_loss = "load_balancing_loss" router_z_loss = "router_z_loss" @@ -236,30 +219,3 @@ def router_weight_initialization_method(self) -> Initializer: return self.router_weight_initialization.get_initializer() else: return init_zeros_ - - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - - # MLP dimensions - tensor_space.add_tensor_dim(mlp := TensorDim(MLPDimNames.mlp, self.ffn_hidden_size, tensor)) - tensor_space.add_tensor_dim(gate_and_up := TensorDim(MLPDimNames.gate_and_up, 2 if self.gated else 1)) - tensor_space.add_tensor_dim(CompositeTensorDim(MLPDimNames.composite_gated_mlp, (gate_and_up, mlp))) - tensor_space.add_tensor_dim(experts := TensorDim(MLPDimNames.experts, self.num_experts)) - tensor_space.add_tensor_dim(CompositeTensorDim(MLPDimNames.composite_expert_mlp, (experts, mlp))) - tensor_space.add_tensor_dim( - CompositeTensorDim(MLPDimNames.composite_gated_expert_mlp, (experts, gate_and_up, mlp)) - ) - tensor_space.add_tensor_dim(TensorDim(MLPDimNames.top_experts, self.num_experts_per_token)) - tensor_space.add_tensor_dim(TensorDim(MLPDimNames.unshared_experts, self.num_unshared_experts)) - - # shared_experts - if self.num_shared_experts: - tensor_space.add_tensor_dim( - shared_experts := 
TensorDim(MLPDimNames.shared_experts, self.num_shared_experts) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(MLPDimNames.composite_shared_expert_mlp, (shared_experts, mlp)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(MLPDimNames.composite_gated_shared_expert_mlp, (shared_experts, gate_and_up, mlp)) - ) diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index f401371a4..a9a13a5ff 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -4,11 +4,12 @@ import torch from fast_llm.core.distributed import ProcessGroup, set_generator -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map -from fast_llm.layers.block.config import BlockDimNames, BlockKwargs -from fast_llm.layers.block.mlp.config import MLPConfig, MLPDimNames, MLPLossNames, RoutingType +from fast_llm.layers.block.config import BlockKwargs +from fast_llm.layers.block.mlp.config import MLPConfig, MLPLossNames, RoutingType from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.linear import Linear @@ -31,18 +32,25 @@ class MixtureOfExpertMLP[ConfigType: MLPConfig](MLPBase[ConfigType]): _group: ProcessGroup - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + ): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." 
- super().__init__(config, tensor_space, block_index, name) + super().__init__(config, distributed_config, hidden_dim, block_index, name) layer_lr_scale = self._config.per_layer_lr_scale[block_index] if self._config.per_layer_lr_scale else None router_lr_scale = get_lr_scale(self._config.router_lr_scale, layer_lr_scale) self.router = Linear( - tensor_space[BlockDimNames.hidden], - tensor_space[MLPDimNames.unshared_experts], + self._hidden_dim, + TensorDim("router_experts", self._config.num_unshared_experts), bias=False, weight_init_method=init_normal_( std=self._config.init_method_std, @@ -52,20 +60,33 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: i lr_scale=router_lr_scale, ) dropless_moe = self._config.dropless_moe - if dropless_moe and tensor_space.distributed_config.sequence_tensor_parallel: + if dropless_moe and self._sequence_parallel: warnings.warn( "Dropless MoE not supported for sequence-tensor-parallel, falling back to looped implementation." 
) dropless_moe = False self._mlp_forward = self._forward_dropless if dropless_moe else self._forward_looped + if self._debug.enabled: + self._top_expert_dim = TensorDim("top_experts", self._config.num_experts_per_token) + + def _get_intermediate_dims(self) -> tuple[TensorDim, TensorDim]: + intermediate_1_dim, intermediate_2_dim = super()._get_intermediate_dims() + experts_dim = TensorDim("experts", self._config.num_experts) + return ( + CompositeTensorDim("moe_intermediate_1", (experts_dim, intermediate_1_dim)), + CompositeTensorDim("moe_intermediate_2", (experts_dim, intermediate_2_dim)), + ) + def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: hidden_states = input_.flatten(0, -2) logits = self.router(hidden_states) if self._debug.enabled: - self._debug(logits, "Router logits", kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.experts,), kwargs) + self._debug( + logits, "Router logits", kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,), kwargs + ) # Apply z_loss if applicable if self._config.expert_z_loss_coefficient > 0.0: @@ -80,7 +101,7 @@ def forward( # Apply input_jitter if applicable: if self.training and self._config.moe_jitter_eps > 0.0: - with set_generator(self._tensor_space.distributed.pp_generator): + with set_generator(self._distributed.pp_generator): logits = self._apply_input_jitter(logits) # Routing @@ -96,12 +117,12 @@ def forward( if self._debug.enabled: # To log all ranks set `global_=False` self._debug( - scores, "Router scores", kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.top_experts,), kwargs + scores, "Router scores", kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,), kwargs ) self._debug( top_experts, "Router top experts", - kwargs[BlockKwargs.hidden_dims][:-1] + (MLPDimNames.top_experts,), + kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,), kwargs, ) @@ -125,7 +146,7 @@ def _forward_dropless( None, 
gated=self._config.gated, activation_type=self._config.activation_type, - group=self._intermediate_dim.parallel_group, + group=self._parallel_dim.group, sequence_parallel=self._sequence_parallel, training=self.training, recompute_level=self._config.mlp_recompute_level, @@ -145,7 +166,7 @@ def _forward_looped( self._config.num_experts, self._config.gated, self._config.activation_type, - self._intermediate_dim.parallel_group, + self._parallel_dim.group, self._sequence_parallel, self.training, self._config.mlp_recompute_level, diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 0716bf777..c18a70db6 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -2,23 +2,30 @@ import torch -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_dim import ConcatenatedTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd from fast_llm.layers.block.block import BlockLayer -from fast_llm.layers.block.config import BlockConfig, BlockDimNames -from fast_llm.layers.block.mlp.config import MLPConfig, MLPDimNames +from fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import LinearBase -from fast_llm.utils import get_lr_scale +from fast_llm.utils import Assert, get_lr_scale class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): - def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: int, name: str): - super().__init__(config, tensor_space, block_index, name) + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + ): + 
super().__init__(config, distributed_config, hidden_dim, block_index, name) - hidden_dim = self._tensor_space[BlockDimNames.hidden] - self._intermediate_dim = self._tensor_space[MLPDimNames.composite_expert_mlp] + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + intermediate_1_dim, intermediate_2_dim = self._get_intermediate_dims() self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation layer_lr_scale = self._config.per_layer_lr_scale[block_index] if self._config.per_layer_lr_scale else None @@ -32,19 +39,19 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - self._tensor_space[MLPDimNames.composite_gated_expert_mlp], + intermediate_1_dim, bias=self._config.add_bias, weight_init_method=self._config.layer_1_weight_initialization_method, bias_init_method=self._config.layer_1_bias_initialization_method, lr_scale=lr_scale, ) self.layer_2 = LinearBase( - self._intermediate_dim, + intermediate_2_dim, hidden_dim, bias=self._config.add_bias, weight_init_method=self._config.layer_2_weight_initialization_method, bias_init_method=self._config.layer_2_bias_initialization_method, - auto_bias_grad_accumulation=self._tensor_space.distributed_config.tensor_parallel > 1, + auto_bias_grad_accumulation=self._distributed_config.tensor_parallel > 1, transposed_weight=True, lr_scale=lr_scale, ) @@ -53,8 +60,28 @@ def __init__(self, config: BlockConfig, tensor_space: TensorSpace, block_index: self.layer_1 = self._config.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) self.layer_2 = self._config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) + def _get_intermediate_dims(self): + intermediate_2_dim = TensorDim("intermediate", self._config.ffn_hidden_size, self._parallel_dim) + if self._config.gated: + 
TensorDim("gate_and_up", 2) + intermediate_1_dim = ConcatenatedTensorDim("gate_and_up", (intermediate_2_dim, intermediate_2_dim)) + else: + intermediate_1_dim = intermediate_2_dim + return intermediate_1_dim, intermediate_2_dim + class MLP[ConfigType: MLPConfig](MLPBase[ConfigType]): + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + ): + Assert.eq(config.num_experts, 1) + super().__init__(config, distributed_config, hidden_dim, block_index, name) + def forward( self, input_: torch.Tensor, @@ -62,7 +89,6 @@ def forward( losses: dict[str, typing.Any] | None = None, metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - parallel_group = self._intermediate_dim.parallel_group return ( mlp_autograd( input_, @@ -70,14 +96,14 @@ def forward( self.layer_1.weight, self.layer_1.bias, self.layer_2.weight, - None if parallel_group else self.layer_2.bias, + None if self._parallel_dim.group else self.layer_2.bias, gated=self._config.gated, activation_type=self._config.activation_type, - group=parallel_group, + group=self._parallel_dim.group, sequence_parallel=self._sequence_parallel, training=self.training, recompute_level=self._config.mlp_recompute_level, transposed_layer_2_weight=self.layer_2.transposed_weight, ), - self.layer_2.bias if parallel_group else None, + self.layer_2.bias if self._parallel_dim.group else None, ) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 8483dc573..90c47ecf4 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -6,7 +6,7 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_ones_, init_zeros_ -from fast_llm.engine.config_utils.tensor_space import TensorDim +from 
fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.utils import Assert if typing.TYPE_CHECKING: diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py index 740b4847c..ca807e67c 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear.py @@ -4,7 +4,7 @@ import torch from fast_llm.engine.config_utils.initialization import init_zeros_ -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.linear import ( input_parallel_linear_autograd, diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index cedfd2294..4af6cb2c3 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -4,7 +4,7 @@ from fast_llm.config import Configurable from fast_llm.engine.config_utils.run import log_main_rank -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd from fast_llm.layers.common.config import ( diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 943c64d01..4c7307e1b 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -3,22 +3,11 @@ from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_normal_ -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import 
CrossEntropyImpl, DistillationLossImpl -from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs +from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.utils import Assert -class LanguageModelDimNames(BlockDimNames): - # Embedding dimensions - position_embed = "position_embed" - vocab = "vocab" - vocab_tp = "vocab_tp" - # Misc - scalar = "scalar" - - class LanguageModelLossNames: language_model_loss = "language_model_loss" z_loss = "z_loss" @@ -220,19 +209,6 @@ def _validate(self) -> None: if self.output_weight_initialization.has_initialization: assert not self.tie_word_embeddings - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - self.transformer.setup_tensor_space(tensor_space) - tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - - # Embedding dimensions - if self.use_absolute_position_embeddings: - tensor_space.add_tensor_dim( - TensorDim(LanguageModelDimNames.position_embed, self.absolute_position_embeddings) - ) - # TODO: Need both? 
- tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab, self.vocab_size)) - tensor_space.add_tensor_dim(TensorDim(LanguageModelDimNames.vocab_tp, self.vocab_size, tensor)) - @property def use_absolute_position_embeddings(self) -> int: return self.absolute_position_embeddings is not None diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 463151079..33e05cde1 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -2,19 +2,20 @@ import torch -from fast_llm.config import Configurable from fast_llm.core.distributed import set_generator from fast_llm.core.ops import reduce_forward, split from fast_llm.engine.base_model.base_model import Layer -from fast_llm.engine.config_utils.tensor_space import TensorSpace -from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelDimNames, LanguageModelKwargs +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.layers.block.block import BlockLayerBase +from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert WORD_EMBEDDINGS_WEIGHT = "word_embeddings_weight" -class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](Configurable[LanguageModelBaseConfig], Layer): +class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](BlockLayerBase[ConfigType], Layer): """ A language model embedding layer. Consists of word embeddings (tensor-parallel or sequence-tensor-parallel), @@ -27,25 +28,30 @@ class LanguageModelEmbedding[ConfigType: LanguageModelBaseConfig](Configurable[L def __init__( self, config: ConfigType, - tensor_space: TensorSpace, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + # TODO: Unnecessary? 
+ block_index: int, + name: str, ): - super().__init__(config) - self._tensor_space = tensor_space - self._distributed_config = self._tensor_space.distributed_config + super().__init__( + config, + config.transformer, + distributed_config, + hidden_dim, + block_index, + name, + ) self._residual_dtype = ( self._distributed_config.optimization_dtype if self._config.transformer.full_precision_residual else self._distributed_config.training_dtype ).torch - self._group_size = self._distributed_config.tensor_parallel - self._sequence_parallel = self._distributed_config.sequence_tensor_parallel - self._parallel_embeddings = ( - self._tensor_space.distributed_config.tensor_parallel > 1 and self._config.parallel_embeddings + self._parallel_embeddings = self._distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + vocab_dim = TensorDim( + "vocab", self._config.vocab_size, self._parallel_dim if self._parallel_embeddings else None ) - hidden_dim = self._tensor_space[LanguageModelDimNames.hidden] - vocab_dim = self._tensor_space[ - LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ] if self._parallel_embeddings: self._vocab_start_index = self._distributed_config.tensor_rank * vocab_dim.size @@ -58,7 +64,7 @@ def __init__( ) if self._config.use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( - (self._tensor_space[LanguageModelDimNames.position_embed], hidden_dim), + (TensorDim("position_embeddings", self._config.max_position_embeddings), self._hidden_dim), init_method=self._config.position_embedding_weight_initialization_method, allow_sequence_tensor_parallel=not self._config.parallel_embeddings, lr_scale=self._config.embeddings_lr_scale, @@ -74,21 +80,20 @@ def __init__( @torch.compile def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask_inputs: bool) -> 
torch.Tensor: Assert.eq(position_ids is not None, self._config.use_absolute_position_embeddings) - group = self._tensor_space.distributed.tensor_group if self._parallel_embeddings: input_mask = (input_ >= self._vocab_start_index) * (input_ < self._vocab_end_index) masked_input = (input_ - self._vocab_start_index) * input_mask embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * input_mask.unsqueeze(2) # noqa - embeddings = reduce_forward(embeddings, group) + embeddings = reduce_forward(embeddings, self._parallel_dim.group) if self._config.use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if self._sequence_parallel: - embeddings = split(embeddings, group=group, dim=0) + embeddings = split(embeddings, group=self._parallel_dim.group, dim=0) else: if self._sequence_parallel: - input_ = split(input_, group=group, dim=0) + input_ = split(input_, group=self._parallel_dim.group, dim=0) if self._config.use_absolute_position_embeddings: - position_ids = split(position_ids, group=group, dim=0) + position_ids = split(position_ids, group=self._parallel_dim.group, dim=0) # handle masked tokens if mask_inputs: input_mask = input_ >= 0 @@ -101,9 +106,7 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask if mask_inputs: embeddings = embeddings * input_mask.unsqueeze(2) with set_generator( - self._tensor_space.distributed.tp_generator - if self._sequence_parallel - else self._tensor_space.distributed.pp_generator + self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator ): embeddings = torch.dropout(embeddings, self._config.transformer.hidden_dropout, self.training) return embeddings.to(dtype=self._residual_dtype) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 364dc745e..aa77089e5 100644 --- a/fast_llm/layers/language_model/head.py +++ 
b/fast_llm/layers/language_model/head.py @@ -4,24 +4,19 @@ from torch._C._distributed_c10d import ReduceOp # noqa from torch.distributed import all_reduce -from fast_llm.config import Configurable from fast_llm.core.ops import split_op from fast_llm.engine.base_model.base_model import Layer -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace -from fast_llm.engine.distributed.config import DistributedDimNames +from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl, TargetFormat, TritonConfig from fast_llm.functional.cross_entropy import cross_entropy_forward_backward, reverse_kl_forward_backward from fast_llm.functional.dpo import compute_dpo_loss from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward -from fast_llm.layers.block.block import DebugLayer +from fast_llm.layers.block.block import BlockLayerBase +from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss -from fast_llm.layers.language_model.config import ( - LanguageModelBaseConfig, - LanguageModelDimNames, - LanguageModelKwargs, - LanguageModelLossNames, -) +from fast_llm.layers.language_model.config import LanguageModelBaseConfig, LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.tensor import ParameterMeta, TensorMeta from fast_llm.utils import Assert, div, get_unique @@ -31,32 +26,35 @@ OUTPUT_WEIGHTS = "output_weights" -class LanguageModelHead[ConfigType: LanguageModelBaseConfig](Configurable[ConfigType], Layer): +class LanguageModelHead[ConfigType: LanguageModelBaseConfig](BlockLayerBase[ConfigType], 
Layer): """ A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). """ - def __init__(self, config: ConfigType, tensor_space: TensorSpace, prediction_distance: int): - super().__init__(config) - self._debug = DebugLayer( - tensor_space, - f"Language model head", - self._config.transformer.debug_transformer, - self._config.transformer.debug_transformer_memory, - ) - self._group_size = tensor_space.distributed_config.tensor_parallel - self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel - self._parallel_embeddings = ( - tensor_space.distributed_config.tensor_parallel > 1 and self._config.parallel_embeddings - ) - self._sequence_parallel_logits = ( - tensor_space.distributed_config.sequence_tensor_parallel and not self._config.parallel_embeddings + def __init__( + self, + config: ConfigType, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + # TODO: Unnecessary? + block_index: int, + name: str, + prediction_distance: int, + ): + super().__init__( + config, + config.transformer, + distributed_config, + hidden_dim, + block_index, + name, ) - self._cross_entropy_splits = self._config.cross_entropy_splits - if self._cross_entropy_splits is not None and self._sequence_parallel: - assert not self._parallel_embeddings + self._parallel_logits = self._distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) - hidden_dim = self._tensor_space[LanguageModelDimNames.hidden] + self._sequence_parallel_logits = self._sequence_parallel and not self._config.parallel_embeddings + if self._config.cross_entropy_splits is not None and self._sequence_parallel: + assert not self._parallel_logits self._loss_coefficient = ( self._config.prediction_loss_coefficient[prediction_distance] @@ -65,10 +63,6 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, prediction_dis ) 
self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance) self.final_norm = self._config.transformer.normalization.get_layer(hidden_dim) - self._logits_scale_factor = self._config.logits_scale_factor - self._language_model_loss_factor = self._config.language_model_loss_factor - self._distillation_loss_factor = self._config.distillation_loss_factor - self._z_loss_factor = self._config.logit_z_loss # Distance of the target token prediction # 0: next-token prediction @@ -77,22 +71,10 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, prediction_dis self._prediction_distance = prediction_distance self._is_last_head = self._prediction_distance == self._config.prediction_heads - 1 - # Only the first head defines the output weights - if self._prediction_distance == 0 and not self._config.tie_word_embeddings: - # untie embedding weights - vocab_dim = self._tensor_space[ - LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ] - self.output_weights = ParameterMeta.from_dims( - (vocab_dim, hidden_dim), - init_method=self._config.output_weight_initialization_method, - lr_scale=self._config.output_lr_scale, - ) - if not self._config.enable_dpo: self._cross_entropy_impl = self._config.cross_entropy_impl if self._cross_entropy_impl == CrossEntropyImpl.auto: - if self._parallel_embeddings: + if self._parallel_logits: self._cross_entropy_impl = CrossEntropyImpl.fused elif TritonConfig.TRITON_ENABLED: self._cross_entropy_impl = CrossEntropyImpl.triton @@ -101,6 +83,20 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, prediction_dis self._forward = wrap_forward_backward(self._forward_backward, grad_is_context) + self.final_norm = self._config.transformer.normalization.get_layer(hidden_dim) + + self._vocab_dim = TensorDim( + "vocab", self._config.vocab_size, self._parallel_dim if self._parallel_logits else None + ) + # Only the first head defines the output weights + if 
self._prediction_distance == 0 and not self._config.tie_word_embeddings: + # untie embedding weights + self.output_weights = ParameterMeta.from_dims( + (self._vocab_dim, hidden_dim), + init_method=self._config.output_weight_initialization_method, + lr_scale=self._config.output_lr_scale, + ) + # PEFT. self.final_norm = self._config.transformer.peft.apply_other(self.final_norm) if hasattr(self, "output_weights"): @@ -111,9 +107,8 @@ def forward( ) -> torch.Tensor: if isinstance(input_, TensorMeta): if self._is_last_head: - return TensorMeta.from_tensor_space( - (DefaultDimNames.scalar,), - self._tensor_space, + return TensorMeta.from_dims( + (scalar_dim,), tensor_name="Loss", reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa ) @@ -157,19 +152,19 @@ def _forward_backward( sequence_index = 1 - int(kwargs[LanguageModelKwargs.sequence_first]) dims[sequence_index] = ( TensorDim( - LanguageModelDimNames.sequence_q_tp, + BlockDimNames.sequence_q_tp, dims[sequence_index].global_size, - DistributedDimNames.tensor, + self._distributed_config.get_distributed_dim(DistributedDimNames.tensor), ) if self._sequence_parallel_logits - else TensorDim(LanguageModelDimNames.sequence_q, dims[sequence_index].global_size) + else TensorDim(BlockDimNames.sequence_q, dims[sequence_index].global_size) ) meta = TensorMeta.from_dims(tuple(dims), tensor_name="transformer hidden_state", dtype=ln_output.dtype) - hidden_state, _ = meta.local_to_global(ln_output.detach(), distributed=self._tensor_space.distributed) + hidden_state, _ = meta.local_to_global(ln_output.detach()) kwargs["hidden_states"][len(kwargs["hidden_states"]) - 1]["tensor"] = hidden_state grad_output = kwargs[LanguageModelKwargs.grad_output] / ( - self._group_size if self._sequence_parallel_logits else 1 + self._parallel_dim.size if self._sequence_parallel_logits else 1 ) output_weights = self._get_output_weights(kwargs) @@ -203,7 +198,7 @@ def _get_targets( if loss_mask is not None: loss_mask = loss_mask.flatten() - if 
self._config.distillation_model is None or self._language_model_loss_factor > 0.0: + if self._config.distillation_model is None or self._config.language_model_loss_factor > 0.0: lm_target = kwargs.get(LanguageModelKwargs.labels) if lm_target is not None: # MTP: Shift the labels @@ -289,7 +284,9 @@ def _logits_cross_entropy_forward_backward_split( logit_input_grad_.copy_(grad_) loss = loss_ if loss is None else loss + loss_ del grad_, loss_ - loss_count = (self._cross_entropy_splits or 1) * (self._group_size if self._sequence_parallel_logits else 1) + loss_count = (self._cross_entropy_splits or 1) * ( + self._parallel_dim.size if self._sequence_parallel_logits else 1 + ) if loss_count != 1: loss.div_(loss_count) if self._sequence_parallel_logits: @@ -310,39 +307,32 @@ def _logits_cross_entropy_forward_backward( input_=input_, weight=weight, bias=None, - group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, - sequence_parallel=self._sequence_parallel and self._parallel_embeddings, + group=self._tensor_space.distributed.tensor_group if self._parallel_logits else None, + sequence_parallel=self._sequence_parallel and self._parallel_logits, ) - if self._z_loss_factor > 0.0: + if self._config.logit_z_loss > 0.0: logits = z_loss( logits, - self._z_loss_factor, + self._config.logit_z_loss, self.training, grad_output, losses, LanguageModelLossNames.z_loss, - logits_scale_factor=self._logits_scale_factor, - ) - if self._debug.enabled and self._cross_entropy_splits is None: - vocab_dim = ( - LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp - ) - sequence_dim = ( - LanguageModelDimNames.sequence_q_tp - if self._sequence_parallel_logits - else LanguageModelDimNames.sequence_q + logits_scale_factor=self._config.logits_scale_factor, ) + if self._debug.enabled and self._config.cross_entropy_splits is None: + sequence_dim = BlockDimNames.sequence_q_tp if self._sequence_parallel_logits else 
BlockDimNames.sequence_q batch_dim = kwargs[LanguageModelKwargs.hidden_dims][1 if kwargs[LanguageModelKwargs.sequence_first] else 0] dims = ( - (sequence_dim, batch_dim, vocab_dim) + (sequence_dim, batch_dim, self._vocab_dim) if kwargs[LanguageModelKwargs.sequence_first] - else (batch_dim, sequence_dim, vocab_dim) + else (batch_dim, sequence_dim, self._vocab_dim) ) - self._debug(logits, "Language model logits", dims, kwargs, scale=self._logits_scale_factor) + self._debug(logits, "Language model logits", dims, kwargs, scale=self._config.logits_scale_factor) if targets is None: - return logits * self._logits_scale_factor, None + return logits * self._config.logits_scale_factor, None dpo_target, lm_target, distillation_target, loss_mask = targets if dpo_target is not None: @@ -363,25 +353,25 @@ def _logits_cross_entropy_forward_backward( logits.flatten(0, -2), lm_target, None, - group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, - grad_output=grad_output * self._loss_coefficient * self._language_model_loss_factor, + group=self._tensor_space.distributed.tensor_group if self._parallel_logits else None, + grad_output=grad_output * self._loss_coefficient * self._config.language_model_loss_factor, implementation=self._cross_entropy_impl, - logits_scale_factor=self._logits_scale_factor, + logits_scale_factor=self._config.logits_scale_factor, target_format=TargetFormat.labels, ) - lm_loss = lm_loss * self._language_model_loss_factor + lm_loss = lm_loss * self._config.language_model_loss_factor else: lm_loss, lm_grad = None, None - if distillation_target is not None and self._distillation_loss_factor > 0.0: + if distillation_target is not None and self._config.distillation_loss_factor > 0.0: if self._config.distillation_loss_implementation == DistillationLossImpl.reverse_kl: distillation_loss, distillation_grad = reverse_kl_forward_backward( logits.flatten(0, -2), distillation_target, loss_mask, - grad_output=grad_output * 
self._loss_coefficient * self._distillation_loss_factor, - group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, - logits_scale_factor=self._logits_scale_factor, + grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, + group=self._tensor_space.distributed.tensor_group if self._parallel_logits else None, + logits_scale_factor=self._config.logits_scale_factor, teacher_softmax_temperature=self._config.teacher_softmax_temperature, target_format=( TargetFormat.labels if self._config.distillation_model is None else TargetFormat.logits @@ -392,17 +382,17 @@ def _logits_cross_entropy_forward_backward( logits.flatten(0, -2), distillation_target, loss_mask, - group=self._tensor_space.distributed.tensor_group if self._parallel_embeddings else None, - grad_output=grad_output * self._loss_coefficient * self._distillation_loss_factor, + group=self._tensor_space.distributed.tensor_group if self._parallel_logits else None, + grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, implementation=self._cross_entropy_impl, - logits_scale_factor=self._logits_scale_factor, + logits_scale_factor=self._config.logits_scale_factor, target_format=TargetFormat.logits, ) else: raise ValueError( f"Invalid distillation loss implementation: {self._config.distillation_loss_implementation}" ) - distillation_loss = distillation_loss * self._distillation_loss_factor + distillation_loss = distillation_loss * self._config.distillation_loss_factor else: distillation_loss, distillation_grad = None, None diff --git a/fast_llm/layers/ssm/block.py b/fast_llm/layers/ssm/block.py index 987d5fa0d..e6374e725 100644 --- a/fast_llm/layers/ssm/block.py +++ b/fast_llm/layers/ssm/block.py @@ -1,11 +1,12 @@ -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig from 
fast_llm.layers.block.block import Block, BlockLayer from fast_llm.layers.block.config import BlockConfig from fast_llm.layers.ssm.config import SSMConfig # TODO: Sort out configs. -class SSMBlock[ConfigType: BlockConfig](Block[BlockConfig]): +class SSMBlock[ConfigType: BlockConfig](Block[ConfigType]): """ A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 """ @@ -14,21 +15,25 @@ class SSMBlock[ConfigType: BlockConfig](Block[BlockConfig]): def __init__( self, - config: BlockConfig, + config: ConfigType, ssm_config: SSMConfig, - tensor_space: TensorSpace, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, mixer_cls: type[BlockLayer], block_index: int, + name: str, return_input: bool = False, ): self._ssm_config = ssm_config self._mixer_cls = mixer_cls - super().__init__(config, tensor_space, block_index, return_input) + super().__init__(config, distributed_config, hidden_dim, block_index, name, return_input) def _create_mixer(self) -> BlockLayer: return self._mixer_cls( self._ssm_config, - tensor_space=self._tensor_space, - block_index=self._block_index, - block_config=self._config, + self._config, + self._distributed_config, + self._hidden_dim, + self._block_index, + f"{self._name} mixer", ) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index dec0675b9..910024e52 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -2,38 +2,13 @@ import typing from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, TensorSpace -from fast_llm.engine.distributed.config import DistributedDimNames from fast_llm.functional.config import ActivationType -from fast_llm.layers.block.config import BlockDimNames -from fast_llm.utils import Assert, div +from fast_llm.utils import Assert if typing.TYPE_CHECKING: from 
fast_llm.engine.config_utils.initialization import Initializer -class SSMDimNames(BlockDimNames): - # TODO: Use separate tensor space for different mixers so there is no risk of name conflict. - state = "ssm_state" # State dimension (N), aka head size / num channels - head_dim = "ssm_head_dim" - head_groups = "ssm_head_groups" - group_heads = "ssm_group_heads" - - convolution_kernel = "ssm_convolution_kernel" # Kernel dimension of the conv1d in mamba layers - - dt_rank = "ssm_dt_rank" - - # Composite dimensions - composite_heads = "ssm_composite_heads" - composite_heads_and_head_dim = "ssm_composite_heads_and_head_dim" - composite_head_groups_and_state = "ssm_composite_head_groups_and_state" - - # Concatenated dimensions - concatenated_convolution = "ssm_concatenated_convolution" - concatenated_x_projection = "ssm_x_concatenated_x_projection" - concatenated_inner_projection = "ssm_concatenated_inner_projection" - - class SSMBlockType(enum.StrEnum): """ An enum for the available mamba types for the MLP layer. 
@@ -46,9 +21,9 @@ class SSMBlockType(enum.StrEnum): def get_mixer_class(self): if self == SSMBlockType.mamba: - from fast_llm.layers.ssm.mamba_layer import MambaLayer + from fast_llm.layers.ssm.mamba import Mamba - return MambaLayer + return Mamba elif self == SSMBlockType.mamba2: from fast_llm.layers.ssm.mamba2 import Mamba2 @@ -79,21 +54,21 @@ class SSMConfig(Config): # TODO: Remove (redundant default) expansion_factor: int = Field( default=2, - desc="Expansion factor for Mamba blocks.", + desc="Expansion factor.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) # head_size [MambaLayer, Mamba2, DiscreteMamba2] state_size: int = Field( default=16, - desc="State size for Mamba blocks.", + desc="State size.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) # [MambaLayer, Mamba2, DiscreteMamba2] conv_kernel_dimension: int = Field( default=4, - desc="Conv kernel dimension for Mamba blocks.", + desc="Conv kernel dimensions.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) @@ -106,19 +81,19 @@ class SSMConfig(Config): # head_groups [DiscreteMamba2] n_qk_heads: int = Field( default=32, - desc="Number of QK heads for Mamba2 blocks.", + desc="Number of QK heads.", hint=FieldHint.architecture, ) # heads [DiscreteMamba2]# TODO: Remove? (redundant) n_v_heads: int = Field( default=32, - desc="Number of V heads for Mamba2 blocks.", + desc="Number of V heads.", hint=FieldHint.architecture, ) # c_size [MambaLayer, Mamba2, DiscreteMamba2]? 
d_inner: None | int = Field( default=None, - desc="Inner dimension for Mamba2 blocks.", + desc="Inner dimension.", hint=FieldHint.core, ) # xb_size [Mamba2] @@ -204,79 +179,3 @@ def _validate(self) -> None: self.activation_type = ActivationType.silu super()._validate() Assert.geq(self.dt_max, self.dt_min) - - def setup_tensor_space(self, tensor_space: TensorSpace, block_type: SSMBlockType) -> None: - tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - - # Head groups are configured differently depending on the block type. - if block_type == SSMBlockType.mamba: - num_heads = div(self.d_inner, self.state_size) - num_head_groups = num_heads - elif block_type == SSMBlockType.mamba2: - num_heads = div(self.d_inner, self.state_size) - num_head_groups = div(self.d_xb, self.state_size) - elif block_type == SSMBlockType.mamba2_discrete: - # TODO: Use different variables? - num_heads = self.n_v_heads - num_head_groups = self.n_qk_heads - else: - raise NotImplementedError(block_type) - - tensor_space.add_tensor_dim(state := TensorDim(SSMDimNames.state, self.state_size)) - if block_type == SSMBlockType.mamba2_discrete: - tensor_space.add_tensor_dim(head_dim := TensorDim(SSMDimNames.head_dim, div(self.d_inner, num_heads))) - else: - head_dim = state - - tensor_space.add_tensor_dim(head_groups := TensorDim(SSMDimNames.head_groups, num_head_groups, tensor)) - tensor_space.add_tensor_dim(group_heads := TensorDim(SSMDimNames.group_heads, div(num_heads, num_head_groups))) - tensor_space.add_tensor_dim( - heads := CompositeTensorDim(SSMDimNames.composite_heads, (head_groups, group_heads)) - ) - tensor_space.add_tensor_dim( - heads_and_head_dim := CompositeTensorDim( - SSMDimNames.composite_heads_and_head_dim, (head_groups, group_heads, head_dim) - ) - ) - tensor_space.add_tensor_dim( - head_groups_and_state := CompositeTensorDim( - SSMDimNames.composite_head_groups_and_state, (head_groups, state) - ) - ) - 
tensor_space.add_tensor_dim(TensorDim(SSMDimNames.convolution_kernel, self.conv_kernel_dimension)) - - # DT projection - if block_type in (SSMBlockType.mamba, SSMBlockType.mamba2): - tensor_space.add_tensor_dim(dt_rank := TensorDim(SSMDimNames.dt_rank, self.dt_rank)) - - if block_type == SSMBlockType.mamba: - tensor_space.add_tensor_dim( - ConcatenatedTensorDim(SSMDimNames.concatenated_x_projection, (dt_rank, state, state)) - ) - # TODO: Use composition instead - tensor_space.add_tensor_dim( - ConcatenatedTensorDim( - SSMDimNames.concatenated_inner_projection, (heads_and_head_dim, heads_and_head_dim) - ) - ) - elif block_type == SSMBlockType.mamba2: - # TODO: Factor out state? - tensor_space.add_tensor_dim( - ConcatenatedTensorDim( - SSMDimNames.concatenated_inner_projection, - (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim), - ) - ) - elif block_type == SSMBlockType.mamba2_discrete: - tensor_space.add_tensor_dim( - ConcatenatedTensorDim( - SSMDimNames.concatenated_inner_projection, - (heads_and_head_dim, head_groups_and_state, head_groups_and_state, heads_and_head_dim, heads), - ) - ) - tensor_space.add_tensor_dim( - ConcatenatedTensorDim( - SSMDimNames.concatenated_convolution, - (heads_and_head_dim, head_groups_and_state, head_groups_and_state), - ) - ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 61291f845..4ae6b4821 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -5,15 +5,16 @@ import torch from fast_llm.engine.config_utils.initialization import init_ones_, init_uniform_centered_, init_zeros_ -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import 
ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.ssm.mamba_layer import init_kaiming_ +from fast_llm.layers.ssm.config import SSMConfig +from fast_llm.layers.ssm.mamba import init_kaiming_ from fast_llm.tensor import ParameterMeta -from fast_llm.utils import get_lr_scale +from fast_llm.utils import div, get_lr_scale logger = logging.getLogger(__name__) @@ -34,48 +35,65 @@ _causal_conv1d_available = False -class DiscreteMamba2(BlockLayer): +class DiscreteMamba2[ConfigType: SSMConfig](BlockLayer[ConfigType]): """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" def __init__( self, - config: SSMConfig, - block_index: int, - tensor_space: TensorSpace, + config: ConfigType, block_config: BlockConfig, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, ): super().__init__( - tensor_space, + config, + block_config, + distributed_config, + hidden_dim, block_index, - self._mixer_name, - debug_level=block_config.debug_transformer, - debug_memory=block_config.debug_transformer_memory, + name, ) - self._config: SSMConfig = config - layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None - lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + state_dim = TensorDim("state", self._config.state_size) + v_head_size_dim = TensorDim("v_head_size", div(self._config.d_inner, self._config.n_v_heads)) - inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] - hidden_dim = tensor_space[SSMDimNames.hidden] - conv1d_dim = 
tensor_space[SSMDimNames.concatenated_convolution] - heads_dim = tensor_space[SSMDimNames.composite_heads] + head_groups_dim = TensorDim( + "head_groups", + self._config.n_qk_heads, + self._distributed_config.get_distributed_dim(DistributedDimNames.tensor), + ) + group_heads_dim = TensorDim("group_heads", div(self._config.n_v_heads, self._config.n_qk_heads)) + heads_dim = CompositeTensorDim("heads", (head_groups_dim, group_heads_dim)) + inner_dim = CompositeTensorDim("inner", (head_groups_dim, group_heads_dim, v_head_size_dim)) + bc_dim = CompositeTensorDim("bc", (head_groups_dim, state_dim)) + convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) + + inner_projection_dim = ConcatenatedTensorDim( + "inner_projection", + (inner_dim, bc_dim, bc_dim, inner_dim, heads_dim), + ) + convolution_dim = ConcatenatedTensorDim("convolution", (inner_dim, bc_dim, bc_dim)) # local_head_groups = head_groups / TP - self._local_head_groups = tensor_space[SSMDimNames.head_groups].size + self._local_head_groups = head_groups_dim.size # local_heads = local_head_groups * group_heads self._local_heads = heads_dim.size # local_inner_size = local_heads * head_size self._local_inner_size = inner_dim.size # local_bc_size = local_head_groups * state - self._local_bc_size = tensor_space[SSMDimNames.composite_head_groups_and_state].size + self._local_bc_size = bc_dim.size + + layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None + lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) # TODO: double check initializations # Projections self.in_proj = OutputParallelLinear( hidden_dim, - tensor_space[SSMDimNames.concatenated_inner_projection], + inner_projection_dim, bias=config.add_bias_linear, weight_init_method=init_kaiming_(block_config.hidden_size), sequence_parallel=self._sequence_parallel, @@ -90,15 +108,17 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( ( - conv1d_dim, 
- tensor_space[DefaultDimNames.scalar], - tensor_space[SSMDimNames.convolution_kernel], + convolution_dim, + scalar_dim, + convolution_kernel_dim, + ), + init_method=init_uniform_centered_( + (convolution_dim.global_size * self._config.conv_kernel_dimension) ** -0.5 ), - init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, ) self.conv1d_bias = ParameterMeta.from_dims( - (conv1d_dim,), + (convolution_dim,), init_method=init_uniform_centered_(self._config.conv_kernel_dimension**-0.5), lr_scale=lr_scale, ) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba.py similarity index 79% rename from fast_llm/layers/ssm/mamba_layer.py rename to fast_llm/layers/ssm/mamba.py index 0dcc29f0b..99609b1c4 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba.py @@ -5,14 +5,15 @@ import torch from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.layers.common.linear import Linear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.ssm.config import SSMConfig from fast_llm.tensor import ParameterMeta -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.utils import Assert, div, get_lr_scale try: from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa @@ -53,31 +54,39 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return LambdaInitializer(init_) -class 
MambaLayer(BlockLayer): +class Mamba[ConfigType: SSMConfig](BlockLayer[ConfigType]): _mixer_name: typing.ClassVar[str] = "mamba" def __init__( self, - config: SSMConfig, - block_index: int, - tensor_space: TensorSpace, + config: ConfigType, block_config: BlockConfig, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, ): super().__init__( - tensor_space, + config, + block_config, + distributed_config, + hidden_dim, block_index, - self._mixer_name, - debug_level=block_config.debug_transformer, - debug_memory=block_config.debug_transformer_memory, + name, ) - assert tensor_space.distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" - self._config = config + assert self._distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" # TODO: It's not silu? Assert.eq(self._config.activation_type, ActivationType.silu) # Tensor dims: - inner_dim = tensor_space[SSMDimNames.composite_heads_and_head_dim] - hidden_dim = tensor_space[SSMDimNames.hidden] + heads_dim = TensorDim("heads", div(self._config.d_inner, self._config.state_size)) + state_dim = TensorDim("state", self._config.state_size) + inner_dim = CompositeTensorDim("inner", (heads_dim, state_dim)) + convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) + dt_rank_dim = TensorDim("dt_rank", self._config.dt_rank) + inner_projection_dim = ConcatenatedTensorDim("inner_projection", (inner_dim, inner_dim)) + x_projection_dim = ConcatenatedTensorDim("x_projection", (dt_rank_dim, state_dim, state_dim)) + layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) @@ -85,7 +94,7 @@ def __init__( # TODO: lr_scale? 
self.in_proj = Linear( hidden_dim, - tensor_space[SSMDimNames.concatenated_inner_projection], + inner_projection_dim, bias=False, weight_init_method=init_kaiming_(hidden_dim.size), ) @@ -93,8 +102,8 @@ def __init__( self.conv1d_weight = ParameterMeta.from_dims( ( inner_dim, - tensor_space[DefaultDimNames.scalar], - tensor_space[SSMDimNames.convolution_kernel], + scalar_dim, + convolution_kernel_dim, ), init_method=init_kaiming_(inner_dim.size), lr_scale=lr_scale, @@ -102,7 +111,7 @@ def __init__( self.x_proj = Linear( inner_dim, - tensor_space[SSMDimNames.concatenated_x_projection], + x_projection_dim, weight_init_method=init_kaiming_(inner_dim.size), bias=False, lr_scale=lr_scale, @@ -111,11 +120,10 @@ def __init__( # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( - (inner_dim, tensor_space[SSMDimNames.dt_rank]), + (inner_dim, dt_rank_dim), init_method=init_kaiming_(self._config.dt_rank), lr_scale=lr_scale, ) - self.dt_proj_bias = ParameterMeta.from_dims( (inner_dim,), init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), @@ -123,12 +131,11 @@ def __init__( ) self.A_log = ParameterMeta.from_dims( - (inner_dim, tensor_space[SSMDimNames.state]), + (inner_dim, state_dim), weight_decay=False, init_method=init_A(self._config.state_size, inner_dim.size), lr_scale=lr_scale, ) - # D "skip" parameter self.D = ParameterMeta.from_dims( (inner_dim,), @@ -136,7 +143,6 @@ def __init__( init_method=init_ones_, lr_scale=lr_scale, ) - self.out_proj = Linear( inner_dim, hidden_dim, diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index b6626e893..a5797a50d 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -4,13 +4,14 @@ import torch from fast_llm.engine.config_utils.initialization import init_ones_, 
init_uniform_centered_ -from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim, scalar_dim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer -from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear -from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias, init_kaiming_ +from fast_llm.layers.ssm.config import SSMConfig +from fast_llm.layers.ssm.mamba import init_A, init_dtprojbias, init_kaiming_ from fast_llm.tensor import ParameterMeta from fast_llm.utils import Assert, div, get_lr_scale @@ -31,63 +32,71 @@ logger = logging.getLogger(__name__) -class Mamba2(BlockLayer): +class Mamba2[ConfigType: SSMConfig](BlockLayer[ConfigType]): """ This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ _mixer_name: typing.ClassVar[str] = "mamba_2" - _XZ_DIMS = ( - SSMDimNames.batch, - SSMDimNames.composite_heads_and_head_dim, - SSMDimNames.sequence_q, - ) - _BC_DIMS = ( - SSMDimNames.batch, - SSMDimNames.composite_heads, - SSMDimNames.state, - SSMDimNames.sequence_q, - ) - def __init__( self, - config: SSMConfig, - tensor_space: TensorSpace, - block_index: int, + config: ConfigType, block_config: BlockConfig, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, ): super().__init__( - tensor_space, + config, + block_config, + distributed_config, + hidden_dim, block_index, - self._mixer_name, - 
debug_level=block_config.debug_transformer, - debug_memory=block_config.debug_transformer_memory, + name, ) - self._config: SSMConfig = config Assert.eq(self._config.activation_type, ActivationType.silu) layer_lr_scale: float | None = ( block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None ) lr_scale: float | tuple[float | None, ...] | None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) - inner_dim: TensorDim = tensor_space[SSMDimNames.composite_heads_and_head_dim] - xb_dim = tensor_space[SSMDimNames.composite_head_groups_and_state] - hidden_dim: TensorDim = tensor_space[SSMDimNames.hidden] - dt_rank_dim = tensor_space[SSMDimNames.dt_rank] + num_heads = div(self._config.d_inner, self._config.state_size) + num_head_groups = div(self._config.d_xb, self._config.state_size) + + state_dim = TensorDim("state", self._config.state_size) + + head_groups_dim = TensorDim( + "head_groups", num_head_groups, self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + ) + group_heads_dim = TensorDim("group_heads", div(num_heads, num_head_groups)) + + heads_dim = CompositeTensorDim("heads", (head_groups_dim, group_heads_dim)) + + inner_dim = CompositeTensorDim("inner", (head_groups_dim, group_heads_dim, state_dim)) + xb_dim = CompositeTensorDim("xb", (head_groups_dim, state_dim)) + convolution_kernel_dim = TensorDim("convolution_kernel", self._config.conv_kernel_dimension) - self._local_heads = tensor_space[SSMDimNames.composite_heads].size - self._local_head_groups = tensor_space[SSMDimNames.head_groups].size + # DT projection + dt_rank_dim = TensorDim("dt_rank", self._config.dt_rank) + + inner_projection_dim = ConcatenatedTensorDim( + "inner_projection", + (inner_dim, xb_dim, xb_dim, inner_dim), + ) + + self._local_heads = heads_dim.size + self._local_head_groups = head_groups_dim.size self._group_heads = div(self._local_heads, self._local_head_groups) self._local_inner_size = inner_dim.size self._local_xb_size = 
xb_dim.size - conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim self.conv1d_weight = ParameterMeta.from_dims( ( conv1d_dim, - tensor_space[DefaultDimNames.scalar], - tensor_space[SSMDimNames.convolution_kernel], + scalar_dim, + convolution_kernel_dim, ), init_method=init_uniform_centered_((conv1d_dim.global_size * self._config.conv_kernel_dimension) ** -0.5), lr_scale=lr_scale, @@ -99,13 +108,12 @@ def __init__( ) self.in_proj = OutputParallelLinear( hidden_dim, - tensor_space[SSMDimNames.concatenated_inner_projection], + inner_projection_dim, bias=config.add_bias_linear, weight_init_method=init_kaiming_(block_config.hidden_size), sequence_parallel=self._sequence_parallel, lr_scale=lr_scale, ) - self.dt_in_proj = Linear( hidden_dim, dt_rank_dim, @@ -131,7 +139,7 @@ def __init__( lr_scale=lr_scale, ) self.A_log = ParameterMeta.from_dims( - (inner_dim, tensor_space[SSMDimNames.state]), + (inner_dim, state_dim), init_method=init_A(self._config.state_size, self._config.d_inner), lr_scale=lr_scale, weight_decay=False, @@ -151,6 +159,19 @@ def __init__( # TODO: lr_scale? 
) + if self._debug.enabled: + _xz_dims = ( + BlockDimNames.batch, + inner_dim, + BlockDimNames.sequence_q, + ) + _bc_dims = ( + BlockDimNames.batch, + heads_dim, + state_dim, + BlockDimNames.sequence_q, + ) + def forward( self, input_: torch.Tensor, diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 0bea58d9a..888d1bec7 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -4,12 +4,14 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim -from fast_llm.engine.config_utils.tensor_space import TensorSpace +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.block.block import BlockLayer +from fast_llm.layers.block.config import BlockConfig, BlockDimNames from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import AttentionConfig, AttentionDimNames, AttentionKwargs +from fast_llm.layers.transformer.config import AttentionConfig, AttentionKwargs from fast_llm.utils import get_lr_scale try: @@ -50,37 +52,64 @@ class Attention[ConfigType: AttentionConfig](BlockLayer[ConfigType]): A self-attention layer. 
""" - _QUERY_DIMS = ( - AttentionDimNames.batch, - AttentionDimNames.sequence_q, - AttentionDimNames.composite_heads, - AttentionDimNames.kv_channels, - ) - _KV_DIMS = ( - AttentionDimNames.batch, - AttentionDimNames.sequence_q, - AttentionDimNames.head_groups, - AttentionDimNames.kv_channels, - ) - _CONTEXT_DIMS = ( - AttentionDimNames.batch, - AttentionDimNames.sequence_q, - AttentionDimNames.composite_dense, - ) - - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, name: str): - super().__init__(config, tensor_space, block_index, name) - self._config = config - self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) - - self._kv_channels = self._tensor_space[AttentionDimNames.kv_channels].size - self._head_groups = self._tensor_space[AttentionDimNames.head_groups].global_size - self._local_head_groups = self._tensor_space[AttentionDimNames.head_groups].size - self._local_heads_per_group = self._tensor_space[AttentionDimNames.group_heads].size + def __init__( + self, + config: ConfigType, + block_config: BlockConfig, + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + ): + super().__init__( + config, + block_config, + distributed_config, + hidden_dim, + block_index, + name, + ) + self._use_flash_attention = self._config.do_use_flash_attention(self._distributed_config) + + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) + self._sequence_data_parallel_dim = self._distributed_config.get_distributed_dim( + DistributedDimNames.sequence_data + ) + head_group_dim = TensorDim( + "head_groups", self._config.head_groups, self._parallel_dim if self._config.head_groups > 1 else None + ) + group_heads_dim = TensorDim( + "group_heads", + div(self._config.num_attention_heads, self._config.head_groups), + None if self._config.head_groups > 1 else self._parallel_dim, + ) + self._local_head_groups = 
head_group_dim.size + self._local_heads_per_group = group_heads_dim.size self._local_heads = self._local_head_groups * self._local_heads_per_group - self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space[AttentionDimNames.hidden] + kv_channels_dim = TensorDim("kv_channels", self._config.kv_channels) + query_dim = CompositeTensorDim("query", (head_group_dim, group_heads_dim, kv_channels_dim)) + key_value_dim = ConcatenatedTensorDim( + "key_value", + ( + CompositeTensorDim("key", (head_group_dim, kv_channels_dim)), + CompositeTensorDim("value", (head_group_dim, kv_channels_dim)), + ), + ) + dense_dim = CompositeTensorDim("dense", (head_group_dim, group_heads_dim, kv_channels_dim)) + + self._softmax_scale = self._config.kv_channels ** (-self._config.attention_softmax_scale_power) + + init_method_qkv = init_normal_( + std=self._config.init_method_std_qkv, + min_val=self._config.init_method_min_qkv, + max_val=self._config.init_method_max_qkv, + ) + init_method_std_attn_proj = init_normal_( + std=self._config.init_method_std_attn_proj, + min_val=self._config.init_method_min_attn_proj, + max_val=self._config.init_method_max_attn_proj, + ) layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) @@ -88,7 +117,7 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: i # TODO: Merge the query and key-value computations? (harder with sequence parallel.) 
self.query = OutputParallelLinear( hidden_dim, - self._tensor_space[AttentionDimNames.composite_query], + query_dim, bias=self._config.add_qkv_bias, weight_init_method=self._config.qkv_weight_initialization_method, bias_init_method=self._config.qkv_bias_initialization_method, @@ -97,7 +126,7 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: i ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space[AttentionDimNames.composite_key_value], + key_value_dim, bias=self._config.add_qkv_bias, weight_init_method=self._config.qkv_weight_initialization_method, bias_init_method=self._config.qkv_bias_initialization_method, @@ -107,11 +136,11 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: i self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward) # Rotary embeddings. - self._rotary = self._config.rotary.build() + self._rotary = self._config.rotary.build(kv_channels_dim) # Output. self.dense = InputParallelLinear( - self._tensor_space[AttentionDimNames.composite_dense], + dense_dim, hidden_dim, bias=self._config.add_dense_bias, weight_init_method=self._config.dense_weight_initialization_method, @@ -119,12 +148,30 @@ def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: i sequence_parallel=self._sequence_parallel, lr_scale=attention_lr_scale, ) - # PEFT. 
self.query = self._config.peft.apply_linear(self.query, TransformerSubLayerName.query) self.key_value = self._config.peft.apply_linear(self.key_value, TransformerSubLayerName.key_value) self.dense = self._config.peft.apply_linear(self.dense, TransformerSubLayerName.dense) + if self._debug.enabled: + self._query_dims = ( + BlockDimNames.batch, + BlockDimNames.sequence_q, + CompositeTensorDim("heads", (head_group_dim, group_heads_dim)), + kv_channels_dim, + ) + self._kv_dims = ( + BlockDimNames.batch, + BlockDimNames.sequence_q, + head_group_dim, + kv_channels_dim, + ) + self._context_dims = ( + BlockDimNames.batch, + BlockDimNames.sequence_q, + dense_dim, + ) + def _attn_fused( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor ) -> torch.Tensor: @@ -133,16 +180,18 @@ def _attn_fused( sk = key.size(1) if self._local_head_groups == 1: - query = query.view(b, sq * self._local_heads, self._kv_channels) + query = query.view(b, sq * self._local_heads, self._config.kv_channels) key = key.transpose(-1, -2) else: query = ( - query.unflatten(-1, (self._local_head_groups, self._local_heads_per_group, self._kv_channels)) + query.unflatten(-1, (self._local_head_groups, self._local_heads_per_group, self._config.kv_channels)) .transpose(1, 2) - .reshape(b * self._local_head_groups, sq * self._local_heads_per_group, self._kv_channels) + .reshape(b * self._local_head_groups, sq * self._local_heads_per_group, self._config.kv_channels) + ) + key = key.unflatten(-1, (self._local_head_groups, self._config.kv_channels)).movedim(1, 3).flatten(0, 1) + value = ( + value.unflatten(-1, (self._local_head_groups, self._config.kv_channels)).transpose(1, 2).flatten(0, 1) ) - key = key.unflatten(-1, (self._local_head_groups, self._kv_channels)).movedim(1, 3).flatten(0, 1) - value = value.unflatten(-1, (self._local_head_groups, self._kv_channels)).transpose(1, 2).flatten(0, 1) attn_weights = torch.empty( (b * self._local_head_groups, sq 
* self._local_heads_per_group, sk), device=query.device, dtype=query.dtype @@ -159,7 +208,7 @@ def _attn_fused( attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) - with set_generator(self._tensor_space.distributed.tp_generator): + with set_generator(self._distributed.tp_generator): attn_weights = torch.dropout(attn_weights, self._config.attention_dropout, self.training) attn_output = torch.bmm( attn_weights.view(b * self._local_head_groups, sq * self._local_heads_per_group, sk), value @@ -169,7 +218,7 @@ def _attn_fused( return attn_output.view(b, sq, -1) else: return ( - attn_output.view(b, self._local_head_groups, sq, self._local_heads_per_group, self._kv_channels) + attn_output.view(b, self._local_head_groups, sq, self._local_heads_per_group, self._config.kv_channels) .transpose(1, 2) .flatten(2) ) @@ -182,17 +231,15 @@ def _query_key_value_forward( handle = None if self._head_groups == 1 and self._sequence_parallel: - key_value, handle = gather_op( - key_value, group=self._tensor_space.distributed.tensor_group, dim=0, async_op=True - ) + key_value, handle = gather_op(key_value, group=self._parallel_dim.group, dim=0, async_op=True) - if self._tensor_space.distributed.sequence_data_group: + if self._sequence_data_parallel_dim.group: if handle: # TODO: This is probably unnecessary. 
handle.wait() # sequence dim may not be zero, but this needs to be handled after `handle.wait()` key_value, handle = gather_op( - key_value, group=self._tensor_space.distributed.sequence_data_group, dim=0, async_op=True + key_value, group=self._sequence_data_parallel_dim.group, dim=0, async_op=True ) query, query_context = self.query.forward_only(input_) @@ -200,8 +247,8 @@ def _query_key_value_forward( if handle: handle.wait() - if self._tensor_space.distributed.sequence_data_group and not sequence_first: - key_value = swap_mult_dim(key_value, self._tensor_space.distributed_config.sequence_data_parallel, 0, 1) + if self._sequence_data_parallel_dim.group and not sequence_first: + key_value = swap_mult_dim(key_value, self._distributed_config.sequence_data_parallel, 0, 1) context = {"query": query_context, "key_value": key_value_context, "sequence_first": sequence_first} return query, key_value, context @@ -212,10 +259,10 @@ def _query_key_value_backward( # TODO: De-allocate qkv grads quicker. handle = None - if self._tensor_space.distributed.sequence_data_group: + if self._sequence_data_parallel_dim.group: key_value_grad, handle = reduce_scatter_op( key_value_grad, - group=self._tensor_space.distributed.sequence_data_group, + group=self._sequence_data_parallel_dim.group, dim=1 - context["sequence_first"], async_op=True, ) @@ -226,7 +273,7 @@ def _query_key_value_backward( if handle: handle.wait() - if self._head_groups == 1 and (group := self._tensor_space.distributed.tensor_group): + if self._head_groups == 1 and (group := self._parallel_dim.group): if self._sequence_parallel: key_value_grad = reduce_scatter_op(key_value_grad, group=group, dim=0) else: @@ -269,7 +316,7 @@ def forward( # Manually add the gradients from later micro-sequences. 
key_value = AttachGrad.apply(key_value, present) - if self._tensor_space.distributed.sequence_data_group: + if self._sequence_data_parallel_dim.group: key_value = ( key_value[: kwargs[AttentionKwargs.sequence_k_dim].size] if sequence_first @@ -281,11 +328,11 @@ def forward( query = query.transpose(0, 1).contiguous() key_value = key_value.transpose(0, 1).contiguous() - key, value = key_value.split(self._local_head_groups * self._kv_channels, dim=-1) + key, value = key_value.split(self._local_head_groups * self._config.kv_channels, dim=-1) - query = query.view(*query.shape[:2], self._local_heads, self._kv_channels) - key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) - value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) + query = query.view(*query.shape[:2], self._local_heads, self._config.kv_channels) + key = key.view(*key.shape[:2], self._local_head_groups, self._config.kv_channels) + value = value.view(*value.shape[:2], self._local_head_groups, self._config.kv_channels) if self._debug.enabled: self._debug(query, "query_rotary_input", self._QUERY_DIMS, kwargs) @@ -296,7 +343,7 @@ def forward( if self._use_flash_attention: assert _flash_available - with set_generator(self._tensor_space.distributed.tp_generator): + with set_generator(self._distributed.tp_generator): if (cu_seqlens_q := kwargs.get(AttentionKwargs.cu_seqlens_q, None)) is not None: out_dims = query.size() query = query.view(-1, query.size(-2), query.size(-1)) @@ -337,10 +384,10 @@ def forward( ) if self._debug.enabled: - self._debug(query, "query", self._QUERY_DIMS, kwargs) - self._debug(key, "key", self._KV_DIMS, kwargs) - self._debug(value, "value", self._KV_DIMS, kwargs) - self._debug(input_, "context", self._CONTEXT_DIMS, kwargs) + self._debug(query, "query", self._query_dims, kwargs) + self._debug(key, "key", self._kv_dims, kwargs) + self._debug(value, "value", self._kv_dims, kwargs) + self._debug(input_, "context", self._context_dims, kwargs) if 
sequence_first: # TODO: Optimize (is contiguous avoidable? Transpose dense output?) diff --git a/fast_llm/layers/transformer/block.py b/fast_llm/layers/transformer/block.py index 89d7a2e3b..dd81a4da5 100644 --- a/fast_llm/layers/transformer/block.py +++ b/fast_llm/layers/transformer/block.py @@ -1,7 +1,6 @@ import logging import typing -from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.layers.block.block import Block, BlockLayer from fast_llm.layers.transformer.attention import Attention from fast_llm.layers.transformer.config import TransformerConfig @@ -10,13 +9,11 @@ class TransformerBlock[ConfigType: TransformerConfig](Block[ConfigType]): - _name = "Transformer layer" # TODO: Standardize to `mixer` _mixer_module_name: typing.ClassVar[str] = "self_attn" _config: TransformerConfig - def __init__(self, config: ConfigType, tensor_space: TensorSpace, block_index: int, return_input: bool = False): - super().__init__(config, tensor_space, block_index, return_input) - def _create_mixer(self) -> BlockLayer: - return Attention(self._config, self._tensor_space, self._block_index) + return Attention( + self._config, self._distributed_config, self._hidden_dim, self._block_index, f"{self._name} attn" + ) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index e8c319b0f..0b76a1e87 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -6,8 +6,7 @@ from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_normal_, init_zeros_ -from fast_llm.engine.config_utils.tensor_space import CompositeTensorDim, TensorDim, TensorSpace -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.engine.distributed.config import DistributedConfig from 
fast_llm.functional.config import TritonConfig from fast_llm.layers.block.config import AddLinearBiasChoices, BlockConfig, BlockDimNames, BlockKwargs, MixerConfig from fast_llm.layers.transformer.rotary.config import RotaryConfig @@ -150,36 +149,6 @@ def projection_size(self): def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - tensor = tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.tensor) - # Needed for multiple inheritance. - super().setup_tensor_space(tensor_space) # Noqa - - tensor_space.add_tensor_dim( - head_groups := TensorDim( - AttentionDimNames.head_groups, self.head_groups, tensor if self.head_groups > 1 else None - ) - ) - tensor_space.add_tensor_dim( - group_heads := TensorDim( - AttentionDimNames.group_heads, - div(self.num_attention_heads, self.head_groups), - None if self.head_groups > 1 else tensor, - ) - ) - tensor_space.add_tensor_dim(key_and_value := TensorDim(AttentionDimNames.key_and_value, 2)) - tensor_space.add_tensor_dim(kv_channels := TensorDim(AttentionDimNames.kv_channels, self.kv_channels)) - tensor_space.add_tensor_dim(CompositeTensorDim(AttentionDimNames.composite_heads, (head_groups, group_heads))) - tensor_space.add_tensor_dim( - CompositeTensorDim(AttentionDimNames.composite_query, (head_groups, group_heads, kv_channels)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(AttentionDimNames.composite_key_value, (key_and_value, head_groups, kv_channels)) - ) - tensor_space.add_tensor_dim( - CompositeTensorDim(AttentionDimNames.composite_dense, (head_groups, group_heads, kv_channels)) - ) - @functools.cached_property def add_qkv_bias(self) -> bool: if isinstance(self.block.add_linear_biases, bool): diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 796c34756..2b941a68a 
100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -6,7 +6,7 @@ from fast_llm.data.data.gpt.data import GPTBatch from fast_llm.engine.base_model.base_model import BaseModel, Layer, LossDef from fast_llm.engine.base_model.config import Preprocessor -from fast_llm.engine.config_utils.tensor_space import TensorDim +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel @@ -36,6 +36,7 @@ def __init__( config: GPTBaseModelConfig, distributed_config: DistributedConfig, ): + self._hidden_dim = TensorDim("hidden", config.transformer.hidden_size) super().__init__(config, distributed_config) self._use_flash_attention = self._config.transformer.do_use_flash_attention(distributed_config) if self._config.use_megatron_initialization: @@ -45,59 +46,79 @@ def __init__( # `self._reference_models` is not populated at this point, so we pass a mutable dict. self._preprocessors: list[Preprocessor] = [] if self._config.use_absolute_position_embeddings: - self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._tensor_space)) + self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._distributed_config)) # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. # TODO: Find a better solution. 
self._preprocessors.append(self._config.transformer.rotary.build(self._tensor_space)) if self._use_flash_attention: - self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._tensor_space)) + self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._distributed_config)) else: - self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._tensor_space)) + self._preprocessors.append(BackupAttentionPreprocessor(self._config.transformer, self._distributed_config)) if self._config.enable_dpo: # TODO better way to pass in? - self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._tensor_space)) + self._preprocessors.append(PreferenceSpanPreprocessor(self._config, self._distributed_config)) - def get_output_layers(self) -> list[Layer]: + def _get_output_layers(self) -> list[Layer]: layers = [] for i in range(self._config.prediction_heads): if i > 0: layers.append( - TransformerBlock( - self._config.transformer, - self._tensor_space, + self._get_block( # TODO MTP: which index? - block_index=max(self._config.transformer.num_blocks + i, 1), + max(self._config.transformer.num_layers + i, 1), + f"MTP head {i} block", # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. - return_input=i < self._config.prediction_heads - 1, + i < self._config.prediction_heads - 1, ) ) - layers.append( - LanguageModelHead( - self._config, - self._tensor_space, - prediction_distance=i, - ) - ) + layers.append(self._get_head(i)) return layers def get_layers(self) -> list[Layer]: return [ - LanguageModelEmbedding(self._config, self._tensor_space), + self._get_embeddings(), *[ - TransformerBlock( - self._config.transformer, - self._tensor_space, - block_index=i + 1, + self._get_block( + i + 1, + f"Block {i + 1}", # The last layer only returns the transformer output.
# The previous layers return a stack of shared_hidden and transformer_output. - return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_blocks - 1, + self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, ) - for i in range(self._config.transformer.num_blocks) + for i in range(self._config.transformer.num_layers) ], - *self.get_output_layers(), + *self._get_output_layers(), ] + def _get_block( + self, + block_index: int, + name: str, + return_input: bool = False, + ): + return TransformerBlock( + self._config.transformer, + self._distributed_config, + self._hidden_dim, + block_index, + name, + return_input, + ) + + def _get_embeddings(self): + return LanguageModelEmbedding(self._config, self._distributed_config, self._hidden_dim, 0, "Embeddings") + + def _get_head(self, prediction_distance): + return LanguageModelHead( + self._config, + self._distributed_config, + self._hidden_dim, + max(self._config.transformer.num_layers + prediction_distance, 1), + f"Language model head {prediction_distance}", + prediction_distance=prediction_distance, + ) + def preprocess_meta( self, batch_meta: GPTBatchConfig | torch.Tensor, phase: PhaseType ) -> list[tuple[TensorMeta, dict]]: diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 7175e6438..1c85327e2 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -6,7 +6,6 @@ from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.engine.checkpoint.config import CheckpointHandler from fast_llm.engine.config_utils.runnable import RunnableConfig -from fast_llm.engine.config_utils.tensor_space import TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig @@ -47,14 +46,6 @@ class HybridSSMBaseModelConfig(GPTBaseModelConfig): # TODO: Support 
combination of different SSM block types. ssm_block_type: SSMBlockType | None = Field(init=False) - def setup_tensor_space(self, tensor_space: TensorSpace) -> None: - """ - Setup the tensor space for the model. - """ - super().setup_tensor_space(tensor_space) - if self.ssm_block_type is not None: - self.ssm.setup_tensor_space(tensor_space, self.ssm_block_type) - def _validate(self): with self._set_implicit_default(None): if self.ssm.dt_rank == "auto" or self.ssm.dt_rank is None: diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 32fbdad9b..94f9eb321 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -1,10 +1,7 @@ import logging import typing -from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.language_model.head import LanguageModelHead from fast_llm.layers.ssm.block import SSMBlock from fast_llm.layers.transformer.block import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner, GPTModel @@ -29,79 +26,39 @@ def __init__( ): super().__init__(config, distributed_config) - def get_output_layers(self) -> list[Layer]: - """ - Get the output layers of the model. - This includes the language model head and any additional heads specified in the configuration. 
- """ - layers: list[Layer] = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] - - if self._config.prediction_heads > 1: + def _get_block( + self, + block_index: int, + name: str, + return_input: bool = False, + ): + if block_index > self._config.transformer.num_layers: + # MTP block block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1] - for i in range(1, self._config.prediction_heads): - if block_type == SSMBlockType.transformer: - layers.append( - TransformerBlock( - self._config.transformer, - self._tensor_space, - block_index=len(self._config.hybrid_block_layout), - return_input=i != self._config.prediction_heads - 1, - ) - ) - else: - layers.append( - SSMBlock( - config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=self._config.ssm_block_type.get_mixer_class(), - block_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - ) - layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=i)) - - return layers - - def get_layers(self) -> list[Layer]: - """ - Create a list of layers for the model, interleaving Transformer and Mamba blocks - according to the block pattern. 
- """ - layers: list[Layer] = [LanguageModelEmbedding(self._config, self._tensor_space)] - - # Create blocks according to pattern - for i, block_type in enumerate(self._config.hybrid_block_layout): - if block_type == SSMBlockType.transformer: - # Transformer block - layers.append( - TransformerBlock( - self._config.transformer, - self._tensor_space, - block_index=i + 1, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - ) - else: - layers.append( - SSMBlock( - config=self._config.transformer, - ssm_config=self._config.ssm, - mixer_cls=self._config.ssm_block_type.get_mixer_class(), - block_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - ) - - # Add the output layers - layers += self.get_output_layers() - - return layers + else: + # Decoder block + block_type = self._config.hybrid_block_layout[block_index - 1] + + if block_type == SSMBlockType.transformer: + return TransformerBlock( + self._config.transformer, + self._distributed_config, + self._hidden_dim, + block_index, + name, + return_input, + ) + else: + return SSMBlock( + self._config.transformer, + self._config.ssm, + self._distributed_config, + self._hidden_dim, + self._config.ssm_block_type.get_mixer_class(), + block_index, + name, + return_input, + ) class HybridSSMModel[ConfigType: HybridSSMModelConfig](GPTModel[ConfigType]): From af990c9ffe85e5e9b6d121c53586217cc7056228 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 Aug 2025 19:09:29 -0400 Subject: [PATCH 10/19] stuff --- .../engine/config_utils/initialization.py | 4 +- fast_llm/layers/block/config.py | 17 +- fast_llm/layers/block/mlp/config.py | 40 +-- .../layers/block/mlp/mixture_of_experts.py | 4 +- fast_llm/layers/block/mlp/mlp.py | 4 +- fast_llm/layers/common/config.py | 139 +------- .../layers/common/normalization/__init__.py | 311 ++++++++++++++++++ 
.../layers/common/normalization/config.py | 158 +++++++++ .../{ => normalization}/normalization.py | 6 +- fast_llm/layers/language_model/config.py | 22 +- fast_llm/layers/ssm/discrete_mamba2.py | 4 +- fast_llm/layers/ssm/mamba.py | 4 +- fast_llm/layers/ssm/mamba2.py | 6 +- fast_llm/layers/transformer/attention.py | 6 +- fast_llm/layers/transformer/config.py | 38 +-- fast_llm/layers/transformer/rotary/config.py | 2 +- fast_llm/models/gpt/conversion.py | 2 +- fast_llm/models/gpt/model.py | 2 +- fast_llm/models/ssm/conversion.py | 2 +- fast_llm/utils.py | 39 ++- tests/functional/test_triton_kernels.py | 6 +- 21 files changed, 571 insertions(+), 245 deletions(-) create mode 100644 fast_llm/layers/common/normalization/__init__.py create mode 100644 fast_llm/layers/common/normalization/config.py rename fast_llm/layers/common/{ => normalization}/normalization.py (98%) diff --git a/fast_llm/engine/config_utils/initialization.py b/fast_llm/engine/config_utils/initialization.py index cdee37935..5e02d6d2e 100644 --- a/fast_llm/engine/config_utils/initialization.py +++ b/fast_llm/engine/config_utils/initialization.py @@ -13,7 +13,7 @@ @config_class(registry=True) class InitializationConfig(Config): _abstract = True - has_initialization: typing.ClassVar[bool] = True + is_default: typing.ClassVar[bool] = False @classmethod def _from_dict( @@ -35,7 +35,7 @@ def get_initializer(self) -> "Initializer": class DefaultInitializationConfig(InitializationConfig): # A placeholder indicating that the class default should be used instead. 
_abstract = False - has_initialization = False + is_default = True @config_class(dynamic_type={InitializationConfig: "fill"}) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index e5f1020e1..3c9da42f6 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,4 +1,3 @@ -import enum import typing from fast_llm.config import Field, FieldHint, check_field, config_class @@ -6,7 +5,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.block.peft import TransformerPeftConfig -from fast_llm.layers.common.config import NormalizationConfig +from fast_llm.layers.common.normalization import NormalizationConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -39,13 +38,6 @@ class BlockKwargs: grad_output = "grad_output" -class AddLinearBiasChoices(str, enum.Enum): - # TODO: Review - nowhere = "nowhere" - everywhere = "everywhere" - only_attn_qkv = "only_attn_qkv" - - @config_class() class BlockLayerConfig(BaseModelConfig): """ @@ -127,7 +119,7 @@ class BlockConfig(BaseModelConfig): desc="Configuration for the MLP.", hint=FieldHint.architecture, ) - # TODO: Review names + # TODO: Allow separate initializations? 
normalization: NormalizationConfig = Field( desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, @@ -136,6 +128,7 @@ class BlockConfig(BaseModelConfig): desc="Configuration for the parameter-efficient fine tuning.", hint=FieldHint.architecture, ) + # TODO: Review names hidden_dropout: float = Field( default=0.0, desc="Dropout applied to the residual connections.", @@ -153,9 +146,9 @@ class BlockConfig(BaseModelConfig): desc="Log the memory usage after each operation in a transformer layer..", hint=FieldHint.logging, ) - add_linear_biases: bool | AddLinearBiasChoices = Field( + add_linear_biases: bool = Field( default=True, - desc="Add biases to all, none or Q, K, V layers. Accepted values: True, False, or AddLinearBiasChoices.", + desc="Whether to add biases to linear layers. May be overridden in individual layer configs.", hint=FieldHint.architecture, ) diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 1d918b0d1..237a538fa 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -149,13 +149,7 @@ def layer_class(self) -> "type[MLPBase]": @property def add_bias(self) -> bool: - from fast_llm.layers.block.config import AddLinearBiasChoices - - if isinstance(self.block.add_linear_biases, bool): - return self.block.add_linear_biases - if self.block.add_linear_biases == AddLinearBiasChoices.everywhere: - return True - return False + return self.block.add_linear_biases def _validate(self) -> None: assert hasattr(self, "block") @@ -181,41 +175,41 @@ def _validate(self) -> None: elif self.mlp_lr_scale is not None: Assert.geq(self.mlp_lr_scale, 0) - if self.layer_1_bias_initialization.has_initialization or self.layer_2_bias_initialization.has_initialization: + if not (self.layer_1_bias_initialization.is_default and self.layer_2_bias_initialization.is_default): assert self.add_bias @functools.cached_property def layer_1_weight_initialization_method(self) 
-> Initializer: - if self.layer_1_weight_initialization.has_initialization: - return self.layer_1_weight_initialization.get_initializer() - else: + if self.layer_1_weight_initialization.is_default: return init_normal_(0, self.block.hidden_size**-0.5) + else: + return self.layer_1_weight_initialization.get_initializer() @functools.cached_property def layer_1_bias_initialization_method(self) -> Initializer: - if self.layer_1_bias_initialization.has_initialization: - return self.layer_1_bias_initialization.get_initializer() - else: + if self.layer_1_bias_initialization.is_default: return init_zeros_ + else: + return self.layer_1_bias_initialization.get_initializer() @functools.cached_property def layer_2_weight_initialization_method(self) -> Initializer: - if self.layer_2_weight_initialization.has_initialization: - return self.layer_2_weight_initialization.get_initializer() - else: + if self.layer_2_weight_initialization.is_default: return init_normal_(0, self.block.hidden_size**-0.5 / max(2 * self.block.num_blocks, 1)) + else: + return self.layer_2_weight_initialization.get_initializer() @functools.cached_property def layer_2_bias_initialization_method(self) -> Initializer: - if self.layer_2_bias_initialization.has_initialization: - return self.layer_2_bias_initialization.get_initializer() - else: + if self.layer_2_bias_initialization.is_default: return init_zeros_ + else: + return self.layer_2_bias_initialization.get_initializer() @functools.cached_property def router_weight_initialization_method(self) -> Initializer: - if self.router_weight_initialization.has_initialization: + if self.router_weight_initialization.is_default: + return init_zeros_ + else: assert self.add_bias return self.router_weight_initialization.get_initializer() - else: - return init_zeros_ diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index a9a13a5ff..fa7258b7e 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ 
b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -13,7 +13,7 @@ from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss from fast_llm.layers.common.linear import Linear -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.utils import Assert, combine_lr_scales logger = logging.getLogger(__name__) @@ -46,7 +46,7 @@ def __init__( super().__init__(config, distributed_config, hidden_dim, block_index, name) layer_lr_scale = self._config.per_layer_lr_scale[block_index] if self._config.per_layer_lr_scale else None - router_lr_scale = get_lr_scale(self._config.router_lr_scale, layer_lr_scale) + router_lr_scale = combine_lr_scales(self._config.router_lr_scale, layer_lr_scale) self.router = Linear( self._hidden_dim, diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index c18a70db6..cc4562dfc 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -10,7 +10,7 @@ from fast_llm.layers.block.mlp.config import MLPConfig from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import LinearBase -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.utils import Assert, combine_lr_scales class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): @@ -34,7 +34,7 @@ def __init__( if isinstance(self._config.mlp_lr_scale, list) else self._config.mlp_lr_scale ) - lr_scale = get_lr_scale(lr_scale, layer_lr_scale) + lr_scale = combine_lr_scales(lr_scale, layer_lr_scale) # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 90c47ecf4..b09672961 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -1,148 +1,11 @@ import abc -import enum -import functools import typing -from fast_llm.config import Field, FieldHint, 
check_field, config_class +from fast_llm.config import Field, FieldHint, config_class from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_ones_, init_zeros_ -from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.utils import Assert if typing.TYPE_CHECKING: from fast_llm.layers.common.linear import LinearBase, LinearLike - from fast_llm.layers.common.normalization import Normalization - - -class NormalizationImplementation(str, enum.Enum): - """ - An enum for the available implementations of layer norm. - """ - - auto = "auto" - torch = "torch" - fused = "fused" - fast = "fast" - triton = "triton" - - -@config_class(registry=True) -class NormalizationConfig(BaseModelConfig): - pass - - @property - @abc.abstractmethod - def module_class(self) -> type["Normalization"]: - pass - - def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "Normalization": - return self.module_class(self, hidden_dim, lr_scale) - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - if cls is NormalizationConfig and cls.get_subclass(default.get("type")) is None: - # Default subclass. 
- return LayerNormalizationConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) - - -@config_class(dynamic_type={NormalizationConfig: "none"}) -class NoNormalizationConfig(NormalizationConfig): - _abstract = False - - @property - def module_class(self) -> type["Normalization"]: - from fast_llm.layers.common.normalization import NoNormalization - - return NoNormalization - - -@config_class() -class LayerNormalizationBaseConfig(NormalizationConfig): - """ - Common configuration for layer norm and rms norm - """ - - # TODO: Rename to normalization_epsilon - epsilon: float = Field( - default=1e-5, - desc="Regularizer for the division.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - zero_centered: bool = Field( - default=False, - desc="Write the normalization weight as `w = 1 + w'`, to improve numerical accuracy when close to one.", - hint=FieldHint.architecture, - ) - implementation: NormalizationImplementation = Field( - default=NormalizationImplementation.auto, - desc="The implementation to use for the normalization layer.", - hint=FieldHint.performance, - ) - weight_initialization: InitializationConfig = Field( - desc="Initialization configuration for the normalization weights. 
Default: fill with ones", - hint=FieldHint.feature, - ) - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - cls._handle_renamed_field(default, "normalization_type", "type") - cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") - cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") - cls._handle_renamed_field(default, "normalization_implementation", "implementation") - cls._handle_renamed_field(default, "layer_norm_init_range", "initialization_range") - return super()._from_dict(default, strict, flat) - - @functools.cached_property - def weight_initialization_method(self) -> Initializer: - if self.weight_initialization.has_initialization: - return self.weight_initialization.get_initializer() - else: - return init_ones_ - - -@config_class(dynamic_type={NormalizationConfig: "layer_norm"}) -class LayerNormalizationConfig(LayerNormalizationBaseConfig): - _abstract = False - bias_initialization: InitializationConfig = Field( - desc="Initialization configuration for the normalization biases. 
Default: fill with zeros", - hint=FieldHint.feature, - ) - - @functools.cached_property - def bias_initialization_method(self) -> Initializer: - if self.bias_initialization.has_initialization: - return self.bias_initialization.get_initializer() - else: - return init_zeros_ - - @property - def module_class(self): - from fast_llm.layers.common.normalization import LayerNormalization - - return LayerNormalization - - -@config_class(dynamic_type={NormalizationConfig: "rms_norm"}) -class RMSNormalizationConfig(LayerNormalizationBaseConfig): - _abstract = False - - @property - def module_class(self): - from fast_llm.layers.common.normalization import RMSNormalization - - return RMSNormalization @config_class() diff --git a/fast_llm/layers/common/normalization/__init__.py b/fast_llm/layers/common/normalization/__init__.py new file mode 100644 index 000000000..a727e4c69 --- /dev/null +++ b/fast_llm/layers/common/normalization/__init__.py @@ -0,0 +1,311 @@ +import abc + +import torch + +from fast_llm.config import Configurable +from fast_llm.engine.config_utils.run import log_main_rank +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.functional.config import TritonConfig +from fast_llm.functional.triton.normalization import triton_normalization_autograd +from fast_llm.layers.common.normalization.config import ( + LayerNormalizationConfig, + NoNormalizationConfig, + NormalizationConfig, + NormalizationImplementation, + RMSNormalizationConfig, +) +from fast_llm.tensor import ParameterMeta, accumulate_gradient +from fast_llm.utils import Assert + +try: + import fused_layer_norm_cuda # noqa + + _fused_normalization_available = True +except ImportError: + _fused_normalization_available = False + +try: + import fast_layer_norm # noqa + + _fast_normalization_available = True +except ImportError: + _fast_normalization_available = False + + +_PERSIST_LN_SIZES = ( + 1024, + 1536, + 2048, + 2304, + 3072, + 3840, + 4096, + 5120, + 6144, + 8192, + 10240, + 
12288, + 12800, + 15360, + 16384, + 18432, + 20480, + 24576, + 25600, + 30720, + 32768, + 40960, + 49152, + 65536, +) + + +class FastLayerNorm(torch.autograd.Function): + """ + The fast layer normalization implementation from `apex.contrib`. + Faster than `FusedLayerNorm`, but doesn't support all layer widths. + TODO: Move to functional. + """ + + @staticmethod + def forward( + ctx, input_: torch.Tensor, normalized_shape: torch.Size, weight: torch.Tensor, bias: torch.Tensor, eps: float + ) -> torch.Tensor: # noqa + assert _fast_normalization_available + Assert.incl(normalized_shape.numel(), _PERSIST_LN_SIZES) + output, _, inv_var = fast_layer_norm.ln_fwd(input_, weight, bias, eps) + ctx.save_for_backward(output, weight, bias, inv_var) + return output + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None, None, None]: # noqa + output, weight, bias, inv_var = ctx.saved_tensors + # TODO: Gradients may be computed unnecessarily. + grad_input, grad_weight, grad_bias, _, _ = fast_layer_norm.ln_bwd( + grad_output, output, None, inv_var, weight, bias, True + ) + if weight.requires_grad: + accumulate_gradient(weight, grad_weight) + if bias.requires_grad: + accumulate_gradient(bias, grad_bias) + return grad_input, None, None, None, None + + +class FusedLayerNorm(torch.autograd.Function): + """ + The fused layer normalization implementation from `apex`. + Faster than the stock pytorch implementation, supports all layer widths. + TODO: Move to functional. 
+ """ + + @staticmethod + def forward( + ctx, input_: torch.Tensor, normalized_shape: torch.Size, weight: torch.Tensor, bias: torch.Tensor, eps: float + ) -> torch.Tensor: # noqa + assert _fused_normalization_available + ctx.eps = eps + ctx.normalized_shape = normalized_shape + output, _, inv_var = fused_layer_norm_cuda.forward_affine(input_, normalized_shape, weight, bias, eps) + ctx.save_for_backward(output, weight, bias, inv_var) + return output + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None, None, None]: # noqa + output, weight, bias, inv_var = ctx.saved_tensors + # TODO: Gradients may be computed unnecessarily. + grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine( + grad_output, None, inv_var, output, ctx.normalized_shape, weight, bias, ctx.eps, True + ) + if weight.requires_grad: + accumulate_gradient(weight, grad_weight) + if bias.requires_grad: + accumulate_gradient(bias, grad_bias) + return grad_input, None, None, None, None + + +class FusedRMSNorm(torch.autograd.Function): + @staticmethod + def forward( + ctx, input_: torch.Tensor, normalized_shape: torch.Size, weight: torch.Tensor, eps: float + ) -> torch.Tensor: # noqa + assert _fused_normalization_available + ctx.eps = eps + ctx.normalized_shape = normalized_shape + output, inv_var = fused_layer_norm_cuda.rms_forward_affine(input_, normalized_shape, weight, eps) + ctx.save_for_backward(output, weight, inv_var) + return output + + @staticmethod + def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None, None]: # noqa + output, weight, inv_var = ctx.saved_tensors + # TODO: Gradients may be computed unnecessarily. 
+ grad_input, grad_weight = fused_layer_norm_cuda.rms_backward_affine( + grad_output.contiguous(), inv_var, output, ctx.normalized_shape, weight, ctx.eps, True + ) + if weight.requires_grad: + accumulate_gradient(weight, grad_weight) + return grad_input, None, None, None + + +class Normalization[ConfigType: NormalizationConfig](Configurable[ConfigType], torch.nn.Module): + def __init__( + self, + config: NormalizationConfig, + hidden_dim: TensorDim, + lr_scale: float | None = None, + ): + super().__init__(config) + self._hidden_dim = hidden_dim + self._lr_scale = lr_scale + assert not self._hidden_dim.is_parallel + + @abc.abstractmethod + def forward(self, input_: torch.Tensor) -> torch.Tensor: + pass + + +class NoNormalization[ConfigType: NoNormalizationConfig](Normalization[ConfigType]): + def forward(self, input_: torch.Tensor) -> torch.Tensor: + return input_ + + +class LayerNormalization[ConfigType: LayerNormalizationConfig](Normalization[ConfigType]): + """ + A layer normalization layer, supporting multiple implementations. + Note: Converting input automatically to training dtype to match Apex behaviour, + needed for full precision residual. + TODO: Review this? 
+ """ + + def __init__( + self, + config: LayerNormalizationConfig, + hidden_dim: TensorDim, + lr_scale: float | None = None, + ): + super().__init__(config, hidden_dim, lr_scale) + implementation = self._config.implementation + if implementation == NormalizationImplementation.auto: + if ( + _fast_normalization_available + and hidden_dim.size in _PERSIST_LN_SIZES + and not self._config.zero_centered + ): + implementation = NormalizationImplementation.fast + elif TritonConfig.TRITON_ENABLED or self._config.zero_centered: + log_main_rank("Fast layer norm unavailable, using backup triton implementation.") + implementation = NormalizationImplementation.triton + elif _fused_normalization_available: + log_main_rank("Fast layer norm unavailable, using backup fused implementation.") + implementation = NormalizationImplementation.fused + else: + log_main_rank("Fast and fused layer norm unavailable, using backup pytorch implementation.") + implementation = NormalizationImplementation.torch + if self._config.zero_centered: + assert implementation == NormalizationImplementation.triton + if implementation == NormalizationImplementation.triton: + self._forward = self._forward_triton + elif implementation == NormalizationImplementation.fast: + self._forward = self._forward_fast + elif implementation == NormalizationImplementation.fused: + self._forward = self._forward_fused + elif implementation == NormalizationImplementation.torch: + self._forward = self._forward_torch + else: + raise NotImplementedError(implementation) + + self.weight = ParameterMeta.from_dims( + (hidden_dim,), + init_method=self._config.weight_initialization_method, + weight_decay=False, + auto_grad_accumulation=implementation == NormalizationImplementation.torch, + lr_scale=lr_scale, + ) + self.bias = ParameterMeta.from_dims( + (hidden_dim,), + init_method=self._config.bias_initialization_method, + weight_decay=False, + auto_grad_accumulation=implementation == NormalizationImplementation.torch, + 
lr_scale=lr_scale, + ) + self._normalized_shape = self.weight.shape + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + return self._forward(input_.view(-1, *self._normalized_shape)).view_as(input_) + + def _forward_triton(self, input_: torch.Tensor) -> torch.Tensor: + return triton_normalization_autograd( + input_, self.weight, self.bias, self._config.epsilon, self.training, self._config.zero_centered + ) + + def _forward_fast(self, input_: torch.Tensor) -> torch.Tensor: + return FastLayerNorm.apply(input_, self._normalized_shape, self.weight, self.bias, self._config.epsilon) + + def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor: + return FusedLayerNorm.apply(input_, self._normalized_shape, self.weight, self.bias, self._config.epsilon) + + def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: + return torch.layer_norm( + input_.to(self.weight.dtype), self._normalized_shape, self.weight, self.bias, self._config.epsilon + ) + + +class RMSNormalization[ConfigType: RMSNormalizationConfig](Configurable[ConfigType], torch.nn.Module): + """ + A RMS normalization layer. + Note: Converting input automatically to training dtype to match Apex behaviour, + needed for full precision residual. + TODO: Review this? 
+ """ + + def __init__( + self, + config: RMSNormalizationConfig, + hidden_dim: TensorDim, + lr_scale: float | None = None, + ): + super().__init__(config, hidden_dim, lr_scale) + assert not hidden_dim.is_parallel + implementation = self._config.implementation + if implementation == NormalizationImplementation.auto: + if TritonConfig.TRITON_ENABLED or self._config.zero_centered: + implementation = NormalizationImplementation.triton + elif _fused_normalization_available: + log_main_rank("Triton RMS norm unavailable, using fused implementation.") + implementation = NormalizationImplementation.fused + else: + log_main_rank("Fused RMS norm unavailable, using backup implementation.") + implementation = NormalizationImplementation.torch + if self._config.zero_centered: + assert implementation == NormalizationImplementation.triton + if implementation == NormalizationImplementation.triton: + self._forward = self._forward_triton + elif implementation == NormalizationImplementation.torch: + self._forward = self._forward_torch + elif implementation == NormalizationImplementation.fused: + self._forward = self._forward_fused + else: + raise NotImplementedError(implementation) + + self.weight = ParameterMeta.from_dims( + (hidden_dim,), + init_method=self._config.weight_initialization_method, + weight_decay=False, + auto_grad_accumulation=True, + lr_scale=lr_scale, + ) + self._normalized_shape = self.weight.shape + + def forward(self, input_: torch.Tensor) -> torch.Tensor: + return self._forward(input_.view(-1, *self._normalized_shape)).view_as(input_) + + def _forward_triton(self, input_: torch.Tensor) -> torch.Tensor: + return triton_normalization_autograd( + input_, self.weight, None, self._config.epsilon, self.training, self._config.zero_centered + ) + + def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor: + return FusedRMSNorm.apply(input_, self._normalized_shape, self.weight, self._config.epsilon) + + def _forward_torch(self, input_: torch.Tensor) -> 
torch.Tensor: + return torch.rms_norm(input_.to(self.weight.dtype), self._normalized_shape, self.weight, self._config.epsilon) diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py new file mode 100644 index 000000000..45aa644a7 --- /dev/null +++ b/fast_llm/layers/common/normalization/config.py @@ -0,0 +1,158 @@ +import abc +import enum +import functools +import typing + +from fast_llm.config import Field, FieldHint, check_field, config_class +from fast_llm.engine.base_model.config import BaseModelConfig +from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_ones_, init_zeros_ +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.layers.common.config import PeftConfig +from fast_llm.utils import Assert + +if typing.TYPE_CHECKING: + from fast_llm.layers.common.normalization.normalization import Normalization + + +class NormalizationImplementation(str, enum.Enum): + """ + An enum for the available implementations of layer norm. + """ + + auto = "auto" + torch = "torch" + fused = "fused" + fast = "fast" + triton = "triton" + + +@config_class(registry=True) +class NormalizationConfig(BaseModelConfig): + pass + + @property + @abc.abstractmethod + def module_class(self) -> type["Normalization"]: + pass + + def get_layer( + self, + hidden_dim: "TensorDim", + lr_scale: float | None = None, + peft: PeftConfig | None = None, + ) -> "Normalization": + out = self.module_class(self, hidden_dim, lr_scale) + if peft: + out = peft.apply_other(out) + return out + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is NormalizationConfig and cls.get_subclass(default.get("type")) is None: + # Default subclass. 
+ return LayerNormalizationConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + +@config_class(dynamic_type={NormalizationConfig: "none"}) +class NoNormalizationConfig(NormalizationConfig): + _abstract = False + + @property + def module_class(self) -> type["Normalization"]: + from fast_llm.layers.common.normalization.normalization import NoNormalization + + return NoNormalization + + +@config_class() +class LayerNormalizationBaseConfig(NormalizationConfig): + """ + Common configuration for layer norm and rms norm + """ + + # TODO: Rename to normalization_epsilon + epsilon: float = Field( + default=1e-5, + desc="Regularizer for the division.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + zero_centered: bool = Field( + default=False, + desc="Write the normalization weight as `w = 1 + w'`, to improve numerical accuracy when close to one.", + hint=FieldHint.architecture, + ) + implementation: NormalizationImplementation = Field( + default=NormalizationImplementation.auto, + desc="The implementation to use for the normalization layer.", + hint=FieldHint.performance, + ) + weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for the normalization weights. 
Default: fill with ones",
+        hint=FieldHint.feature,
+    )
+    lr_scale: float | None = Field(
+        default=None,
+        desc="Learning rate scaling factor.",
+        hint=FieldHint.feature,
+    )
+
+    @functools.cached_property
+    def weight_initialization_method(self) -> Initializer:
+        if self.weight_initialization.is_default:
+            return init_ones_
+        else:
+            return self.weight_initialization.get_initializer()
+
+    @classmethod
+    def _from_dict(
+        cls,
+        default: dict[str, typing.Any],
+        strict: bool = True,
+        flat: bool = False,
+    ) -> typing.Self:
+        cls._handle_renamed_field(default, "normalization_type", "type")
+        cls._handle_renamed_field(default, "layer_norm_eps", "epsilon")
+        cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered")
+        cls._handle_renamed_field(default, "normalization_implementation", "implementation")
+        cls._handle_renamed_field(default, "layer_norm_init_range", "initialization_range")
+        return super()._from_dict(default, strict, flat)
+
+
+@config_class(dynamic_type={NormalizationConfig: "layer_norm"})
+class LayerNormalizationConfig(LayerNormalizationBaseConfig):
+    _abstract = False
+    bias_initialization: InitializationConfig = Field(
+        desc="Initialization configuration for the normalization biases. 
Default: fill with zeros",
+        hint=FieldHint.feature,
+    )
+
+    @functools.cached_property
+    def bias_initialization_method(self) -> Initializer:
+        if self.bias_initialization.is_default:
+            return init_zeros_
+        else:
+            return self.bias_initialization.get_initializer()
+
+    @property
+    def module_class(self):
+        from fast_llm.layers.common.normalization.normalization import LayerNormalization
+
+        return LayerNormalization
+
+
+@config_class(dynamic_type={NormalizationConfig: "rms_norm"})
+class RMSNormalizationConfig(LayerNormalizationBaseConfig):
+    _abstract = False
+
+    @property
+    def module_class(self):
+        from fast_llm.layers.common.normalization.normalization import RMSNormalization
+
+        return RMSNormalization
diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization/normalization.py
similarity index 98%
rename from fast_llm/layers/common/normalization.py
rename to fast_llm/layers/common/normalization/normalization.py
index 4af6cb2c3..dac4a7548 100644
--- a/fast_llm/layers/common/normalization.py
+++ b/fast_llm/layers/common/normalization/normalization.py
@@ -7,7 +7,7 @@
 from fast_llm.engine.config_utils.tensor_dim import TensorDim
 from fast_llm.functional.config import TritonConfig
 from fast_llm.functional.triton.normalization import triton_normalization_autograd
-from fast_llm.layers.common.config import (
+from fast_llm.layers.common.normalization import (
     LayerNormalizationConfig,
     NoNormalizationConfig,
     NormalizationConfig,
@@ -15,7 +15,7 @@
     RMSNormalizationConfig,
 )
 from fast_llm.tensor import ParameterMeta, accumulate_gradient
-from fast_llm.utils import Assert
+from fast_llm.utils import Assert, combine_lr_scales
 
 try:
     import fused_layer_norm_cuda  # noqa
@@ -156,7 +156,7 @@ def __init__(
     ):
         super().__init__(config)
         self._hidden_dim = hidden_dim
-        self._lr_scale = lr_scale
+        self._lr_scale = combine_lr_scales(self._config.lr_scale, lr_scale)
         assert not self._hidden_dim.is_parallel
 
     @abc.abstractmethod
diff --git 
a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 4c7307e1b..b06a870dd 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -204,9 +204,9 @@ def _validate(self) -> None: Assert.eq( len(self.transformer.per_layer_lr_scale), self.transformer.num_blocks + self.prediction_heads - 1 + 1 ) - if self.output_weight_initialization.has_initialization: + if not self.output_weight_initialization.is_default: assert self.use_absolute_position_embeddings - if self.output_weight_initialization.has_initialization: + if not self.output_weight_initialization.is_default: assert not self.tie_word_embeddings @property @@ -215,21 +215,21 @@ def use_absolute_position_embeddings(self) -> int: @functools.cached_property def word_embedding_weight_initialization_method(self) -> Initializer: - if self.word_embedding_weight_initialization.has_initialization: - return self.word_embedding_weight_initialization.get_initializer() - else: + if self.word_embedding_weight_initialization.is_default: return init_normal_(self.transformer.hidden_size**-0.5) + else: + return self.word_embedding_weight_initialization.get_initializer() @functools.cached_property def position_embedding_weight_initialization_method(self) -> Initializer: - if self.position_embedding_weight_initialization.has_initialization: - return self.position_embedding_weight_initialization.get_initializer() - else: + if self.position_embedding_weight_initialization.is_default: return init_normal_(self.transformer.hidden_size**-0.5) + else: + return self.position_embedding_weight_initialization.get_initializer() @functools.cached_property def output_weight_initialization_method(self) -> Initializer: - if self.output_weight_initialization.has_initialization: - return self.output_weight_initialization.get_initializer() - else: + if self.output_weight_initialization.is_default: return init_normal_(self.transformer.hidden_size**-0.5) + else: + return 
self.output_weight_initialization.get_initializer() diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 4ae6b4821..374acffd2 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -14,7 +14,7 @@ from fast_llm.layers.ssm.config import SSMConfig from fast_llm.layers.ssm.mamba import init_kaiming_ from fast_llm.tensor import ParameterMeta -from fast_llm.utils import div, get_lr_scale +from fast_llm.utils import combine_lr_scales, div logger = logging.getLogger(__name__) @@ -87,7 +87,7 @@ def __init__( self._local_bc_size = bc_dim.size layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None - lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + lr_scale = combine_lr_scales(self._config.mamba_lr_scale, layer_lr_scale) # TODO: double check initializations # Projections diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index 99609b1c4..bcb98d7c8 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -13,7 +13,7 @@ from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig from fast_llm.tensor import ParameterMeta -from fast_llm.utils import Assert, div, get_lr_scale +from fast_llm.utils import Assert, combine_lr_scales, div try: from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn as _mamba_inner_fn # noqa @@ -88,7 +88,7 @@ def __init__( x_projection_dim = ConcatenatedTensorDim("x_projection", (dt_rank_dim, state_dim, state_dim)) layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None - lr_scale = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + lr_scale = combine_lr_scales(self._config.mamba_lr_scale, layer_lr_scale) # TODO: Backward compatibility? # TODO: lr_scale? 
diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index a5797a50d..09b96b7dd 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -13,7 +13,7 @@ from fast_llm.layers.ssm.config import SSMConfig from fast_llm.layers.ssm.mamba import init_A, init_dtprojbias, init_kaiming_ from fast_llm.tensor import ParameterMeta -from fast_llm.utils import Assert, div, get_lr_scale +from fast_llm.utils import Assert, combine_lr_scales, div try: from mamba_ssm.ops.selective_scan_interface import selective_scan_fn # noqa @@ -60,7 +60,9 @@ def __init__( layer_lr_scale: float | None = ( block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None ) - lr_scale: float | tuple[float | None, ...] | None = get_lr_scale(self._config.mamba_lr_scale, layer_lr_scale) + lr_scale: float | tuple[float | None, ...] | None = combine_lr_scales( + self._config.mamba_lr_scale, layer_lr_scale + ) num_heads = div(self._config.d_inner, self._config.state_size) num_head_groups = div(self._config.d_xb, self._config.state_size) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 888d1bec7..26503dd2b 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -12,7 +12,7 @@ from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear from fast_llm.layers.transformer.config import AttentionConfig, AttentionKwargs -from fast_llm.utils import get_lr_scale +from fast_llm.utils import combine_lr_scales try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -112,7 +112,7 @@ def __init__( ) layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) + attention_lr_scale = 
combine_lr_scales(self._config.attention_lr_scale, layer_lr_scale) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) self.query = OutputParallelLinear( @@ -136,7 +136,7 @@ def __init__( self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward) # Rotary embeddings. - self._rotary = self._config.rotary.build(kv_channels_dim) + self._rotary = self._config.rotary.get_layer(kv_channels_dim) # Output. self.dense = InputParallelLinear( diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index 0b76a1e87..a55888fa3 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -8,7 +8,7 @@ from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_normal_, init_zeros_ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig -from fast_llm.layers.block.config import AddLinearBiasChoices, BlockConfig, BlockDimNames, BlockKwargs, MixerConfig +from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs, MixerConfig from fast_llm.layers.transformer.rotary.config import RotaryConfig from fast_llm.utils import Assert, div @@ -136,9 +136,9 @@ def _validate(self) -> None: warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") Assert.multiple(self.num_attention_heads, self.head_groups) - if self.qkv_bias_initialization.has_initialization: + if not self.qkv_bias_initialization.is_default: assert self.add_qkv_bias - if self.dense_bias_initialization.has_initialization: + if not self.dense_bias_initialization.is_default: assert self.add_dense_bias @functools.cached_property @@ -151,43 +151,39 @@ def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: @functools.cached_property def add_qkv_bias(self) -> bool: - if isinstance(self.block.add_linear_biases, bool): - 
return self.block.add_linear_biases - return self.block.add_linear_biases != AddLinearBiasChoices.nowhere + return self.block.add_linear_biases @functools.cached_property def add_dense_bias(self) -> bool: - if isinstance(self.block.add_linear_biases, bool): - return self.block.add_linear_biases - return self.block.add_linear_biases == AddLinearBiasChoices.everywhere + return self.block.add_linear_biases @functools.cached_property def qkv_weight_initialization_method(self) -> Initializer: - if self.qkv_weight_initialization.has_initialization: - return self.qkv_weight_initialization.get_initializer() - else: + if self.qkv_weight_initialization.is_default: return init_normal_(0, self.block.hidden_size**-0.5) + else: + return self.qkv_weight_initialization.get_initializer() @functools.cached_property def qkv_bias_initialization_method(self) -> Initializer: - if self.qkv_bias_initialization.has_initialization: - return self.qkv_bias_initialization.get_initializer() - else: + if self.qkv_bias_initialization.is_default: return init_zeros_ + else: + return self.qkv_bias_initialization.get_initializer() @functools.cached_property def dense_weight_initialization_method(self) -> Initializer: - if self.dense_weight_initialization.has_initialization: - return self.dense_weight_initialization.get_initializer() - else: + if self.dense_weight_initialization.is_default: return init_normal_(0, self.block.hidden_size**-0.5 / max(2 * self.block.num_blocks, 1)) + else: + return self.dense_weight_initialization.get_initializer() @functools.cached_property def dense_bias_initialization_method(self) -> Initializer: - if self.dense_bias_initialization.has_initialization: - return self.dense_bias_initialization.get_initializer() - else: + if self.dense_bias_initialization.is_default: return init_zeros_ + else: + return self.dense_bias_initialization.get_initializer() @config_class() diff --git a/fast_llm/layers/transformer/rotary/config.py b/fast_llm/layers/transformer/rotary/config.py 
index f0e0079c7..6cc19fce8 100644 --- a/fast_llm/layers/transformer/rotary/config.py +++ b/fast_llm/layers/transformer/rotary/config.py @@ -29,7 +29,7 @@ def _from_dict( return NoRotaryConfig._from_dict(default, strict, flat) return super()._from_dict(default, strict=strict, flat=flat) - def build(self, kv_channels_dim: TensorDim) -> "Rotary": + def get_layer(self, kv_channels_dim: TensorDim) -> "Rotary": return self._get_configurable_class()(self, kv_channels_dim) @classmethod diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 0ef970db2..58fc2cf44 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -25,7 +25,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.block.mlp.config import RoutingType -from fast_llm.layers.common.config import LayerNormalizationConfig +from fast_llm.layers.common.normalization import LayerNormalizationConfig from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig from fast_llm.layers.transformer.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 2b941a68a..2011903e0 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -49,7 +49,7 @@ def __init__( self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._distributed_config)) # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. # TODO: Find a better solution. 
- self._preprocessors.append(self._config.transformer.rotary.build(self._tensor_space)) + self._preprocessors.append(self._config.transformer.rotary.get_layer(self._tensor_space)) if self._use_flash_attention: self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._distributed_config)) else: diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index 11d888eaf..012f2fae4 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -21,7 +21,7 @@ from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.config import RMSNormalizationConfig +from fast_llm.layers.common.normalization import RMSNormalizationConfig from fast_llm.layers.ssm.config import DTInitType, SSMBlockType from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter from fast_llm.models.ssm.config import ( diff --git a/fast_llm/utils.py b/fast_llm/utils.py index 58285d408..f7f5e9663 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -348,22 +348,29 @@ def check_equal_nested(config_a, config_b): raise ValueError("\n".join(errors)) -def get_lr_scale( - lr_scale: float | None | tuple[float | None, ...], layer_lr_scale: float | None -) -> float | None | tuple[float | None, ...]: - """ - Combine module and layer lr_scale. - If one is None, return the other. 
- """ - if lr_scale is None: - return layer_lr_scale - if layer_lr_scale is None: - return lr_scale - if isinstance(lr_scale, float): - return lr_scale * layer_lr_scale - if isinstance(lr_scale, tuple): - return tuple(lrs * layer_lr_scale if lrs is not None else layer_lr_scale for lrs in lr_scale) - raise ValueError(f"Invalid lr_scale: {lr_scale} (type {type(lr_scale)})") +def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]): + # Remove `None` entries. + lr_scales = [lr_scale for lr_scale in lr_scales if lr_scale is not None] + if not lr_scales: + # Everything is None + return None + tuple_length = None + # Check if we have tuples, and determine the length. + for lr_scale in lr_scales: + if isinstance(lr_scale, tuple): + if tuple_length is None: + tuple_length = len(lr_scale) + else: + assert len(lr_scale) == tuple_length + if tuple_length is None: + # No tuple: simple product. + return math.prod(lr_scales) + else: + # Tuple(s): use recursion. + return [ + combine_lr_scales(*[lr_scale[i] if isinstance(lr_scale, tuple) else lr_scale for lr_scale in lr_scales]) + for i in range(tuple_length) + ] class Interrupter: diff --git a/tests/functional/test_triton_kernels.py b/tests/functional/test_triton_kernels.py index e4ad937b7..3f4446e4d 100644 --- a/tests/functional/test_triton_kernels.py +++ b/tests/functional/test_triton_kernels.py @@ -92,7 +92,7 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels): y1 = apply_rotary_embeddings( x, DefaultRotaryConfig(triton=False) - .build(None) + .get_layer(None) ._get_frequencies( sequence_length, kv_channels, @@ -103,7 +103,9 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, kv_channels): y2 = convert_rotary_real_to_complex( triton_rotary_( convert_rotary_complex_to_real(x, kv_channels, 3), - DefaultRotaryConfig(triton=True).build(None)._get_frequencies(sequence_length, kv_channels, device="cuda"), + DefaultRotaryConfig(triton=True) + .get_layer(None) + 
._get_frequencies(sequence_length, kv_channels, device="cuda"), ), kv_channels, 3, From 0d2fc8974aca7bef3ca15b577ab8b603732be3f8 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 Aug 2025 19:57:31 -0400 Subject: [PATCH 11/19] stuff --- .../layers/common/normalization/__init__.py | 311 ------------------ fast_llm/layers/ssm/discrete_mamba2.py | 4 +- fast_llm/layers/transformer/block.py | 1 - fast_llm/layers/transformer/config.py | 15 +- fast_llm/models/gpt/conversion.py | 2 +- fast_llm/models/gpt/model.py | 34 +- 6 files changed, 22 insertions(+), 345 deletions(-) diff --git a/fast_llm/layers/common/normalization/__init__.py b/fast_llm/layers/common/normalization/__init__.py index a727e4c69..e69de29bb 100644 --- a/fast_llm/layers/common/normalization/__init__.py +++ b/fast_llm/layers/common/normalization/__init__.py @@ -1,311 +0,0 @@ -import abc - -import torch - -from fast_llm.config import Configurable -from fast_llm.engine.config_utils.run import log_main_rank -from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.functional.config import TritonConfig -from fast_llm.functional.triton.normalization import triton_normalization_autograd -from fast_llm.layers.common.normalization.config import ( - LayerNormalizationConfig, - NoNormalizationConfig, - NormalizationConfig, - NormalizationImplementation, - RMSNormalizationConfig, -) -from fast_llm.tensor import ParameterMeta, accumulate_gradient -from fast_llm.utils import Assert - -try: - import fused_layer_norm_cuda # noqa - - _fused_normalization_available = True -except ImportError: - _fused_normalization_available = False - -try: - import fast_layer_norm # noqa - - _fast_normalization_available = True -except ImportError: - _fast_normalization_available = False - - -_PERSIST_LN_SIZES = ( - 1024, - 1536, - 2048, - 2304, - 3072, - 3840, - 4096, - 5120, - 6144, - 8192, - 10240, - 12288, - 12800, - 15360, - 16384, - 18432, - 20480, - 24576, - 25600, - 30720, - 32768, - 40960, - 49152, 
- 65536, -) - - -class FastLayerNorm(torch.autograd.Function): - """ - The fast layer normalization implementation from `apex.contrib`. - Faster than `FusedLayerNorm`, but doesn't support all layer widths. - TODO: Move to functional. - """ - - @staticmethod - def forward( - ctx, input_: torch.Tensor, normalized_shape: torch.Size, weight: torch.Tensor, bias: torch.Tensor, eps: float - ) -> torch.Tensor: # noqa - assert _fast_normalization_available - Assert.incl(normalized_shape.numel(), _PERSIST_LN_SIZES) - output, _, inv_var = fast_layer_norm.ln_fwd(input_, weight, bias, eps) - ctx.save_for_backward(output, weight, bias, inv_var) - return output - - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None, None, None]: # noqa - output, weight, bias, inv_var = ctx.saved_tensors - # TODO: Gradients may be computed unnecessarily. - grad_input, grad_weight, grad_bias, _, _ = fast_layer_norm.ln_bwd( - grad_output, output, None, inv_var, weight, bias, True - ) - if weight.requires_grad: - accumulate_gradient(weight, grad_weight) - if bias.requires_grad: - accumulate_gradient(bias, grad_bias) - return grad_input, None, None, None, None - - -class FusedLayerNorm(torch.autograd.Function): - """ - The fused layer normalization implementation from `apex`. - Faster than the stock pytorch implementation, supports all layer widths. - TODO: Move to functional. 
- """ - - @staticmethod - def forward( - ctx, input_: torch.Tensor, normalized_shape: torch.Size, weight: torch.Tensor, bias: torch.Tensor, eps: float - ) -> torch.Tensor: # noqa - assert _fused_normalization_available - ctx.eps = eps - ctx.normalized_shape = normalized_shape - output, _, inv_var = fused_layer_norm_cuda.forward_affine(input_, normalized_shape, weight, bias, eps) - ctx.save_for_backward(output, weight, bias, inv_var) - return output - - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None, None, None]: # noqa - output, weight, bias, inv_var = ctx.saved_tensors - # TODO: Gradients may be computed unnecessarily. - grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine( - grad_output, None, inv_var, output, ctx.normalized_shape, weight, bias, ctx.eps, True - ) - if weight.requires_grad: - accumulate_gradient(weight, grad_weight) - if bias.requires_grad: - accumulate_gradient(bias, grad_bias) - return grad_input, None, None, None, None - - -class FusedRMSNorm(torch.autograd.Function): - @staticmethod - def forward( - ctx, input_: torch.Tensor, normalized_shape: torch.Size, weight: torch.Tensor, eps: float - ) -> torch.Tensor: # noqa - assert _fused_normalization_available - ctx.eps = eps - ctx.normalized_shape = normalized_shape - output, inv_var = fused_layer_norm_cuda.rms_forward_affine(input_, normalized_shape, weight, eps) - ctx.save_for_backward(output, weight, inv_var) - return output - - @staticmethod - def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None, None]: # noqa - output, weight, inv_var = ctx.saved_tensors - # TODO: Gradients may be computed unnecessarily. 
- grad_input, grad_weight = fused_layer_norm_cuda.rms_backward_affine( - grad_output.contiguous(), inv_var, output, ctx.normalized_shape, weight, ctx.eps, True - ) - if weight.requires_grad: - accumulate_gradient(weight, grad_weight) - return grad_input, None, None, None - - -class Normalization[ConfigType: NormalizationConfig](Configurable[ConfigType], torch.nn.Module): - def __init__( - self, - config: NormalizationConfig, - hidden_dim: TensorDim, - lr_scale: float | None = None, - ): - super().__init__(config) - self._hidden_dim = hidden_dim - self._lr_scale = lr_scale - assert not self._hidden_dim.is_parallel - - @abc.abstractmethod - def forward(self, input_: torch.Tensor) -> torch.Tensor: - pass - - -class NoNormalization[ConfigType: NoNormalizationConfig](Normalization[ConfigType]): - def forward(self, input_: torch.Tensor) -> torch.Tensor: - return input_ - - -class LayerNormalization[ConfigType: LayerNormalizationConfig](Normalization[ConfigType]): - """ - A layer normalization layer, supporting multiple implementations. - Note: Converting input automatically to training dtype to match Apex behaviour, - needed for full precision residual. - TODO: Review this? 
- """ - - def __init__( - self, - config: LayerNormalizationConfig, - hidden_dim: TensorDim, - lr_scale: float | None = None, - ): - super().__init__(config, hidden_dim, lr_scale) - implementation = self._config.implementation - if implementation == NormalizationImplementation.auto: - if ( - _fast_normalization_available - and hidden_dim.size in _PERSIST_LN_SIZES - and not self._config.zero_centered - ): - implementation = NormalizationImplementation.fast - elif TritonConfig.TRITON_ENABLED or self._config.zero_centered: - log_main_rank("Fast layer norm unavailable, using backup triton implementation.") - implementation = NormalizationImplementation.triton - elif _fused_normalization_available: - log_main_rank("Fast layer norm unavailable, using backup fused implementation.") - implementation = NormalizationImplementation.fused - else: - log_main_rank("Fast and fused layer norm unavailable, using backup pytorch implementation.") - implementation = NormalizationImplementation.torch - if self._config.zero_centered: - assert implementation == NormalizationImplementation.triton - if implementation == NormalizationImplementation.triton: - self._forward = self._forward_triton - elif implementation == NormalizationImplementation.fast: - self._forward = self._forward_fast - elif implementation == NormalizationImplementation.fused: - self._forward = self._forward_fused - elif implementation == NormalizationImplementation.torch: - self._forward = self._forward_torch - else: - raise NotImplementedError(implementation) - - self.weight = ParameterMeta.from_dims( - (hidden_dim,), - init_method=self._config.weight_initialization_method, - weight_decay=False, - auto_grad_accumulation=implementation == NormalizationImplementation.torch, - lr_scale=lr_scale, - ) - self.bias = ParameterMeta.from_dims( - (hidden_dim,), - init_method=self._config.bias_initialization_method, - weight_decay=False, - auto_grad_accumulation=implementation == NormalizationImplementation.torch, - 
lr_scale=lr_scale, - ) - self._normalized_shape = self.weight.shape - - def forward(self, input_: torch.Tensor) -> torch.Tensor: - return self._forward(input_.view(-1, *self._normalized_shape)).view_as(input_) - - def _forward_triton(self, input_: torch.Tensor) -> torch.Tensor: - return triton_normalization_autograd( - input_, self.weight, self.bias, self._config.epsilon, self.training, self._config.zero_centered - ) - - def _forward_fast(self, input_: torch.Tensor) -> torch.Tensor: - return FastLayerNorm.apply(input_, self._normalized_shape, self.weight, self.bias, self._config.epsilon) - - def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor: - return FusedLayerNorm.apply(input_, self._normalized_shape, self.weight, self.bias, self._config.epsilon) - - def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: - return torch.layer_norm( - input_.to(self.weight.dtype), self._normalized_shape, self.weight, self.bias, self._config.epsilon - ) - - -class RMSNormalization[ConfigType: RMSNormalizationConfig](Configurable[ConfigType], torch.nn.Module): - """ - A RMS normalization layer. - Note: Converting input automatically to training dtype to match Apex behaviour, - needed for full precision residual. - TODO: Review this? 
- """ - - def __init__( - self, - config: RMSNormalizationConfig, - hidden_dim: TensorDim, - lr_scale: float | None = None, - ): - super().__init__(config, hidden_dim, lr_scale) - assert not hidden_dim.is_parallel - implementation = self._config.implementation - if implementation == NormalizationImplementation.auto: - if TritonConfig.TRITON_ENABLED or self._config.zero_centered: - implementation = NormalizationImplementation.triton - elif _fused_normalization_available: - log_main_rank("Triton RMS norm unavailable, using fused implementation.") - implementation = NormalizationImplementation.fused - else: - log_main_rank("Fused RMS norm unavailable, using backup implementation.") - implementation = NormalizationImplementation.torch - if self._config.zero_centered: - assert implementation == NormalizationImplementation.triton - if implementation == NormalizationImplementation.triton: - self._forward = self._forward_triton - elif implementation == NormalizationImplementation.torch: - self._forward = self._forward_torch - elif implementation == NormalizationImplementation.fused: - self._forward = self._forward_fused - else: - raise NotImplementedError(implementation) - - self.weight = ParameterMeta.from_dims( - (hidden_dim,), - init_method=self._config.weight_initialization_method, - weight_decay=False, - auto_grad_accumulation=True, - lr_scale=lr_scale, - ) - self._normalized_shape = self.weight.shape - - def forward(self, input_: torch.Tensor) -> torch.Tensor: - return self._forward(input_.view(-1, *self._normalized_shape)).view_as(input_) - - def _forward_triton(self, input_: torch.Tensor) -> torch.Tensor: - return triton_normalization_autograd( - input_, self.weight, None, self._config.epsilon, self.training, self._config.zero_centered - ) - - def _forward_fused(self, input_: torch.Tensor) -> torch.Tensor: - return FusedRMSNorm.apply(input_, self._normalized_shape, self.weight, self._config.epsilon) - - def _forward_torch(self, input_: torch.Tensor) -> 
torch.Tensor: - return torch.rms_norm(input_.to(self.weight.dtype), self._normalized_shape, self.weight, self._config.epsilon) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 374acffd2..f1020a903 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -36,7 +36,9 @@ class DiscreteMamba2[ConfigType: SSMConfig](BlockLayer[ConfigType]): - """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" + """ + This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py. + """ _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" diff --git a/fast_llm/layers/transformer/block.py b/fast_llm/layers/transformer/block.py index dd81a4da5..a5aad45a9 100644 --- a/fast_llm/layers/transformer/block.py +++ b/fast_llm/layers/transformer/block.py @@ -11,7 +11,6 @@ class TransformerBlock[ConfigType: TransformerConfig](Block[ConfigType]): # TODO: Standardize to `mixer` _mixer_module_name: typing.ClassVar[str] = "self_attn" - _config: TransformerConfig def _create_mixer(self) -> BlockLayer: return Attention( diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/transformer/config.py index a55888fa3..d8c5bf923 100644 --- a/fast_llm/layers/transformer/config.py +++ b/fast_llm/layers/transformer/config.py @@ -8,26 +8,13 @@ from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_normal_, init_zeros_ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig -from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs, MixerConfig +from fast_llm.layers.block.config import BlockConfig, BlockKwargs, MixerConfig from fast_llm.layers.transformer.rotary.config import RotaryConfig from 
fast_llm.utils import Assert, div logger = logging.getLogger(__name__) -class AttentionDimNames(BlockDimNames): - # A set of common tensor dim names packed into a namespace. - # Self-attention dimensions - head_groups = "head_groups" - group_heads = "group_heads" - key_and_value = "key_value" - kv_channels = "kv_channels" - composite_heads = "composite_heads" - composite_query = "composite_query" - composite_key_value = "composite_key_value" - composite_dense = "composite_dense" - - class AttentionKwargs(BlockKwargs): rotary_freq_q = "rotary_freq_q" rotary_freq_k = "rotary_freq_k" diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 58fc2cf44..de9b51b08 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -25,7 +25,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.block.mlp.config import RoutingType -from fast_llm.layers.common.normalization import LayerNormalizationConfig +from fast_llm.layers.common.normalization.config import LayerNormalizationConfig from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig from fast_llm.layers.transformer.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 2011903e0..9422900a8 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -10,13 +10,14 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.block.mlp.config import MLPLossNames, RoutingType from 
fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor from fast_llm.layers.transformer.block import TransformerBlock -from fast_llm.layers.transformer.config import AttentionDimNames, AttentionKwargs +from fast_llm.layers.transformer.config import AttentionKwargs from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron @@ -49,7 +50,9 @@ def __init__( self._preprocessors.append(PositionEmbeddingPreprocessor(self._config, self._distributed_config)) # We have multiple identical rotary modules/preprocessors, so it's simpler to make a new one here. # TODO: Find a better solution. 
- self._preprocessors.append(self._config.transformer.rotary.get_layer(self._tensor_space)) + self._preprocessors.append( + self._config.transformer.rotary.get_layer(TensorDim("kv_channels", self._config.transformer.kv_channels)) + ) if self._use_flash_attention: self._preprocessors.append(FlashAttnVarlenPreprocessor(self._config.transformer, self._distributed_config)) else: @@ -137,8 +140,8 @@ def preprocess_meta( micro_sequence_length = sequence_length truncate_documents = True - batch_data = self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.batch_data) - batch_dim = TensorDim(AttentionDimNames.batch, micro_batch_size * batch_data.size, batch_data) + batch_data = self._distributed_config.get_distributed_dim(DistributedDimNames.batch_data) + batch_dim = TensorDim(BlockDimNames.batch, micro_batch_size * batch_data.size, batch_data) if micro_sequence_length is None: micro_sequence_length = sequence_length @@ -147,19 +150,17 @@ def preprocess_meta( # TODO: Calculate hidden dims elsewhere? 
sequence_q_dim = TensorDim( - AttentionDimNames.sequence_q, + BlockDimNames.sequence_q, micro_sequence_length, - self._tensor_space.distributed_config.get_distributed_dim(DistributedDimNames.sequence_data), + self._distributed_config.get_distributed_dim(DistributedDimNames.sequence_data), ) hidden_sequence_q_dim = ( TensorDim( - AttentionDimNames.sequence_q_tp, + BlockDimNames.sequence_q_tp, micro_sequence_length, - self._tensor_space.distributed_config.get_distributed_dim( - DistributedDimNames.tensor_and_sequence_data - ), + self._distributed_config.get_distributed_dim(DistributedDimNames.tensor_and_sequence_data), ) - if self._tensor_space.distributed_config.sequence_tensor_parallel + if self._distributed_config.sequence_tensor_parallel else sequence_q_dim ) @@ -170,11 +171,10 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - hidden_dim = self._tensor_space[AttentionDimNames.hidden] hidden_dims = ( - (hidden_sequence_q_dim, batch_dim, hidden_dim) + (hidden_sequence_q_dim, batch_dim, self._hidden_dim) if sequence_first - else (batch_dim, hidden_sequence_q_dim, hidden_dim) + else (batch_dim, hidden_sequence_q_dim, self._hidden_dim) ) common_kwargs = { @@ -187,7 +187,7 @@ def preprocess_meta( } sequence_k_pasts = range( - sequence_q_dim.size * self._tensor_space.distributed_config.sequence_data_rank, + sequence_q_dim.size * self._distributed_config.sequence_data_rank, sequence_length, micro_sequence_length, ) @@ -201,7 +201,7 @@ def preprocess_meta( preprocessed_meta = [] for i, sequence_k_past in enumerate(sequence_k_pasts): sequence_k = sequence_k_past + sequence_q_dim.size - sequence_k_dim = TensorDim(AttentionDimNames.sequence_k, sequence_k) + sequence_k_dim = TensorDim(BlockDimNames.sequence_k, sequence_k) tokens = TensorMeta.from_dims( hidden_dims[:2], tensor_name=f"tokens_{sequence_k_past}_to_{sequence_k-1}", dtype=torch.int64 @@ -255,7 +255,7 @@ def preprocess( prediction_heads: 
int = self._config.prediction_heads batch.token_ids = batch.token_ids.to( - device=self._tensor_space.distributed.device, + device=self._distributed.device, dtype=torch.int64, non_blocking=True, ) From a7cb0189975f1fd32af9ccdeb1f4d4dd13b39ba1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 Aug 2025 20:02:08 -0400 Subject: [PATCH 12/19] stuff --- fast_llm/layers/block/block.py | 11 ++++++----- fast_llm/layers/block/config.py | 16 ++++++++++++++-- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index d63ac78c1..c3f7f7d86 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -101,18 +101,19 @@ def __init__( name: str, ): super().__init__(config, distributed_config) + self._block_config = block_config self._hidden_dim = hidden_dim self._block_index = block_index self._name = name self._sequence_parallel: bool = self._distributed_config.sequence_tensor_parallel self._debug = DebugLayer( self._name, - block_config.debug_transformer, - block_config.debug_transformer_memory, + self._block_config.debug_transformer, + self._block_config.debug_transformer_memory, ) -class BlockLayer[ConfigType: BlockLayerConfig](BlockLayerBase[ConfigType], torch.nn.Module): +class BlockLayer[ConfigType: BlockLayerConfig](BlockLayerBase[ConfigType]): """ Base class for mixer and MLP modules. 
""" @@ -154,8 +155,8 @@ def __init__( self._return_input: bool = return_input # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale - self.norm_1 = self._config.peft.apply_other(self._config.normalization.get_layer(hidden_dim)) - self.norm_2 = self._config.peft.apply_other(self._config.normalization.get_layer(hidden_dim)) + self.norm_1 = self._config.peft.apply_other(self._config.normalization.get_layer(self._hidden_dim)) + self.norm_2 = self._config.peft.apply_other(self._config.normalization.get_layer(self._hidden_dim)) # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix. setattr( self, diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 3c9da42f6..ef72cf1e7 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -52,9 +52,21 @@ def layer_class(self) -> "type[BlockLayer]": raise NotImplementedError() def get_layer( - self, distributed_config: DistributedConfig, hidden_dim: TensorDim, block_index: int, name: str + self, + block_config: "BlockConfig", + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, ) -> "BlockLayer": - return self.layer_class(self, distributed_config, hidden_dim, block_index, name) + return self.layer_class( + self, + block_config, + distributed_config, + hidden_dim, + block_index, + name, + ) @config_class(registry=True) From ddf3ac2c9ed64a9d850e05ddc669986bb516eba9 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 14 Aug 2025 20:40:15 -0400 Subject: [PATCH 13/19] peft --- fast_llm/layers/block/config.py | 6 +- fast_llm/layers/block/mlp/mlp.py | 13 +- fast_llm/layers/block/peft.py | 128 ------------------ fast_llm/layers/common/config.py | 58 -------- .../layers/common/normalization/config.py | 2 +- fast_llm/layers/common/peft/__init__.py | 0 fast_llm/layers/common/peft/config.py | 91 +++++++++++++ fast_llm/layers/common/{ => peft}/peft.py | 44 
+++--- fast_llm/layers/transformer/attention.py | 28 ++-- 9 files changed, 140 insertions(+), 230 deletions(-) delete mode 100644 fast_llm/layers/block/peft.py delete mode 100644 fast_llm/layers/common/config.py create mode 100644 fast_llm/layers/common/peft/__init__.py create mode 100644 fast_llm/layers/common/peft/config.py rename fast_llm/layers/common/{ => peft}/peft.py (70%) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index ef72cf1e7..9df11bc44 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -4,8 +4,8 @@ from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.block.peft import TransformerPeftConfig -from fast_llm.layers.common.normalization import NormalizationConfig +from fast_llm.layers.common.normalization.config import NormalizationConfig +from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -136,7 +136,7 @@ class BlockConfig(BaseModelConfig): desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, ) - peft: TransformerPeftConfig = Field( + peft: PeftConfig = Field( desc="Configuration for the parameter-efficient fine tuning.", hint=FieldHint.architecture, ) diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index cc4562dfc..bd85e2089 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -7,8 +7,8 @@ from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd from fast_llm.layers.block.block import BlockLayer +from fast_llm.layers.block.config import BlockConfig from fast_llm.layers.block.mlp.config import MLPConfig -from fast_llm.layers.block.peft import 
TransformerSubLayerName from fast_llm.layers.common.linear import LinearBase from fast_llm.utils import Assert, combine_lr_scales @@ -17,18 +17,21 @@ class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): def __init__( self, config: ConfigType, + block_config: BlockConfig, distributed_config: DistributedConfig, hidden_dim: TensorDim, block_index: int, name: str, ): - super().__init__(config, distributed_config, hidden_dim, block_index, name) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name) self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) intermediate_1_dim, intermediate_2_dim = self._get_intermediate_dims() self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = self._config.per_layer_lr_scale[block_index] if self._config.per_layer_lr_scale else None + layer_lr_scale = ( + self._block_config.per_layer_lr_scale[block_index] if self._block_config.per_layer_lr_scale else None + ) lr_scale = ( tuple(self._config.mlp_lr_scale) if isinstance(self._config.mlp_lr_scale, list) @@ -57,8 +60,8 @@ def __init__( ) # PEFT. - self.layer_1 = self._config.peft.apply_linear(self.layer_1, TransformerSubLayerName.mlp_1) - self.layer_2 = self._config.peft.apply_linear(self.layer_2, TransformerSubLayerName.mlp_2) + self.layer_1 = self._block_config.peft.apply_linear(self.layer_1, False) + self.layer_2 = self._block_config.peft.apply_linear(self.layer_2, False) def _get_intermediate_dims(self): intermediate_2_dim = TensorDim("intermediate", self._config.ffn_hidden_size, self._parallel_dim) diff --git a/fast_llm/layers/block/peft.py b/fast_llm/layers/block/peft.py deleted file mode 100644 index 66bc675ed..000000000 --- a/fast_llm/layers/block/peft.py +++ /dev/null @@ -1,128 +0,0 @@ -""" -TODO: Generalize beyond transformers. 
-""" - -import abc -import enum -import typing - -from fast_llm.config import Field, FieldHint, config_class -from fast_llm.layers.common.config import LoRAConfig, NoPeftConfig, PeftConfig -from fast_llm.utils import div - -if typing.TYPE_CHECKING: - import torch - - from fast_llm.layers.common.linear import LinearBase, LinearLike - from fast_llm.tensor import ParameterMeta - - -class TransformerSubLayerName(str, enum.Enum): - # TODO: Use this to replace AddLinearBiasChoices. - query = "query" - key = "key" - value_ = "value" - key_value = "key_value" - dense = "dense" - mlp_1 = "mlp_1" - mlp_2 = "mlp_2" - - -@config_class(registry=True) -class TransformerPeftConfig(PeftConfig): - @abc.abstractmethod - def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - pass - - @abc.abstractmethod - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - pass - - @abc.abstractmethod - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - pass - - @classmethod - def _from_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - flat: bool = False, - ) -> typing.Self: - if cls is TransformerPeftConfig and cls.get_subclass(default.get("type")) is None: - # Default subclass. 
- return TransformerNoPeftConfig._from_dict(default, strict, flat) - return super()._from_dict(default, strict=strict, flat=flat) - - -@config_class(dynamic_type={TransformerPeftConfig: "none"}) -class TransformerNoPeftConfig(NoPeftConfig, TransformerPeftConfig): - _abstract = False - - def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - return super().apply_linear(linear) - - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - return module - - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - return parameter - - -@config_class(dynamic_type={TransformerPeftConfig: "lora"}) -class TransformerLoRAConfig(LoRAConfig, TransformerPeftConfig): - layers: list[TransformerSubLayerName] = Field( - default=(TransformerSubLayerName.query, TransformerSubLayerName.value_), - desc="The layers on which to apply LoRA.", - hint=FieldHint.feature, - ) - freeze_others: bool = Field( - default=True, - desc="Whether to freeze other layers during training.", - ) - - def apply_linear(self, linear: "LinearBase", layer_type: TransformerSubLayerName | None = None) -> "LinearLike": - if layer_type is None or self.layers is None or layer_type in self.layers: - if layer_type == TransformerSubLayerName.key: - return super().apply_linear(linear, out_channel_end=div(linear._out_dim.global_size, 2)) - elif layer_type == TransformerSubLayerName.value_: - return super().apply_linear(linear, out_channel_begin=div(linear._out_dim.global_size, 2)) - else: - return super().apply_linear(linear) - elif self.freeze_others: - linear.weight.requires_grad = False - return linear - - def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": - if self.freeze_others: - for parameter in module.parameters(): - parameter.requires_grad = False - return module - - def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": - if self.freeze_others: - parameter.requires_grad = False - 
return parameter - - def _validate(self) -> None: - super()._validate() - if TransformerSubLayerName.mlp_1 in self.layers or TransformerSubLayerName.mlp_2 in self.layers: - # TODO: Add MLP support. - raise NotImplementedError("LoRA not supported for MLP.") - if TransformerSubLayerName.dense in self.layers: - # TODO: Support InputParallelLinear (different output format). - raise NotImplementedError("LoRA not supported for attention dense layer.") - if ( - sum( - name in self.layers - for name in ( - TransformerSubLayerName.key_value, - TransformerSubLayerName.key, - TransformerSubLayerName.value_, - ) - ) - > 1 - ): - raise ValueError( - f"{TransformerSubLayerName.key_value.value}, {TransformerSubLayerName.key.value} and {TransformerSubLayerName.value_.value} are mutually exclusive." - ) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py deleted file mode 100644 index b09672961..000000000 --- a/fast_llm/layers/common/config.py +++ /dev/null @@ -1,58 +0,0 @@ -import abc -import typing - -from fast_llm.config import Field, FieldHint, config_class -from fast_llm.engine.base_model.config import BaseModelConfig - -if typing.TYPE_CHECKING: - from fast_llm.layers.common.linear import LinearBase, LinearLike - - -@config_class() -class PeftConfig(BaseModelConfig): - @abc.abstractmethod - def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": - pass - - -@config_class() -class NoPeftConfig(PeftConfig): - _abstract = False - - def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": - return linear - - -@config_class() -class LoRAConfig(PeftConfig): - _abstract = False - - rank: int = Field( - default=8, - desc="The LoRA rank, i.e. 
the size of the intermediate dimension.", - hint=FieldHint.stability, - ) - alpha: float = Field( - default=8.0, - desc="The LoRA scaling parameter.", - hint=FieldHint.stability, - ) - dropout: float = Field( - default=0.0, - desc="Dropout rate for LoRA.", - hint=FieldHint.stability, - ) - - def apply_linear(self, linear: "LinearBase", **kwargs) -> "LinearLike": - from fast_llm.layers.common.peft import lora_linear - - # TODO: Init method? - return lora_linear( - linear, - linear.weight.param_init_method, - linear.weight.param_init_method, - self.rank, - self.alpha, - self.dropout, - **kwargs, - ) diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py index 45aa644a7..12f7c5ee7 100644 --- a/fast_llm/layers/common/normalization/config.py +++ b/fast_llm/layers/common/normalization/config.py @@ -7,7 +7,7 @@ from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_ones_, init_zeros_ from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.layers.common.config import PeftConfig +from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: diff --git a/fast_llm/layers/common/peft/__init__.py b/fast_llm/layers/common/peft/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fast_llm/layers/common/peft/config.py b/fast_llm/layers/common/peft/config.py new file mode 100644 index 000000000..d1e21e340 --- /dev/null +++ b/fast_llm/layers/common/peft/config.py @@ -0,0 +1,91 @@ +import typing + +from fast_llm.config import Field, FieldHint, config_class +from fast_llm.engine.base_model.config import BaseModelConfig + +if typing.TYPE_CHECKING: + import torch + + from fast_llm.layers.common.linear import LinearBase, LinearLike, OutputParallelLinear + from fast_llm.layers.common.normalization.normalization import Normalization + from 
fast_llm.tensor import ParameterMeta + + +@config_class(registry=True) +class PeftConfig(BaseModelConfig): + def apply_linear( + self, + module: "LinearBase", + enabled: bool, + out_channel_begin: int | None = None, + out_channel_end: int | None = None, + ) -> "LinearLike": + return self.apply_other(module) + + def apply_normalization(self, module: "Normalization") -> "Normalization": + return self.apply_other(module) + + def apply_other(self, module: "torch.nn.Module") -> "torch.nn.Module": + for parameter in module.parameters(): + self.apply_weight(parameter) + return module + + def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": + return parameter + + +@config_class(dynamic_type={PeftConfig: "none"}) +class NoPeftConfig(PeftConfig): + _abstract = False + + +@config_class(dynamic_type={PeftConfig: "lora"}) +class LoRAConfig(PeftConfig): + _abstract = False + + rank: int = Field( + default=8, + desc="The LoRA rank, i.e. the size of the intermediate dimension.", + hint=FieldHint.stability, + ) + alpha: float = Field( + default=8.0, + desc="The LoRA scaling parameter.", + hint=FieldHint.stability, + ) + dropout: float = Field( + default=0.0, + desc="Dropout rate for LoRA.", + hint=FieldHint.stability, + ) + freeze_others: bool = Field( + default=True, + desc="Whether to freeze other layers during training.", + ) + + def apply_linear( + self, + module: "LinearBase", + enabled: bool, + out_channel_begin: int | None = None, + out_channel_end: int | None = None, + ) -> "LinearLike": + if not enabled: + return self.apply_other(module) + + from fast_llm.layers.common.linear import InputParallelLinear + from fast_llm.layers.common.peft.peft import lora_linear + + if isinstance(module, InputParallelLinear): + # TODO: Support InputParallelLinear (different output format). 
+ raise NotImplementedError("LoRA not supported for InputParallelLinear.") + elif isinstance(module, OutputParallelLinear): + assert out_channel_begin is None and out_channel_end is None + + # TODO: Init method? + return lora_linear(module, self.rank, self.alpha, self.dropout, out_channel_begin, out_channel_end) + + def apply_weight(self, parameter: "ParameterMeta") -> "ParameterMeta": + if self.freeze_others: + parameter.requires_grad = False + return parameter diff --git a/fast_llm/layers/common/peft.py b/fast_llm/layers/common/peft/peft.py similarity index 70% rename from fast_llm/layers/common/peft.py rename to fast_llm/layers/common/peft/peft.py index 87991ef29..9e0ca0dd0 100644 --- a/fast_llm/layers/common/peft.py +++ b/fast_llm/layers/common/peft/peft.py @@ -8,21 +8,19 @@ def lora_linear( - layer: LinearBase, - init_method_0, - init_method_1, + module: LinearBase, rank: int, alpha: float, dropout: float = 0.0, out_channel_begin: int | None = None, out_channel_end: int | None = None, ): - layer.weight.requires_grad = False - in_dim = layer._in_dim + module.weight.requires_grad = False + in_dim = module._in_dim assert not in_dim.is_parallel, "LoRA not supported with tensor parallelism." if in_dim.parallel_dim is not None: in_dim = TensorDim(in_dim.name, in_dim.global_size) - out_dim = layer._out_dim + out_dim = module._out_dim assert not out_dim.is_parallel, "LoRA not supported with tensor parallelism." 
if out_dim.parallel_dim is not None: out_dim = TensorDim(out_dim.name, out_dim.global_size) @@ -36,27 +34,27 @@ def lora_linear( middle_dim = TensorDim("lora_middle", rank) - layer.lora_0 = Linear( + module.lora_0 = Linear( in_dim, middle_dim, bias=False, - weight_init_method=init_method_0, - transposed_weight=layer.transposed_weight, - lr_scale=layer.weight.lr_scale, + weight_init_method=module.weight.param_init_method, + transposed_weight=module.transposed_weight, + lr_scale=module.weight.lr_scale, ) - layer.lora_1 = Linear( + module.lora_1 = Linear( middle_dim, out_dim, bias=False, - weight_init_method=init_method_1, - transposed_weight=layer.transposed_weight, - lr_scale=layer.weight.lr_scale, + weight_init_method=module.weight.param_init_method, + transposed_weight=module.transposed_weight, + lr_scale=module.weight.lr_scale, ) # TODO: Implement proper backward pass. - layer.lora_0.weight.auto_grad_accumulation = True - layer.lora_1.weight.auto_grad_accumulation = True + module.lora_0.weight.auto_grad_accumulation = True + module.lora_1.weight.auto_grad_accumulation = True - old_forward = layer._forward + old_forward = module._forward def forward_only(input_: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: # TODO: torch compile? 
@@ -66,8 +64,8 @@ def forward_only(input_: torch.Tensor) -> tuple[torch.Tensor, tuple[torch.Tensor if isinstance(output, tuple): layer_out, tp_bias = output[0] assert tp_bias is None - lora_out = (alpha / rank) * layer.lora_1( - layer.lora_0(torch.dropout(input_, dropout, layer.training) if dropout > 0.0 else input_) + lora_out = (alpha / rank) * module.lora_1( + module.lora_0(torch.dropout(input_, dropout, module.training) if dropout > 0.0 else input_) ) if out_channel_begin is None: output = output + lora_out @@ -83,8 +81,8 @@ def backward( output.backward(grad_output) return input_.grad - layer._forward = wrap_forward_backward(forward_only, backward) - layer.forward_only = forward_only - layer.backward = backward + module._forward = wrap_forward_backward(forward_only, backward) + module.forward_only = forward_only + module.backward = backward - return layer + return module diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 26503dd2b..c1f8010ba 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -4,15 +4,16 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim -from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim -from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.config_utils.initialization import init_normal_ +from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim +from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockDimNames from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, 
OutputParallelLinear from fast_llm.layers.transformer.config import AttentionConfig, AttentionKwargs -from fast_llm.utils import combine_lr_scales +from fast_llm.utils import combine_lr_scales, div try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -101,14 +102,14 @@ def __init__( self._softmax_scale = self._config.kv_channels ** (-self._config.attention_softmax_scale_power) init_method_qkv = init_normal_( - std=self._config.init_method_std_qkv, - min_val=self._config.init_method_min_qkv, - max_val=self._config.init_method_max_qkv, + std=self._block_config.init_method_std_qkv, + min_val=self._block_config.init_method_min_qkv, + max_val=self._block_config.init_method_max_qkv, ) init_method_std_attn_proj = init_normal_( - std=self._config.init_method_std_attn_proj, - min_val=self._config.init_method_min_attn_proj, - max_val=self._config.init_method_max_attn_proj, + std=self._block_config.init_method_std_attn_proj, + min_val=self._block_config.init_method_min_attn_proj, + max_val=self._block_config.init_method_max_attn_proj, ) layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None @@ -149,9 +150,12 @@ def __init__( lr_scale=attention_lr_scale, ) # PEFT. 
- self.query = self._config.peft.apply_linear(self.query, TransformerSubLayerName.query) - self.key_value = self._config.peft.apply_linear(self.key_value, TransformerSubLayerName.key_value) - self.dense = self._config.peft.apply_linear(self.dense, TransformerSubLayerName.dense) + self.query = self._block_config.peft.apply_linear(self.query, True) + + self.key_value = self._block_config.peft.apply_linear( + self.key_value, True, out_channel_begin=div(self.key_value._out_dim.global_size, 2) + ) + self.dense = self._block_config.peft.apply_linear(self.dense, TransformerSubLayerName.dense) if self._debug.enabled: self._query_dims = ( From e7741b71aa686507a80c2a8658080dbbd2418f39 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 15 Aug 2025 13:44:38 -0400 Subject: [PATCH 14/19] stuff --- fast_llm/layers/block/block.py | 19 +++++++++++-- fast_llm/layers/block/config.py | 2 ++ fast_llm/layers/block/mlp/config.py | 2 +- .../layers/block/mlp/mixture_of_experts.py | 9 +++---- fast_llm/layers/block/mlp/mlp.py | 7 +++-- fast_llm/layers/ssm/block.py | 27 +++++++++---------- fast_llm/layers/ssm/config.py | 2 +- fast_llm/layers/ssm/discrete_mamba2.py | 15 +++-------- fast_llm/layers/ssm/mamba.py | 13 +++------ fast_llm/layers/ssm/mamba2.py | 19 ++++--------- fast_llm/layers/transformer/attention.py | 4 +-- fast_llm/layers/transformer/block.py | 16 ++++++----- 12 files changed, 67 insertions(+), 68 deletions(-) diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index c3f7f7d86..da981a0d8 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -11,6 +11,7 @@ from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.engine.distributed.distributed import Distributed from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockLayerConfig from 
fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -99,6 +100,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): super().__init__(config, distributed_config) self._block_config = block_config @@ -111,6 +113,7 @@ def __init__( self._block_config.debug_transformer, self._block_config.debug_transformer_memory, ) + self._lr_scale = lr_scale class BlockLayer[ConfigType: BlockLayerConfig](BlockLayerBase[ConfigType]): @@ -141,6 +144,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, return_input: bool = False, ): super().__init__( @@ -150,6 +154,7 @@ def __init__( hidden_dim, block_index, name, + lr_scale, ) # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input: bool = return_input @@ -162,13 +167,23 @@ def __init__( self, self._config.mixer.module_name, self._config.mixer.get_layer( - self._distributed_config, self._hidden_dim, self._block_index, f"{self._name} mixer" + self._config, + self._distributed_config, + self._hidden_dim, + self._block_index, + f"{self._name} mixer", + self._lr_scale, ), ) self.mlp = self._config.mlp.get_layer( - self._distributed_config, self._hidden_dim, self._block_index, f"{self._name} mlp" + self._distributed_config, self._hidden_dim, self._block_index, f"{self._name} mlp", self._lr_scale ) + def setup(self, distributed: Distributed) -> None: + super().setup(distributed) + getattr(self, self._config.mixer.module_name).setup(distributed) + self.mlp.setup(distributed) + @torch.compile def _bias_dropout_add( self, input_: torch.Tensor, bias: torch.Tensor | None, residual: torch.Tensor diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 9df11bc44..4dbce8eb0 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -58,6 +58,7 @@ def 
get_layer( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ) -> "BlockLayer": return self.layer_class( self, @@ -66,6 +67,7 @@ def get_layer( hidden_dim, block_index, name, + lr_scale, ) diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 237a538fa..416670c82 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -92,7 +92,7 @@ class MLPConfig(BlockLayerConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - mlp_lr_scale: float | None | list[float | None] = Field( + mlp_lr_scale: float | None | tuple[float | None] = Field( default=None, desc="Custom learning rate scale for each expert.", doc="May be used to freeze some experts by setting their scale to zero.", diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index fa7258b7e..b7a50ee8d 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -4,6 +4,7 @@ import torch from fast_llm.core.distributed import ProcessGroup, set_generator +from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped @@ -39,14 +40,12 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." 
- super().__init__(config, distributed_config, hidden_dim, block_index, name) - - layer_lr_scale = self._config.per_layer_lr_scale[block_index] if self._config.per_layer_lr_scale else None - router_lr_scale = combine_lr_scales(self._config.router_lr_scale, layer_lr_scale) + super().__init__(config, distributed_config, hidden_dim, block_index, name, lr_scale) self.router = Linear( self._hidden_dim, @@ -57,7 +56,7 @@ def __init__( min_val=self._config.init_method_min, max_val=self._config.init_method_max, ), - lr_scale=router_lr_scale, + lr_scale=combine_lr_scales(self._config.router_lr_scale, self._lr_scale), ) dropless_moe = self._config.dropless_moe if dropless_moe and self._sequence_parallel: diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index bd85e2089..23da37766 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -22,8 +22,9 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): - super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) intermediate_1_dim, intermediate_2_dim = self._get_intermediate_dims() @@ -77,13 +78,15 @@ class MLP[ConfigType: MLPConfig](MLPBase[ConfigType]): def __init__( self, config: ConfigType, + block_config: BlockConfig, distributed_config: DistributedConfig, hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): Assert.eq(config.num_experts, 1) - super().__init__(config, distributed_config, hidden_dim, block_index, name) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) def forward( self, diff --git a/fast_llm/layers/ssm/block.py b/fast_llm/layers/ssm/block.py index e6374e725..408f21041 100644 
--- a/fast_llm/layers/ssm/block.py +++ b/fast_llm/layers/ssm/block.py @@ -1,3 +1,5 @@ +import functools + from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.block.block import Block, BlockLayer @@ -11,29 +13,26 @@ class SSMBlock[ConfigType: BlockConfig](Block[ConfigType]): A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 """ - _name = "Llamba block" - def __init__( self, config: ConfigType, ssm_config: SSMConfig, distributed_config: DistributedConfig, hidden_dim: TensorDim, - mixer_cls: type[BlockLayer], block_index: int, + lr_scale: float | list[float] | None, name: str, + mixer_class: type[BlockLayer], return_input: bool = False, ): self._ssm_config = ssm_config - self._mixer_cls = mixer_cls - super().__init__(config, distributed_config, hidden_dim, block_index, name, return_input) + self._mixer_class = mixer_class + super().__init__(config, distributed_config, hidden_dim, block_index, name, lr_scale, return_input) + + @functools.cached_property + def _mixer_class(self) -> type[BlockLayer]: + return self._mixer_class - def _create_mixer(self) -> BlockLayer: - return self._mixer_cls( - self._ssm_config, - self._config, - self._distributed_config, - self._hidden_dim, - self._block_index, - f"{self._name} mixer", - ) + @property + def _mixer_config(self) -> SSMConfig: + return self._config diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 910024e52..8917feaf6 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -68,7 +68,7 @@ class SSMConfig(Config): # [MambaLayer, Mamba2, DiscreteMamba2] conv_kernel_dimension: int = Field( default=4, - desc="Conv kernel dimensions.", + desc="Conv kernel dimension.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 
f1020a903..7fea3d480 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -37,7 +37,7 @@ class DiscreteMamba2[ConfigType: SSMConfig](BlockLayer[ConfigType]): """ - This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py. + This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py """ _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" @@ -50,15 +50,9 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): - super().__init__( - config, - block_config, - distributed_config, - hidden_dim, - block_index, - name, - ) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) state_dim = TensorDim("state", self._config.state_size) v_head_size_dim = TensorDim("v_head_size", div(self._config.d_inner, self._config.n_v_heads)) @@ -88,8 +82,7 @@ def __init__( # local_bc_size = local_head_groups * state self._local_bc_size = bc_dim.size - layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None - lr_scale = combine_lr_scales(self._config.mamba_lr_scale, layer_lr_scale) + lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) # TODO: double check initializations # Projections diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index bcb98d7c8..0ca98a174 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -65,15 +65,9 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): - super().__init__( - config, - block_config, - distributed_config, - hidden_dim, - block_index, - name, - ) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) assert 
self._distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" # TODO: It's not silu? Assert.eq(self._config.activation_type, ActivationType.silu) @@ -87,8 +81,7 @@ def __init__( inner_projection_dim = ConcatenatedTensorDim("inner_projection", (inner_dim, inner_dim)) x_projection_dim = ConcatenatedTensorDim("x_projection", (dt_rank_dim, state_dim, state_dim)) - layer_lr_scale = block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None - lr_scale = combine_lr_scales(self._config.mamba_lr_scale, layer_lr_scale) + lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) # TODO: Backward compatibility? # TODO: lr_scale? diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 09b96b7dd..bf9c30521 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -47,22 +47,10 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): - super().__init__( - config, - block_config, - distributed_config, - hidden_dim, - block_index, - name, - ) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) Assert.eq(self._config.activation_type, ActivationType.silu) - layer_lr_scale: float | None = ( - block_config.per_layer_lr_scale[block_index] if block_config.per_layer_lr_scale else None - ) - lr_scale: float | tuple[float | None, ...] 
| None = combine_lr_scales( - self._config.mamba_lr_scale, layer_lr_scale - ) num_heads = div(self._config.d_inner, self._config.state_size) num_head_groups = div(self._config.d_xb, self._config.state_size) @@ -94,6 +82,9 @@ def __init__( self._local_inner_size = inner_dim.size self._local_xb_size = xb_dim.size conv1d_dim = inner_dim if self._config.repeat_kv_before_conv else xb_dim + + lr_scale = combine_lr_scales(self._lr_scale, self._config.mamba_lr_scale) + self.conv1d_weight = ParameterMeta.from_dims( ( conv1d_dim, diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index c1f8010ba..bfe2b58c3 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -234,7 +234,7 @@ def _query_key_value_forward( handle = None - if self._head_groups == 1 and self._sequence_parallel: + if self._config.head_groups == 1 and self._sequence_parallel: key_value, handle = gather_op(key_value, group=self._parallel_dim.group, dim=0, async_op=True) if self._sequence_data_parallel_dim.group: @@ -277,7 +277,7 @@ def _query_key_value_backward( if handle: handle.wait() - if self._head_groups == 1 and (group := self._parallel_dim.group): + if self._config.head_groups == 1 and (group := self._parallel_dim.group): if self._sequence_parallel: key_value_grad = reduce_scatter_op(key_value_grad, group=group, dim=0) else: diff --git a/fast_llm/layers/transformer/block.py b/fast_llm/layers/transformer/block.py index a5aad45a9..ba593461b 100644 --- a/fast_llm/layers/transformer/block.py +++ b/fast_llm/layers/transformer/block.py @@ -1,9 +1,10 @@ +import functools import logging import typing -from fast_llm.layers.block.block import Block, BlockLayer +from fast_llm.layers.block.block import Block from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.transformer.config import AttentionConfig, TransformerConfig logger = 
logging.getLogger(__name__) @@ -12,7 +13,10 @@ class TransformerBlock[ConfigType: TransformerConfig](Block[ConfigType]): # TODO: Standardize to `mixer` _mixer_module_name: typing.ClassVar[str] = "self_attn" - def _create_mixer(self) -> BlockLayer: - return Attention( - self._config, self._distributed_config, self._hidden_dim, self._block_index, f"{self._name} attn" - ) + @functools.cached_property + def _mixer_class(self) -> type[Attention]: + return Attention + + @property + def _mixer_config(self) -> AttentionConfig: + return self._config From 0d3f4a6de2f29bddd4f3af0d31babc594fced4a4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 15 Aug 2025 14:06:31 -0400 Subject: [PATCH 15/19] stuff --- fast_llm/layers/language_model/embedding.py | 8 ++-- fast_llm/layers/language_model/head.py | 32 +++++++------- fast_llm/layers/ssm/mamba.py | 4 -- fast_llm/layers/transformer/attention.py | 48 ++++++--------------- 4 files changed, 34 insertions(+), 58 deletions(-) diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 33e05cde1..b93ee0d25 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -41,13 +41,15 @@ def __init__( hidden_dim, block_index, name, + # TODO: Add lr scale? 
+ None, ) self._residual_dtype = ( self._distributed_config.optimization_dtype - if self._config.transformer.full_precision_residual + if self._block_config.full_precision_residual else self._distributed_config.training_dtype ).torch - self._parallel_embeddings = self._distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._parallel_embeddings = self._distributed_config.tensor_parallel > 1 and self._config.parallel_embeddings self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) vocab_dim = TensorDim( "vocab", self._config.vocab_size, self._parallel_dim if self._parallel_embeddings else None @@ -108,7 +110,7 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask with set_generator( self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator ): - embeddings = torch.dropout(embeddings, self._config.transformer.hidden_dropout, self.training) + embeddings = torch.dropout(embeddings, self._block_config.hidden_dropout, self.training) return embeddings.to(dtype=self._residual_dtype) def forward( diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index aa77089e5..8093e0562 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -48,6 +48,8 @@ def __init__( hidden_dim, block_index, name, + # TODO: Add lr scale? 
+ None, ) self._parallel_logits = self._distributed_config.tensor_parallel > 1 and config.parallel_embeddings self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) @@ -62,7 +64,6 @@ def __init__( else 1.0 ) self._loss_name = LanguageModelLossNames.multi_token_prediction_loss(prediction_distance) - self.final_norm = self._config.transformer.normalization.get_layer(hidden_dim) # Distance of the target token prediction # 0: next-token prediction @@ -222,10 +223,7 @@ def _get_targets( targets = (dpo_target, lm_target, distillation_target, loss_mask) if self._sequence_parallel_logits: - targets = [ - None if target is None else split_op(target, self._tensor_space.distributed.tensor_group, 0) - for target in targets - ] + targets = [None if target is None else split_op(target, self._parallel_dim.group, 0) for target in targets] if not any(target is not None for target in targets): # Simplify so we don't have to check every time. targets = None @@ -247,7 +245,7 @@ def _logits_cross_entropy_forward_backward_split( kwargs: dict, losses: dict | None = None, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: - if self._cross_entropy_splits is None or targets is None: + if self._config.cross_entropy_splits is None or targets is None: loss, logit_input_grad = self._logits_cross_entropy_forward_backward( input_, targets, weight, grad_output, kwargs, losses ) @@ -257,18 +255,19 @@ def _logits_cross_entropy_forward_backward_split( return None, None else: loss = None - # TODO MTP: allow a _cross_entropy_splits that is not a divisor of the sequence length - grad_output /= self._cross_entropy_splits + # TODO MTP: allow a cross_entropy_splits that is not a divisor of the sequence length + grad_output /= self._config.cross_entropy_splits logit_input = input_.flatten(0, -2) if self.training: logit_input_grad = torch.empty_like(logit_input) else: logit_input_grad = None split_size = div( - get_unique(target.size(0) for target in targets if 
 target is not None), self._cross_entropy_splits + get_unique(target.size(0) for target in targets if target is not None), + self._config.cross_entropy_splits, ) tensors_split = [ - [None] * self._cross_entropy_splits if tensor is None else tensor.split(split_size) + [None] * self._config.cross_entropy_splits if tensor is None else tensor.split(split_size) for tensor in [logit_input, *targets, logit_input_grad] ] for logit_input_, *targets_, logit_input_grad_ in zip(*tensors_split, strict=True): @@ -284,14 +283,14 @@ def _logits_cross_entropy_forward_backward_split( logit_input_grad_.copy_(grad_) loss = loss_ if loss is None else loss + loss_ del grad_, loss_ - loss_count = (self._cross_entropy_splits or 1) * ( + loss_count = (self._config.cross_entropy_splits or 1) * ( self._parallel_dim.size if self._sequence_parallel_logits else 1 ) if loss_count != 1: loss.div_(loss_count) if self._sequence_parallel_logits: # TODO: Async - all_reduce(loss, group=self._tensor_space.distributed.tensor_group) + all_reduce(loss, group=self._distributed.tensor_group) return loss, logit_input_grad.view_as(input_) if logit_input_grad is not None else None def _logits_cross_entropy_forward_backward( @@ -303,11 +302,12 @@ def _logits_cross_entropy_forward_backward( kwargs: dict, losses: dict | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: + group = self._parallel_dim.group if self._parallel_logits else None logits, context = output_parallel_linear_forward( input_=input_, weight=weight, bias=None, - group=self._tensor_space.distributed.tensor_group if self._parallel_logits else None, + group=self._distributed.tensor_group if self._parallel_logits else None, sequence_parallel=self._sequence_parallel and self._parallel_logits, ) @@ -353,7 +353,7 @@ def _logits_cross_entropy_forward_backward( logits.flatten(0, -2), lm_target, None, - group=self._tensor_space.distributed.tensor_group if self._parallel_logits else None, + group=group, grad_output=grad_output * 
self._loss_coefficient * self._config.language_model_loss_factor, implementation=self._cross_entropy_impl, logits_scale_factor=self._config.logits_scale_factor, @@ -370,7 +370,7 @@ def _logits_cross_entropy_forward_backward( distillation_target, loss_mask, grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, - group=self._tensor_space.distributed.tensor_group if self._parallel_logits else None, + group=group, logits_scale_factor=self._config.logits_scale_factor, teacher_softmax_temperature=self._config.teacher_softmax_temperature, target_format=( @@ -382,7 +382,7 @@ def _logits_cross_entropy_forward_backward( logits.flatten(0, -2), distillation_target, loss_mask, - group=self._tensor_space.distributed.tensor_group if self._parallel_logits else None, + group=group, grad_output=grad_output * self._loss_coefficient * self._config.distillation_loss_factor, implementation=self._cross_entropy_impl, logits_scale_factor=self._config.logits_scale_factor, diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index 0ca98a174..59fd03a1e 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -91,7 +91,6 @@ def __init__( bias=False, weight_init_method=init_kaiming_(hidden_dim.size), ) - self.conv1d_weight = ParameterMeta.from_dims( ( inner_dim, @@ -101,7 +100,6 @@ def __init__( init_method=init_kaiming_(inner_dim.size), lr_scale=lr_scale, ) - self.x_proj = Linear( inner_dim, x_projection_dim, @@ -110,7 +108,6 @@ def __init__( lr_scale=lr_scale, ) self.x_proj.weight.auto_grad_accumulation = True - # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( (inner_dim, dt_rank_dim), @@ -122,7 +119,6 @@ def __init__( init_method=init_dtprojbias(self._config.dt_max, self._config.dt_min, self._config.dt_init_floor), lr_scale=lr_scale, ) - 
self.A_log = ParameterMeta.from_dims( (inner_dim, state_dim), weight_decay=False, diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index bfe2b58c3..c7a21fa0c 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -4,7 +4,6 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim -from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward @@ -61,15 +60,9 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, + lr_scale: float | list[float] | None, ): - super().__init__( - config, - block_config, - distributed_config, - hidden_dim, - block_index, - name, - ) + super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) self._use_flash_attention = self._config.do_use_flash_attention(self._distributed_config) self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) @@ -101,19 +94,7 @@ def __init__( self._softmax_scale = self._config.kv_channels ** (-self._config.attention_softmax_scale_power) - init_method_qkv = init_normal_( - std=self._block_config.init_method_std_qkv, - min_val=self._block_config.init_method_min_qkv, - max_val=self._block_config.init_method_max_qkv, - ) - init_method_std_attn_proj = init_normal_( - std=self._block_config.init_method_std_attn_proj, - min_val=self._block_config.init_method_min_attn_proj, - max_val=self._block_config.init_method_max_attn_proj, - ) - - layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None - attention_lr_scale = 
combine_lr_scales(self._config.attention_lr_scale, layer_lr_scale) + lr_scale = combine_lr_scales(self._lr_scale, self._config.attention_lr_scale) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) self.query = OutputParallelLinear( @@ -123,7 +104,7 @@ def __init__( weight_init_method=self._config.qkv_weight_initialization_method, bias_init_method=self._config.qkv_bias_initialization_method, sequence_parallel=self._sequence_parallel, - lr_scale=attention_lr_scale, + lr_scale=lr_scale, ) self.key_value = OutputParallelLinear( hidden_dim, @@ -132,7 +113,7 @@ def __init__( weight_init_method=self._config.qkv_weight_initialization_method, bias_init_method=self._config.qkv_bias_initialization_method, sequence_parallel=self._sequence_parallel, - lr_scale=attention_lr_scale, + lr_scale=lr_scale, ) self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward) @@ -147,7 +128,7 @@ def __init__( weight_init_method=self._config.dense_weight_initialization_method, bias_init_method=self._config.dense_bias_initialization_method, sequence_parallel=self._sequence_parallel, - lr_scale=attention_lr_scale, + lr_scale=lr_scale, ) # PEFT. self.query = self._block_config.peft.apply_linear(self.query, True) @@ -252,7 +233,7 @@ def _query_key_value_forward( handle.wait() if self._sequence_data_parallel_dim.group and not sequence_first: - key_value = swap_mult_dim(key_value, self._distributed_config.sequence_data_parallel, 0, 1) + key_value = swap_mult_dim(key_value, self._sequence_parallel, 0, 1) context = {"query": query_context, "key_value": key_value_context, "sequence_first": sequence_first} return query, key_value, context @@ -261,15 +242,12 @@ def _query_key_value_backward( self, query_grad: torch.Tensor, key_value_grad: torch.Tensor, context: dict ) -> torch.Tensor: # TODO: De-allocate qkv grads quicker. 
- handle = None - - if self._sequence_data_parallel_dim.group: - key_value_grad, handle = reduce_scatter_op( - key_value_grad, - group=self._sequence_data_parallel_dim.group, - dim=1 - context["sequence_first"], - async_op=True, - ) + key_value_grad, handle = reduce_scatter_op( + key_value_grad, + group=self._sequence_data_parallel_dim.group, + dim=1 - context["sequence_first"], + async_op=True, + ) # TODO: Overlap with both. input_grad = self.query.backward(query_grad, context.pop("query")) From 651be5d52c0f97c9ad50fa1d1c0e3a2661d1bb35 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 20 Aug 2025 10:58:53 -0400 Subject: [PATCH 16/19] Reduce diff --- .../common/normalization/normalization.py | 27 +++++-------------- fast_llm/layers/language_model/embedding.py | 9 ++++--- fast_llm/layers/language_model/head.py | 10 ++++--- 3 files changed, 17 insertions(+), 29 deletions(-) diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index dac4a7548..5e5dc8795 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -7,7 +7,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.normalization import triton_normalization_autograd -from fast_llm.layers.common.normalization import ( +from fast_llm.layers.common.normalization.config import ( LayerNormalizationConfig, NoNormalizationConfig, NormalizationConfig, @@ -148,12 +148,7 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None, class Normalization[ConfigType: NormalizationConfig](Configurable[ConfigType], torch.nn.Module): - def __init__( - self, - config: NormalizationConfig, - hidden_dim: TensorDim, - lr_scale: float | None = None, - ): + def __init__(self, config: NormalizationConfig, hidden_dim: TensorDim, lr_scale: float | None = None): 
super().__init__(config) self._hidden_dim = hidden_dim self._lr_scale = combine_lr_scales(self._config.lr_scale, lr_scale) @@ -177,12 +172,7 @@ class LayerNormalization[ConfigType: LayerNormalizationConfig](Normalization[Con TODO: Review this? """ - def __init__( - self, - config: LayerNormalizationConfig, - hidden_dim: TensorDim, - lr_scale: float | None = None, - ): + def __init__(self, config: LayerNormalizationConfig, hidden_dim: TensorDim, lr_scale: float | None = None): super().__init__(config, hidden_dim, lr_scale) implementation = self._config.implementation if implementation == NormalizationImplementation.auto: @@ -219,14 +209,14 @@ def __init__( init_method=self._config.weight_initialization_method, weight_decay=False, auto_grad_accumulation=implementation == NormalizationImplementation.torch, - lr_scale=lr_scale, + lr_scale=self._lr_scale, ) self.bias = ParameterMeta.from_dims( (hidden_dim,), init_method=self._config.bias_initialization_method, weight_decay=False, auto_grad_accumulation=implementation == NormalizationImplementation.torch, - lr_scale=lr_scale, + lr_scale=self._lr_scale, ) self._normalized_shape = self.weight.shape @@ -258,12 +248,7 @@ class RMSNormalization[ConfigType: RMSNormalizationConfig](Configurable[ConfigTy TODO: Review this? 
""" - def __init__( - self, - config: RMSNormalizationConfig, - hidden_dim: TensorDim, - lr_scale: float | None = None, - ): + def __init__(self, config: RMSNormalizationConfig, hidden_dim: TensorDim, lr_scale: float | None = None): super().__init__(config, hidden_dim, lr_scale) assert not hidden_dim.is_parallel implementation = self._config.implementation diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index b93ee0d25..b8ef2c6d1 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -82,20 +82,21 @@ def __init__( @torch.compile def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask_inputs: bool) -> torch.Tensor: Assert.eq(position_ids is not None, self._config.use_absolute_position_embeddings) + group = self._parallel_dim.group if self._parallel_embeddings: input_mask = (input_ >= self._vocab_start_index) * (input_ < self._vocab_end_index) masked_input = (input_ - self._vocab_start_index) * input_mask embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * input_mask.unsqueeze(2) # noqa - embeddings = reduce_forward(embeddings, self._parallel_dim.group) + embeddings = reduce_forward(embeddings, group) if self._config.use_absolute_position_embeddings: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if self._sequence_parallel: - embeddings = split(embeddings, group=self._parallel_dim.group, dim=0) + embeddings = split(embeddings, group=group, dim=0) else: if self._sequence_parallel: - input_ = split(input_, group=self._parallel_dim.group, dim=0) + input_ = split(input_, group=group, dim=0) if self._config.use_absolute_position_embeddings: - position_ids = split(position_ids, group=self._parallel_dim.group, dim=0) + position_ids = split(position_ids, group=group, dim=0) # handle masked tokens if mask_inputs: input_mask = input_ >= 0 diff --git 
a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 8093e0562..a21fd9934 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -111,7 +111,9 @@ def forward( return TensorMeta.from_dims( (scalar_dim,), tensor_name="Loss", - reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa + reductions=( + (self._distributed_config.get_distributed_dim(DistributedDimNames.data), ReduceOp.AVG), + ), ) else: return TensorMeta.from_dims(input_.dims[1:], tensor_name="Shared hidden") @@ -290,7 +292,7 @@ def _logits_cross_entropy_forward_backward_split( loss.div_(loss_count) if self._sequence_parallel_logits: # TODO: Async - all_reduce(loss, group=self._distributed.tensor_group) + all_reduce(loss, group=self._parallel_dim.group) return loss, logit_input_grad.view_as(input_) if logit_input_grad is not None else None def _logits_cross_entropy_forward_backward( @@ -307,7 +309,7 @@ def _logits_cross_entropy_forward_backward( input_=input_, weight=weight, bias=None, - group=self._distributed.tensor_group if self._parallel_logits else None, + group=group, sequence_parallel=self._sequence_parallel and self._parallel_logits, ) @@ -353,7 +355,7 @@ def _logits_cross_entropy_forward_backward( logits.flatten(0, -2), lm_target, None, - group=sgroup, + group=group, grad_output=grad_output * self._loss_coefficient * self._config.language_model_loss_factor, implementation=self._cross_entropy_impl, logits_scale_factor=self._config.logits_scale_factor, From 92da4cd5f54c07e9b7beaf8c8c999b897f78e834 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 20 Aug 2025 11:14:27 -0400 Subject: [PATCH 17/19] Reduce diff --- fast_llm/layers/block/block.py | 5 +++-- fast_llm/layers/block/config.py | 2 +- fast_llm/layers/block/mlp/mixture_of_experts.py | 2 +- fast_llm/layers/block/mlp/mlp.py | 4 ++-- fast_llm/layers/ssm/block.py | 2 +- fast_llm/layers/ssm/discrete_mamba2.py | 2 +- fast_llm/layers/ssm/mamba.py | 2 +- 
fast_llm/layers/ssm/mamba2.py | 2 +- fast_llm/layers/transformer/attention.py | 2 +- fast_llm/utils.py | 8 ++++---- 10 files changed, 16 insertions(+), 15 deletions(-) diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index da981a0d8..991ecb3e7 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -100,7 +100,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): super().__init__(config, distributed_config) self._block_config = block_config @@ -144,7 +144,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, return_input: bool = False, ): super().__init__( @@ -162,6 +162,7 @@ def __init__( # TODO: add a separate norm_lr_scale self.norm_1 = self._config.peft.apply_other(self._config.normalization.get_layer(self._hidden_dim)) self.norm_2 = self._config.peft.apply_other(self._config.normalization.get_layer(self._hidden_dim)) + # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix. 
setattr( self, diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 4dbce8eb0..7f91dd0b0 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -58,7 +58,7 @@ def get_layer( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ) -> "BlockLayer": return self.layer_class( self, diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index b7a50ee8d..f507ca1ab 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -40,7 +40,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): Assert.gt(config.num_experts, 1) # TODO: Implement? diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 23da37766..0c1a18ae2 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -22,7 +22,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) @@ -83,7 +83,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): Assert.eq(config.num_experts, 1) super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) diff --git a/fast_llm/layers/ssm/block.py b/fast_llm/layers/ssm/block.py index 408f21041..fef890d41 100644 --- a/fast_llm/layers/ssm/block.py +++ b/fast_llm/layers/ssm/block.py @@ -20,7 +20,7 @@ def __init__( distributed_config: DistributedConfig, hidden_dim: TensorDim, block_index: int, - lr_scale: float | list[float] | None, + lr_scale: float | None, name: str, mixer_class: type[BlockLayer], 
return_input: bool = False, diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 7fea3d480..0d91fbaff 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -50,7 +50,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) state_dim = TensorDim("state", self._config.state_size) diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index 59fd03a1e..79a0e5c8e 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -65,7 +65,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) assert self._distributed_config.tensor_parallel == 1, "Tensor-parallel not supported for MambaLayer" diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index bf9c30521..eec134a22 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -47,7 +47,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) Assert.eq(self._config.activation_type, ActivationType.silu) diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index c7a21fa0c..1fd7b6be1 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -60,7 +60,7 @@ def __init__( hidden_dim: TensorDim, block_index: int, name: str, - lr_scale: float | list[float] | None, + lr_scale: float | None, ): super().__init__(config, block_config, 
distributed_config, hidden_dim, block_index, name, lr_scale) self._use_flash_attention = self._config.do_use_flash_attention(self._distributed_config) diff --git a/fast_llm/utils.py b/fast_llm/utils.py index f7f5e9663..fb16ef192 100644 --- a/fast_llm/utils.py +++ b/fast_llm/utils.py @@ -348,9 +348,9 @@ def check_equal_nested(config_a, config_b): raise ValueError("\n".join(errors)) -def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]): +def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]) -> float | None | tuple[float | None, ...]: # Remove `None` entries. - lr_scales = [lr_scale for lr_scale in lr_scales if lr_scale is not None] + lr_scales = tuple(lr_scale for lr_scale in lr_scales if lr_scale is not None) if not lr_scales: # Everything is None return None @@ -367,10 +367,10 @@ def combine_lr_scales(*lr_scales: float | None | tuple[float | None, ...]): return math.prod(lr_scales) else: # Tuple(s): use recursion. - return [ + return tuple( combine_lr_scales(*[lr_scale[i] if isinstance(lr_scale, tuple) else lr_scale for lr_scale in lr_scales]) for i in range(tuple_length) - ] + ) class Interrupter: From 11a5a219cb43e131f1a9e5c3fb2c73bb6ad89c18 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 20 Aug 2025 12:52:36 -0400 Subject: [PATCH 18/19] Fix merge --- fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py | 2 +- fast_llm/layers/{transformer => attention}/__init__.py | 0 .../layers/{transformer => attention}/attention.py | 2 +- fast_llm/layers/{transformer => attention}/block.py | 4 ++-- fast_llm/layers/{transformer => attention}/config.py | 2 +- .../layers/{transformer => attention}/preprocessing.py | 2 +- .../{transformer => attention}/rotary/__init__.py | 0 .../layers/{transformer => attention}/rotary/config.py | 10 +++++----- .../layers/{transformer => attention}/rotary/rotary.py | 4 ++-- fast_llm/layers/block/config.py | 2 +- fast_llm/models/gpt/conversion.py | 6 +++--- fast_llm/models/gpt/huggingface.py 
| 2 +- fast_llm/models/gpt/megatron.py | 6 +++--- fast_llm/models/gpt/model.py | 6 +++--- fast_llm/models/ssm/model.py | 2 +- tests/functional/test_triton_kernels.py | 4 ++-- tests/layers/test_lm_head.py | 2 +- tests/test_attention.py | 6 +++--- tests/test_multi_stage.py | 2 +- 19 files changed, 32 insertions(+), 32 deletions(-) rename fast_llm/layers/{transformer => attention}/__init__.py (100%) rename fast_llm/layers/{transformer => attention}/attention.py (99%) rename fast_llm/layers/{transformer => attention}/block.py (77%) rename fast_llm/layers/{transformer => attention}/config.py (99%) rename fast_llm/layers/{transformer => attention}/preprocessing.py (98%) rename fast_llm/layers/{transformer => attention}/rotary/__init__.py (100%) rename fast_llm/layers/{transformer => attention}/rotary/config.py (92%) rename fast_llm/layers/{transformer => attention}/rotary/rotary.py (98%) diff --git a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py index 8f4dffedf..439d1da2e 100644 --- a/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py +++ b/fast_llm/engine/evaluation/lm_eval/fast_llm_wrapper.py @@ -16,7 +16,7 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.evaluation.lm_eval.utils import prepare_lm_eval_simple_eval_params, process_lm_eval_results from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM -from fast_llm.layers.transformer.rotary.config import NoRotaryConfig +from fast_llm.layers.attention.rotary.config import NoRotaryConfig logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/transformer/__init__.py b/fast_llm/layers/attention/__init__.py similarity index 100% rename from fast_llm/layers/transformer/__init__.py rename to fast_llm/layers/attention/__init__.py diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/attention/attention.py similarity index 99% rename from 
fast_llm/layers/transformer/attention.py rename to fast_llm/layers/attention/attention.py index 1fd7b6be1..cec026b82 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -7,11 +7,11 @@ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward +from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockDimNames from fast_llm.layers.block.peft import TransformerSubLayerName from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear -from fast_llm.layers.transformer.config import AttentionConfig, AttentionKwargs from fast_llm.utils import combine_lr_scales, div try: diff --git a/fast_llm/layers/transformer/block.py b/fast_llm/layers/attention/block.py similarity index 77% rename from fast_llm/layers/transformer/block.py rename to fast_llm/layers/attention/block.py index ba593461b..3396a2997 100644 --- a/fast_llm/layers/transformer/block.py +++ b/fast_llm/layers/attention/block.py @@ -2,9 +2,9 @@ import logging import typing +from fast_llm.layers.attention.attention import Attention +from fast_llm.layers.attention.config import AttentionConfig, TransformerConfig from fast_llm.layers.block.block import Block -from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import AttentionConfig, TransformerConfig logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/transformer/config.py b/fast_llm/layers/attention/config.py similarity index 99% rename from fast_llm/layers/transformer/config.py rename to fast_llm/layers/attention/config.py index d8c5bf923..301a7c5c1 100644 --- a/fast_llm/layers/transformer/config.py +++ 
b/fast_llm/layers/attention/config.py @@ -8,8 +8,8 @@ from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_normal_, init_zeros_ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig +from fast_llm.layers.attention.rotary.config import RotaryConfig from fast_llm.layers.block.config import BlockConfig, BlockKwargs, MixerConfig -from fast_llm.layers.transformer.rotary.config import RotaryConfig from fast_llm.utils import Assert, div logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/attention/preprocessing.py similarity index 98% rename from fast_llm/layers/transformer/preprocessing.py rename to fast_llm/layers/attention/preprocessing.py index 769177668..24ef3397c 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/attention/preprocessing.py @@ -6,7 +6,7 @@ from fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.transformer.config import AttentionConfig, AttentionKwargs +from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs from fast_llm.tensor import TensorMeta logger = logging.getLogger(__name__) diff --git a/fast_llm/layers/transformer/rotary/__init__.py b/fast_llm/layers/attention/rotary/__init__.py similarity index 100% rename from fast_llm/layers/transformer/rotary/__init__.py rename to fast_llm/layers/attention/rotary/__init__.py diff --git a/fast_llm/layers/transformer/rotary/config.py b/fast_llm/layers/attention/rotary/config.py similarity index 92% rename from fast_llm/layers/transformer/rotary/config.py rename to fast_llm/layers/attention/rotary/config.py index 6cc19fce8..4ebd6c5dc 100644 --- a/fast_llm/layers/transformer/rotary/config.py +++ 
b/fast_llm/layers/attention/rotary/config.py @@ -10,7 +10,7 @@ from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.layers.transformer.rotary.rotary import DefaultRotary, Llama3Rotary, NoRotary, Rotary, YarnRotary + from fast_llm.layers.attention.rotary.rotary import DefaultRotary, Llama3Rotary, NoRotary, Rotary, YarnRotary @config_class(registry=True) @@ -44,7 +44,7 @@ class NoRotaryConfig(RotaryConfig): @classmethod def _get_configurable_class(self) -> "type[NoRotary]": - from fast_llm.layers.transformer.rotary.rotary import NoRotary + from fast_llm.layers.attention.rotary.rotary import NoRotary return NoRotary @@ -75,7 +75,7 @@ def _validate(self) -> None: warnings.warn("Triton is disabled, but the triton rotary kernel will be used anyway.") def _get_configurable_class(self) -> "type[DefaultRotary]": - from fast_llm.layers.transformer.rotary.rotary import DefaultRotary + from fast_llm.layers.attention.rotary.rotary import DefaultRotary return DefaultRotary @@ -97,7 +97,7 @@ def _validate(self) -> None: Assert.gt(self.high_frequency_factor, self.low_frequency_factor) def _get_configurable_class(self) -> "type[Llama3Rotary]": - from fast_llm.layers.transformer.rotary.rotary import Llama3Rotary + from fast_llm.layers.attention.rotary.rotary import Llama3Rotary return Llama3Rotary @@ -137,6 +137,6 @@ def _validate(self) -> None: super()._validate() def _get_configurable_class(self) -> "type[YarnRotary]": - from fast_llm.layers.transformer.rotary.rotary import YarnRotary + from fast_llm.layers.attention.rotary.rotary import YarnRotary return YarnRotary diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py similarity index 98% rename from fast_llm/layers/transformer/rotary/rotary.py rename to fast_llm/layers/attention/rotary/rotary.py index bbf8b524a..53b24c9bb 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -8,8 +8,8 @@ from 
fast_llm.engine.base_model.config import Preprocessor from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.functional.triton.rotary import triton_rotary_autograd_ -from fast_llm.layers.transformer.config import AttentionKwargs -from fast_llm.layers.transformer.rotary.config import ( +from fast_llm.layers.attention.config import AttentionKwargs +from fast_llm.layers.attention.rotary.config import ( DefaultRotaryConfig, Llama3RotaryConfig, NoRotaryConfig, diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 7f91dd0b0..c85fa4aff 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -91,7 +91,7 @@ def _from_dict( flat: bool = False, ) -> typing.Self: if cls is MixerConfig and cls.get_subclass(default.get("type")) is None: - from fast_llm.layers.transformer.config import AttentionConfig + from fast_llm.layers.attention.config import AttentionConfig # Default subclass. return AttentionConfig._from_dict(default, strict, flat) diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index de9b51b08..f26e811d0 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -24,11 +24,11 @@ from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType +from fast_llm.layers.attention.config import TransformerConfig +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig +from fast_llm.layers.attention.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex from fast_llm.layers.block.mlp.config import RoutingType from fast_llm.layers.common.normalization.config import LayerNormalizationConfig -from fast_llm.layers.transformer.config import TransformerConfig -from 
fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig, Llama3RotaryConfig, YarnRotaryConfig -from fast_llm.layers.transformer.rotary.rotary import convert_rotary_complex_to_real, convert_rotary_real_to_complex from fast_llm.models.gpt.config import ( DiffusionDreamGPTHuggingfaceCheckpointFormat, DiffusionLlamaGPTHuggingfaceCheckpointFormat, diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 4e3f258fc..2f99ae4c3 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -9,7 +9,7 @@ from fast_llm.engine.distributed.config import PhaseType from fast_llm.engine.inference.config import HuggingfaceModelConfig from fast_llm.engine.inference.huggingface import HuggingfaceBaseModelForCausalLM -from fast_llm.layers.transformer.config import AttentionKwargs +from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index 20ed8e828..5d3130549 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -1,7 +1,7 @@ import typing -from fast_llm.layers.transformer.config import TransformerConfig -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig +from fast_llm.layers.attention.config import TransformerConfig +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: @@ -94,7 +94,7 @@ def _init_attention_megatron( raise NotImplementedError(meta.tensor_name) if isinstance(config.rotary, DefaultRotaryConfig) and config.rotary.complex_format: - from fast_llm.layers.transformer.rotary.config import convert_rotary_real_to_complex + from fast_llm.layers.attention.rotary.config import convert_rotary_real_to_complex # Megatron uses (2, kv_channels/2) for the complex split; 
we use (kv_channels/2, 2). # TODO: Avoid unnecessarily changing the value and dense tensors. diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 9422900a8..a0b2d6f4d 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -10,15 +10,15 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.engine.multi_stage.fast_llm_model import FastLLMModel +from fast_llm.layers.attention.block import TransformerBlock +from fast_llm.layers.attention.config import AttentionKwargs +from fast_llm.layers.attention.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.block.mlp.config import MLPLossNames, RoutingType from fast_llm.layers.language_model.config import LanguageModelKwargs, LanguageModelLossNames from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT, LanguageModelEmbedding from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead from fast_llm.layers.language_model.preprocessing import PositionEmbeddingPreprocessor, PreferenceSpanPreprocessor -from fast_llm.layers.transformer.block import TransformerBlock -from fast_llm.layers.transformer.config import AttentionKwargs -from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 94f9eb321..26f437215 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -2,8 +2,8 @@ import typing from fast_llm.engine.distributed.config import DistributedConfig +from 
fast_llm.layers.attention.block import TransformerBlock from fast_llm.layers.ssm.block import SSMBlock -from fast_llm.layers.transformer.block import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType diff --git a/tests/functional/test_triton_kernels.py b/tests/functional/test_triton_kernels.py index 3f4446e4d..5a9065454 100644 --- a/tests/functional/test_triton_kernels.py +++ b/tests/functional/test_triton_kernels.py @@ -23,8 +23,8 @@ from fast_llm.functional.triton.pointwise import triton_add, triton_copy, triton_fill from fast_llm.functional.triton.rotary import triton_rotary_ from fast_llm.functional.triton.sparse_copy import get_sparse_map -from fast_llm.layers.transformer.rotary.config import DefaultRotaryConfig -from fast_llm.layers.transformer.rotary.rotary import ( +from fast_llm.layers.attention.rotary.config import DefaultRotaryConfig +from fast_llm.layers.attention.rotary.rotary import ( apply_rotary_embeddings, convert_rotary_complex_to_real, convert_rotary_real_to_complex, diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index 8c33aed4d..380ab0550 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -6,10 +6,10 @@ from fast_llm.config import UpdateType from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl +from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.language_model.embedding import WORD_EMBEDDINGS_WEIGHT from fast_llm.layers.language_model.head import OUTPUT_WEIGHTS, LanguageModelHead -from fast_llm.layers.transformer.config import AttentionKwargs from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTModelConfig from fast_llm.utils import Assert from 
tests.utils.utils import get_base_model, get_stage, requires_cuda diff --git a/tests/test_attention.py b/tests/test_attention.py index 7d05e0a66..9564a931f 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -5,10 +5,10 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed +from fast_llm.layers.attention.attention import Attention +from fast_llm.layers.attention.config import AttentionKwargs, TransformerConfig +from fast_llm.layers.attention.preprocessing import FlashAttnVarlenPreprocessor from fast_llm.layers.block.config import BlockDimNames -from fast_llm.layers.transformer.attention import Attention -from fast_llm.layers.transformer.config import AttentionKwargs, TransformerConfig -from fast_llm.layers.transformer.preprocessing import FlashAttnVarlenPreprocessor from fast_llm.utils import Assert diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index 0639ec7ed..56356cf7a 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -3,8 +3,8 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer +from fast_llm.layers.attention.block import TransformerBlock from fast_llm.layers.ssm.block import SSMBlock -from fast_llm.layers.transformer.block import TransformerBlock from fast_llm.utils import Assert from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup From 0418be13424f103fc6ac0927797f1d8d2bb0de9a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 20 Aug 2025 13:16:39 -0400 Subject: [PATCH 19/19] Fix merge --- fast_llm/layers/common/peft/config.py | 2 +- fast_llm/layers/common/peft/{peft.py => lora.py} | 0 2 files changed, 1 insertion(+), 1 deletion(-) rename 
fast_llm/layers/common/peft/{peft.py => lora.py} (100%) diff --git a/fast_llm/layers/common/peft/config.py b/fast_llm/layers/common/peft/config.py index d1e21e340..4090c001a 100644 --- a/fast_llm/layers/common/peft/config.py +++ b/fast_llm/layers/common/peft/config.py @@ -74,7 +74,7 @@ def apply_linear( return self.apply_other(module) from fast_llm.layers.common.linear import InputParallelLinear - from fast_llm.layers.common.peft.peft import lora_linear + from fast_llm.layers.common.peft.lora import lora_linear if isinstance(module, InputParallelLinear): # TODO: Support InputParallelLinear (different output format). diff --git a/fast_llm/layers/common/peft/peft.py b/fast_llm/layers/common/peft/lora.py similarity index 100% rename from fast_llm/layers/common/peft/peft.py rename to fast_llm/layers/common/peft/lora.py