Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
3ed07e5
stuff
jlamypoirier Aug 1, 2025
94bf7ac
stuff
jlamypoirier Aug 1, 2025
a9d1d56
stuff
jlamypoirier Aug 1, 2025
87988d5
stuff
jlamypoirier Aug 1, 2025
2a2b764
stuff
jlamypoirier Aug 1, 2025
c44e3b7
Merge branch 'block_interface' into block_interface_config
jlamypoirier Aug 8, 2025
230551d
Merge branch 'block_interface' into block_interface_config
jlamypoirier Aug 8, 2025
dfe4780
stuff
jlamypoirier Aug 8, 2025
d64e032
Merge branch 'block_interface' into block_interface_config
jlamypoirier Aug 14, 2025
4561843
stuff
jlamypoirier Aug 14, 2025
ccbb38f
stuff
jlamypoirier Aug 14, 2025
b70dd19
stuff
jlamypoirier Aug 14, 2025
af990c9
stuff
jlamypoirier Aug 14, 2025
07fba17
Merge branch 'block_interface' into block_interface_config
jlamypoirier Aug 14, 2025
0d2fc89
stuff
jlamypoirier Aug 14, 2025
a7cb018
stuff
jlamypoirier Aug 15, 2025
ddf3ac2
peft
jlamypoirier Aug 15, 2025
495618c
Merge branch 'block_interface' into block_interface_config
jlamypoirier Aug 15, 2025
e7741b7
stuff
jlamypoirier Aug 15, 2025
385eb0d
Merge branch 'block_interface' into block_interface_config
jlamypoirier Aug 15, 2025
0d3f4a6
stuff
jlamypoirier Aug 15, 2025
b6fd59e
Merge branch 'block_interface' into block_interface_config
jlamypoirier Aug 20, 2025
651be5d
Reduce diff
jlamypoirier Aug 20, 2025
92da4cd
Reduce diff
jlamypoirier Aug 20, 2025
34f2a7c
Merge branch 'block_interface' into block_interface_config
jlamypoirier Aug 20, 2025
11a5a21
Fix merge
jlamypoirier Aug 20, 2025
782dfa3
Merge branch 'block_interface' into block_interface_config
jlamypoirier Aug 20, 2025
ce5d351
Merge branch 'block_interface' into block_interface_config
jlamypoirier Aug 20, 2025
0418be1
Fix merge
jlamypoirier Aug 20, 2025
e81d098
Merge branch 'block_interface_linear' into block_interface_config
jlamypoirier Aug 20, 2025
313b1be
Merge branch 'block_interface_linear' into block_interface_config
jlamypoirier Aug 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/developer_guide/conversion.md
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ Continuing our `AwesomeModel` handler example, we define:
def _create_weight_converters(self) -> list[WeightConverter]:
converters = []
# The set of converters may depend on the base model configuration, which is accessible through `self._model.base_model_config`.
num_layers = self._model.config.base_model.transformer.num_layers
num_layers = self._model.config.base_model.transformer.num_blocks

# A simple renaming example, for the word embeddings.
converters.append(WeightConverter("layers.0.word_embeddings_weight", "model.embed_tokens.weight"))
Expand Down
57 changes: 27 additions & 30 deletions fast_llm/layers/attention/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@

from fast_llm.core.distributed import set_generator
from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim
from fast_llm.engine.config_utils.initialization import init_normal_
from fast_llm.engine.config_utils.tensor_dim import CompositeTensorDim, ConcatenatedTensorDim, TensorDim
from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
from fast_llm.functional.autograd import wrap_forward_backward
from fast_llm.layers.attention.config import AttentionConfig, AttentionKwargs
from fast_llm.layers.block.block import BlockLayer
from fast_llm.layers.block.config import BlockConfig, BlockDimNames
from fast_llm.utils import div
from fast_llm.layers.block.peft import TransformerSubLayerName
from fast_llm.layers.common.linear import InputParallelLinear, OutputParallelLinear
from fast_llm.utils import combine_lr_scales, div

