diff --git a/docs/developer_guide/conversion.md b/docs/developer_guide/conversion.md index 35a324db..a465cb9a 100644 --- a/docs/developer_guide/conversion.md +++ b/docs/developer_guide/conversion.md @@ -232,7 +232,7 @@ Continuing our `AwesomeModel` handler example, we define: def _create_weight_converters(self) -> list[WeightConverter]: converters = [] # The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`. - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks # A simple renaming example, for the word embeddings. converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 43ef8beb..cec026b8 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -4,14 +4,15 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim -from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockDimNames -from fast_llm.utils import div +from fast_llm.layers.block.peft import TransformerSubLayerName +from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear +from fast_llm.utils import combine_lr_scales, div try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -93,37 +94,26 @@ def __init__( self._softmax_scale = self._config.kv_channels ** (-self._config.attention_softmax_scale_power) - init_method_qkv = init_normal_( - std=self._config.init_method_std_qkv, - min_val=self._config.init_method_min_qkv, - max_val=self._config.init_method_max_qkv, - ) - init_method_std_attn_proj = init_normal_( - std=self._config.init_method_std_attn_proj, - min_val=self._config.init_method_min_attn_proj, - max_val=self._config.init_method_max_attn_proj, - ) + lr_scale = combine_lr_scales(self._lr_scale, self._config.attention_lr_scale) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) - self.query = self._config.query_layer.get_layer( + self.query = OutputParallelLinear( hidden_dim, query_dim, - bias=self._block_config.add_linear_biases, - weight_init_method=self._config.query_layer.weight_initialization, - bias_init_method=self._config.query_layer.bias_initialization, + bias=self._config.add_qkv_bias, + weight_init_method=self._config.qkv_weight_initialization_method, + bias_init_method=self._config.qkv_bias_initialization_method, sequence_parallel=self._sequence_parallel, - lr_scale=self._lr_scale, - peft=self._config.block.peft, + lr_scale=lr_scale, ) - # TODO: Separate Peft (others?) 
for key and value - self.key_value = self._config.query_layer.get_layer( + self.key_value = OutputParallelLinear( hidden_dim, - query_dim, - weight_init_method=self._config.key_layer.weight_initialization, - bias_init_method=self._config.key_layer.bias_initialization, + key_value_dim, + bias=self._config.add_qkv_bias, + weight_init_method=self._config.qkv_weight_initialization_method, + bias_init_method=self._config.qkv_bias_initialization_method, sequence_parallel=self._sequence_parallel, - lr_scale=self._lr_scale, - peft=self._config.block.peft, + lr_scale=lr_scale, ) self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward) @@ -131,15 +121,22 @@ def __init__( self._rotary = self._config.rotary.get_layer(kv_channels_dim) # Output. - self.dense = self._config.dense_layer.get_layer( + self.dense = InputParallelLinear( dense_dim, hidden_dim, - weight_init_method=self._config.dense_layer.weight_initialization, - bias_init_method=self._config.dense_layer.bias_initialization, + bias=self._config.add_dense_bias, + weight_init_method=self._config.dense_weight_initialization_method, + bias_init_method=self._config.dense_bias_initialization_method, sequence_parallel=self._sequence_parallel, - lr_scale=self._lr_scale, - peft=self._config.block.peft, + lr_scale=lr_scale, + ) + # PEFT. + self.query = self._block_config.peft.apply_linear(self.query, True) + + self.key_value = self._block_config.peft.apply_linear( + self.key_value, True, out_channel_begin=div(self.key_value._out_dim.global_size, 2) ) + self.dense = self._block_config.peft.apply_linear(self.dense, TransformerSubLayerName.dense) if self._debug.enabled: self._query_dims = ( diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index c00dd8a2..301a7c5c 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -1,15 +1,15 @@ import functools import logging +import typing import warnings -from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.config_utils.data_type import DataType -from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ +from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_normal_, init_zeros_ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig from fast_llm.layers.attention.rotary.config import RotaryConfig -from fast_llm.layers.block.config import BlockConfig, BlockKwargs -from fast_llm.layers.common.linear.config import LinearConfig +from fast_llm.layers.block.config import BlockConfig, BlockKwargs, MixerConfig from fast_llm.utils import Assert, div logger = logging.getLogger(__name__) @@ -29,29 +29,14 @@ class AttentionKwargs(BlockKwargs): past_key_values = "past_key_values" -@config_class() -class AttentionConfig(Config): - # TODO: Make mixer class dynamic. +@config_class(dynamic_type={MixerConfig: "attention"}) +class AttentionConfig(MixerConfig): _abstract = False + # Needed for backward compatibility. 
TODO: remove + module_name: typing.ClassVar[str] = "attn" + # TODO: Review names - query_layer: LinearConfig = Field( - desc="Configuration for the query layer.", - hint=FieldHint.architecture, - ) - key_layer: LinearConfig = Field( - desc="Configuration for the key layer.", - hint=FieldHint.architecture, - ) - # TODO: Use - value_layer: LinearConfig = Field( - desc="Configuration for the value layer.", - hint=FieldHint.architecture, - ) - dense_layer: LinearConfig = Field( - desc="Initialization configuration for the dense layer.", - hint=FieldHint.feature, - ) rotary: RotaryConfig = Field( desc="Configuration for the rotary positional embeddings.", hint=FieldHint.architecture, @@ -107,28 +92,30 @@ class AttentionConfig(Config): " Under muP (if scaling number of heads instead of kv_channels): use 0.5.", valid=skip_valid_if_none(check_field(Assert.geq, 0)), ) + qkv_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for the query, key and value layer weights." + " Default: normal(std=hidden_size**-0.5)", + hint=FieldHint.feature, + ) + qkv_bias_initialization: InitializationConfig = Field( + desc="Initialization configuration for the query, key and value layer biases. Default: fill with zeros.", + hint=FieldHint.feature, + ) + dense_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for the dense layer weight." + " Default: normal(std=(2 * num_blocks * hidden_size)**-0.5)", + hint=FieldHint.feature, + ) + dense_bias_initialization: InitializationConfig = Field( + desc="Initialization configuration for the dense layer biases. Default: fill with zeros.", + hint=FieldHint.feature, + ) def _validate(self) -> None: with self._set_implicit_default(): - # TODO: Make this work without inheritance. + # TODO: hidden_size not yet validated. if self.kv_channels is None: - self.kv_channels = div(self.hidden_size, self.num_attention_heads) - # TODO: Block variables as defaults? 
- for layer, scale, enable_peft in ( - zip( - (self.query_layer, self.key_layer, self.value_layer, self.dense_layer), - (1, 1, 1, 2 * max(self.num_blocks, 1)), - (True, False, True, False), - ), - ): - layer.default = LinearConfig( - bias=True, - weight_initialization=init_normal_(0, (self.hidden_size * scale) ** -0.5), - bias_initialization=init_zeros_, - lr_scale=None, - enable_peft=True, - ) - super()._validate() + self.kv_channels = div(self.block.hidden_size, self.num_attention_heads) super()._validate() @@ -136,6 +123,10 @@ def _validate(self) -> None: warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") Assert.multiple(self.num_attention_heads, self.head_groups) + if not self.qkv_bias_initialization.is_default: + assert self.add_qkv_bias + if not self.dense_bias_initialization.is_default: + assert self.add_dense_bias @functools.cached_property def projection_size(self): @@ -145,8 +136,44 @@ def projection_size(self): def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool: return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16) + @functools.cached_property + def add_qkv_bias(self) -> bool: + return self.block.add_linear_biases + + @functools.cached_property + def add_dense_bias(self) -> bool: + return self.block.add_linear_biases + + @functools.cached_property + def qkv_weight_initialization_method(self) -> Initializer: + if self.qkv_weight_initialization.is_default: + return init_normal_(0, self.block.hidden_size**-0.5) + else: + return self.qkv_weight_initialization.get_initializer() + + @functools.cached_property + def qkv_bias_initialization_method(self) -> Initializer: + if self.qkv_bias_initialization.is_default: + return init_zeros_ + else: + return self.qkv_bias_initialization.get_initializer() + + @functools.cached_property + def dense_weight_initialization_method(self) -> Initializer: + if self.dense_weight_initialization.is_default: + return init_normal_(0, self.block.hidden_size**-0.5 / max(2 * self.block.num_blocks, 1)) + else: + return self.dense_weight_initialization.get_initializer() + + @functools.cached_property + def dense_bias_initialization_method(self) -> Initializer: + if self.dense_bias_initialization.is_default: + return init_zeros_ + else: + return self.dense_bias_initialization.get_initializer() + @config_class() -# TODO: Use composition instead -class TransformerConfig(AttentionConfig, BlockConfig): - _abstract = False +# TODO: Remove +class TransformerConfig(BlockConfig): + pass diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index f90fce69..75388bfb 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -12,7 +12,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.engine.distributed.distributed import Distributed -from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockLayerConfig from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import TensorMeta @@ -116,7 +116,7 @@ def __init__( self._lr_scale = lr_scale -class BlockLayer[ConfigType: Config](BlockLayerBase[ConfigType]): +class BlockLayer[ConfigType: BlockLayerConfig](BlockLayerBase[ConfigType]): """ Base class for mixer and MLP modules. 
""" @@ -137,9 +137,6 @@ class Block[ConfigType: BlockConfig](BlockLayerBase[ConfigType], Layer): A transformer-like decoder base block with abstract mixer. """ - # TODO: Standardize to `mixer` - _mixer_module_name: typing.ClassVar[str] = "mixer" - def __init__( self, config: ConfigType, @@ -173,9 +170,8 @@ def __init__( # Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix. setattr( self, - self._mixer_module_name, - self._mixer_class( - self._mixer_config, + self._config.mixer.module_name, + self._config.mixer.get_layer( self._config, self._distributed_config, self._hidden_dim, @@ -184,34 +180,13 @@ def __init__( self._lr_scale, ), ) - - # TODO: Use dynamic type. - from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP - from fast_llm.layers.block.mlp.mlp import MLP - - self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, - self._config, - self._distributed_config, - self._hidden_dim, - self._block_index, - f"{self._name} MLP", - self._lr_scale, + self.mlp = self._config.mlp.get_layer( + self._distributed_config, self._hidden_dim, self._block_index, f"{self._name} mlp", self._lr_scale ) - @functools.cached_property - @abc.abstractmethod - def _mixer_class(self) -> type[BlockLayer]: - pass - - @property - @abc.abstractmethod - def _mixer_config(self) -> Config: - pass - def setup(self, distributed: Distributed) -> None: super().setup(distributed) - getattr(self, self._mixer_module_name).setup(distributed) + getattr(self, self._config.mixer.module_name).setup(distributed) self.mlp.setup(distributed) @torch.compile @@ -241,7 +216,7 @@ def forward( hidden_states = self.norm_1(input_) if self._debug.enabled: self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs) - hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs) + hidden_states, bias = getattr(self, self._config.mixer.module_name)(hidden_states, kwargs) if self._debug.enabled: self._debug( hidden_states if bias is None else hidden_states + bias, diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index c8922af3..c85fa4af 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -1,10 +1,16 @@ +import typing + from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.layers.block.mlp.config import MLPConfig +from fast_llm.engine.config_utils.tensor_dim import TensorDim +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.common.normalization.config import NormalizationConfig from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert +if typing.TYPE_CHECKING: + from fast_llm.layers.block.block import BlockLayer + # TODO: Generalize these beyond language models? (Ex. vision) @@ -33,10 +39,101 @@ class BlockKwargs: @config_class() -# TODO: Use composition instead -class BlockConfig(MLPConfig, BaseModelConfig): +class BlockLayerConfig(BaseModelConfig): + """ + A common class for mixers and mlps, which have the exact same interface. 
+ """ + + _abstract = True + block: "BlockConfig" = Field(init=False) + + @property + def layer_class(self) -> "type[BlockLayer]": + raise NotImplementedError() + + def get_layer( + self, + block_config: "BlockConfig", + distributed_config: DistributedConfig, + hidden_dim: TensorDim, + block_index: int, + name: str, + lr_scale: float | None, + ) -> "BlockLayer": + return self.layer_class( + self, + block_config, + distributed_config, + hidden_dim, + block_index, + name, + lr_scale, + ) + + +@config_class(registry=True) +class MixerConfig(BlockLayerConfig): + _abstract = True + + # Needed for backward compatibility. TODO: Standardize to `mixer` + module_name: typing.ClassVar[str] = "mixer" + + def _validate(self) -> None: + assert hasattr(self, "block") + Assert.is_(self.block.mixer, self) + super()._validate() + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is MixerConfig and cls.get_subclass(default.get("type")) is None: + from fast_llm.layers.attention.config import AttentionConfig + + # Default subclass. + return AttentionConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + + +@config_class(registry=True) +class MLPBaseConfig(BlockLayerConfig): + _abstract = True + + def _validate(self) -> None: + assert hasattr(self, "block") + Assert.is_(self.block.mlp, self) + super()._validate() + + @classmethod + def _from_dict( + cls, + default: dict[str, typing.Any], + strict: bool = True, + flat: bool = False, + ) -> typing.Self: + if cls is MLPBaseConfig and cls.get_subclass(default.get("type")) is None: + from fast_llm.layers.block.mlp.config import MLPConfig + + # Default subclass. + return MLPConfig._from_dict(default, strict, flat) + return super()._from_dict(default, strict=strict, flat=flat) + - # TODO: Allow separate config for each normalization layer? +@config_class() +class BlockConfig(BaseModelConfig): + _abstract = False + mixer: MixerConfig = Field( + desc="Configuration for the mixer.", + hint=FieldHint.architecture, + ) + mlp: MLPBaseConfig = Field( + desc="Configuration for the MLP.", + hint=FieldHint.architecture, + ) + # TODO: Allow separate initializations? normalization: NormalizationConfig = Field( desc="Configuration for the normalization layers architecture.", hint=FieldHint.architecture, @@ -52,11 +149,6 @@ class BlockConfig(MLPConfig, BaseModelConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) - full_precision_residual: bool = Field( - default=False, - desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", - hint=FieldHint.stability, - ) debug_transformer: int = Field( default=0, desc="Log the output of each operation in a transformer layer.", @@ -71,10 +163,11 @@ class BlockConfig(MLPConfig, BaseModelConfig): add_linear_biases: bool = Field( default=True, desc="Whether to add biases to linear layers. May be overridden in individual layer configs.", + hint=FieldHint.architecture, ) # TODO: Move these, not specific to a single block. 
- num_layers: int = Field( + num_blocks: int = Field( default=12, desc="Number of blocks in the model.", hint=FieldHint.architecture, @@ -86,7 +179,12 @@ class BlockConfig(MLPConfig, BaseModelConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - per_layer_lr_scale: list[float | None] | None = Field( + full_precision_residual: bool = Field( + default=False, + desc="Store the residuals for the transformer in full precision (`optimization_dtype`).", + hint=FieldHint.stability, + ) + per_layer_lr_scale: list[float] | None = Field( default=None, desc="Custom learning rate scale for each layer.", doc="May be used to freeze some layers by setting their scale to zero.", diff --git a/fast_llm/layers/block/mlp/config.py b/fast_llm/layers/block/mlp/config.py index 09495c8f..416670c8 100644 --- a/fast_llm/layers/block/mlp/config.py +++ b/fast_llm/layers/block/mlp/config.py @@ -1,14 +1,15 @@ import enum +import functools import typing -from fast_llm.config import Config, Field, FieldHint, check_field, config_class -from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_ +from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none +from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_normal_, init_zeros_ from fast_llm.functional.config import ActivationType, MLPRecomputeLevel -from fast_llm.layers.common.linear.config import LinearConfig +from fast_llm.layers.block.config import BlockLayerConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: - pass + from fast_llm.layers.block.mlp.mlp import MLPBase class MLPLossNames: @@ -21,23 +22,11 @@ class RoutingType(str, enum.Enum): sinkhorn = "sinkhorn" -@config_class() -class MLPConfig(Config): - # TODO: Review names # TODO: Separate MoE? +@config_class(dynamic_type={BlockLayerConfig: "mlp"}) +class MLPConfig(BlockLayerConfig): + # TODO: Review names + # TODO: Separate MoE? _abstract = False - layer_1: LinearConfig = Field( - desc="Configuration for the first MLP layer.", - hint=FieldHint.architecture, - ) - layer_2: LinearConfig = Field( - desc="Configuration for the second MLP layer.", - hint=FieldHint.architecture, - ) - router: LinearConfig = Field( - # TODO: Improve default? - desc="Configuration for the MoE router.", - hint=FieldHint.feature, - ) ffn_hidden_size: int = Field( default=None, desc="Hidden dimension of the MLP intermediate state. Default: 4 * hidden_size.", @@ -103,6 +92,18 @@ class MLPConfig(Config): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) + mlp_lr_scale: float | None | tuple[float | None] = Field( + default=None, + desc="Custom learning rate scale for each expert.", + doc="May be used to freeze some experts by setting their scale to zero.", + hint=FieldHint.feature, + ) + router_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate for the MoE router weight.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) dropless_moe: bool = Field( default=True, desc="Evaluate all the experts at once using dropless MoE.", hint=FieldHint.expert ) @@ -112,26 +113,52 @@ class MLPConfig(Config): " Reduces memory usage, but increases fragmentation and requires CPU synchronisation. Not recommended.", hint=FieldHint.expert, ) + layer_1_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for the first mlp layer weights. 
Default: normal(0, hidden_size**-0.5).", + hint=FieldHint.feature, + ) + layer_1_bias_initialization: InitializationConfig = Field( + desc="Initialization configuration for the first mlp layer biases. Default: fill with zeros.", + hint=FieldHint.feature, + ) + layer_2_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for the second mlp layer weights." + " Default: normal((2 * num_blocks * hidden_size)**-0.5)", + hint=FieldHint.feature, + ) + layer_2_bias_initialization: InitializationConfig = Field( + desc="Initialization configuration for the second mlp layer biases. Default: fill with zeros.", + hint=FieldHint.feature, + ) + router_weight_initialization: InitializationConfig = Field( + # TODO: Improve default? + desc="Initialization configuration for the MoE router weight. Default: normal(0, hidden_size**-0.5).", + hint=FieldHint.feature, + ) - def _validate(self) -> None: - # TODO: Make this work without inheritance. - for layer, bias, scale in zip( - (self.layer_1, self.layer_2, self.router), - (self.add_linear_biases, self.add_linear_biases, False), - (1, max(self.num_blocks, 1), 1), - ): - layer.default = LinearConfig( - bias=bias, - weight_initialization=init_normal_(0, (self.hidden_size * scale) ** -0.5), - bias_initialization=init_zeros_, - apply_peft=False, - ) + @property + def layer_class(self) -> "type[MLPBase]": + if self.num_experts > 1: + from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP + + return MixtureOfExpertMLP + else: + from fast_llm.layers.block.mlp.mlp import MLP + + return MLP + @property + def add_bias(self) -> bool: + return self.block.add_linear_biases + + def _validate(self) -> None: + assert hasattr(self, "block") with self._set_implicit_default(): if self.activation_type is None: self.activation_type = ActivationType.silu if self.gated else ActivationType.gelu + # TODO: `hidden_size` not yet validated. 
if self.ffn_hidden_size is None: - self.ffn_hidden_size = 4 * self.hidden_size + self.ffn_hidden_size = 4 * self.block.hidden_size self.num_unshared_experts = self.num_experts - self.num_shared_experts @@ -140,10 +167,49 @@ def _validate(self) -> None: Assert.leq(self.num_shared_experts, self.num_experts) Assert.leq(self.num_shared_experts + self.num_experts_per_token, self.num_experts) - if isinstance(self.mlp_lr_scale, tuple): + if isinstance(self.mlp_lr_scale, list): Assert.eq(len(self.mlp_lr_scale), self.num_experts) for scale in self.mlp_lr_scale: if scale is not None: Assert.geq(scale, 0) elif self.mlp_lr_scale is not None: Assert.geq(self.mlp_lr_scale, 0) + + if not (self.layer_1_bias_initialization.is_default and self.layer_2_bias_initialization.is_default): + assert self.add_bias + + @functools.cached_property + def layer_1_weight_initialization_method(self) -> Initializer: + if self.layer_1_weight_initialization.is_default: + return init_normal_(0, self.block.hidden_size**-0.5) + else: + return self.layer_1_weight_initialization.get_initializer() + + @functools.cached_property + def layer_1_bias_initialization_method(self) -> Initializer: + if self.layer_1_bias_initialization.is_default: + return init_zeros_ + else: + return self.layer_1_bias_initialization.get_initializer() + + @functools.cached_property + def layer_2_weight_initialization_method(self) -> Initializer: + if self.layer_2_weight_initialization.is_default: + return init_normal_(0, self.block.hidden_size**-0.5 / max(2 * self.block.num_blocks, 1)) + else: + return self.layer_2_weight_initialization.get_initializer() + + @functools.cached_property + def layer_2_bias_initialization_method(self) -> Initializer: + if self.layer_2_bias_initialization.is_default: + return init_zeros_ + else: + return self.layer_2_bias_initialization.get_initializer() + + @functools.cached_property + def router_weight_initialization_method(self) -> Initializer: + if self.router_weight_initialization.is_default: + return init_zeros_ + else: + assert self.add_bias + return self.router_weight_initialization.get_initializer() diff --git a/fast_llm/layers/block/mlp/mixture_of_experts.py b/fast_llm/layers/block/mlp/mixture_of_experts.py index b50d8958..f507ca1a 100644 --- a/fast_llm/layers/block/mlp/mixture_of_experts.py +++ b/fast_llm/layers/block/mlp/mixture_of_experts.py @@ -4,15 +4,17 @@ import torch from fast_llm.core.distributed import ProcessGroup, set_generator +from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map -from fast_llm.layers.block.config import BlockConfig, BlockKwargs +from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.block.mlp.config import MLPConfig, MLPLossNames, RoutingType from fast_llm.layers.block.mlp.mlp import MLPBase from fast_llm.layers.common.auxiliary_loss import AuxiliaryLoss, z_loss -from fast_llm.utils import Assert +from fast_llm.layers.common.linear import Linear +from fast_llm.utils import Assert, combine_lr_scales logger = logging.getLogger(__name__) @@ -34,24 +36,27 @@ class MixtureOfExpertMLP[ConfigType: MLPConfig](MLPBase[ConfigType]): def __init__( self, config: ConfigType, - block_config: BlockConfig, distributed_config: DistributedConfig, hidden_dim: TensorDim, block_index: int, 
name: str, - # TODO: Vary LR scale (or full config?) per expert. lr_scale: float | None, ): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." - super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) + super().__init__(config, distributed_config, hidden_dim, block_index, name, lr_scale) - self.router = self._config.router.get_layer( + self.router = Linear( self._hidden_dim, TensorDim("router_experts", self._config.num_unshared_experts), - lr_scale=self._lr_scale, - peft=self._block_config.peft, + bias=False, + weight_init_method=init_normal_( + std=self._config.init_method_std, + min_val=self._config.init_method_min, + max_val=self._config.init_method_max, + ), + lr_scale=combine_lr_scales(self._config.router_lr_scale, self._lr_scale), ) dropless_moe = self._config.dropless_moe if dropless_moe and self._sequence_parallel: diff --git a/fast_llm/layers/block/mlp/mlp.py b/fast_llm/layers/block/mlp/mlp.py index 0139aa21..0c1a18ae 100644 --- a/fast_llm/layers/block/mlp/mlp.py +++ b/fast_llm/layers/block/mlp/mlp.py @@ -2,7 +2,6 @@ import torch -from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import ConcatenatedTensorDim, TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import TritonConfig @@ -10,7 +9,8 @@ from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig from fast_llm.layers.block.mlp.config import MLPConfig -from fast_llm.utils import Assert +from fast_llm.layers.common.linear import LinearBase +from fast_llm.utils import Assert, combine_lr_scales class MLPBase[ConfigType: MLPConfig](BlockLayer[ConfigType]): @@ -25,42 +25,45 @@ def __init__( lr_scale: float | None, ): super().__init__(config, block_config, distributed_config, hidden_dim, block_index, name, lr_scale) + self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) intermediate_1_dim, intermediate_2_dim = self._get_intermediate_dims() + self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - init_method_1 = init_normal_( - std=self._config.init_method_std_mlp_1, - min_val=self._config.init_method_min_mlp_1, - max_val=self._config.init_method_max_mlp_1, + layer_lr_scale = ( + self._block_config.per_layer_lr_scale[block_index] if self._block_config.per_layer_lr_scale else None ) - init_method_2 = init_normal_( - std=self._config.init_method_std_mlp_2, - min_val=self._config.init_method_min_mlp_2, - max_val=self._config.init_method_max_mlp_2, + lr_scale = ( + tuple(self._config.mlp_lr_scale) + if isinstance(self._config.mlp_lr_scale, list) + else self._config.mlp_lr_scale ) - - self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation + lr_scale = combine_lr_scales(lr_scale, layer_lr_scale) # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) - self.layer_1 = self.config.layer_1.get_layer( + self.layer_1 = LinearBase( hidden_dim, intermediate_1_dim, - sequence_parallel=self._sequence_parallel, - transposed_weight=False, - auto_bias_grad_accumulation=False, - lr_scale=self._lr_scale, - peft=self._block_config.peft, + bias=self._config.add_bias, + weight_init_method=self._config.layer_1_weight_initialization_method, + 
bias_init_method=self._config.layer_1_bias_initialization_method, + lr_scale=lr_scale, ) - self.layer_2 = self.config.layer_2.get_layer( + self.layer_2 = LinearBase( intermediate_2_dim, hidden_dim, - sequence_parallel=self._sequence_parallel, - transposed_weight=True, + bias=self._config.add_bias, + weight_init_method=self._config.layer_2_weight_initialization_method, + bias_init_method=self._config.layer_2_bias_initialization_method, auto_bias_grad_accumulation=self._distributed_config.tensor_parallel > 1, - lr_scale=self._lr_scale, - peft=self._block_config.peft, + transposed_weight=True, + lr_scale=lr_scale, ) + # PEFT. + self.layer_1 = self._block_config.peft.apply_linear(self.layer_1, False) + self.layer_2 = self._block_config.peft.apply_linear(self.layer_2, False) + def _get_intermediate_dims(self): intermediate_2_dim = TensorDim("intermediate", self._config.ffn_hidden_size, self._parallel_dim) if self._config.gated: @@ -71,7 +74,7 @@ def _get_intermediate_dims(self): return intermediate_1_dim, intermediate_2_dim -class MLP[ConfigType: BlockConfig](MLPBase[ConfigType]): +class MLP[ConfigType: MLPConfig](MLPBase[ConfigType]): def __init__( self, config: ConfigType, diff --git a/fast_llm/layers/common/linear/linear.py b/fast_llm/layers/common/linear.py similarity index 77% rename from fast_llm/layers/common/linear/linear.py rename to fast_llm/layers/common/linear.py index 744428c3..ca807e67 100644 --- a/fast_llm/layers/common/linear/linear.py +++ b/fast_llm/layers/common/linear.py @@ -3,7 +3,7 @@ import torch -from fast_llm.config import Configurable +from fast_llm.engine.config_utils.initialization import init_zeros_ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.linear import ( @@ -15,9 +15,7 @@ output_parallel_linear_backward, output_parallel_linear_forward, ) -from fast_llm.layers.common.linear.config import LinearConfig from fast_llm.tensor import ParameterMeta -from fast_llm.utils import combine_lr_scales logger = logging.getLogger(__name__) @@ -37,41 +35,41 @@ def backward(self, grad_output: torch.Tensor, context: typing.Any) -> torch.Tens raise NotImplementedError() -class LinearBase(Configurable[LinearConfig], LinearLike): +class LinearBase(LinearLike): """ A base module for linear layers holding weights and biases. """ def __init__( self, - config: LinearConfig, in_dim: TensorDim, out_dim: TensorDim, *, + bias=True, + weight_init_method, + bias_init_method=init_zeros_, transposed_weight: bool = False, - sequence_parallel: bool = False, auto_bias_grad_accumulation: bool = False, - lr_scale: float | None = None, + lr_scale: float | None | tuple[float | None, ...] 
= None, ): - super().__init__(config) + super().__init__() self._transposed_weight = transposed_weight - self._sequence_parallel = sequence_parallel self._in_dim = in_dim self._out_dim = out_dim - self._lr_scale = combine_lr_scales(self._config.lr_scale, lr_scale) + self._weight_init_method = weight_init_method self.weight = ParameterMeta.from_dims( (self._in_dim, self._out_dim) if self._transposed_weight else (self._out_dim, self._in_dim), - init_method=self._config.weight_initialization, + init_method=weight_init_method, auto_grad_accumulation=False, - lr_scale=self._lr_scale, + lr_scale=lr_scale, ) - if self._config.bias: + if bias: self.bias = ParameterMeta.from_dims( (self._out_dim,), - init_method=self._config.bias_initialization, + init_method=bias_init_method, weight_decay=False, auto_grad_accumulation=auto_bias_grad_accumulation, - lr_scale=self._lr_scale, + lr_scale=lr_scale, ) else: self.bias = None @@ -80,10 +78,6 @@ def __init__( def transposed_weight(self) -> bool: return self._transposed_weight - @property - def lr_scale(self) -> float | None: - return self._lr_scale - class Linear(LinearBase): """ @@ -92,25 +86,24 @@ class Linear(LinearBase): def __init__( self, - config: LinearConfig, in_dim: TensorDim, out_dim: TensorDim, *, + bias=True, + weight_init_method, + bias_init_method=init_zeros_, transposed_weight: bool = False, - sequence_parallel: bool = False, - auto_bias_grad_accumulation: bool = False, - lr_scale: float | None = None, + lr_scale: float | None | tuple[float | None, ...] = None, ): assert not in_dim.is_parallel assert not out_dim.is_parallel - assert not sequence_parallel super().__init__( - config, in_dim, out_dim, + bias=bias, + weight_init_method=weight_init_method, + bias_init_method=bias_init_method, transposed_weight=transposed_weight, - sequence_parallel=sequence_parallel, - auto_bias_grad_accumulation=auto_bias_grad_accumulation, lr_scale=lr_scale, ) @@ -130,25 +123,26 @@ class OutputParallelLinear(LinearBase): def __init__( self, - config: LinearConfig, in_dim: TensorDim, out_dim: TensorDim, *, + bias=True, + weight_init_method, + bias_init_method=init_zeros_, transposed_weight: bool = False, sequence_parallel: bool = False, - auto_bias_grad_accumulation: bool = False, - lr_scale: float | None = None, + lr_scale: float | None | tuple[float | None, ...] = None, ): assert not in_dim.is_parallel self._group_size = 1 if out_dim.parallel_dim is None else out_dim.parallel_dim.size self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( - config, in_dim, out_dim, + bias=bias, + weight_init_method=weight_init_method, + bias_init_method=bias_init_method, transposed_weight=transposed_weight, - sequence_parallel=sequence_parallel and self._group_size > 1, - auto_bias_grad_accumulation=auto_bias_grad_accumulation, lr_scale=lr_scale, ) @@ -173,25 +167,28 @@ class InputParallelLinear(LinearBase): def __init__( self, - config: LinearConfig, in_dim: TensorDim, out_dim: TensorDim, *, - transposed_weight: bool = False, + bias=True, + weight_init_method, + bias_init_method=init_zeros_, sequence_parallel: bool = False, - auto_bias_grad_accumulation: bool = False, - lr_scale: float | None = None, + transposed_weight: bool = False, + lr_scale: float | None | tuple[float | None, ...] 
= None, ): assert not out_dim.is_parallel self._group_size = 1 if in_dim.parallel_dim is None else in_dim.parallel_dim.size self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( - config, in_dim, out_dim, + bias=bias, + weight_init_method=weight_init_method, + bias_init_method=bias_init_method, transposed_weight=transposed_weight, - sequence_parallel=sequence_parallel and self._group_size > 1, - auto_bias_grad_accumulation=auto_bias_grad_accumulation, + # Tensor-parallel bias is computed in _bias_dropout_grad. + auto_bias_grad_accumulation=self._group_size > 1, lr_scale=lr_scale, ) diff --git a/fast_llm/layers/common/linear/__init__.py b/fast_llm/layers/common/linear/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/fast_llm/layers/common/linear/config.py b/fast_llm/layers/common/linear/config.py deleted file mode 100644 index 3a5d0c91..00000000 --- a/fast_llm/layers/common/linear/config.py +++ /dev/null @@ -1,136 +0,0 @@ -import typing - -from fast_llm.config import Config, Field, FieldHint, config_class -from fast_llm.engine.config_utils.initialization import InitializationConfig -from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.utils import combine_lr_scales - -if typing.TYPE_CHECKING: - from fast_llm.engine.config_utils.tensor_dim import TensorDim - from fast_llm.layers.common.linear.linear import InputParallelLinear, Linear, LinearLike, OutputParallelLinear - from fast_llm.tensor import ParameterMeta - - -@config_class() -class LinearWeightConfig(Config): - initialization: InitializationConfig = Field( - desc="Initialization configuration.", - hint=FieldHint.feature, - ) - lr_scale: float | None = None - - # Fixed defaults don't make sense because each parent layer uses its own. - # Instead, we use this variable to set defaults dynamically. - # This can either be a constant, - # or may point to another config, ex. to set a default for all layers in a model. - default: typing.Self = Field(init=False) - - def _validate(self) -> None: - if hasattr(self, "default"): - self.default.validate() - with self._set_implicit_default(): - if self.initialization.is_default: - self.initialization = self.default.initialization - if self.lr_scale is None: - self.lr_scale = self.default.lr_scale - if None in (self.initialization, self.lr_scale): - raise ValueError("Missing default values for linear weight configuration.") - - super()._validate() - - def get_weight( - self, - in_dim: TensorDim, - out_dim: TensorDim, - *, - transposed_weight: bool = False, - auto_grad_accumulation: bool = False, - lr_scale: float | None, - ) -> "ParameterMeta": - from fast_llm.tensor import ParameterMeta - - return ParameterMeta.from_dims( - (in_dim, out_dim) if transposed_weight else (out_dim, in_dim), - init_method=self.initialization, - auto_grad_accumulation=auto_grad_accumulation, - lr_scale=combine_lr_scales(self.lr_scale, lr_scale), - ) - - -@config_class() -class LinearConfig(Config): - weight_initialization: InitializationConfig = Field( - desc="Initialization configuration for the weight.", - hint=FieldHint.feature, - ) - bias_initialization: InitializationConfig = Field( - desc="Initialization configuration for the bias.", - hint=FieldHint.feature, - ) - bias: bool = Field( - default=None, - desc="Use bias.", - hint=FieldHint.architecture, - ) - lr_scale: float | None = None - apply_peft: bool = Field( - default=None, - desc="Apply peft on this layer if defined. 
Otherwise, treat the layer as a non-peft layer (may be frozen).", - hint=FieldHint.feature, - ) - # Fixed defaults don't make sense because each parent layer uses its own. - # Instead, we use this variable to set defaults dynamically. - # This can either be a constant, - # or may point to another config, ex. to set a default for all layers in a model. - default: typing.Self = Field(init=False) - - def _validate(self) -> None: - if hasattr(self, "default"): - self.default.validate() - with self._set_implicit_default(): - if self.bias is None: - self.bias = self.default.bias - if self.weight_initialization.is_default: - self.weight_initialization = self.default.weight_initialization - if self.bias_initialization.is_default: - self.bias_initialization = self.default.bias_initialization - if self.lr_scale is None: - self.lr_scale = self.default.lr_scale - if self.apply_peft is None: - self.apply_peft = self.default.apply_peft - if None in (self.bias, self.weight_initialization, self.bias_initialization, self.lr_scale, self.apply_peft): - raise ValueError("Missing default values for linear layer configuration.") - - super()._validate() - - def get_layer( - self, - in_dim: TensorDim, - out_dim: TensorDim, - *, - sequence_parallel: bool = False, - transposed_weight: bool = False, - auto_bias_grad_accumulation: bool = False, - lr_scale: float | None, - peft: PeftConfig | None = None, - ) -> "LinearLike": - if in_dim.parallel_dim is not None: - assert out_dim.parallel_dim is None - cls = InputParallelLinear - elif out_dim.parallel_dim is not None: - cls = OutputParallelLinear - else: - assert not sequence_parallel - cls = Linear - out = cls( - self, - in_dim, - out_dim, - transposed_weight=transposed_weight, - sequence_parallel=sequence_parallel, - auto_bias_grad_accumulation=auto_bias_grad_accumulation, - lr_scale=lr_scale, - ) - if peft is not None: - out = peft.apply_linear(out, self.apply_peft) - return out diff --git a/fast_llm/layers/common/normalization/config.py b/fast_llm/layers/common/normalization/config.py index 95b5314f..12f7c5ee 100644 --- a/fast_llm/layers/common/normalization/config.py +++ b/fast_llm/layers/common/normalization/config.py @@ -6,11 +6,11 @@ from fast_llm.config import Field, FieldHint, check_field, config_class from fast_llm.engine.base_model.config import BaseModelConfig from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_ones_, init_zeros_ +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: - from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.layers.common.normalization.normalization import Normalization diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index e31fe948..5e5dc879 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -148,7 +148,7 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None, None, class Normalization[ConfigType: NormalizationConfig](Configurable[ConfigType], torch.nn.Module): - def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | None = None): + def __init__(self, config: NormalizationConfig, hidden_dim: TensorDim, lr_scale: float | None = None): super().__init__(config) self._hidden_dim = hidden_dim self._lr_scale = 
combine_lr_scales(self._config.lr_scale, lr_scale) @@ -172,7 +172,7 @@ class LayerNormalization[ConfigType: LayerNormalizationConfig](Normalization[Con TODO: Review this? """ - def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | None = None): + def __init__(self, config: LayerNormalizationConfig, hidden_dim: TensorDim, lr_scale: float | None = None): super().__init__(config, hidden_dim, lr_scale) implementation = self._config.implementation if implementation == NormalizationImplementation.auto: @@ -240,7 +240,7 @@ def _forward_torch(self, input_: torch.Tensor) -> torch.Tensor: ) -class RMSNormalization[ConfigType: RMSNormalizationConfig](Normalization[ConfigType], torch.nn.Module): +class RMSNormalization[ConfigType: RMSNormalizationConfig](Configurable[ConfigType], torch.nn.Module): """ A RMS normalization layer. Note: Converting input automatically to training dtype to match Apex behaviour, @@ -248,7 +248,7 @@ class RMSNormalization[ConfigType: RMSNormalizationConfig](Normalization[ConfigT TODO: Review this? """ - def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | None = None): + def __init__(self, config: RMSNormalizationConfig, hidden_dim: TensorDim, lr_scale: float | None = None): super().__init__(config, hidden_dim, lr_scale) assert not hidden_dim.is_parallel implementation = self._config.implementation @@ -277,7 +277,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | init_method=self._config.weight_initialization_method, weight_decay=False, auto_grad_accumulation=True, - lr_scale=self._lr_scale, + lr_scale=lr_scale, ) self._normalized_shape = self.weight.shape diff --git a/fast_llm/layers/common/peft/config.py b/fast_llm/layers/common/peft/config.py index 0bf68adc..4090c001 100644 --- a/fast_llm/layers/common/peft/config.py +++ b/fast_llm/layers/common/peft/config.py @@ -6,7 +6,7 @@ if typing.TYPE_CHECKING: import torch - from fast_llm.layers.common.linear.linear import InputParallelLinear, LinearBase, LinearLike, OutputParallelLinear + from fast_llm.layers.common.linear import LinearBase, LinearLike, OutputParallelLinear from fast_llm.layers.common.normalization.normalization import Normalization from fast_llm.tensor import ParameterMeta @@ -73,6 +73,7 @@ def apply_linear( if not enabled: return self.apply_other(module) + from fast_llm.layers.common.linear import InputParallelLinear from fast_llm.layers.common.peft.lora import lora_linear if isinstance(module, InputParallelLinear): diff --git a/fast_llm/layers/common/peft/lora.py b/fast_llm/layers/common/peft/lora.py index 70e26817..9e0ca0dd 100644 --- a/fast_llm/layers/common/peft/lora.py +++ b/fast_llm/layers/common/peft/lora.py @@ -4,8 +4,7 @@ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.functional.autograd import wrap_forward_backward -from fast_llm.layers.common.linear.config import LinearConfig -from fast_llm.layers.common.linear.linear import Linear, LinearBase +from fast_llm.layers.common.linear import Linear, LinearBase def lora_linear( @@ -35,8 +34,6 @@ def lora_linear( middle_dim = TensorDim("lora_middle", rank) - # Use the same config as the wrapped linear - config = LinearConfig.from_dict(module.config, {"bias": False, "lr_scale": module.weight.lr_scale}) module.lora_0 = Linear( in_dim, middle_dim, diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 60216ec7..b06a870d 100644 --- a/fast_llm/layers/language_model/config.py +++ 
b/fast_llm/layers/language_model/config.py @@ -1,13 +1,10 @@ -import typing +import functools from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none from fast_llm.engine.base_model.config import BaseModelConfig -from fast_llm.engine.config_utils.initialization import init_normal_ +from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_normal_ from fast_llm.functional.config import CrossEntropyImpl, DistillationLossImpl -from fast_llm.layers.attention.config import TransformerConfig -from fast_llm.layers.attention.rotary.config import NoRotaryConfig -from fast_llm.layers.block.config import BlockKwargs -from fast_llm.layers.common.linear.config import LinearWeightConfig +from fast_llm.layers.block.config import BlockConfig, BlockKwargs from fast_llm.utils import Assert @@ -39,43 +36,22 @@ class LanguageModelKwargs(BlockKwargs): @config_class() class LanguageModelBaseConfig(BaseModelConfig): # TODO: block - transformer: TransformerConfig = Field( + transformer: BlockConfig = Field( desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) - word_embeddings_layer: LinearWeightConfig = Field( - desc="Configuration for the word embedding (weight).", - hint=FieldHint.architecture, - ) - position_embeddings_layer: LinearWeightConfig = Field( - desc="Configuration for the word embedding (weight).", - hint=FieldHint.architecture, - ) - output_layer: LinearWeightConfig = Field( - desc="Configuration for the LM output layer (weight). Ignored for tied embeddings", - hint=FieldHint.architecture, - ) - # TODO: Move to `position_embeddings_layer`? - max_position_embeddings: int = Field( - default=2048, - desc="Number of absolute position embeddings, if applicable.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - # TODO: Move to `word_embeddings_layer`/`output_layer`? vocab_size: int = Field( default=49152, desc="Size of the vocabulary, i.e., number of vocabulary embeddings and logits.", hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - # TODO: Move to `position_embeddings_layer`? - use_position_embeddings: bool = Field( + absolute_position_embeddings: int | None = Field( + # TODO: backward compatibility? default=None, - desc="Enable absolute position embeddings. Default: Enable unless using rotary embeddings.", + desc="Number of absolute position embeddings, if applicable.", hint=FieldHint.architecture, ) - # TODO: Move to `output_layer`? (dynamic type?) tie_word_embeddings: bool = Field( default=True, desc="Tie the output weights (logits) with the vocabulary embedding.", @@ -127,7 +103,6 @@ class LanguageModelBaseConfig(BaseModelConfig): # Tensor-parallel word embeddings # (Default init std is different, dropout won't match, needs seq_first = False.) # (disable to allow for sequence-parallel embeddings and logits, better for larger models) - # TODO: Rename to `vocab_parallel`? Move to `word_embeddings_layer`/`output_layer`? 
parallel_embeddings: bool = Field( default=True, desc="Allow for tensor-parallel vocabulary embeddings and output weights.", @@ -176,29 +151,45 @@ class LanguageModelBaseConfig(BaseModelConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) + embeddings_lr_scale: float | None = Field( + default=None, + desc="Learning rate scale for the word embeddings.", + doc="May be used to freeze some layers by setting their scale to zero.", + hint=FieldHint.feature, + valid=skip_valid_if_none(check_field(Assert.geq, 0)), + ) + output_lr_scale: float | None = Field( + default=None, + desc="Custom learning rate scale for the output weights.", + doc="May be used to freeze the output weights by setting their scale to zero.", + hint=FieldHint.feature, + ) prediction_loss_coefficient: list[float] | None = Field( default=None, desc="Loss coefficient for each prediction head.", doc="If not provided, all heads are equally weighted.", hint=FieldHint.feature, ) + word_embedding_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for word embeddings. Default: normal(std=hidden_size**-0.5)", + hint=FieldHint.feature, + ) + position_embedding_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for position embeddings. Default: normal(hidden_size**-0.5)", + hint=FieldHint.feature, + ) + output_weight_initialization: InitializationConfig = Field( + desc="Initialization configuration for untied output weights. Default: normal(hidden_size**-0.5)", + hint=FieldHint.feature, + ) def _validate(self) -> None: - default_init = init_normal_(0, self.hidden_size**-0.5) - self.word_embeddings_layer.default = LinearWeightConfig(weight_initialization=default_init) - # TODO: Use `word_embeddings_layer` as default? (More consistent with tied weights) - self.output_layer.default = LinearWeightConfig(weight_initialization=default_init) - self.position_embeddings_layer.default = LinearWeightConfig(weight_initialization=default_init) - - self.transformer.validate() with self._set_implicit_default(): if self.language_model_loss_factor is None: if self.distillation_model is None: self.language_model_loss_factor = 1.0 else: self.language_model_loss_factor = 0.0 - if self.use_position_embeddings is None: - self.use_position_embeddings = isinstance(self.transformer.rotary, NoRotaryConfig) super()._validate() if self.distillation_model is not None: if self.prediction_heads > 1: @@ -211,29 +202,34 @@ def _validate(self) -> None: # -1 because the first prediction head's transformer layer is accounted for in num_layers # +1 because the layer index starts at 1 Assert.eq( - len(self.transformer.per_layer_lr_scale), self.transformer.num_layers + self.prediction_heads - 1 + 1 + len(self.transformer.per_layer_lr_scale), self.transformer.num_blocks + self.prediction_heads - 1 + 1 ) - - @property - def num_absolute_position_embeddings(self) -> int: - # TODO: Rename from max embeddings. - return self.max_position_embeddings if self.use_absolute_position_embeddings else None + if not self.output_weight_initialization.is_default: + assert self.use_absolute_position_embeddings + if not self.output_weight_initialization.is_default: + assert not self.tie_word_embeddings @property def use_absolute_position_embeddings(self) -> int: - # TODO: Set through num embeddings instead instead. 
- return self.use_position_embeddings + return self.absolute_position_embeddings is not None + + @functools.cached_property + def word_embedding_weight_initialization_method(self) -> Initializer: + if self.word_embedding_weight_initialization.is_default: + return init_normal_(self.transformer.hidden_size**-0.5) + else: + return self.word_embedding_weight_initialization.get_initializer() + + @functools.cached_property + def position_embedding_weight_initialization_method(self) -> Initializer: + if self.position_embedding_weight_initialization.is_default: + return init_normal_(self.transformer.hidden_size**-0.5) + else: + return self.position_embedding_weight_initialization.get_initializer() - @classmethod - def from_flat_dict( - cls, - default: dict[str, typing.Any], - strict: bool = True, - ) -> typing.Self: - # The backward compatibility fix in `NormalizationArchitectureConfig` - # won't work for older checkpoints saved with a flat config. - # TODO v0.3: Remove flat format - cls._handle_renamed_field(default, "normalization_type", "type") - cls._handle_renamed_field(default, "layer_norm_eps", "epsilon") - cls._handle_renamed_field(default, "zero_centered_normalization", "zero_centered") - return super().from_flat_dict(default, strict) + @functools.cached_property + def output_weight_initialization_method(self) -> Initializer: + if self.output_weight_initialization.is_default: + return init_normal_(self.transformer.hidden_size**-0.5) + else: + return self.output_weight_initialization.get_initializer() diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 323e26a6..b8ef2c6d 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -5,7 +5,6 @@ from fast_llm.core.distributed import set_generator from fast_llm.core.ops import reduce_forward, split from fast_llm.engine.base_model.base_model import Layer -from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.layers.block.block import BlockLayerBase @@ -47,11 +46,10 @@ def __init__( ) self._residual_dtype = ( self._distributed_config.optimization_dtype - if config.transformer.full_precision_residual + if self._block_config.full_precision_residual else self._distributed_config.training_dtype ).torch - self._sequence_parallel = self._distributed_config.sequence_tensor_parallel - self._parallel_embeddings = self._distributed_config.tensor_parallel > 1 and config.parallel_embeddings + self._parallel_embeddings = self._distributed_config.tensor_parallel > 1 and self._config.parallel_embeddings self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) vocab_dim = TensorDim( "vocab", self._config.vocab_size, self._parallel_dim if self._parallel_embeddings else None @@ -61,19 +59,17 @@ def __init__( self._vocab_start_index = self._distributed_config.tensor_rank * vocab_dim.size self._vocab_end_index = (self._distributed_config.tensor_rank + 1) * vocab_dim.size - self.word_embeddings_weight = self._config.word_embeddings_layer.get_weight( - vocab_dim, self._hidden_dim, auto_grad_accumulation=True, lr_scale=self._lr_scale + self.word_embeddings_weight = ParameterMeta.from_dims( + (vocab_dim, hidden_dim), + init_method=self._config.word_embedding_weight_initialization_method, + lr_scale=self._config.embeddings_lr_scale, ) if 
self._config.use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( (TensorDim("position_embeddings", self._config.max_position_embeddings), self._hidden_dim), - init_method=init_normal_( - std=config.init_method_std_embed, - min_val=config.init_method_min_embed, - max_val=config.init_method_max_embed, - ), - allow_sequence_tensor_parallel=not config.parallel_embeddings, - lr_scale=config.embeddings_lr_scale, + init_method=self._config.position_embedding_weight_initialization_method, + allow_sequence_tensor_parallel=not self._config.parallel_embeddings, + lr_scale=self._config.embeddings_lr_scale, ) # PEFT. @@ -115,7 +111,7 @@ def _forward(self, input_: torch.Tensor, position_ids: torch.Tensor | None, mask with set_generator( self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator ): - embeddings = torch.dropout(embeddings, self._config.transformer.hidden_dropout, self.training) + embeddings = torch.dropout(embeddings, self._block_config.hidden_dropout, self.training) return embeddings.to(dtype=self._residual_dtype) def forward( diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 59a071c6..a21fd993 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -6,7 +6,6 @@ from fast_llm.core.ops import split_op from fast_llm.engine.base_model.base_model import Layer -from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim, scalar_dim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import grad_is_context, wrap_forward_backward @@ -93,16 +92,9 @@ def __init__( # Only the first head defines the output weights if self._prediction_distance == 0 and not self._config.tie_word_embeddings: # untie embedding weights - self.output_weights = self._config.word_embeddings_layer.get_weight( - self._vocab_dim, hidden_dim, auto_grad_accumulation=True, lr_scale=self._lr_scale - ) self.output_weights = ParameterMeta.from_dims( (self._vocab_dim, hidden_dim), - init_method=init_normal_( - std=self._config.init_method_std_embed, - min_val=self._config.init_method_min_embed, - max_val=self._config.init_method_max_embed, - ), + init_method=self._config.output_weight_initialization_method, lr_scale=self._config.output_lr_scale, ) diff --git a/fast_llm/layers/ssm/block.py b/fast_llm/layers/ssm/block.py index 22d01a5c..fef890d4 100644 --- a/fast_llm/layers/ssm/block.py +++ b/fast_llm/layers/ssm/block.py @@ -20,8 +20,8 @@ def __init__( distributed_config: DistributedConfig, hidden_dim: TensorDim, block_index: int, - name: str, lr_scale: float | None, + name: str, mixer_class: type[BlockLayer], return_input: bool = False, ): @@ -35,4 +35,4 @@ def _mixer_class(self) -> type[BlockLayer]: @property def _mixer_config(self) -> SSMConfig: - return self._ssm_config + return self._config diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 9c991bd7..f9462a94 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -10,7 +10,7 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs -from fast_llm.layers.common.linear.linear import InputParallelLinear, OutputParallelLinear +from fast_llm.layers.common.linear import 
InputParallelLinear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig from fast_llm.layers.ssm.mamba import init_kaiming_ from fast_llm.tensor import ParameterMeta diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index b5596a2f..453c14af 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -10,7 +10,7 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockKwargs -from fast_llm.layers.common.linear.linear import Linear +from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig from fast_llm.tensor import ParameterMeta from fast_llm.utils import Assert, combine_lr_scales, div diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 549fb902..2659e415 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -9,7 +9,7 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.block.block import BlockLayer from fast_llm.layers.block.config import BlockConfig, BlockDimNames, BlockKwargs -from fast_llm.layers.common.linear.linear import InputParallelLinear, Linear, OutputParallelLinear +from fast_llm.layers.common.linear import InputParallelLinear, Linear, OutputParallelLinear from fast_llm.layers.ssm.config import SSMConfig from fast_llm.layers.ssm.mamba import init_A, init_dtprojbias, init_kaiming_ from fast_llm.tensor import ParameterMeta diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 3ca2d71f..1ee83630 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -188,15 +188,12 @@ class GPTTrainerConfig(PretrainedGPTModelConfig, TrainerConfig): reference_models: dict[str, PretrainedGPTModelConfig] = FieldUpdate() def _validate(self) -> None: - if self.batch.sequence_length is None: - # TODO: Drop this. 
- self.batch.sequence_length = self.model.base_model.max_position_embeddings if self.model.base_model.use_megatron_initialization: set_megatron_distributed_seeds(self.model.distributed) super()._validate() if self.model.base_model.use_absolute_position_embeddings: - Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) + Assert.geq(self.model.base_model.absolute_position_embeddings, self.batch.sequence_length) distillation_model = self.model.base_model.distillation_model dpo_reference_model = self.model.base_model.dpo_reference_model diff --git a/fast_llm/models/gpt/conversion.py b/fast_llm/models/gpt/conversion.py index 36975dea..f26e811d 100644 --- a/fast_llm/models/gpt/conversion.py +++ b/fast_llm/models/gpt/conversion.py @@ -176,7 +176,7 @@ def _create_weight_converters( self, ) -> list[WeightConverter]: converters = [] - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks # Embeddings converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight")) @@ -241,13 +241,13 @@ def _create_transformer_layer_converters( converters += self._get_weight_and_bias_converters( f"{fast_llm_layer_name}.mlp.layer_1", (), - transformer_config.add_mlp_bias, + transformer_config.add_bias, cls=IgnoreExportWeightConverter, ) converters += self._get_weight_and_bias_converters( f"{fast_llm_layer_name}.mlp.layer_2", (), - transformer_config.add_mlp_bias, + transformer_config.add_bias, cls=IgnoreExportWeightConverter, ) converters += [IgnoreExportWeightConverter(f"{fast_llm_layer_name}.mlp.router.weight", ())] @@ -256,7 +256,7 @@ def _create_transformer_layer_converters( return converters def _create_lm_head_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks prediction_heads = self._model.config.base_model.prediction_heads norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) converters = [] @@ -344,12 +344,12 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig transformer_config: TransformerConfig = self._model.config.base_model.transformer return [ *self._get_weight_and_bias_converters( - f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", transformer_config.add_mlp_bias + f"{fast_llm_prefix}.mlp.layer_1", f"{hf_prefix}.mlp.c_fc", transformer_config.add_bias ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.c_proj", - transformer_config.add_mlp_bias, + transformer_config.add_bias, MLPLayer2Converter, ), ] @@ -463,13 +463,13 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - transformer_config.add_mlp_bias, + transformer_config.add_bias, SplitWeightConverter, ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.down_proj", - transformer_config.add_mlp_bias, + transformer_config.add_bias, MLPLayer2Converter, ), ] @@ -531,13 +531,13 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - transformer_config.add_mlp_bias, + 
transformer_config.add_bias, SplitWeightConverter, ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.down_proj", - transformer_config.add_mlp_bias, + transformer_config.add_bias, MLPLayer2Converter, ), ] @@ -641,20 +641,20 @@ def _get_mlp_converters(self, fast_llm_prefix: str, hf_prefix: str) -> list[Weig *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_1", (f"{hf_prefix}.mlp.gate_proj", f"{hf_prefix}.mlp.up_proj"), - transformer_config.add_mlp_bias, + transformer_config.add_bias, SplitWeightConverter, ), *self._get_weight_and_bias_converters( f"{fast_llm_prefix}.mlp.layer_2", f"{hf_prefix}.mlp.down_proj", - transformer_config.add_mlp_bias, + transformer_config.add_bias, MLPLayer2Converter, ), ] # Override base method to handle the MTP heads def _create_lm_head_converters(self) -> list[WeightConverter]: - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks prediction_heads = self._model.config.base_model.prediction_heads norm_bias: bool = isinstance(self._model.config.base_model.transformer.normalization, LayerNormalizationConfig) converters = [] diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index b13c7772..a0b2d6f4 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -100,18 +100,12 @@ def _get_block( name: str, return_input: bool = False, ): - lr_scale = ( - None - if self._config.transformer.per_layer_lr_scale is None - else self._config.transformer.per_layer_lr_scale[block_index] - ) return TransformerBlock( self._config.transformer, self._distributed_config, self._hidden_dim, block_index, name, - lr_scale, return_input, ) @@ -397,7 +391,7 @@ def loss_defs(self) -> list[LossDef]: LossDef( name=MLPLossNames.load_balancing_loss, formatted_name="load balancing loss", - count=self._config.transformer.num_layers, + count=self._config.transformer.num_blocks, ) ) if self._config.transformer.expert_z_loss_coefficient: @@ -405,7 +399,7 @@ def loss_defs(self) -> list[LossDef]: LossDef( name=MLPLossNames.router_z_loss, formatted_name="router z loss", - count=self._config.transformer.num_layers, + count=self._config.transformer.num_blocks, ) ) if self._config.logit_z_loss: @@ -445,7 +439,7 @@ def get_tflops(self, phase: PhaseType, elapsed_time_per_iteration, batch_size, s consumed_tokens_per_iteration = sequence_length * batch_size - num_transformer_layers = transformer_config.num_layers + self._config.base_model.prediction_heads - 1 + num_transformer_layers = transformer_config.num_blocks + self._config.base_model.prediction_heads - 1 transformer_flops_base = ( 2 * checkpoint_activations_factor * consumed_tokens_per_iteration * num_transformer_layers ) diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 9d54675b..1c85327e 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -58,13 +58,13 @@ def _validate(self): if self.hybrid_block_layout is None: with self._set_implicit_default(): - self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_layers + self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_blocks - if len(self.hybrid_block_layout) != self.transformer.num_layers: - message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" - if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: + if 
len(self.hybrid_block_layout) != self.transformer.num_blocks: + message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_blocks}" + if self.transformer.num_blocks % len(self.hybrid_block_layout) != 0: raise ValueError(message) - num_repeats = self.transformer.num_layers // len(self.hybrid_block_layout) + num_repeats = self.transformer.num_blocks // len(self.hybrid_block_layout) logger.warning(f"{message}, will repeat {self.hybrid_block_layout} {num_repeats} times.") self.hybrid_block_layout = self.hybrid_block_layout * num_repeats @@ -171,7 +171,7 @@ def _validate(self) -> None: else: Assert.eq(self.reference_models.keys(), {name}) if self.model.base_model.use_absolute_position_embeddings: - Assert.geq(self.model.base_model.num_absolute_position_embeddings, self.batch.sequence_length) + Assert.geq(self.model.base_model.absolute_position_embeddings, self.batch.sequence_length) # if self.model.base_model.distillation_model is not None: # # TODO: Support loss masking for distillation? # assert not self.batch.use_loss_masking_spans diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index e9b18b84..012f2fae 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -21,7 +21,7 @@ from fast_llm.engine.checkpoint.huggingface import CustomModelingExportMixin, HuggingfaceStateDictCheckpointHandler from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType -from fast_llm.layers.common.normalization.config import RMSNormalizationConfig +from fast_llm.layers.common.normalization import RMSNormalizationConfig from fast_llm.layers.ssm.config import DTInitType, SSMBlockType from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter from fast_llm.models.ssm.config import ( @@ -223,7 +223,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: def _create_weight_converters(self) -> list[WeightConverter]: converters = super()._create_weight_converters() or [] - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear for i in range(num_layers): @@ -387,7 +387,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: def _create_weight_converters(self) -> list[WeightConverter]: # not using super() because LLamba model is called backbone in the checkpoints converters = [] - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks norm_bias: bool = False ssm_bias: bool = self._model.config.base_model.ssm.add_bias_linear @@ -581,7 +581,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: def _create_weight_converters(self) -> list[WeightConverter]: converters = super()._create_weight_converters() - num_layers = self._model.config.base_model.transformer.num_layers + num_layers = self._model.config.base_model.transformer.num_blocks norm_bias: bool = False # Embedding and output diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 9b79e74a..26f43721 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -1,6 +1,7 @@ import logging import typing +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.attention.block import TransformerBlock from 
fast_llm.layers.ssm.block import SSMBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTInferenceRunner, GPTModel @@ -16,6 +17,15 @@ class HybridSSMBaseModel[ConfigType: HybridSSMBaseModelConfig](GPTBaseModel[Conf As for the mixer, transformer uses MHA. For the LlambaBlock we support Mamba1 and discrete mamba2. """ + _is_setup: bool = False + + def __init__( + self, + config: HybridSSMBaseModelConfig, + distributed_config: DistributedConfig, + ): + super().__init__(config, distributed_config) + def _get_block( self, block_index: int, @@ -29,12 +39,6 @@ def _get_block( # Decoder block block_type = self._config.hybrid_block_layout[block_index - 1] - lr_scale = ( - None - if self._config.transformer.per_layer_lr_scale is None - else self._config.transformer.per_layer_lr_scale[block_index] - ) - if block_type == SSMBlockType.transformer: return TransformerBlock( self._config.transformer, @@ -42,7 +46,6 @@ def _get_block( self._hidden_dim, block_index, name, - lr_scale, return_input, ) else: @@ -51,10 +54,9 @@ def _get_block( self._config.ssm, self._distributed_config, self._hidden_dim, + self._config.ssm_block_type.get_mixer_class(), block_index, name, - lr_scale, - self._config.ssm_block_type.get_mixer_class(), return_input, ) diff --git a/tests/models/test_generate.py b/tests/models/test_generate.py index 7f0b902f..cb9c69cc 100644 --- a/tests/models/test_generate.py +++ b/tests/models/test_generate.py @@ -354,7 +354,7 @@ def _test_forward_return_hidden_states( # hidden_states include embeddings layer assert ( - len(res_fast_llm.hidden_states) - 1 == fast_llm_model.config.fast_llm_config.base_model.transformer.num_layers + len(res_fast_llm.hidden_states) - 1 == fast_llm_model.config.fast_llm_config.base_model.transformer.num_blocks ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index e9bdeba9..4815dcb3 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -162,6 +162,7 @@ def _update_and_add_testing_config( "model.base_model.transformer.num_attention_heads=8", "model.base_model.transformer.head_groups=8", "model.base_model.transformer.init_method_std=0.022", + "model.base_model.transformer.use_position_embeddings=True", f"model.base_model.vocab_size={MODEL_TEST_VOCAB_SIZE}", f"model.multi_stage.debug_param_init={_LOG_LEVEL}", f"model.multi_stage.debug_layer_outputs={_LOG_LEVEL}", @@ -258,6 +259,7 @@ def _update_and_add_testing_config( extra_args=[ "model.base_model.transformer.head_groups=4", "model.base_model.transformer.rotary.type=default", + "model.base_model.transformer.use_position_embeddings=False", # Unused, but prevents issues with conversion tests. "model.base_model.max_position_embeddings=2048", ],
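
Note: the hybrid_block_layout handling changed in fast_llm/models/ssm/config.py above can be summarized by a minimal standalone sketch: a layout shorter than transformer.num_blocks is tiled to cover every block, and a length that does not divide num_blocks evenly is rejected. The helper name and the block tags below are hypothetical, for illustration only, and are not part of the Fast-LLM API.

def expand_block_layout(layout: list[str], num_blocks: int) -> list[str]:
    # Same rule as the validation above: an exact match is kept as-is.
    if len(layout) == num_blocks:
        return layout
    # A mismatched length is only accepted if it divides num_blocks evenly.
    if num_blocks % len(layout) != 0:
        raise ValueError(
            f"hybrid_block_layout length {len(layout)} does not match num_blocks {num_blocks}"
        )
    # Tile the pattern until it covers every block.
    return layout * (num_blocks // len(layout))

# Example: a two-entry pattern (placeholder tags) tiled over a six-block model.
assert expand_block_layout(["t", "m2d"], 6) == ["t", "m2d", "t", "m2d", "t", "m2d"]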