try:
from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa
Expand Down Expand Up @@ -93,53 +94,49 @@ def __init__(

self._softmax_scale = self._config.kv_channels ** (-self._config.attention_softmax_scale_power)

init_method_qkv = init_normal_(
std=self._config.init_method_std_qkv,
min_val=self._config.init_method_min_qkv,
max_val=self._config.init_method_max_qkv,
)
init_method_std_attn_proj = init_normal_(
std=self._config.init_method_std_attn_proj,
min_val=self._config.init_method_min_attn_proj,
max_val=self._config.init_method_max_attn_proj,
)
lr_scale = combine_lr_scales(self._lr_scale, self._config.attention_lr_scale)

# TODO: Merge the query and key-value computations? (harder with sequence parallel.)
self.query = self._config.query_layer.get_layer(
self.query = OutputParallelLinear(
hidden_dim,
query_dim,
bias=self._block_config.add_linear_biases,
weight_init_method=self._config.query_layer.weight_initialization,
bias_init_method=self._config.query_layer.bias_initialization,
bias=self._config.add_qkv_bias,
weight_init_method=self._config.qkv_weight_initialization_method,
bias_init_method=self._config.qkv_bias_initialization_method,
sequence_parallel=self._sequence_parallel,
lr_scale=self._lr_scale,
peft=self._config.block.peft,
lr_scale=lr_scale,
)
# TODO: Separate Peft (others?) for key and value
self.key_value = self._config.query_layer.get_layer(
self.key_value = OutputParallelLinear(
hidden_dim,
query_dim,
weight_init_method=self._config.key_layer.weight_initialization,
bias_init_method=self._config.key_layer.bias_initialization,
key_value_dim,
bias=self._config.add_qkv_bias,
weight_init_method=self._config.qkv_weight_initialization_method,
bias_init_method=self._config.qkv_bias_initialization_method,
sequence_parallel=self._sequence_parallel,
lr_scale=self._lr_scale,
peft=self._config.block.peft,
lr_scale=lr_scale,
)
self._query_key_value = wrap_forward_backward(self._query_key_value_forward, self._query_key_value_backward)

# Rotary embeddings.
self._rotary = self._config.rotary.get_layer(kv_channels_dim)

# Output.
self.dense = self._config.dense_layer.get_layer(
self.dense = InputParallelLinear(
dense_dim,
hidden_dim,
weight_init_method=self._config.dense_layer.weight_initialization,
bias_init_method=self._config.dense_layer.bias_initialization,
bias=self._config.add_dense_bias,
weight_init_method=self._config.dense_weight_initialization_method,
bias_init_method=self._config.dense_bias_initialization_method,
sequence_parallel=self._sequence_parallel,
lr_scale=self._lr_scale,
peft=self._config.block.peft,
lr_scale=lr_scale,
)
# PEFT.
self.query = self._block_config.peft.apply_linear(self.query, True)

self.key_value = self._block_config.peft.apply_linear(
self.key_value, True, out_channel_begin=div(self.key_value._out_dim.global_size, 2)
)
self.dense = self._block_config.peft.apply_linear(self.dense, TransformerSubLayerName.dense)

if self._debug.enabled:
self._query_dims = (
Expand Down
117 changes: 72 additions & 45 deletions fast_llm/layers/attention/config.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import functools
import logging
import typing
import warnings

from fast_llm.config import Config, Field, FieldHint, check_field, config_class, skip_valid_if_none
from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none
from fast_llm.engine.config_utils.data_type import DataType
from fast_llm.engine.config_utils.initialization import init_normal_, init_zeros_
from fast_llm.engine.config_utils.initialization import InitializationConfig, Initializer, init_normal_, init_zeros_
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.functional.config import TritonConfig
from fast_llm.layers.attention.rotary.config import RotaryConfig
from fast_llm.layers.block.config import BlockConfig, BlockKwargs
from fast_llm.layers.common.linear.config import LinearConfig
from fast_llm.layers.block.config import BlockConfig, BlockKwargs, MixerConfig
from fast_llm.utils import Assert, div

logger = logging.getLogger(__name__)
Expand All @@ -29,29 +29,14 @@ class AttentionKwargs(BlockKwargs):
past_key_values = "past_key_values"


@config_class()
class AttentionConfig(Config):
# TODO: Make mixer class dynamic.
@config_class(dynamic_type={MixerConfig: "attention"})
class AttentionConfig(MixerConfig):
_abstract = False

# Needed for backward compatibility. TODO: remove
module_name: typing.ClassVar[str] = "attn"

# TODO: Review names
query_layer: LinearConfig = Field(
desc="Configuration for the query layer.",
hint=FieldHint.architecture,
)
key_layer: LinearConfig = Field(
desc="Configuration for the key layer.",
hint=FieldHint.architecture,
)
# TODO: Use
value_layer: LinearConfig = Field(
desc="Configuration for the value layer.",
hint=FieldHint.architecture,
)
dense_layer: LinearConfig = Field(
desc="Initialization configuration for the dense layer.",
hint=FieldHint.feature,
)
rotary: RotaryConfig = Field(
desc="Configuration for the rotary positional embeddings.",
hint=FieldHint.architecture,
Expand Down Expand Up @@ -107,35 +92,41 @@ class AttentionConfig(Config):
" Under muP (if scaling number of heads instead of kv_channels): use 0.5.",
valid=skip_valid_if_none(check_field(Assert.geq, 0)),
)
qkv_weight_initialization: InitializationConfig = Field(
desc="Initialization configuration for the query, key and value layer weights."
" Default: normal(std=hidden_size**-0.5)",
hint=FieldHint.feature,
)
qkv_bias_initialization: InitializationConfig = Field(
desc="Initialization configuration for the query, key and value layer biases. Default: fill with zeros.",
hint=FieldHint.feature,
)
dense_weight_initialization: InitializationConfig = Field(
desc="Initialization configuration for the dense layer weight."
" Default: normal(std=(2 * num_blocks * hidden_size)**-0.5)",
hint=FieldHint.feature,
)
dense_bias_initialization: InitializationConfig = Field(
desc="Initialization configuration for the dense layer biases. Default: fill with zeros.",
hint=FieldHint.feature,
)

def _validate(self) -> None:
with self._set_implicit_default():
# TODO: Make this work without inheritance.
# TODO: hidden_size not yet validated.
if self.kv_channels is None:
self.kv_channels = div(self.hidden_size, self.num_attention_heads)
# TODO: Block variables as defaults?
for layer, scale, enable_peft in (
zip(
(self.query_layer, self.key_layer, self.value_layer, self.dense_layer),
(1, 1, 1, 2 * max(self.num_blocks, 1)),
(True, False, True, False),
),
):
layer.default = LinearConfig(
bias=True,
weight_initialization=init_normal_(0, (self.hidden_size * scale) ** -0.5),
bias_initialization=init_zeros_,
lr_scale=None,
enable_peft=True,
)
super()._validate()
self.kv_channels = div(self.block.hidden_size, self.num_attention_heads)

super()._validate()

if not TritonConfig.TRITON_ENABLED:
warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.")

Assert.multiple(self.num_attention_heads, self.head_groups)
if not self.qkv_bias_initialization.is_default:
assert self.add_qkv_bias
if not self.dense_bias_initialization.is_default:
assert self.add_dense_bias

@functools.cached_property
def projection_size(self):
Expand All @@ -145,8 +136,44 @@ def projection_size(self):
def do_use_flash_attention(self, distributed_config: DistributedConfig) -> bool:
return self.use_flash_attention and distributed_config.training_dtype in (DataType.float16, DataType.bfloat16)

@functools.cached_property
def add_qkv_bias(self) -> bool:
return self.block.add_linear_biases

@functools.cached_property
def add_dense_bias(self) -> bool:
return self.block.add_linear_biases

@functools.cached_property
def qkv_weight_initialization_method(self) -> Initializer:
if self.qkv_weight_initialization.is_default:
return init_normal_(0, self.block.hidden_size**-0.5)
else:
return self.qkv_weight_initialization.get_initializer()

@functools.cached_property
def qkv_bias_initialization_method(self) -> Initializer:
if self.qkv_bias_initialization.is_default:
return init_zeros_
else:
return self.qkv_bias_initialization.get_initializer()

@functools.cached_property
def dense_weight_initialization_method(self) -> Initializer:
if self.dense_weight_initialization.is_default:
return init_normal_(0, self.block.hidden_size**-0.5 / max(2 * self.block.num_blocks, 1))
else:
return self.dense_weight_initialization.get_initializer()

@functools.cached_property
def dense_bias_initialization_method(self) -> Initializer:
if self.dense_bias_initialization.is_default:
return init_zeros_
else:
return self.dense_bias_initialization.get_initializer()


@config_class()
# TODO: Use composition instead
class TransformerConfig(AttentionConfig, BlockConfig):
_abstract = False
# TODO: Remove
class TransformerConfig(BlockConfig):
pass
41 changes: 8 additions & 33 deletions fast_llm/layers/block/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from fast_llm.engine.config_utils.tensor_dim import TensorDim
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.engine.distributed.distributed import Distributed
from fast_llm.layers.block.config import BlockConfig, BlockKwargs
from fast_llm.layers.block.config import BlockConfig, BlockKwargs, BlockLayerConfig
from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage
from fast_llm.tensor import TensorMeta

Expand Down Expand Up @@ -116,7 +116,7 @@ def __init__(
self._lr_scale = lr_scale


class BlockLayer[ConfigType: Config](BlockLayerBase[ConfigType]):
class BlockLayer[ConfigType: BlockLayerConfig](BlockLayerBase[ConfigType]):
"""
Base class for mixer and MLP modules.
"""
Expand All @@ -137,9 +137,6 @@ class Block[ConfigType: BlockConfig](BlockLayerBase[ConfigType], Layer):
A transformer-like decoder base block with abstract mixer.
"""

# TODO: Standardize to `mixer`
_mixer_module_name: typing.ClassVar[str] = "mixer"

def __init__(
self,
config: ConfigType,
Expand Down Expand Up @@ -173,9 +170,8 @@ def __init__(
# Attribute should be mixer, but Attention uses a different name for backward compatibility. TODO: Fix.
setattr(
self,
self._mixer_module_name,
self._mixer_class(
self._mixer_config,
self._config.mixer.module_name,
self._config.mixer.get_layer(
self._config,
self._distributed_config,
self._hidden_dim,
Expand All @@ -184,34 +180,13 @@ def __init__(
self._lr_scale,
),
)

# TODO: Use dynamic type.
from fast_llm.layers.block.mlp.mixture_of_experts import MixtureOfExpertMLP
from fast_llm.layers.block.mlp.mlp import MLP

self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)(
self._config,
self._config,
self._distributed_config,
self._hidden_dim,
self._block_index,
f"{self._name} MLP",
self._lr_scale,
self.mlp = self._config.mlp.get_layer(
self._distributed_config, self._hidden_dim, self._block_index, f"{self._name} mlp", self._lr_scale
)

@functools.cached_property
@abc.abstractmethod
def _mixer_class(self) -> type[BlockLayer]:
pass

@property
@abc.abstractmethod
def _mixer_config(self) -> Config:
pass

def setup(self, distributed: Distributed) -> None:
super().setup(distributed)
getattr(self, self._mixer_module_name).setup(distributed)
getattr(self, self._config.mixer.module_name).setup(distributed)
self.mlp.setup(distributed)

@torch.compile
Expand Down Expand Up @@ -241,7 +216,7 @@ def forward(
hidden_states = self.norm_1(input_)
if self._debug.enabled:
self._debug(hidden_states, "norm 1", kwargs[BlockKwargs.hidden_dims], kwargs)
hidden_states, bias = getattr(self, self._mixer_module_name)(hidden_states, kwargs)
hidden_states, bias = getattr(self, self._config.mixer.module_name)(hidden_states, kwargs)
if self._debug.enabled:
self._debug(
hidden_states if bias is None else hidden_states + bias,
Expand Down
Loading