From 50083ba88a0bfa58747d2bc8307814b62af1a79a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 15:14:13 -0400 Subject: [PATCH 01/10] SSM debugging --- Megatron-LM | 2 +- fast_llm/engine/multi_stage/stage_base.py | 2 + fast_llm/layers/language_model/head.py | 16 ++- fast_llm/layers/ssm/config.py | 34 +++--- fast_llm/layers/ssm/discrete_mamba2.py | 23 ++-- fast_llm/layers/ssm/llamba_block.py | 29 +++-- fast_llm/layers/ssm/mamba2.py | 38 ++++-- fast_llm/layers/ssm/mamba_layer.py | 36 +++--- fast_llm/layers/transformer/attention.py | 72 +++-------- .../layers/transformer/mixture_of_experts.py | 6 +- fast_llm/layers/transformer/mlp.py | 10 +- fast_llm/layers/transformer/transformer.py | 94 ++++++++++++--- fast_llm/logging.py | 2 + fast_llm/models/gpt/model.py | 12 +- fast_llm/models/ssm/config.py | 40 +++---- fast_llm/models/ssm/model.py | 113 +++++------------- setup.cfg | 7 +- tests/data/test_blending.py | 1 + tests/data/test_concatenate.py | 1 + tests/data/test_fim.py | 2 + tests/test_attention.py | 4 +- tests/test_multi_stage.py | 8 +- tests/utils/model_configs.py | 1 + 23 files changed, 271 insertions(+), 282 deletions(-) diff --git a/Megatron-LM b/Megatron-LM index 511e8f5cb..75b0d9787 160000 --- a/Megatron-LM +++ b/Megatron-LM @@ -1 +1 @@ -Subproject commit 511e8f5cbe3ab8291953ac64e5beceb727a1b814 +Subproject commit 75b0d97876006c4b6b23fce302100d18dbf7db37 diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 2f18f1360..9a8ce2092 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -191,6 +191,8 @@ def initialize_weights(self) -> None: # Initialize all global weights on every gpu, then select the appropriate slice if applicable. global_param = parameter.new_empty(global_shape, device=self._distributed.device) meta.init_parameter(global_param, distributed=self._distributed) + # It happens. 
+ Assert.eq(global_param.shape, global_shape) if self._mode.on_device: parameter.copy_(fsdp.parameter_global_to_shard(global_param, meta.tensor_name)) elif self._mode.on_device: diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 25fc2b28d..21bf3bbd0 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -125,12 +125,16 @@ def forward( self, input_: torch.Tensor, kwargs: dict, losses: dict | None = None, metrics: dict | None = None ) -> torch.Tensor: if isinstance(input_, TensorMeta): - return TensorMeta.from_tensor_space( - (DefaultDimNames.scalar,), - self._tensor_space, - tensor_name="Loss", - reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa - ) + if self._is_last_head: + return TensorMeta.from_tensor_space( + (DefaultDimNames.scalar,), + self._tensor_space, + tensor_name="Loss", + reductions=((DistributedDimNames.data, ReduceOp.AVG),), # noqa + ) + else: + return TensorMeta.from_dims(input_.dims[1:], tensor_name="Shared hidden") + if not self._is_last_head: # MTP: split the stacked input shared_hidden, input_ = torch.unbind(input_, dim=0) diff --git a/fast_llm/layers/ssm/config.py b/fast_llm/layers/ssm/config.py index 46d629aa8..a1f357de9 100644 --- a/fast_llm/layers/ssm/config.py +++ b/fast_llm/layers/ssm/config.py @@ -23,6 +23,7 @@ class SSMDimNames: # Mamba 2 x_proj_dim_2 = "x_proj_dim_2" # d_xb + c_heads = "c_heads" class SSMBlockType(enum.StrEnum): @@ -35,6 +36,22 @@ class SSMBlockType(enum.StrEnum): mamba2 = "m2" transformer = "t" + def get_mixer_class(self): + if self == SSMBlockType.mamba: + from fast_llm.layers.ssm.mamba_layer import MambaLayer + + return MambaLayer + elif self == SSMBlockType.mamba2: + from fast_llm.layers.ssm.mamba2 import Mamba2 + + return Mamba2 + elif self == SSMBlockType.mamba2_discrete: + from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 + + return DiscreteMamba2 + else: + raise NotImplementedError(self) + @config_class() class SSMConfig(LLMBlockConfig): @@ -95,11 +112,6 @@ class SSMConfig(LLMBlockConfig): desc="The MLP intermediate activation type. 
Default: SiLU for gated MLP, GeLU otherwise.", hint=FieldHint.architecture, ) - debug_ssm: bool = Field( - default=False, - desc="debug_ssm", - hint=FieldHint.optional, - ) dt_min: float = Field( default=0.001, desc="Minimum step size for discretization", @@ -147,18 +159,6 @@ class SSMConfig(LLMBlockConfig): hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - dt_min: float = Field( - default=0.001, - desc="Minimum step size for discretization", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), - ) - dt_init_floor: float = Field( - default=1e-4, - desc="Minimum value for initializing dt", - hint=FieldHint.core, - valid=check_field(Assert.gt, 0), - ) dt_scale: float = Field( default=1.0, desc="Scale for dt", diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 934cd2b5d..734e35b21 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -1,5 +1,6 @@ import logging import math +import typing import einops import torch @@ -7,7 +8,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames -from fast_llm.layers.transformer.config import TransformerKwargs +from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ from fast_llm.utils import get_lr_scale @@ -36,29 +38,29 @@ def bias_init_method(conv_weight): return init_uniform_(-bound, bound) -class DiscreteMamba2(torch.nn.Module): +class DiscreteMamba2(Mixer): """DiscreteMamba2 (This code is adapted from https://github.com/cartesia-ai/edge/blob/main/cartesia-pytorch/cartesia_pytorch/Llamba/mixers/discrete_mamba2.py).""" + _mixer_name: typing.ClassVar[str] = "discrete_mamba_2" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, - return_input: bool = False, + transformer_config: TransformerConfig, ): """ See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. Other options are all experimental and should not need to be configured. 
""" # factory_kwargs = {"device": "meta"} # , "dtype": torch.bfloat16} - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self.config: SSMConfig = config bias = config.add_bias_linear - self.layer_idx = layer_idx - self._return_input = return_input - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) - logger.info(f"Setting lr_scale for layer {layer_idx} of type {type(self)}: {mamba_layer_lr_scale}") + logger.info(f"Setting lr_scale for layer {block_index} of type {type(self)}: {mamba_layer_lr_scale}") td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) @@ -226,9 +228,6 @@ def forward(self, hidden_states, kwargs): out = self.out_proj(y * torch.nn.functional.silu(z + self.z_bias)) outputs["hidden_states"] = out[:, :seqlen, :].contiguous() - if self._return_input: - return torch.stack([input_, outputs["hidden_states"]], dim=0) - # TODO: since we do not support inference for now, we only return the hidden states for now. return outputs["hidden_states"], None diff --git a/fast_llm/layers/ssm/llamba_block.py b/fast_llm/layers/ssm/llamba_block.py index ee222d6d2..986606634 100644 --- a/fast_llm/layers/ssm/llamba_block.py +++ b/fast_llm/layers/ssm/llamba_block.py @@ -1,6 +1,6 @@ import typing -from fast_llm.layers.transformer.transformer import BaseBlock +from fast_llm.layers.transformer.transformer import BaseBlock, Mixer if typing.TYPE_CHECKING: from fast_llm.engine.config_utils.tensor_space import TensorSpace @@ -8,27 +8,30 @@ from fast_llm.layers.transformer.config import TransformerConfig -class LlambaBlock(BaseBlock): +class SSMBlock(BaseBlock): """ A transformer-like decoder block with a SSM mixer, see https://arxiv.org/abs/2502.14458 """ _name = "Llamba block" - _mixer_module_name = "mixer" def __init__( self, - config_transformer: "TransformerConfig", - config_ssm: "SSMConfig", + transformer_config: "TransformerConfig", + ssm_config: "SSMConfig", tensor_space: "TensorSpace", - mixer_cls, - layer_index: int, + mixer_cls: type[Mixer], + block_index: int, return_input: bool = False, ): - self.mixer_cls = mixer_cls - self._config_ssm = config_ssm - self._debug_mode = self._config_ssm.debug_ssm - super().__init__(config_transformer, tensor_space, layer_index, return_input) + self._ssm_config = ssm_config + self._mixer_cls = mixer_cls + super().__init__(transformer_config, tensor_space, block_index, return_input) - def _create_mixer(self): - self.mixer = self.mixer_cls(self._config_ssm, layer_idx=self._layer_index, tensor_space=self._tensor_space) + def _create_mixer(self) -> Mixer: + return self._mixer_cls( + self._ssm_config, + tensor_space=self._tensor_space, + block_index=self._block_index, + transformer_config=self._config, + ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index a03509abb..ead32fa2a 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -7,6 +7,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames +from fast_llm.layers.transformer.transformer 
import Mixer from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ from fast_llm.utils import get_lr_scale @@ -43,24 +45,36 @@ def bias_init_method(conv_weight): return init_uniform_(-bound, bound) -class Mamba2(torch.nn.Module): +class Mamba2(Mixer): """ This code is adapted from https://github.com/jxiw/M1/blob/537a1ca5407a786a99dc6c721873493cf8750d5e/mamba/hybrid_mamba_layer.py """ + _mixer_name: typing.ClassVar[str] = "mamba_2" + + _XZ_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.inner_dim, + TransformerDimNames.sequence_q, + ) + _BC_DIMS = ( + TransformerDimNames.batch, + SSMDimNames.c_heads, + SSMDimNames.state_dim, + TransformerDimNames.sequence_q, + ) + def __init__( self, config: SSMConfig, - layer_idx: int, tensor_space: TensorSpace, - return_input: bool = False, + block_index: int, + transformer_config: TransformerConfig, ): - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self.config: SSMConfig = config bias: bool = config.add_bias_linear - self.layer_idx = layer_idx - self._return_input = return_input - layer_lr_scale: float | None = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale: float | None = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale: float | tuple[float | None, ...] | None = get_lr_scale( self.config.mamba_lr_scale, layer_lr_scale ) @@ -236,6 +250,13 @@ def forward(self, hidden_states, kwargs): x = repeat_kv(x, self.repeat_group) x = einops.rearrange(x, "b n_group l dstate -> b (n_group dstate) l") + if self._debug_level: + self._debug_log(z, "z", self._XZ_DIMS, kwargs) + self._debug_log(x, "x", self._XZ_DIMS, kwargs) + self._debug_log(B, "b", self._BC_DIMS, kwargs) + self._debug_log(C, "c", self._BC_DIMS, kwargs) + self._debug_log(dt, "dt", self._XZ_DIMS, kwargs) + y = selective_scan_fn( x, dt, @@ -249,6 +270,9 @@ def forward(self, hidden_states, kwargs): return_last_state=False, ) + if self._debug_level: + self._debug_log(y, "y", self._XZ_DIMS, kwargs) + if ssm_state is not None: y, last_state = y ssm_state.copy_(einops.rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head)) diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 7c824d235..a95e94c03 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -1,4 +1,5 @@ import math +import typing from typing import Callable import einops @@ -7,6 +8,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.transformer.config import TransformerConfig +from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ from fast_llm.utils import get_lr_scale @@ -44,12 +47,12 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) def init_dtprojbias( - d_inner: int, dt_max: float, dt_min: float, dt_init_floor: float, factory_kwargs: dict + d_inner: int, dt_max: float, dt_min: float, dt_init_floor: float ) -> Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - dt = torch.exp( - torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) - 
).clamp(min=dt_init_floor) + dt = torch.exp(torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)).clamp( + min=dt_init_floor + ) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) tensor.copy_(inv_dt) @@ -58,20 +61,18 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) return init_ -class MambaLayer(torch.nn.Module): +class MambaLayer(Mixer): + _mixer_name: typing.ClassVar[str] = "mamba" + def __init__( self, config: SSMConfig, - layer_idx: int, + block_index: int, tensor_space: TensorSpace, - return_input: bool = False, + transformer_config: TransformerConfig, ): - factory_kwargs = {} - super().__init__() + super().__init__(tensor_space, block_index, debug_level=transformer_config.debug_transformer) self.config: SSMConfig = config - self.layer_idx = layer_idx - - self._debug_mode = config.debug_ssm # Tensor dims: td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) @@ -88,7 +89,7 @@ def __init__( self.d_state = td_state.size self.d_model = td_model.size self.dt_rank = tdt_rank.size - layer_lr_scale = config.per_layer_lr_scale[layer_idx] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) self.in_proj_weight = ParameterMeta.from_dims( @@ -113,7 +114,6 @@ def __init__( weight_init_method=kaiming_init_(td_inner.size), bias=False, lr_scale=mamba_layer_lr_scale, - **factory_kwargs, ) self.x_proj.weight.auto_grad_accumulation = True @@ -127,7 +127,7 @@ def __init__( self.dt_proj_bias = ParameterMeta.from_dims( (td_inner,), init_method=init_dtprojbias( - self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor, factory_kwargs + self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor ), lr_scale=mamba_layer_lr_scale, ) @@ -153,10 +153,8 @@ def __init__( bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. 
weight_init_method=kaiming_init_(td_model.size), lr_scale=mamba_layer_lr_scale, - **factory_kwargs, ) self.out_proj.weight.auto_grad_accumulation = True - self._return_input = return_input def forward(self, hidden_states, kwargs): assert _mamba_available @@ -168,8 +166,6 @@ def forward(self, hidden_states, kwargs): "d (b l) -> b d l", l=seqlen, ) - if self._debug_mode: - print("XZ: ", xz.shape) A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # In the backward pass we write dx and dz next to each other to avoid torch.cat @@ -189,6 +185,4 @@ def forward(self, hidden_states, kwargs): delta_bias=self.dt_proj_bias.float(), delta_softplus=True, ) - if self._return_input: - out = torch.stack((hidden_states, out), dim=0) return out, None diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 3351c9906..174e19588 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -13,9 +13,9 @@ TransformerKwargs, TransformerSubLayerName, ) -from fast_llm.logging import log_distributed_grad, log_distributed_tensor -from fast_llm.tensor import TensorMeta, init_normal_, init_zeros_ -from fast_llm.utils import Assert, get_lr_scale +from fast_llm.layers.transformer.transformer import Mixer +from fast_llm.tensor import init_normal_, init_zeros_ +from fast_llm.utils import get_lr_scale try: from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func # noqa @@ -50,11 +50,13 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple[torch.Tensor, None]: # no return grad, None -class Attention(torch.nn.Module): +class Attention(Mixer): """ A self-attention layer. """ + _mixer_name: typing.ClassVar[str] = "attn" + _QUERY_DIMS = ( TransformerDimNames.batch, TransformerDimNames.sequence_q, @@ -64,7 +66,7 @@ class Attention(torch.nn.Module): _KV_DIMS = ( TransformerDimNames.batch, TransformerDimNames.sequence_q, - TransformerDimNames.group_heads, + TransformerDimNames.head_groups, TransformerDimNames.kv_channels, ) _CONTEXT_DIMS = ( @@ -73,19 +75,9 @@ class Attention(torch.nn.Module): TransformerDimNames.composite_dense, ) - def __init__( - self, - config: TransformerConfig, - tensor_space: TensorSpace, - layer_index, - ): - super().__init__() + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int): + super().__init__(tensor_space, block_index, config.debug_transformer) self._config = config - self._tensor_space = tensor_space - # Assert.in_range_incl(layer_index, 1, max(self._config.num_layers, 1)) - self._layer_index = layer_index - self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel - self._debug_transformer = self._config.debug_transformer self._use_flash_attention = self._config.do_use_flash_attention(self._tensor_space.distributed_config) init_method_qkv = init_normal_( @@ -108,7 +100,7 @@ def __init__( hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) # TODO: Merge the query and key-value computations? (harder with sequence parallel.) 
@@ -178,10 +170,10 @@ def _attn_fused( query, key, beta=0, - alpha=self._softmax_scale / self._layer_index, + alpha=self._softmax_scale / self._block_index, ).view(b, self._local_head_groups, sq, self._local_heads_per_group, sk) - attn_weights = attn_weights.to(torch.float32) * self._layer_index + attn_weights = attn_weights.to(torch.float32) * self._block_index attn_weights = torch.where(mask, attn_weights, mask_value) attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1).to(query.dtype) @@ -200,40 +192,6 @@ def _attn_fused( .flatten(2) ) - def _get_meta( - self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> TensorMeta: - hidden_dims = {dim.name: dim for dim in kwargs[TransformerKwargs.hidden_dims]} - return TensorMeta.from_dims( - tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) - for dim_name in dim_names - ), - tensor_name=f"transformer layer {self._layer_index} attn {name}", - dtype=input_.dtype, - ) - - def _debug_log( - self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] - ) -> None: - # TODO: Local vs global - Assert.gt(self._debug_transformer, 0) - log_distributed_tensor( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name, dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - if tensor.requires_grad: - log_distributed_grad( - "", - tensor, - level=self._debug_transformer, - meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), - distributed=self._tensor_space.distributed, - ) - def _query_key_value_forward( self, input_: torch.Tensor, sequence_first: bool ) -> tuple[torch.Tensor, torch.Tensor, dict[str, typing.Any]]: @@ -300,7 +258,7 @@ def _decide_window_size(self) -> int | None: # https://github.com/huggingface/transformers/blob/5e2183f344911aa82aba0b83778a4f196cff378e/src/transformers/models/qwen2/modular_qwen2.py#L71 # TODO: make universal per layer config window_size = self._config.window_size - if self._config.max_window_layers is not None and self._layer_index < self._config.max_window_layers: + if self._config.max_window_layers is not None and self._block_index < self._config.max_window_layers: window_size = None return window_size @@ -341,7 +299,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ key = key.view(*key.shape[:2], self._local_head_groups, self._kv_channels) value = value.view(*value.shape[:2], self._local_head_groups, self._kv_channels) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query_rotary_input", self._QUERY_DIMS, kwargs) self._debug_log( key, @@ -395,7 +353,7 @@ def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[ kwargs[TransformerKwargs.attention_mask_value], ) - if self._debug_transformer: + if self._debug_level: self._debug_log(query, "query", self._QUERY_DIMS, kwargs) self._debug_log( key, diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index a46af1387..73f83ccf5 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -40,11 +40,11 @@ class MixtureOfExpertMLP(MLPBase): _group: ProcessGroup - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", 
block_index: int = 0): Assert.gt(config.num_experts, 1) # TODO: Implement? assert not config.add_linear_biases, "Biases not supported for MoE." - super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) self._config = config self._tensor_space = tensor_space self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory @@ -59,7 +59,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._z_loss_factor = config.expert_z_loss_coefficient self._moe_jitter_eps = config.moe_jitter_eps - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index b01eb2aa5..efe0c5cc5 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -14,10 +14,10 @@ class MLPBase(Layer, ABC): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): super().__init__() self._name = name - self._layer_index = layer_index + self._block_index = block_index init_method_1 = init_normal_( std=config.init_method_std_mlp_1, @@ -39,7 +39,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s self._activation_type = config.activation_type self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation - layer_lr_scale = config.per_layer_lr_scale[layer_index] if config.per_layer_lr_scale else None + layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None lr_scale = tuple(config.mlp_lr_scale) if isinstance(config.mlp_lr_scale, list) else config.mlp_lr_scale lr_scale = get_lr_scale(lr_scale, layer_lr_scale) @@ -69,9 +69,9 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s class MLP(MLPBase): - def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", layer_index: int = 0): + def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: str = "mlp", block_index: int = 0): Assert.eq(config.num_experts, 1) - super().__init__(config, tensor_space, name, layer_index) + super().__init__(config, tensor_space, name, block_index) def forward( self, diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index 147452073..d08db9a94 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -8,25 +8,85 @@ from fast_llm.engine.base_model.base_model import Layer from fast_llm.engine.config_utils.run import log_pipeline_parallel_main_rank from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace -from fast_llm.layers.transformer.attention import Attention from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames, TransformerKwargs from fast_llm.layers.transformer.mixture_of_experts import MixtureOfExpertMLP from fast_llm.layers.transformer.mlp import MLP from fast_llm.logging import log_distributed_grad, log_distributed_tensor, log_memory_usage from fast_llm.tensor import 
TensorMeta +from fast_llm.utils import Assert logger = logging.getLogger(__name__) +class Mixer(torch.nn.Module, abc.ABC): + """ + Base class for mixer modules. + """ + + _mixer_name: typing.ClassVar[str] + + def __init__(self, tensor_space: TensorSpace, block_index: int, debug_level: int = 0): + super().__init__() + self._tensor_space = tensor_space + self._sequence_parallel = self._tensor_space.distributed_config.sequence_tensor_parallel + self._block_index = block_index + self._debug_level = debug_level + + @abc.abstractmethod + def forward(self, input_: torch.Tensor, kwargs: dict[str, typing.Any]) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Mixer module forward. Returns the output hidden states and an optional bias, + in case its addition can be made more efficient in `_bias_dropout_add`. + """ + + def _get_meta( + self, input_: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> TensorMeta: + hidden_dims = { + dim.name: dim + for dim in kwargs[TransformerKwargs.hidden_dims] + (kwargs[TransformerKwargs.sequence_q_dim],) + } + return TensorMeta.from_dims( + tuple( + hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) + for dim_name in dim_names + ), + tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", + dtype=input_.dtype, + ) + + def _debug_log( + self, tensor: torch.Tensor, name: str, dim_names: tuple[str, ...], kwargs: dict[str, typing.Any] + ) -> None: + # TODO: Local vs global + Assert.gt(self._debug_level, 0) + log_distributed_tensor( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name, dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + if tensor.requires_grad: + log_distributed_grad( + "", + tensor, + level=self._debug_level, + meta=self._get_meta(tensor, name + " grad", dim_names, kwargs), + distributed=self._tensor_space.distributed, + ) + + class BaseBlock(Layer, abc.ABC): """ A transformer-like decoder base block with abstract mixer. """ - _mixer_module_name = "self_attn" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "mixer" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): super().__init__() self._config: TransformerConfig = config @@ -35,18 +95,19 @@ def __init__( # For multi-token prediction, return a stack of shared_hidden and transformer_output. self._return_input: bool = return_input - self._layer_index = layer_index + self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) # Note, layer_lr_scale does not impact the norms - # TODO: add a seperate norm_lr_scale + # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) self.norm_2 = self._config.normalization.get_layer(hidden_dim) - self._create_mixer() + # The mixer needs to be created here for backward-compatible weight ordering. + setattr(self, self._mixer_module_name, self._create_mixer()) self.mlp = (MixtureOfExpertMLP if self._config.num_experts > 1 else MLP)( - self._config, self._tensor_space, f"{self.name} mlp", layer_index=layer_index + self._config, self._tensor_space, f"{self.name} mlp", block_index=block_index ) # PEFT. 
@@ -54,7 +115,7 @@ def __init__( self.norm_2 = self._config.peft.apply_other(self.norm_2) @abc.abstractmethod - def _create_mixer(self): + def _create_mixer(self) -> Mixer: pass @torch.compile @@ -67,7 +128,7 @@ def _bias_dropout_add( @property def name(self) -> str: - return f"{self._name} {self._layer_index}" + return f"{self._name} {self._block_index}" def _get_meta(self, tensor: torch.Tensor, name: str, kwargs: dict): dims = kwargs[TransformerKwargs.hidden_dims] @@ -137,14 +198,17 @@ def forward( return hidden_states -class TransformerLayer(BaseBlock): +class TransformerBlock(BaseBlock): _name = "Transformer layer" - _mixer_module_name = "self_attn" + # TODO: Standardize to `mixer` + _mixer_module_name: typing.ClassVar[str] = "self_attn" def __init__( - self, config: TransformerConfig, tensor_space: TensorSpace, layer_index: int, return_input: bool = False + self, config: TransformerConfig, tensor_space: TensorSpace, block_index: int, return_input: bool = False ): - super().__init__(config, tensor_space, layer_index, return_input) + super().__init__(config, tensor_space, block_index, return_input) + + def _create_mixer(self) -> Mixer: + from fast_llm.layers.transformer.attention import Attention - def _create_mixer(self): - self.self_attn = Attention(self._config, self._tensor_space, self._layer_index) + return Attention(self._config, self._tensor_space, self._block_index) diff --git a/fast_llm/logging.py b/fast_llm/logging.py index e8334de6e..6d555a0bb 100644 --- a/fast_llm/logging.py +++ b/fast_llm/logging.py @@ -138,6 +138,8 @@ def log_tensor[ if level < 1: return tensor = tensor.detach() + if tensor.ndim == 0: + tensor = tensor[None] save_stats = TensorLogs.config.save shape = tuple(tensor.shape) _, dtype = str(tensor.dtype).split("torch.") diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 444ad72b2..4c1eab46f 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -21,7 +21,7 @@ TransformerLossNames, ) from fast_llm.layers.transformer.preprocessing import BackupAttentionPreprocessor, FlashAttnVarlenPreprocessor -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, GPTModelConfig from fast_llm.models.gpt.megatron import get_init_megatron from fast_llm.tensor import ParameterMeta, TensorMeta @@ -68,11 +68,11 @@ def get_output_layers(self) -> list[Layer]: for i in range(self._config.prediction_heads): if i > 0: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, # TODO MTP: which index? - layer_index=max(self._config.transformer.num_layers + i, 1), + block_index=max(self._config.transformer.num_layers + i, 1), # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. return_input=i < self._config.prediction_heads - 1, @@ -91,10 +91,10 @@ def get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, # The last layer only returns the transformer output. # The previous layers return a stack of shared_hidden and transformer_output. 
return_input=self._config.prediction_heads > 1 and i == self._config.transformer.num_layers - 1, @@ -336,7 +336,7 @@ def embedding(self) -> LanguageModelEmbedding: return self.layers[0] @property - def transformer_layers(self) -> list[TransformerLayer]: + def transformer_layers(self) -> list[TransformerBlock]: return self.layers[1:-1] @property diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index cc83f11be..9ca0123b2 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -9,9 +9,8 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.multi_stage.config import FastLLMModelConfig, PretrainedFastLLMModelConfig from fast_llm.engine.training.config import TrainerConfig -from fast_llm.layers.language_model.config import LanguageModelBaseConfig from fast_llm.layers.ssm.config import SSMBlockType, SSMConfig, SSMDimNames -from fast_llm.models.gpt.config import GPTBatchConfig, PretrainedGPTModelConfig +from fast_llm.models.gpt.config import GPTBaseModelConfig, GPTBatchConfig, PretrainedGPTModelConfig from fast_llm.utils import Assert if typing.TYPE_CHECKING: @@ -24,14 +23,14 @@ @config_class() -class HybridSSMBaseModelConfig(LanguageModelBaseConfig): +class HybridSSMBaseModelConfig(GPTBaseModelConfig): _abstract = False ssm: SSMConfig = Field( desc="Configuration for the transformer architecture.", hint=FieldHint.architecture, ) - hybrid_block_layout: list[str] | None = Field( + hybrid_block_layout: list[SSMBlockType] | None = Field( default=None, desc=f"Pattern of blocks to use in the model. Available types: {SSMBlockType.__members__.values()}", hint=FieldHint.core, @@ -41,9 +40,8 @@ class HybridSSMBaseModelConfig(LanguageModelBaseConfig): desc="Multi-token prediction mixer to use in the model. If None, will use the last block type in `hybrid_block_layout`.", hint=FieldHint.optional, ) - use_megatron_initialization: bool = Field( - default=False, desc="Exactly match the initialization of a Megatron model.", hint=FieldHint.testing - ) # TODO: is this needed? + # TODO: Support combination of different SSM block types. 
+ ssm_block_type: SSMBlockType | None = Field(init=False) def setup_tensor_space(self, tensor_space: TensorSpace) -> None: """ @@ -83,6 +81,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner + self.ssm.dt_rank tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba2, inner_proj_dim)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim_2, self.ssm.d_xb)) + tensor_space.add_tensor_dim(TensorDim(SSMDimNames.c_heads, d_inner // self.ssm.state_size)) def _validate(self): with self._set_implicit_default(None): @@ -96,30 +95,21 @@ def _validate(self): if self.hybrid_block_layout is None: with self._set_implicit_default(): - self.hybrid_block_layout = [SSMBlockType.mamba2_discrete.value] + self.hybrid_block_layout = [SSMBlockType.mamba2_discrete] * self.transformer.num_layers if len(self.hybrid_block_layout) != self.transformer.num_layers: + message = f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" if self.transformer.num_layers % len(self.hybrid_block_layout) != 0: - raise ValueError( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}" - ) - num_repeats = int(self.transformer.num_layers // len(self.hybrid_block_layout)) - logger.warning( - f"hybrid_block_layout length {len(self.hybrid_block_layout)} does not match num_layers {self.transformer.num_layers}, will repeat {self.hybrid_block_layout} {num_repeats} times" - ) + raise ValueError(message) + num_repeats = self.transformer.num_layers // len(self.hybrid_block_layout) + logger.warning(f"{message}, will repeat {self.hybrid_block_layout} {num_repeats} times.") self.hybrid_block_layout = self.hybrid_block_layout * num_repeats - Assert.eq(len(self.hybrid_block_layout), self.transformer.num_layers) - Assert.custom( - lambda _: all(block_type in SSMBlockType.__members__.values() for block_type in self.hybrid_block_layout), - f"Invalid block type: {self.hybrid_block_layout}. Must be one of {SSMBlockType.__members__.values()}", - ) - Assert.custom( - lambda _: self.default_mtp_type in SSMBlockType.__members__.values() or self.default_mtp_type is None, - f"Invalid MTP type: {self.default_mtp_type}. Must be one of {SSMBlockType.__members__.values()} or None", - ) - super()._validate() + ssm_block_types = set(self.hybrid_block_layout) - {SSMBlockType.transformer} + # TODO: Support combination of different SSM block types. 
+ Assert.leq(len(ssm_block_types), 1) + self.ssm_block_type = ssm_block_types.pop() if ssm_block_types else None class LLambaHuggingfaceCheckpointFormat(CheckpointFormat): diff --git a/fast_llm/models/ssm/model.py b/fast_llm/models/ssm/model.py index 02a5ac239..89f0cd4aa 100644 --- a/fast_llm/models/ssm/model.py +++ b/fast_llm/models/ssm/model.py @@ -5,11 +5,8 @@ from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding from fast_llm.layers.language_model.head import LanguageModelHead -from fast_llm.layers.ssm.discrete_mamba2 import DiscreteMamba2 -from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.ssm.mamba2 import Mamba2 -from fast_llm.layers.ssm.mamba_layer import MambaLayer -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.ssm.llamba_block import SSMBlock +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.gpt.model import GPTBaseModel, GPTModel from fast_llm.models.ssm.config import HybridSSMBaseModelConfig, HybridSSMModelConfig, SSMBlockType @@ -31,7 +28,6 @@ def __init__( config: HybridSSMBaseModelConfig, distributed_config: DistributedConfig, ): - self.SSM_BLOCK_CLS = LlambaBlock # TODO: extend to other block types if needed super().__init__(config, distributed_config) def get_output_layers(self) -> list[Layer]: @@ -39,52 +35,31 @@ def get_output_layers(self) -> list[Layer]: Get the output layers of the model. This includes the language model head and any additional heads specified in the configuration. """ - layers = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] + layers: list[Layer] = [LanguageModelHead(self._config, self._tensor_space, prediction_distance=0)] if self._config.prediction_heads > 1: block_type = self._config.default_mtp_type or self._config.hybrid_block_layout[-1] for i in range(1, self._config.prediction_heads): if block_type == SSMBlockType.transformer: layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=len(self._config.hybrid_block_layout), + block_index=len(self._config.hybrid_block_layout), return_input=i != self._config.prediction_heads - 1, ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=Mamba2, - layer_index=len(self._config.hybrid_block_layout), - tensor_space=self._tensor_space, - return_input=i != self._config.prediction_heads - 1, - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. 
Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + block_index=len(self._config.hybrid_block_layout), + tensor_space=self._tensor_space, + return_input=i != self._config.prediction_heads - 1, + ) + ) layers.append(LanguageModelHead(self._config, self._tensor_space, prediction_distance=i)) return layers @@ -94,63 +69,35 @@ def get_layers(self) -> list[Layer]: Create a list of layers for the model, interleaving Transformer and Mamba blocks according to the block pattern. """ - layers = [LanguageModelEmbedding(self._config, self._tensor_space)] + layers: list[Layer] = [LanguageModelEmbedding(self._config, self._tensor_space)] # Create blocks according to pattern for i, block_type in enumerate(self._config.hybrid_block_layout): if block_type == SSMBlockType.transformer: # Transformer block layers.append( - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, return_input=( i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 ), ) ) - elif block_type == SSMBlockType.mamba2_discrete: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=DiscreteMamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba: - # Create Mamba block - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=MambaLayer, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) - - elif block_type == SSMBlockType.mamba2: - mamba_block = self.SSM_BLOCK_CLS( - config_transformer=self._config.transformer, - config_ssm=self._config.ssm, - mixer_cls=Mamba2, - layer_index=i + 1, - tensor_space=self._tensor_space, - return_input=( - i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 - ), - ) - layers.append(mamba_block) else: - raise ValueError(f"Invalid block type: {block_type}. Must be {SSMBlockType.__members__}") + layers.append( + SSMBlock( + transformer_config=self._config.transformer, + ssm_config=self._config.ssm, + mixer_cls=self._config.ssm_block_type.get_mixer_class(), + block_index=i + 1, + tensor_space=self._tensor_space, + return_input=( + i == len(self._config.hybrid_block_layout) - 1 and self._config.prediction_heads > 1 + ), + ) + ) # Add the output layers layers += self.get_output_layers() diff --git a/setup.cfg b/setup.cfg index 843aa15ca..c086af7d0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,14 +48,9 @@ HUGGINGFACE = # Required to run SSMs # To install on cpu environment (ex. 
for IDE support): -# MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[CORE,SSM]" --no-build-isolation +# MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install -e ".[SSM]" --no-build-isolation SSM = mamba_ssm[causal-conv1d]==2.2.4 - cartesia_pytorch>=0.0.2 - -GENERATION = - lm_eval>=0.4.9 - DEV = # Pre-commit git hook diff --git a/tests/data/test_blending.py b/tests/data/test_blending.py index 438782dfe..3e6c37632 100644 --- a/tests/data/test_blending.py +++ b/tests/data/test_blending.py @@ -193,6 +193,7 @@ def test_gpt_blended_mixed(): def test_gpt_blended_mixed_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_concatenate.py b/tests/data/test_concatenate.py index e951cc2b1..4f36cdf89 100644 --- a/tests/data/test_concatenate.py +++ b/tests/data/test_concatenate.py @@ -39,6 +39,7 @@ def test_gpt_concatenate(): def test_gpt_concatenate_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { diff --git a/tests/data/test_fim.py b/tests/data/test_fim.py index 7472f1958..004b96289 100644 --- a/tests/data/test_fim.py +++ b/tests/data/test_fim.py @@ -58,6 +58,7 @@ def test_gpt_fim(): def test_gpt_fim_data(): + get_test_dataset() get_test_data_and_compare_samples( { "datasets": { @@ -81,6 +82,7 @@ def test_gpt_fim_data(): def test_gpt_fim_data_legacy(): + get_test_dataset() get_test_data_and_compare_samples( { "format": "list", diff --git a/tests/test_attention.py b/tests/test_attention.py index 87b0d3e59..dd36b840a 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -17,12 +17,12 @@ def test_decide_window_size(): # Arrange - Case 1: window_size is returned (layer_index >= max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 2 + attention._block_index = 2 assert attention._decide_window_size() == 512 # Arrange - Case 2: window_size is None (layer_index < max_window_layers) attention._config = TransformerConfig(window_size=512, max_window_layers=2) - attention._layer_index = 1 + attention._block_index = 1 assert attention._decide_window_size() is None # Arrange - Case 3: max_window_layers is None (always return window_size) diff --git a/tests/test_multi_stage.py b/tests/test_multi_stage.py index c530a170c..2f125717e 100644 --- a/tests/test_multi_stage.py +++ b/tests/test_multi_stage.py @@ -3,9 +3,10 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.training.config import TrainerConfig from fast_llm.engine.training.trainer import Trainer -from fast_llm.layers.ssm.llamba_block import LlambaBlock -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.ssm.llamba_block import SSMBlock +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.utils import Assert +from tests.utils.dataset import get_model_test_dataset from tests.utils.model_configs import ModelTestingGroup from tests.utils.utils import requires_cuda @@ -23,6 +24,7 @@ def _get_trainer_from_args(args: list[str], model_type: str = "gpt") -> Trainer: @requires_cuda @pytest.mark.model_testing_group(ModelTestingGroup.basic) def test_frozen_weights(model_testing_config): + get_model_test_dataset() args = model_testing_config.config_args + ["run.tensor_logs.save=False"] model_ref = _get_trainer_from_args(args, 
model_testing_config.model_type)._multi_stage model_frozen = _get_trainer_from_args( @@ -39,7 +41,7 @@ def test_frozen_weights(model_testing_config): model_frozen._num_stages, ) frozen_parameter_counts = [ - sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerLayer, LlambaBlock)) else 0 + sum(p.numel() for p in layer.mlp.parameters()) if isinstance(layer, (TransformerBlock, SSMBlock)) else 0 for layer in model_ref.base_model.layers ] for weight_buffer_ref, weight_buffer_frozen in zip( diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 1eee3675d..42252c620 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -523,6 +523,7 @@ def _update_and_add_testing_config( model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", + f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}", ], megatron_args=None, checkpoint_format=None, From 7b32699be7c1a1fb29cc7386eb33280b0bc19a5c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 17:28:56 -0400 Subject: [PATCH 02/10] stuff --- fast_llm/layers/ssm/mamba2.py | 57 ++++++++++++++--------------------- fast_llm/models/ssm/config.py | 2 +- tests/utils/model_configs.py | 2 +- 3 files changed, 24 insertions(+), 37 deletions(-) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index ead32fa2a..b936ccf14 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -7,6 +7,7 @@ from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames +from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames from fast_llm.layers.transformer.transformer import Mixer from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ @@ -97,9 +98,9 @@ def __init__( if self.repeat_kv_before_conv: self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("1", 1), td_conv_kernel), + (td_inner, td_conv_kernel), init_method=init_uniform_( - 1 / math.sqrt(td_inner.size * td_conv_kernel.size), + -1 / math.sqrt(td_inner.size * td_conv_kernel.size), 1 / math.sqrt(td_inner.size * td_conv_kernel.size), ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 lr_scale=mamba_layer_lr_scale, @@ -110,9 +111,9 @@ def __init__( ) else: self.conv1d_weight = ParameterMeta.from_dims( - (td_xb, TensorDim("1", 1), td_conv_kernel), + (td_xb, td_conv_kernel), init_method=init_uniform_( - 1 / math.sqrt(td_xb.size * td_conv_kernel.size), + -1 / math.sqrt(td_xb.size * td_conv_kernel.size), 1 / math.sqrt(td_xb.size * td_conv_kernel.size), ), ) @@ -133,7 +134,13 @@ def __init__( weight_init_method=kaiming_init_(td_model.size), lr_scale=mamba_layer_lr_scale, ) - + self.dt_in_proj = Linear( + td_model, + tdt_rank, + bias=config.add_bias_linear, + weight_init_method=kaiming_init_(transformer_config.hidden_size), + lr_scale=mamba_layer_lr_scale, + ) # Initialize special dt projection to preserve variance at initialization dt_scale = config.dt_scale # 1.0 dt_init_std = self.dt_rank**-0.5 * dt_scale @@ -144,24 +151,6 @@ def __init__( else: raise NotImplementedError - # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - dt_max = config.dt_max # or 0.1 - dt_min = config.dt_min 
# or 0.001 - dt_init_floor = config.dt_init_floor # or 1e-4 - dt = torch.exp(torch.rand(self.d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)).clamp( - min=dt_init_floor - ) - # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - inv_dt = dt + torch.log(-torch.expm1(-dt)) - - def init_from_tensor_( - value: torch.Tensor, - ) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.copy_(value) - - return init_ - self.dt_proj = Linear( tdt_rank, td_inner, @@ -171,18 +160,16 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) ) # define bias outside the linear layer since its also used in the selective_scan_fn self.dt_proj_bias = ParameterMeta.from_dims( - (td_inner,), init_method=init_from_tensor_(inv_dt), lr_scale=mamba_layer_lr_scale + (td_inner,), + init_method=init_dtprojbias( + self.d_inner, self.config.dt_max, self.config.dt_min, self.config.dt_init_floor + ), + lr_scale=mamba_layer_lr_scale, ) - A = einops.repeat( - torch.arange(1, self.d_state + 1, dtype=torch.float32), - "n -> d n", - d=self.d_inner, - ).contiguous() - A_log = torch.log(A).flatten() # Keep A_log in fp32 self.A_log = ParameterMeta.from_dims( (td_inner, td_state), - init_method=init_from_tensor_(A_log), + init_method=init_A(self.config.state_size, self.config.d_inner), lr_scale=mamba_layer_lr_scale, weight_decay=False, ) @@ -214,8 +201,8 @@ def forward(self, hidden_states, kwargs): A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) + zxbc = self.in_proj(hidden_states) + z, x, B, C = torch.split(zxbc, [self.d_inner, self.d_xb, self.d_xb, self.d_inner], dim=-1) x = einops.rearrange(x, "b l d -> b d l") z = einops.rearrange(z, "b l d -> b d l") @@ -225,7 +212,7 @@ def forward(self, hidden_states, kwargs): B = einops.rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() C = einops.rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() - dt = self.dt_proj(dt) + self.dt_proj_bias # B, L, d_inner + dt = self.dt_proj(self.dt_in_proj(hidden_states)) + self.dt_proj_bias # B, L, d_inner dt = einops.rearrange(dt, "b l d -> b d l") # B, d_inner, L if self.repeat_kv_before_conv: @@ -238,7 +225,7 @@ def forward(self, hidden_states, kwargs): if _causal_conv1d_available: x = _causal_conv1d_fn( x=x, - weight=einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), + weight=self.conv1d_weight, bias=self.conv1d_bias, activation=self.activation, ) # B, L, D diff --git a/fast_llm/models/ssm/config.py b/fast_llm/models/ssm/config.py index 9ca0123b2..b04b1f210 100644 --- a/fast_llm/models/ssm/config.py +++ b/fast_llm/models/ssm/config.py @@ -78,7 +78,7 @@ def setup_tensor_space(self, tensor_space: TensorSpace) -> None: tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_discrete_mamba2, inner_proj_dim)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.conv_dim, conv_dim)) elif SSMBlockType.mamba2.value in self.hybrid_block_layout: - inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner + self.ssm.dt_rank + inner_proj_dim: int = 2 * self.ssm.d_xb + 2 * d_inner # + self.ssm.dt_rank tensor_space.add_tensor_dim(TensorDim(SSMDimNames.inner_proj_mamba2, inner_proj_dim)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.x_proj_dim_2, 
self.ssm.d_xb)) tensor_space.add_tensor_dim(TensorDim(SSMDimNames.c_heads, d_inner // self.ssm.state_size)) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 42252c620..4976ad2b1 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -523,7 +523,7 @@ def _update_and_add_testing_config( model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", - f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}", + # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}", ], megatron_args=None, checkpoint_format=None, From b49c42febac4f32dc1be83655b242d6199a385bc Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 18:16:42 -0400 Subject: [PATCH 03/10] misc --- fast_llm/layers/ssm/discrete_mamba2.py | 4 ++-- fast_llm/layers/ssm/mamba2.py | 8 ++++---- fast_llm/layers/ssm/mamba_layer.py | 4 ++-- .../modeling_ssm_hybrid_apriel15b.py | 20 +++++++++++++------ tests/utils/model_configs.py | 1 - 5 files changed, 22 insertions(+), 15 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index 734e35b21..c0ae7e781 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -5,7 +5,7 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs @@ -103,7 +103,7 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (td_conv, TensorDim("1", 1), td_conv_kernel), + (td_conv, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=init_uniform_( 1 / math.sqrt(td_conv.size * td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index b936ccf14..74c212add 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -4,7 +4,7 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorDim, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias @@ -98,7 +98,7 @@ def __init__( if self.repeat_kv_before_conv: self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, td_conv_kernel), + (td_inner, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=init_uniform_( -1 / math.sqrt(td_inner.size * td_conv_kernel.size), 1 / math.sqrt(td_inner.size * td_conv_kernel.size), @@ -111,7 +111,7 @@ def __init__( ) else: self.conv1d_weight = ParameterMeta.from_dims( - (td_xb, td_conv_kernel), + (td_xb, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=init_uniform_( -1 / math.sqrt(td_xb.size * td_conv_kernel.size), 1 / math.sqrt(td_xb.size * td_conv_kernel.size), @@ -225,7 +225,7 @@ def forward(self, hidden_states, kwargs): if _causal_conv1d_available: x = _causal_conv1d_fn( x=x, - weight=self.conv1d_weight, + 
weight=einops.rearrange(self.conv1d_weight, "d 1 w -> d w"), bias=self.conv1d_bias, activation=self.activation, ) # B, L, D diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index a95e94c03..4493332ce 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -5,7 +5,7 @@ import einops import torch -from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace +from fast_llm.engine.config_utils.tensor_space import DefaultDimNames, TensorSpace from fast_llm.layers.common.linear import Linear from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig @@ -98,7 +98,7 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, TensorDim("D_inner_2", self.d_inner // self.d_inner), td_conv_kernel), + (td_inner, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), init_method=kaiming_init_(td_inner.size), lr_scale=mamba_layer_lr_scale, ) diff --git a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py index f8f6a0520..4fde72458 100644 --- a/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py +++ b/fast_llm/models/ssm/external/apriel_15b_hybrid/modeling_ssm_hybrid_apriel15b.py @@ -843,9 +843,8 @@ def __init__( self.num_C_head = self.d_inner // self.d_state self.repeat_group = self.num_C_head // self.num_xb_head - self.in_proj = nn.Linear( - self.d_model, 2 * self.d_xb + 2 * self.d_inner + self.dt_rank, bias=bias, **factory_kwargs - ) + self.in_proj = nn.Linear(self.d_model, 2 * self.d_xb + 2 * self.d_inner, bias=bias, **factory_kwargs) + self.dt_in_proj = nn.Linear(self.d_model, self.dt_rank, bias=bias, **factory_kwargs) self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=dt_proj_bias, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization @@ -933,8 +932,17 @@ def forward( outputs = {} A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - zxbcdt = self.in_proj(hidden_states) - z, x, B, C, dt = torch.split(zxbcdt, [self.d_inner, self.d_xb, self.d_xb, self.d_inner, self.dt_rank], dim=-1) + zxbc = self.in_proj(hidden_states) + z, x, B, C = torch.split( + zxbc, + [ + self.d_inner, + self.d_xb, + self.d_xb, + self.d_inner, + ], + dim=-1, + ) x = rearrange(x, "b l d -> b d l") z = rearrange(z, "b l d -> b d l") @@ -944,7 +952,7 @@ def forward( B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous() C = rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous() - dt = self.dt_proj(dt) # B, L, d_inner + dt = self.dt_proj(self.dt_in_proj(hidden_states)) # B, L, d_inner dt = rearrange(dt, "b l d -> b d l") # B, d_inner, L if self.repeat_kv_before_conv: diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 4976ad2b1..1eee3675d 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -523,7 +523,6 @@ def _update_and_add_testing_config( model_type="hybrid_ssm", extra_args=[ "model.base_model.hybrid_block_layout=['t','m2']", - # f"model.base_model.transformer.debug_transformer={_LOG_LEVEL}", ], megatron_args=None, checkpoint_format=None, From 31f5d415ef0c7eeca54a26d415076cbf3ba33cfd Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 24 Jul 2025 18:20:26 -0400 Subject: [PATCH 04/10] misc --- fast_llm/models/ssm/conversion.py | 18 
+++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/fast_llm/models/ssm/conversion.py b/fast_llm/models/ssm/conversion.py index d57300252..43e3c67e5 100644 --- a/fast_llm/models/ssm/conversion.py +++ b/fast_llm/models/ssm/conversion.py @@ -3,6 +3,7 @@ import pathlib import typing +from fast_llm.config import MISSING from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import ( ConstantExportParamConverter, @@ -19,7 +20,7 @@ from fast_llm.engine.multi_stage.config import FastLLMModelConfig from fast_llm.functional.config import ActivationType from fast_llm.layers.common.config import RMSNormalizationConfig -from fast_llm.layers.ssm.config import SSMBlockType +from fast_llm.layers.ssm.config import DTInitType, SSMBlockType from fast_llm.models.gpt.conversion import CommonLlamaHuggingfaceCheckpointHandler, MLPLayer2Converter from fast_llm.models.ssm.config import ( AprielSSMHHybridHuggingfaceCheckpointFormat, @@ -42,11 +43,11 @@ class HybridModelCheckpointHandler(HuggingfaceStateDictCheckpointHandler): @classmethod def _create_config_converters(cls) -> list[ParamConverter]: - block_converter = RenameParamConverter( + block_converter = MappedConfigParamConverter( fast_llm_names=(("hybrid_block_layout",),), export_names=(("hybrid_block_layout",),), - ignore_missing=True, - default_value=[cls._default_block_type], + fast_llm_value=lambda x: [cls._default_block_type] if x == MISSING else x, + export_value=lambda x: [x_.value for x_ in x], ) return super()._create_config_converters() + [block_converter] @@ -202,7 +203,7 @@ def _create_config_converters(cls) -> list[ParamConverter]: ignore_missing=True, default_value=4, ), - RenameParamConverter( + MappedConfigParamConverter( fast_llm_names=(("ssm", "dt_init"),), export_names=( ( @@ -210,8 +211,8 @@ def _create_config_converters(cls) -> list[ParamConverter]: "dt_init", ), ), - ignore_missing=True, - default_value="random", + fast_llm_value=lambda x: DTInitType.random if x == MISSING else DTInitType(x), + export_value=lambda x: x.value, ), ] @@ -258,6 +259,9 @@ def _create_weight_converters(self) -> list[WeightConverter]: ) # ================================================ # Mamba2 specific parameters + converters += self._get_weight_and_bias_converters( + f"layers.{i+1}.mixer.dt_in_proj", f"model.layers.{i}.mixer.dt_in_proj", ssm_bias + ) converters += self._get_weight_and_bias_converters( f"layers.{i+1}.mixer.dt_proj", f"model.layers.{i}.mixer.dt_proj", False ) From 5eea938403a74bcf8ee7f0c504e3d8bb6fe118f7 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 28 Jul 2025 11:52:24 -0400 Subject: [PATCH 05/10] fix --- fast_llm/models/custom/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fast_llm/models/custom/model.py b/fast_llm/models/custom/model.py index c206ef406..534d813ff 100644 --- a/fast_llm/models/custom/model.py +++ b/fast_llm/models/custom/model.py @@ -7,7 +7,7 @@ from fast_llm.engine.distributed.config import DistributedConfig, PhaseType from fast_llm.engine.schedule.config import BatchConfig from fast_llm.layers.language_model.embedding import LanguageModelEmbedding -from fast_llm.layers.transformer.transformer import TransformerLayer +from fast_llm.layers.transformer.transformer import TransformerBlock from fast_llm.models.custom.config import CustomBaseModelConfig, CustomModelConfig from fast_llm.models.custom.head import CustomHead from fast_llm.models.gpt.config import GPTBaseModelConfig @@ -31,10 +31,10 @@ def 
get_layers(self) -> list[Layer]: return [ LanguageModelEmbedding(self._config, self._tensor_space), *[ - TransformerLayer( + TransformerBlock( self._config.transformer, self._tensor_space, - layer_index=i + 1, + block_index=i + 1, ) for i in range(self._config.transformer.num_layers) ], From e536af9d935fe789b98683777e3e320eaf5d7e62 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 28 Jul 2025 16:15:17 -0400 Subject: [PATCH 06/10] Concatenated dim --- fast_llm/engine/config_utils/tensor_space.py | 224 +++++++++++++----- fast_llm/engine/multi_stage/fsdp.py | 32 +-- fast_llm/engine/multi_stage/stage_base.py | 5 +- fast_llm/layers/common/config.py | 6 +- fast_llm/layers/common/linear.py | 8 +- fast_llm/layers/common/normalization.py | 4 +- fast_llm/layers/common/peft.py | 4 +- fast_llm/layers/language_model/embedding.py | 8 +- fast_llm/layers/language_model/head.py | 10 +- .../layers/language_model/preprocessing.py | 4 +- fast_llm/layers/transformer/attention.py | 16 +- .../layers/transformer/mixture_of_experts.py | 6 +- fast_llm/layers/transformer/mlp.py | 6 +- fast_llm/layers/transformer/preprocessing.py | 2 +- .../transformer/rotary/preprocessing.py | 4 +- fast_llm/layers/transformer/rotary/rotary.py | 4 +- fast_llm/layers/transformer/transformer.py | 4 +- fast_llm/models/gpt/megatron.py | 29 +-- fast_llm/models/gpt/model.py | 2 +- fast_llm/tensor.py | 169 ++++++++----- 20 files changed, 346 insertions(+), 201 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 99c1bcf70..cf2974a99 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -1,3 +1,4 @@ +import logging import math import typing @@ -5,9 +6,13 @@ from fast_llm.utils import Assert, div if typing.TYPE_CHECKING: + import torch + from fast_llm.core.distributed import ProcessGroup from fast_llm.engine.distributed.distributed import Distributed +logger = logging.getLogger(__name__) + class TensorDim: def __init__(self, name: str, global_size: int | None, parallel_dim: DistributedDim | None = None): @@ -19,11 +24,11 @@ def __init__(self, name: str, global_size: int | None, parallel_dim: Distributed def __repr__(self) -> str: return ( - f"TensorDim(" + f"{type(self).__name__}(" f"name={self._name}," f" size={self._size}," f" global_size={self._global_size}," - f" parallel_dim={None if self.parallel_dim is None else self._parallel_dim}" + f" parallel_dim={self._parallel_dim}" f")" ) @@ -38,83 +43,180 @@ def name(self) -> str: def size(self) -> int: return self._size - @property - def expanded_shape(self) -> tuple[int, ...]: - return (self._size,) - - @property - def ndim(self) -> int: - return 1 - @property def global_size(self) -> int: return self._global_size @property - def global_expanded_shape(self) -> tuple[int, ...]: - return (self._size if self._parallel_dim is None else self._size * self._parallel_dim.size,) + def is_parallel(self) -> bool: + return self._parallel_dim is not None and self._parallel_dim.size > 1 @property def parallel_dim(self) -> DistributedDim | None: + # TODO: Make more flexible for derived classes? return self._parallel_dim - @property - def parallel_dim_index(self) -> int | None: - return None if self._parallel_dim is None else 0 - @property def parallel_group(self) -> "ProcessGroup|None": + # TODO: Make more flexible for derived classes? 
return None if self._parallel_dim is None else self._parallel_dim.group def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - assert self.parallel_dim is not None + assert self.is_parallel return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim) + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + if self.is_parallel: + from fast_llm.core.ops import gather_op + + return gather_op(tensor, self.parallel_group, dim) + else: + return tensor + + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + if self.is_parallel: + output = tensor.new_full((*tensor.shape[:dim], self.parallel_dim.size, *tensor.shape[dim:]), fill_value) + output.narrow(dim, self.parallel_dim.rank, 1).copy_(tensor.unsqueeze(dim)).squeeze(dim) + return output.flatten(dim, dim + 1) + else: + return tensor + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + return ( + tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank] + if self.parallel_dim is not None and self.parallel_dim.size > 1 + else tensor + ) + class CompositeTensorDim(TensorDim): - def __init__(self, name: str, dims: tuple[TensorDim, ...]): - # TODO: Recursive composition?? - parallel_dims = [(i, dim.parallel_dim) for i, dim in enumerate(dims) if dim.parallel_dim] - Assert.leq(len(parallel_dims), 1) + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): + parallel_dim = None + for dim, tensor_dim in enumerate(tensor_dims): + if tensor_dim.is_parallel: + # TODO: Allow more than one parallel subdim? + assert parallel_dim is None + parallel_dim = tensor_dim.parallel_dim + self._parallel_dim_index = dim super().__init__( name=name, - global_size=math.prod(dim.global_size for dim in dims), - parallel_dim=parallel_dims[0][1] if parallel_dims else None, - ) - self._dims = dims - self._parallel_dim_index = ( - sum(dim.ndim for dim in self._dims[: parallel_dims[0][0]]) - + self._dims[parallel_dims[0][0]].parallel_dim_index - if parallel_dims - else None + global_size=math.prod(dim.global_size for dim in tensor_dims), + parallel_dim=parallel_dim, ) + self._tensor_dims = tensor_dims - @property - def dims(self) -> tuple[TensorDim, ...]: - return self._dims + def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + assert self._parallel_dim_index is not None + dims = list(self._tensor_dims) + dims[self._parallel_dim_index] = dims[self._parallel_dim_index].replace_parallel_dim(distributed_dim) + return CompositeTensorDim(self.name, tuple(dims)) - @property - def ndim(self) -> int: - return sum(dim.ndim for dim in self._dims) + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in enumerate(self._tensor_dims): + tensor = tensor_dim.local_to_global(tensor, dim + i) - @property - def expanded_shape(self) -> tuple[int, ...]: - return sum((dim.expanded_shape for dim in self._dims), ()) + return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - @property - def global_expanded_shape(self) -> tuple[int, ...]: - return sum((dim.global_expanded_shape for dim in self._dims), ()) + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in 
self._tensor_dims]) + for i, tensor_dim in enumerate(self._tensor_dims): + tensor = tensor_dim.local_to_global_partial(tensor, dim + i) + + return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) + for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): + tensor = tensor_dim.global_to_local(tensor, dim + i) + return tensor if expand else tensor.flatten(dim, dim + len(self._tensor_dims) - 1) - @property - def parallel_dim_index(self) -> int | None: - return self._parallel_dim_index + +class ConcatenatedTensorDim(TensorDim): + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): + parallel_dim = tensor_dims[0].parallel_dim + for dim, tensor_dim in enumerate(tensor_dims[1:]): + # TODO: Allow more flexibility? + Assert.is_(tensor_dim.parallel_dim, parallel_dim) + + super().__init__( + name=name, + global_size=sum(dim.global_size for dim in tensor_dims), + parallel_dim=parallel_dim, + ) + self._tensor_dims = tensor_dims def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: - assert self.parallel_dim_index is not None - dims = list(self.dims) - dims[self.parallel_dim_index] = dims[self.parallel_dim_index].replace_parallel_dim(distributed_dim) - return CompositeTensorDim(self.name, tuple(dims)) + assert self.is_parallel + return ConcatenatedTensorDim( + self.name, tuple(tensor_dim.replace_parallel_dim(distributed_dim) for tensor_dim in self._tensor_dims) + ) + + def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + import torch + + return ( + torch.concatenate( + [ + tensor_dim.local_to_global(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) + + def local_to_global_partial( + self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 + ) -> "torch.Tensor": + import torch + + return ( + torch.concatenate( + [ + tensor_dim.local_to_global_partial(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) + + def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + if self.is_parallel and expand: + raise NotImplementedError() + import torch + + return ( + torch.concatenate( + [ + tensor_dim.global_to_local(tensor_, dim) + for tensor_, tensor_dim in zip( + tensor.split([tensor_dim.global_size for tensor_dim in self._tensor_dims], dim), + self._tensor_dims, + strict=True, + ) + ], + dim, + ) + if self.is_parallel + else tensor + ) class DefaultDimNames: @@ -147,21 +249,19 @@ def distributed(self) -> "Distributed": assert self._is_setup return self._distributed - def add_tensor_dim(self, dim: TensorDim) -> None: - if isinstance(dim, CompositeTensorDim): - for dim_ in dim.dims: - Assert.incl(dim_.name, self._tensor_dims) - Assert.eq(dim_, self._tensor_dims[dim_.name]) - if dim.name in self._tensor_dims: - Assert.eq(dim, self._tensor_dims[dim.name]) + def add_tensor_dim(self, tensor_dim: TensorDim) -> None: + if tensor_dim.name in self._tensor_dims: + Assert.eq(tensor_dim, self._tensor_dims[tensor_dim.name]) else: - if 
dim.parallel_dim is not None: - assert dim.parallel_dim.name in self._distributed_config.distributed_dims, dim.parallel_dim.name + if tensor_dim.parallel_dim is not None: + assert ( + tensor_dim.parallel_dim.name in self._distributed_config.distributed_dims + ), tensor_dim.parallel_dim.name Assert.eq( - dim.parallel_dim.__dict__, - self._distributed_config.distributed_dims[dim.parallel_dim.name].__dict__, + tensor_dim.parallel_dim.__dict__, + self._distributed_config.distributed_dims[tensor_dim.parallel_dim.name].__dict__, ) - self._tensor_dims[dim.name] = dim + self._tensor_dims[tensor_dim.name] = tensor_dim - def get_tensor_dim(self, name: str) -> TensorDim: + def __getitem__(self, name: str) -> TensorDim: return self._tensor_dims[name] diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index 5b44bf14b..be15cd37a 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -441,39 +441,21 @@ def _get_parameter_shard_indices_in_full_weight( where it is located in the shard if it exists, or -1 if it's not in the shard. Used to determine the location of each entry in a different distributed configuration. """ - - # Create an empty index for the global parameter. - index = torch.full( - parameter_meta.global_shape, - -1, - dtype=torch.int64, - device=device, - ) # Set the shard slice of the global parameter to corresponding indices of the parameter slice of the shard begin, end = self._get_parameter_range_in_shard(parameter_name) - buffer_index = parameter_meta.global_to_local(index, expand=True) - # Copying directly into `buffer_index` requires a view of the tensor, which may not be feasible. - # In that case, we work with a separate tensor to be copied back into `buffer_index`. - try: - buffer_index_flat = buffer_index.view(-1) - is_view = True - except RuntimeError: - buffer_index_flat = buffer_index.new_full((buffer_index.numel(),), -1) - is_view = False - - # Copy the shard indices at their respective positions in the flat buffer index. - buffer_index_flat[ + # Create an empty local index to hold the local shard indices. + buffer_index = torch.full_like(parameter_meta, -1, dtype=torch.int64, device=device) + + # Copy the shard indices at their respective positions in the buffer index. + buffer_index.flatten()[ self._index_buffer_to_param( self._fsdp_dim.rank * self._shard_size, parameter_name ) : self._index_buffer_to_param((self._fsdp_dim.rank + 1) * self._shard_size, parameter_name) ].copy_(torch.arange(begin, end, dtype=torch.int64, device=device)) - # If needed, copy the flat buffer index back into the index. - if not is_view: - buffer_index.copy_(buffer_index_flat.view_as(buffer_index)) - - return index + # Create a global index from the local one. 
+ return parameter_meta.local_to_global_partial(buffer_index, -1) def copy_shard_overlaps( self, diff --git a/fast_llm/engine/multi_stage/stage_base.py b/fast_llm/engine/multi_stage/stage_base.py index 9a8ce2092..3218a1963 100644 --- a/fast_llm/engine/multi_stage/stage_base.py +++ b/fast_llm/engine/multi_stage/stage_base.py @@ -185,8 +185,9 @@ def initialize_weights(self) -> None: # Multi-gpu init may be different because of TP or FSDP (different shape), or PP (not on device) global_shape = meta.global_shape - if self._distributed_config.reproducible_init and ( - global_shape.numel() != parameter.numel() or not self._mode.on_device + if meta.requires_global_initialization or ( + self._distributed_config.reproducible_init + and (global_shape.numel() != parameter.numel() or not self._mode.on_device) ): # Initialize all global weights on every gpu, then select the appropriate slice if applicable. global_param = parameter.new_empty(global_shape, device=self._distributed.device) diff --git a/fast_llm/layers/common/config.py b/fast_llm/layers/common/config.py index 9f32ac689..07dadbc22 100644 --- a/fast_llm/layers/common/config.py +++ b/fast_llm/layers/common/config.py @@ -99,7 +99,7 @@ class LayerNormalizationBaseConfig(NormalizationConfig): ) def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> "LayerNorm | RMSNorm": - from fast_llm.tensor import init_uniform_ + from fast_llm.tensor import init_uniform_centered_ kwargs = { "hidden_dim": hidden_dim, @@ -110,9 +110,7 @@ def get_layer(self, hidden_dim: "TensorDim", lr_scale: float | None = None) -> " } if self.initialization_range: mean = 0 if self.zero_centered else 1 - kwargs["weight_init_method"] = init_uniform_( - mean - self.initialization_range, mean + self.initialization_range - ) + kwargs["weight_init_method"] = init_uniform_centered_(self.initialization_range, mean=mean) return self.module_class(**kwargs) @property diff --git a/fast_llm/layers/common/linear.py b/fast_llm/layers/common/linear.py index cd19a47a5..7249ef569 100644 --- a/fast_llm/layers/common/linear.py +++ b/fast_llm/layers/common/linear.py @@ -94,8 +94,8 @@ def __init__( transposed_weight: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): - assert in_dim.parallel_dim is None - assert out_dim.parallel_dim is None + assert not in_dim.is_parallel + assert not out_dim.is_parallel super().__init__( in_dim, out_dim, @@ -132,7 +132,7 @@ def __init__( sequence_parallel: bool = False, lr_scale: float | None | tuple[float | None, ...] = None, ): - assert in_dim.parallel_dim is None + assert not in_dim.is_parallel self._group_size = 1 if out_dim.parallel_dim is None else out_dim.parallel_dim.size self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( @@ -176,7 +176,7 @@ def __init__( transposed_weight: bool = False, lr_scale: float | None | tuple[float | None, ...] 
= None, ): - assert out_dim.parallel_dim is None + assert not out_dim.is_parallel self._group_size = 1 if in_dim.parallel_dim is None else in_dim.parallel_dim.size self._sequence_parallel = sequence_parallel and self._group_size > 1 super().__init__( diff --git a/fast_llm/layers/common/normalization.py b/fast_llm/layers/common/normalization.py index 5f30beaef..bccc1d627 100644 --- a/fast_llm/layers/common/normalization.py +++ b/fast_llm/layers/common/normalization.py @@ -158,7 +158,7 @@ def __init__( lr_scale: float | None = None, ): super().__init__() - assert hidden_dim.parallel_dim is None + assert not hidden_dim.is_parallel self._eps = eps self._zero_centered = zero_centered if implementation == NormalizationImplementation.auto: @@ -242,7 +242,7 @@ def __init__( lr_scale: float | None = None, ): super().__init__() - assert hidden_dim.parallel_dim is None + assert not hidden_dim.is_parallel self._eps = eps self._zero_centered = zero_centered if implementation == NormalizationImplementation.auto: diff --git a/fast_llm/layers/common/peft.py b/fast_llm/layers/common/peft.py index 3a1966e51..08f3e535b 100644 --- a/fast_llm/layers/common/peft.py +++ b/fast_llm/layers/common/peft.py @@ -19,12 +19,12 @@ def lora_linear( ): layer.weight.requires_grad = False in_dim = layer._in_dim + assert not in_dim.is_parallel, "LoRA not supported with tensor parallelism." if in_dim.parallel_dim is not None: - assert in_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism." in_dim = TensorDim(in_dim.name, in_dim.global_size) out_dim = layer._out_dim + assert not out_dim.is_parallel, "LoRA not supported with tensor parallelism." if out_dim.parallel_dim is not None: - assert out_dim.parallel_dim.size == 1, "LoRA not supported with tensor parallelism." out_dim = TensorDim(out_dim.name, out_dim.global_size) if out_channel_begin is not None or out_channel_end is not None: if out_channel_begin is None: diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 7036a1e97..f6f43d199 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -46,10 +46,10 @@ def __init__( self._dropout_p = config.transformer.hidden_dropout self._use_absolute_position_embeddings = config.use_absolute_position_embeddings - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - vocab_dim = tensor_space.get_tensor_dim( + hidden_dim = tensor_space[TransformerDimNames.hidden] + vocab_dim = tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ) + ] if self._parallel_embeddings: self._vocab_start_index = self._distributed_config.tensor_rank * vocab_dim.size @@ -66,7 +66,7 @@ def __init__( ) if self._use_absolute_position_embeddings: self.position_embeddings_weight = ParameterMeta.from_dims( - (tensor_space.get_tensor_dim(LanguageModelDimNames.position_embed), hidden_dim), + (tensor_space[LanguageModelDimNames.position_embed], hidden_dim), init_method=init_normal_( std=config.init_method_std_embed, min_val=config.init_method_min_embed, diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 21bf3bbd0..210cad644 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -61,7 +61,7 @@ def __init__( if self._cross_entropy_splits is not None and self._sequence_parallel: assert not self._parallel_embeddings - hidden_dim = 
self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] self._loss_coefficient = ( config.prediction_loss_coefficient[prediction_distance] if config.prediction_loss_coefficient else 1.0 @@ -108,9 +108,9 @@ def _init_output_weights(self, hidden_dim: TensorDim, config) -> None: if self._tie_word_embeddings or self._prediction_distance > 0: return # untie embedding weights - vocab_dim = self._tensor_space.get_tensor_dim( + vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab_tp if self._parallel_embeddings else LanguageModelDimNames.vocab - ) + ] self.output_weights = ParameterMeta.from_dims( (vocab_dim, hidden_dim), init_method=init_normal_( @@ -338,9 +338,9 @@ def _logits_cross_entropy_forward_backward( logits_scale_factor=self._logits_scale_factor, ) if self._debug_transformer and self._cross_entropy_splits is None: - vocab_dim = self._tensor_space.get_tensor_dim( + vocab_dim = self._tensor_space[ LanguageModelDimNames.vocab if self._sequence_parallel_logits else LanguageModelDimNames.vocab_tp - ) + ] dims = [*kwargs[TransformerKwargs.hidden_dims][:-1], vocab_dim] sequence_index = 1 - int(kwargs[TransformerKwargs.sequence_first]) dims[sequence_index] = ( diff --git a/fast_llm/layers/language_model/preprocessing.py b/fast_llm/layers/language_model/preprocessing.py index d719bef3d..c8d53a789 100644 --- a/fast_llm/layers/language_model/preprocessing.py +++ b/fast_llm/layers/language_model/preprocessing.py @@ -28,7 +28,7 @@ def __init__( assert config.use_absolute_position_embeddings self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: @@ -76,7 +76,7 @@ def __init__(self, config: LanguageModelBaseConfig, tensor_space: TensorSpace): self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def preprocess_meta(self, kwargs: dict[str, typing.Any]) -> None: return diff --git a/fast_llm/layers/transformer/attention.py b/fast_llm/layers/transformer/attention.py index 174e19588..c59b191af 100644 --- a/fast_llm/layers/transformer/attention.py +++ b/fast_llm/layers/transformer/attention.py @@ -91,14 +91,14 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i max_val=self._config.init_method_max_attn_proj, ) - self._kv_channels = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels).size - self._head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).global_size - self._local_head_groups = self._tensor_space.get_tensor_dim(TransformerDimNames.head_groups).size - self._local_heads_per_group = self._tensor_space.get_tensor_dim(TransformerDimNames.group_heads).size + self._kv_channels = self._tensor_space[TransformerDimNames.kv_channels].size + self._head_groups = self._tensor_space[TransformerDimNames.head_groups].global_size + self._local_head_groups = self._tensor_space[TransformerDimNames.head_groups].size + self._local_heads_per_group = self._tensor_space[TransformerDimNames.group_heads].size self._local_heads = self._local_head_groups * 
self._local_heads_per_group self._softmax_scale = self._kv_channels ** (-self._config.attention_softmax_scale_power) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] layer_lr_scale = config.per_layer_lr_scale[block_index] if config.per_layer_lr_scale else None attention_lr_scale = get_lr_scale(self._config.attention_lr_scale, layer_lr_scale) @@ -106,7 +106,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # TODO: Merge the query and key-value computations? (harder with sequence parallel.) self.query = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_query), + self._tensor_space[TransformerDimNames.composite_query], bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -115,7 +115,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i ) self.key_value = OutputParallelLinear( hidden_dim, - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_key_value), + self._tensor_space[TransformerDimNames.composite_key_value], bias=self._config.add_attn_qkv_bias, weight_init_method=init_method_qkv, bias_init_method=init_method_qkv if self._config.random_bias_init else init_zeros_, @@ -129,7 +129,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, block_i # Output. self.dense = InputParallelLinear( - self._tensor_space.get_tensor_dim(TransformerDimNames.composite_dense), + self._tensor_space[TransformerDimNames.composite_dense], hidden_dim, bias=self._config.add_attn_dense_bias, weight_init_method=init_method_std_attn_proj, diff --git a/fast_llm/layers/transformer/mixture_of_experts.py b/fast_llm/layers/transformer/mixture_of_experts.py index 73f83ccf5..4fd2844d5 100644 --- a/fast_llm/layers/transformer/mixture_of_experts.py +++ b/fast_llm/layers/transformer/mixture_of_experts.py @@ -63,8 +63,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s router_lr_scale = get_lr_scale(config.router_lr_scale, layer_lr_scale) self.router = Linear( - tensor_space.get_tensor_dim(TransformerDimNames.hidden), - tensor_space.get_tensor_dim(TransformerDimNames.unshared_experts), + tensor_space[TransformerDimNames.hidden], + tensor_space[TransformerDimNames.unshared_experts], bias=False, weight_init_method=init_normal_( std=config.init_method_std, min_val=config.init_method_min, max_val=config.init_method_max @@ -255,7 +255,7 @@ def _debug_log( def _get_meta(self, tensor: torch.Tensor, name: str, dim_name: str, kwargs: dict[str, typing.Any]) -> TensorMeta: return TensorMeta.from_dims( - kwargs[TransformerKwargs.hidden_dims][:-1] + (self._tensor_space.get_tensor_dim(dim_name),), + kwargs[TransformerKwargs.hidden_dims][:-1] + (self._tensor_space[dim_name],), tensor_name=f"{self._name} {name}", dtype=tensor.dtype, ) diff --git a/fast_llm/layers/transformer/mlp.py b/fast_llm/layers/transformer/mlp.py index efe0c5cc5..101d97ef3 100644 --- a/fast_llm/layers/transformer/mlp.py +++ b/fast_llm/layers/transformer/mlp.py @@ -30,8 +30,8 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s max_val=config.init_method_max_mlp_2, ) - hidden_dim = tensor_space.get_tensor_dim(TransformerDimNames.hidden) - self._intermediate_dim = tensor_space.get_tensor_dim(TransformerDimNames.composite_expert_mlp) + hidden_dim = 
tensor_space[TransformerDimNames.hidden] + self._intermediate_dim = tensor_space[TransformerDimNames.composite_expert_mlp] self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel self._recompute_level = config.mlp_recompute_level @@ -46,7 +46,7 @@ def __init__(self, config: TransformerConfig, tensor_space: TensorSpace, name: s # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = LinearBase( hidden_dim, - tensor_space.get_tensor_dim(TransformerDimNames.composite_gated_expert_mlp), + tensor_space[TransformerDimNames.composite_gated_expert_mlp], bias=config.add_mlp_bias, weight_init_method=init_method_1, bias_init_method=init_method_1 if config.random_bias_init else init_zeros_, diff --git a/fast_llm/layers/transformer/preprocessing.py b/fast_llm/layers/transformer/preprocessing.py index dc3ddeb52..3f0e14eb7 100644 --- a/fast_llm/layers/transformer/preprocessing.py +++ b/fast_llm/layers/transformer/preprocessing.py @@ -28,7 +28,7 @@ def __init__( self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config assert not self._config.do_use_flash_attention(self._distributed_config) - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] def _create_tensors(self, sequence_length: int) -> None: if sequence_length <= self._tensor_cache_max_sequence_length: diff --git a/fast_llm/layers/transformer/rotary/preprocessing.py b/fast_llm/layers/transformer/rotary/preprocessing.py index cc83dae02..c357411b6 100644 --- a/fast_llm/layers/transformer/rotary/preprocessing.py +++ b/fast_llm/layers/transformer/rotary/preprocessing.py @@ -25,8 +25,8 @@ def __init__( self._config = config self._tensor_space = tensor_space self._distributed_config = self._tensor_space.distributed_config - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: self._create_tensors(kwargs[TransformerKwargs.sequence_length]) diff --git a/fast_llm/layers/transformer/rotary/rotary.py b/fast_llm/layers/transformer/rotary/rotary.py index 056b9aa4c..17b18a1ca 100644 --- a/fast_llm/layers/transformer/rotary/rotary.py +++ b/fast_llm/layers/transformer/rotary/rotary.py @@ -82,8 +82,8 @@ def __init__( super().__init__(config, tensor_space) self._tensor_space = tensor_space if self._tensor_space is not None: - self._scalar_dim = self._tensor_space.get_tensor_dim(DefaultDimNames.scalar) - self._kv_channels_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.kv_channels) + self._scalar_dim = self._tensor_space[DefaultDimNames.scalar] + self._kv_channels_dim = self._tensor_space[TransformerDimNames.kv_channels] def preprocess(self, batch, kwargs: dict[str, typing.Any]) -> None: assert self._tensor_space is not None diff --git a/fast_llm/layers/transformer/transformer.py b/fast_llm/layers/transformer/transformer.py index d08db9a94..75d06f268 100644 --- a/fast_llm/layers/transformer/transformer.py +++ b/fast_llm/layers/transformer/transformer.py @@ -48,7 +48,7 @@ def _get_meta( } return TensorMeta.from_dims( tuple( - hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space.get_tensor_dim(dim_name) + 
hidden_dims[dim_name] if dim_name in hidden_dims else self._tensor_space[dim_name] for dim_name in dim_names ), tensor_name=f"Block {self._block_index} {self._mixer_name} {name}", @@ -97,7 +97,7 @@ def __init__( self._block_index = block_index self._debug_mode = self._config.debug_transformer or self._config.debug_transformer_memory - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] # Note, layer_lr_scale does not impact the norms # TODO: add a separate norm_lr_scale self.norm_1 = self._config.normalization.get_layer(hidden_dim) diff --git a/fast_llm/models/gpt/megatron.py b/fast_llm/models/gpt/megatron.py index e7379e61e..20ed8e828 100644 --- a/fast_llm/models/gpt/megatron.py +++ b/fast_llm/models/gpt/megatron.py @@ -14,8 +14,8 @@ def get_init_megatron( meta: "ParameterMeta", config: TransformerConfig -) -> typing.Callable[["torch.Tensor", "Distributed"], "torch.Tensor"]: - def init_megatron(tensor: "torch.Tensor", distributed: "Distributed"): +) -> typing.Callable[["torch.Tensor", "Distributed"], None]: + def init_megatron(tensor: "torch.Tensor", distributed: "Distributed") -> None: Assert.eq(distributed.config.world_size, 1) if "bias" in meta.tensor_name: # Generator unused. @@ -29,11 +29,11 @@ def init_megatron(tensor: "torch.Tensor", distributed: "Distributed"): elif config.num_experts > 1 and "mlp.layer_" in meta.tensor_name: tensor_ = _init_moe_mlp_megatron(config, meta, tensor, distributed) elif "mlp.layer_2" in meta.tensor_name: - tensor_ = _init_transposed_mlp_weight_megatron(config, meta, tensor, distributed) + tensor_ = _init_transposed_mlp_weight_megatron(meta, tensor, distributed) else: # Word embedding (override generator), layer norm (generator unused), other mlp weights. return meta.param_init_method(meta, tensor, distributed.tp_init_generator) - return tensor.copy_(tensor_.reshape_as(tensor)) + tensor.copy_(tensor_.reshape_as(tensor)) return init_megatron @@ -58,9 +58,9 @@ def _init_attention_megatron( generator = distributed.tp_init_generator state = generator.get_state() # Initialize a mock dense layer to advance the random state - dense_tensor_ = meta.param_init_method( + meta.param_init_method( meta, - tensor.new_empty( + dense_tensor_ := tensor.new_empty( config.kv_channels * config.num_attention_heads, config.hidden_size, ), @@ -68,9 +68,9 @@ def _init_attention_megatron( ) # QKV is split differently. (Assuming no tensor-parallel.) heads_per_group = div(config.num_attention_heads, config.head_groups) - qkv_tensor_ = meta.param_init_method( + meta.param_init_method( meta, - tensor.new_empty( + qkv_tensor_ := tensor.new_empty( config.head_groups, heads_per_group + 2, config.kv_channels, @@ -110,18 +110,19 @@ def _init_position_embeddings_megatron( # Megatron initializes the position embeddings on cpu twice. 
assert meta.param_init_method is not None generator = distributed.default_cpu_generator - tensor_ = meta.param_init_method(meta, torch.empty(tensor.shape, dtype=tensor.dtype), generator) - return meta.param_init_method(meta, tensor_, generator) + meta.param_init_method(meta, tensor_ := torch.empty(tensor.shape, dtype=tensor.dtype), generator) + meta.param_init_method(meta, tensor_, generator) + return tensor_ def _init_transposed_mlp_weight_megatron( - config: TransformerConfig, meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" + meta: "ParameterMeta", tensor: "torch.Tensor", distributed: "Distributed" ) -> "torch.Tensor": import torch # Megatron never transposes the mlp layer 2 weight. assert meta.param_init_method is not None - tensor_ = meta.param_init_method(meta, torch.empty_like(tensor), distributed.tp_init_generator) + meta.param_init_method(meta, tensor_ := torch.empty_like(tensor), distributed.tp_init_generator) return tensor_.view(meta.size(1), meta.size(0)).t() @@ -132,8 +133,8 @@ def _init_moe_router_megatron( # Megatron initializes the router on cpu. assert meta.param_init_method is not None - tensor_ = meta.param_init_method( - meta, torch.empty(tensor.shape, dtype=tensor.dtype), distributed.default_cpu_generator + meta.param_init_method( + meta, tensor_ := torch.empty(tensor.shape, dtype=tensor.dtype), distributed.default_cpu_generator ) return tensor_ diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 4c1eab46f..49a5dcbd3 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -155,7 +155,7 @@ def preprocess_meta( sequence_first = self._config.sequence_first assert not (need_sequence_first and not sequence_first) - hidden_dim = self._tensor_space.get_tensor_dim(TransformerDimNames.hidden) + hidden_dim = self._tensor_space[TransformerDimNames.hidden] hidden_dims = ( (hidden_sequence_q_dim, batch_dim, hidden_dim) if sequence_first diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index d780e4d6d..b3795b740 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -1,17 +1,21 @@ +import abc import functools +import logging import math import typing import torch from fast_llm.core.distributed import ReduceOp -from fast_llm.core.ops import gather_op, reduce_op +from fast_llm.core.ops import reduce_op from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed from fast_llm.functional.triton.pointwise import triton_add, triton_copy from fast_llm.utils import Assert +logger = logging.getLogger(__name__) + class _SafeTensorSliceMeta(type): def __instancecheck__(self, instance) -> bool: @@ -146,7 +150,7 @@ def from_tensor_space( reductions: tuple[tuple[str, ReduceOp], ...] = (), **kwargs: typing.Any, ) -> typing.Self: - dims = tuple(tensor_space.get_tensor_dim(dim_name) for dim_name in dim_names) + dims = tuple(tensor_space[dim_name] for dim_name in dim_names) if reductions: # kwarg not available for ParameterMeta, so we only provide if necessary. 
kwargs["reductions"] = tuple( @@ -158,22 +162,23 @@ def from_tensor_space( def global_shape(self) -> torch.Size: return torch.Size([dim.global_size for dim in self.dims]) - def local_to_global( - self, - tensor: torch.Tensor, - *, - distributed: Distributed, - ) -> tuple[torch.Tensor, ...]: + def local_to_global(self, tensor: torch.Tensor, *, distributed: Distributed) -> tuple[torch.Tensor, ...]: + """ + Reconstruct a global tensor from its distributed slices. Support lazy-loaded safetensor slices. + Returns a view of the input tensor (or the input tensor itself) when possible. + """ + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.shape) # Tensors are always either split or duplicated in the tensor-parallel direction. # TODO: Avoid hard-coded assumptions on duplication - is_first_rank = distributed.config.tensor_rank == 0 - modified = False - for i, dim in enumerate(self.dims): - if dim.parallel_group is not None: - tensor = gather_op( - tensor.unflatten(i, dim.expanded_shape), dim.parallel_group, i + dim.parallel_dim_index - ).flatten(i, i + len(dim.expanded_shape) - 1) - is_first_rank, modified = is_first_rank and dim.parallel_group.rank() == 0, True + is_first_rank, modified = distributed.config.tensor_rank == 0, False + + for dim, tensor_dim in enumerate(self.dims): + if tensor_dim.is_parallel: + tensor = tensor_dim.local_to_global(tensor, dim) + is_first_rank &= tensor_dim.parallel_dim.rank == 0 + modified = True for distributed_dim, op in self._reductions: if distributed_dim.group is not None: @@ -182,28 +187,48 @@ def local_to_global( tensor = tensor.clone() tensor = reduce_op(tensor, distributed_dim.group, op=op) is_first_rank, modified = is_first_rank and distributed_dim.group.rank() == 0, True + Assert.eq(tensor.shape, self.global_shape) return tensor, is_first_rank - def global_to_local( - self, - tensor: torch.Tensor | SafeTensorSlice, - # Return an expanded tensor, avoiding `flatten` which copies the data. - expand: bool = False, - ) -> torch.Tensor: + def local_to_global_partial(self, tensor: torch.Tensor, fill_value: float | int = -1) -> torch.Tensor: """ - Recover the tensor-parallel slice of a tensor. Support lazy-loaded safetensor slices. + Construct a tensor of shape `self.global_shape` that contains its local slice at the appropriate location, + i.e. for which `self.global_to_local(self.local_to_global_partial(tensor)) == tensor`. + Other entries are filled with `fill_value`. + Returns a view of the input tensor (or the input tensor itself) when possible. + """ + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.shape) + assert not self._reductions + logger.info(f"AAAA {self.tensor_name} {self.shape} {self.global_shape} {tensor.shape}") + for dim, tensor_dim in enumerate(self.dims): + if tensor_dim.is_parallel: + tensor = tensor_dim.local_to_global_partial(tensor, dim, fill_value) + logger.info( + f"BBBB {self.tensor_name} {self.shape} {self.global_shape} {tensor.shape} {tensor_dim.is_parallel}" + ) + + Assert.eq(tensor.shape, self.global_shape) + return tensor + + def global_to_local(self, tensor: torch.Tensor | SafeTensorSlice) -> torch.Tensor: + """ + Select the local slice of a global tensor. Support lazy-loaded safetensor slices. + Returns a view of the input tensor (or the input tensor itself) when possible. """ # Take a trivial slice to convert safetensor slices. 
- tensor_ = tensor[:] + tensor = tensor[:] assert not self._reductions + if tensor.ndim == 0: + tensor = tensor[None] + Assert.eq(tensor.shape, self.global_shape) - for i, dim in reversed(list(enumerate(self.dims))): - if dim.parallel_dim is not None and dim.parallel_dim.size > 1: - tensor_ = tensor_.unflatten(i, dim.global_expanded_shape).chunk( - dim.parallel_dim.size, i + dim.parallel_dim_index - )[dim.parallel_dim.rank] + for dim, tensor_dim in reversed(list(enumerate(self.dims))): + tensor = tensor_dim.global_to_local(tensor, dim) - return tensor_ if expand else tensor_.reshape(self.shape) + Assert.eq(tensor.shape, self.shape) + return tensor @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -237,7 +262,7 @@ def __init__( *, tensor_name: str = "", dims: tuple[TensorDim, ...], - init_method: typing.Callable[["ParameterMeta", torch.Tensor, torch.Generator], torch.Tensor] | None = None, + init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None" = None, weight_decay: bool = True, # Pass a list to split the parameter in contiguous (dim=0) chunks of equal size for optimization. lr_scale: float | None | tuple[float | None, ...] = None, @@ -247,7 +272,11 @@ def __init__( allow_no_grad: bool = False, ): super().__init__(data, tensor_name=tensor_name, dims=dims) - self.param_init_method = init_method + if init_method is not None and not isinstance(init_method, Initializer): + # Support non-wrapped callables for convenience. + assert callable(init_method) + init_method = LambdaInitializer(init_method) + self.param_init_method: Initializer | None = init_method self.param_weight_decay = weight_decay self._is_param = True self.param_grad_is_zero = False @@ -272,7 +301,7 @@ def __new__( *, tensor_name: str = "", dims: tuple[TensorDim, ...], - init_method: typing.Callable, + init_method: "Initializer | typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None] | None", weight_decay: bool = True, lr_scale: float | None | tuple[float | None, ...] 
= None, allow_sequence_tensor_parallel: bool = True, @@ -293,12 +322,20 @@ def __repr__(self, *, tensor_contents=()) -> str: def init_parameter(self, tensor: torch.Tensor, distributed: Distributed) -> None: assert self.param_init_method is not None - if distributed.config.tensor_parallel == 1 or distributed.config.reproducible_init: + if ( + distributed.config.tensor_parallel == 1 + or distributed.config.reproducible_init + or self.param_init_method.requires_global_initialization + ): generator = distributed.pp_init_generator else: generator = distributed.tp_init_generator if self.is_tensor_parallel else distributed.pp_init_generator self.param_init_method(self, tensor, generator) + @property + def requires_global_initialization(self) -> bool: + return self.param_init_method.requires_global_initialization + def save(self) -> dict[str, typing.Any]: return { "name": self.tensor_name, @@ -330,11 +367,32 @@ def accumulate_gradient(param: torch.Tensor, grad: torch.Tensor) -> None: triton_add(grad, param.grad_buffer, out=param.grad_buffer) # noqa -def init_fill_(value) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa - return tensor.fill_(value) +class Initializer(abc.ABC): + @abc.abstractmethod + def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + pass + + requires_global_initialization = False + + +class LambdaInitializer(Initializer): + def __init__( + self, + init_method: typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], None], + requires_global_initialization: bool = False, + ) -> None: + self._init_method = init_method + self.requires_global_initialization = requires_global_initialization + + def __call__(self, meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: + return self._init_method(meta, tensor, generator) + - return init_ +def init_fill_(value: float) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa + tensor.fill_(value) + + return LambdaInitializer(init_) init_zeros_ = init_fill_(0.0) @@ -342,30 +400,35 @@ def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) def init_normal_( - mean=0.0, std=1.0, min_val=None, max_val=None -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa + mean: float = 0.0, std: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa tensor = tensor.normal_(mean, std, generator=generator) if min_val is not None or max_val is not None: - return tensor.clamp_(min=min_val, max=max_val) # noqa - else: - return tensor + tensor.clamp_(min=min_val, max=max_val) - return init_ + return LambdaInitializer(init_) -def kaiming_init_(d_in): +def init_kaiming_(d_in: float) -> LambdaInitializer: return init_normal_(0.0, math.sqrt(2.0 / d_in)) def init_uniform_( - low=0.0, high=1.0, min_val=None, max_val=None -) -> typing.Callable[[ParameterMeta, torch.Tensor, torch.Generator], torch.Tensor]: - def init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator): # noqa + low: float = 0.0, high: float = 1.0, min_val: float | None = None, max_val: float | None = None +) -> LambdaInitializer: + def 
init_(meta: ParameterMeta, tensor: torch.Tensor, generator: torch.Generator) -> None: # noqa tensor = tensor.uniform_(low, high, generator=generator) if min_val is not None or max_val is not None: - return tensor.clamp_(min=min_val, max=max_val) # noqa - else: - return tensor + tensor.clamp_(min=min_val, max=max_val) + + return LambdaInitializer(init_) + - return init_ +def init_uniform_centered_(high: float, max_val: float | None = None, mean: float = 0.0) -> LambdaInitializer: + return init_uniform_( + mean - high, + mean + high, + min_val=None if max_val is None else mean - max_val, + max_val=None if max_val is None else mean + max_val, + ) From 017f5cc5a021d9a2ef58e5d1903f60c4917f311c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 28 Jul 2025 18:09:53 -0400 Subject: [PATCH 07/10] fixes --- fast_llm/layers/ssm/discrete_mamba2.py | 24 ++++++++++----------- fast_llm/layers/ssm/mamba2.py | 26 +++++++++++----------- fast_llm/layers/ssm/mamba_layer.py | 30 ++++++++++++-------------- 3 files changed, 39 insertions(+), 41 deletions(-) diff --git a/fast_llm/layers/ssm/discrete_mamba2.py b/fast_llm/layers/ssm/discrete_mamba2.py index c0ae7e781..6012f74a7 100644 --- a/fast_llm/layers/ssm/discrete_mamba2.py +++ b/fast_llm/layers/ssm/discrete_mamba2.py @@ -10,7 +10,7 @@ from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig, TransformerKwargs from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import ParameterMeta, init_ones_, init_uniform_, init_zeros_, kaiming_init_ +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_, init_uniform_, init_zeros_ from fast_llm.utils import get_lr_scale logger = logging.getLogger(__name__) @@ -62,14 +62,14 @@ def __init__( mamba_layer_lr_scale = get_lr_scale(self.config.mamba_lr_scale, layer_lr_scale) logger.info(f"Setting lr_scale for layer {block_index} of type {type(self)}: {mamba_layer_lr_scale}") - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) - td_conv = tensor_space.get_tensor_dim(SSMDimNames.conv_dim) - td_n_qk_heads = tensor_space.get_tensor_dim(SSMDimNames.qk_heads) - td_n_v_heads = tensor_space.get_tensor_dim(SSMDimNames.v_heads) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) - td_inner_proj = tensor_space.get_tensor_dim(SSMDimNames.inner_proj_discrete_mamba2) + td_inner = tensor_space[SSMDimNames.inner_dim] + td_state = tensor_space[SSMDimNames.state_dim] + td_model = tensor_space[SSMDimNames.model_dim] + td_conv = tensor_space[SSMDimNames.conv_dim] + td_n_qk_heads = tensor_space[SSMDimNames.qk_heads] + td_n_v_heads = tensor_space[SSMDimNames.v_heads] + td_conv_kernel = tensor_space[SSMDimNames.conv_kernel_size] + td_inner_proj = tensor_space[SSMDimNames.inner_proj_discrete_mamba2] self.d_model = td_model.size self.d_inner = td_inner.size @@ -88,7 +88,7 @@ def __init__( td_model, td_inner_proj, bias=bias, - weight_init_method=kaiming_init_(td_model.size), + weight_init_method=init_kaiming_(td_model.size), lr_scale=mamba_layer_lr_scale, ) self.z_bias = ( @@ -103,7 +103,7 @@ def __init__( ) self.conv1d_weight = ParameterMeta.from_dims( - (td_conv, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), + (td_conv, tensor_space[DefaultDimNames.scalar], td_conv_kernel), init_method=init_uniform_( 1 / math.sqrt(td_conv.size * 
td_conv_kernel.size), 1 / math.sqrt(td_conv.size * td_conv_kernel.size) ), # see https://github.com/pytorch/pytorch/blob/1eba9b3aa3c43f86f4a2c807ac8e12c4a7767340/torch/nn/modules/conv.py#L180C53-L180C67 @@ -126,7 +126,7 @@ def __init__( td_inner, td_model, bias=bias, - weight_init_method=kaiming_init_(td_inner.size), + weight_init_method=init_kaiming_(td_inner.size), lr_scale=mamba_layer_lr_scale, ) diff --git a/fast_llm/layers/ssm/mamba2.py b/fast_llm/layers/ssm/mamba2.py index 74c212add..9dfad8462 100644 --- a/fast_llm/layers/ssm/mamba2.py +++ b/fast_llm/layers/ssm/mamba2.py @@ -10,7 +10,7 @@ from fast_llm.layers.ssm.mamba_layer import init_A, init_dtprojbias from fast_llm.layers.transformer.config import TransformerConfig, TransformerDimNames from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import ParameterMeta, init_fill_, init_ones_, init_uniform_, kaiming_init_ +from fast_llm.tensor import ParameterMeta, init_fill_, init_kaiming_, init_ones_, init_uniform_ from fast_llm.utils import get_lr_scale try: @@ -80,13 +80,13 @@ def __init__( self.config.mamba_lr_scale, layer_lr_scale ) - td_inner: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.inner_dim) - td_state: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.state_dim) - td_model: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.model_dim) - tdt_rank: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.dt_rank) - td_xb: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.x_proj_dim_2) - td_inner_proj: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.inner_proj_mamba2) - td_conv_kernel: TensorDim = tensor_space.get_tensor_dim(name=SSMDimNames.conv_kernel_size) + td_inner: TensorDim = tensor_space[SSMDimNames.inner_dim] + td_state: TensorDim = tensor_space[SSMDimNames.state_dim] + td_model: TensorDim = tensor_space[SSMDimNames.model_dim] + tdt_rank: TensorDim = tensor_space[SSMDimNames.dt_rank] + td_xb: TensorDim = tensor_space[SSMDimNames.x_proj_dim_2] + td_inner_proj: TensorDim = tensor_space[SSMDimNames.inner_proj_mamba2] + td_conv_kernel: TensorDim = tensor_space[SSMDimNames.conv_kernel_size] self.repeat_kv_before_conv = config.repeat_kv_before_conv @@ -98,7 +98,7 @@ def __init__( if self.repeat_kv_before_conv: self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), + (td_inner, tensor_space[DefaultDimNames.scalar], td_conv_kernel), init_method=init_uniform_( -1 / math.sqrt(td_inner.size * td_conv_kernel.size), 1 / math.sqrt(td_inner.size * td_conv_kernel.size), @@ -111,7 +111,7 @@ def __init__( ) else: self.conv1d_weight = ParameterMeta.from_dims( - (td_xb, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), + (td_xb, tensor_space[DefaultDimNames.scalar], td_conv_kernel), init_method=init_uniform_( -1 / math.sqrt(td_xb.size * td_conv_kernel.size), 1 / math.sqrt(td_xb.size * td_conv_kernel.size), @@ -131,14 +131,14 @@ def __init__( td_model, td_inner_proj, bias=bias, - weight_init_method=kaiming_init_(td_model.size), + weight_init_method=init_kaiming_(td_model.size), lr_scale=mamba_layer_lr_scale, ) self.dt_in_proj = Linear( td_model, tdt_rank, bias=config.add_bias_linear, - weight_init_method=kaiming_init_(transformer_config.hidden_size), + weight_init_method=init_kaiming_(transformer_config.hidden_size), lr_scale=mamba_layer_lr_scale, ) # Initialize special dt projection to preserve variance at initialization @@ -185,7 +185,7 @@ def __init__( td_inner, td_model, 
bias=bias, - weight_init_method=kaiming_init_(td_inner.size), + weight_init_method=init_kaiming_(td_inner.size), ) def forward(self, hidden_states, kwargs): diff --git a/fast_llm/layers/ssm/mamba_layer.py b/fast_llm/layers/ssm/mamba_layer.py index 4493332ce..5e0ae786e 100644 --- a/fast_llm/layers/ssm/mamba_layer.py +++ b/fast_llm/layers/ssm/mamba_layer.py @@ -10,7 +10,7 @@ from fast_llm.layers.ssm.config import SSMConfig, SSMDimNames from fast_llm.layers.transformer.config import TransformerConfig from fast_llm.layers.transformer.transformer import Mixer -from fast_llm.tensor import ParameterMeta, init_ones_, kaiming_init_ +from fast_llm.tensor import ParameterMeta, init_kaiming_, init_ones_ from fast_llm.utils import get_lr_scale try: @@ -75,15 +75,13 @@ def __init__( self.config: SSMConfig = config # Tensor dims: - td_inner = tensor_space.get_tensor_dim(SSMDimNames.inner_dim) - td_inner_proj = tensor_space.get_tensor_dim( - SSMDimNames.inner_proj_mamba - ) # TensorDim("D_inner_2", self.d_inner * 2) - tdt_rank = tensor_space.get_tensor_dim(SSMDimNames.dt_rank) - td_x_proj = tensor_space.get_tensor_dim(SSMDimNames.x_proj_dim) - td_state = tensor_space.get_tensor_dim(SSMDimNames.state_dim) - td_model = tensor_space.get_tensor_dim(SSMDimNames.model_dim) - td_conv_kernel = tensor_space.get_tensor_dim(SSMDimNames.conv_kernel_size) + td_inner = tensor_space[SSMDimNames.inner_dim] + td_inner_proj = tensor_space[SSMDimNames.inner_proj_mamba] # TensorDim("D_inner_2", self.d_inner * 2) + tdt_rank = tensor_space[SSMDimNames.dt_rank] + td_x_proj = tensor_space[SSMDimNames.x_proj_dim] + td_state = tensor_space[SSMDimNames.state_dim] + td_model = tensor_space[SSMDimNames.model_dim] + td_conv_kernel = tensor_space[SSMDimNames.conv_kernel_size] self.d_conv = td_conv_kernel.size self.d_inner = td_inner.size self.d_state = td_state.size @@ -94,12 +92,12 @@ def __init__( self.in_proj_weight = ParameterMeta.from_dims( (td_inner_proj, td_model), - init_method=kaiming_init_(td_model.size), + init_method=init_kaiming_(td_model.size), ) self.conv1d_weight = ParameterMeta.from_dims( - (td_inner, tensor_space.get_tensor_dim(DefaultDimNames.scalar), td_conv_kernel), - init_method=kaiming_init_(td_inner.size), + (td_inner, tensor_space[DefaultDimNames.scalar], td_conv_kernel), + init_method=init_kaiming_(td_inner.size), lr_scale=mamba_layer_lr_scale, ) @@ -111,7 +109,7 @@ def __init__( self.x_proj = Linear( td_inner, td_x_proj, - weight_init_method=kaiming_init_(td_inner.size), + weight_init_method=init_kaiming_(td_inner.size), bias=False, lr_scale=mamba_layer_lr_scale, ) @@ -120,7 +118,7 @@ def __init__( # TODO: the weights are initialized a bit differently here https://github.com/state-spaces/mamba/blob/0cce0fa645f100f00620ddf2333c2b7712abfdec/mamba_ssm/modules/mamba_simple.py#L82 self.dt_proj_weight = ParameterMeta.from_dims( (td_inner, tdt_rank), - init_method=kaiming_init_(tdt_rank.size), + init_method=init_kaiming_(tdt_rank.size), lr_scale=mamba_layer_lr_scale, ) @@ -151,7 +149,7 @@ def __init__( td_inner, td_model, bias=False, # TODO: note, if bias is used there is a problem in the MambaInnerFn.backward for the bias grads. I think this bias is not used in other mamba repos. 
- weight_init_method=kaiming_init_(td_model.size), + weight_init_method=init_kaiming_(td_model.size), lr_scale=mamba_layer_lr_scale, ) self.out_proj.weight.auto_grad_accumulation = True From 6bf06d6aecb9a2a0de67ad7a42690db071a812f4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 29 Jul 2025 15:51:13 -0400 Subject: [PATCH 08/10] fix --- fast_llm/tensor.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py index b3795b740..d080e6a1e 100644 --- a/fast_llm/tensor.py +++ b/fast_llm/tensor.py @@ -201,13 +201,9 @@ def local_to_global_partial(self, tensor: torch.Tensor, fill_value: float | int tensor = tensor[None] Assert.eq(tensor.shape, self.shape) assert not self._reductions - logger.info(f"AAAA {self.tensor_name} {self.shape} {self.global_shape} {tensor.shape}") for dim, tensor_dim in enumerate(self.dims): if tensor_dim.is_parallel: tensor = tensor_dim.local_to_global_partial(tensor, dim, fill_value) - logger.info( - f"BBBB {self.tensor_name} {self.shape} {self.global_shape} {tensor.shape} {tensor_dim.is_parallel}" - ) Assert.eq(tensor.shape, self.global_shape) return tensor From 2ddc3a748817ee98785344e03809cfd67590e954 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 29 Jul 2025 16:15:10 -0400 Subject: [PATCH 09/10] fix --- fast_llm/engine/config_utils/tensor_space.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index cf2974a99..6c4b95b20 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -95,7 +95,7 @@ class CompositeTensorDim(TensorDim): def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): parallel_dim = None for dim, tensor_dim in enumerate(tensor_dims): - if tensor_dim.is_parallel: + if tensor_dim.parallel_dim is not None: # TODO: Allow more than one parallel subdim? assert parallel_dim is None parallel_dim = tensor_dim.parallel_dim From bd4ff0d03fd7f878c6b8d1551ffa682326f2d150 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 12 Aug 2025 14:21:51 -0400 Subject: [PATCH 10/10] doc --- fast_llm/engine/config_utils/tensor_space.py | 86 +++++++++++++++++++- fast_llm/tensor.py | 8 +- 2 files changed, 91 insertions(+), 3 deletions(-) diff --git a/fast_llm/engine/config_utils/tensor_space.py b/fast_llm/engine/config_utils/tensor_space.py index 6c4b95b20..66176ee0f 100644 --- a/fast_llm/engine/config_utils/tensor_space.py +++ b/fast_llm/engine/config_utils/tensor_space.py @@ -15,6 +15,16 @@ class TensorDim: + """ + Describes a simple, atomic dimension of a tensor and its size. + The dimension may be parallelized along a distributed dimension `parallel_dim`, + in which case its actual (local) `size` will differ from its `global_size`. + + TensorDim's are used to represent the metadata of tensors through `TensorMeta`. + + This class also serves as a base for more complex tensor dimensions. + """ + def __init__(self, name: str, global_size: int | None, parallel_dim: DistributedDim | None = None): # TODO: Handle None for unknown sizes? self._name = name @@ -62,10 +72,25 @@ def parallel_group(self) -> "ProcessGroup|None": return None if self._parallel_dim is None else self._parallel_dim.group def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + """ + Create a copy of the tensor dimension, where the parallel dimension is replaced by `distributed_dim`, + but the local size remains the same. 
+
+        Used in `TensorMeta.replace_tensor_parallel_dim`.
+        """
         assert self.is_parallel
         return TensorDim(self.name, self.size * distributed_dim.size, distributed_dim)
 
     def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor":
+        """
+        Partially reconstruct a global tensor from local `tensor` slices whose dimension `dim` is described by `self`.
+        If the dimension is parallelized, this amounts to gathering the slices along dimension `dim`
+        over the process group of `parallel_dim`; otherwise the input tensor is returned.
+        The method needs to be called by all members of the parallel group, each using its own local slice.
+
+        Used in `TensorMeta.local_to_global`,
+        which iterates over the tensor dimensions to fully reconstruct the global tensor.
+        """
         if self.is_parallel:
             from fast_llm.core.ops import gather_op
 
@@ -76,6 +101,14 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor
     def local_to_global_partial(
         self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1
     ) -> "torch.Tensor":
+        """
+        Partially reconstruct a global tensor from a local `tensor` whose dimension `dim` is described by `self`.
+        Unlike `local_to_global`, this method does not need to be called from a distributed setting.
+        Instead, entries from other ranks are populated with `fill_value`.
+
+        Used in `TensorMeta.local_to_global_partial`,
+        which iterates over the tensor dimensions to fully reconstruct the global tensor.
+        """
         if self.is_parallel:
             output = tensor.new_full((*tensor.shape[:dim], self.parallel_dim.size, *tensor.shape[dim:]), fill_value)
             output.narrow(dim, self.parallel_dim.rank, 1).copy_(tensor.unsqueeze(dim)).squeeze(dim)
@@ -84,6 +117,14 @@ def local_to_global_partial(
         return tensor
 
     def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor":
+        """
+        Partially recover a local tensor slice from a global `tensor` whose dimension `dim` is described by `self`.
+        If the dimension is parallel, this amounts to taking the `rank`th chunk of size `size` along dimension `dim`
+        for the process group of `self.parallel_dim`; otherwise the input tensor is returned.
+
+        Used in `TensorMeta.global_to_local`,
+        which iterates over the tensor dimensions to fully reconstruct the local tensor.
+        """
         return (
             tensor.chunk(self.parallel_dim.size, dim)[self.parallel_dim.rank]
             if self.parallel_dim is not None and self.parallel_dim.size > 1
@@ -92,11 +133,20 @@ def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = F
 
 
 class CompositeTensorDim(TensorDim):
+    """
+    A composite tensor dimension that represents multiple dimensions flattened into one.
+    This typically happens for flattened views of higher-dimensional tensors, or tensors that can be expanded as such.
+    If one of the composed dimensions -- other than the first one -- is parallelized,
+    this is **not** equivalent to an atomic `TensorDim` of the same size,
+    as the relation between local and global tensors is different.
+
+    At most one of the sub-dimensions may be parallelized. TODO: Allow for more than one?
+    """
+
     def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]):
         parallel_dim = None
         for dim, tensor_dim in enumerate(tensor_dims):
             if tensor_dim.parallel_dim is not None:
-                # TODO: Allow more than one parallel subdim?
assert parallel_dim is None parallel_dim = tensor_dim.parallel_dim self._parallel_dim_index = dim @@ -109,12 +159,19 @@ def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): self._tensor_dims = tensor_dims def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + """ + Create a copy of the tensor dimension, where the parallel dimension is replaced by `distributed_dim`, + but the local size remains the same. + """ assert self._parallel_dim_index is not None dims = list(self._tensor_dims) dims[self._parallel_dim_index] = dims[self._parallel_dim_index].replace_parallel_dim(distributed_dim) return CompositeTensorDim(self.name, tuple(dims)) def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + """ + Partially reconstruct a global tensor from local `tensor` slices whose dimension `dim` is described by `self`. + """ tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) for i, tensor_dim in enumerate(self._tensor_dims): tensor = tensor_dim.local_to_global(tensor, dim + i) @@ -124,6 +181,10 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor def local_to_global_partial( self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1 ) -> "torch.Tensor": + """ + Partially reconstruct a global tensor from a local `tensor` whose dimension `dim` is described by `self`, + populating other ranks with `fill_value`. + """ tensor = tensor.unflatten(dim, [tensor_dim.size for tensor_dim in self._tensor_dims]) for i, tensor_dim in enumerate(self._tensor_dims): tensor = tensor_dim.local_to_global_partial(tensor, dim + i) @@ -131,6 +192,9 @@ def local_to_global_partial( return tensor.flatten(dim, dim + len(self._tensor_dims) - 1) def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor": + """ + Partially recover a local tensor slice from a global `tensor` whose dimension `dim` is described by `self`. + """ tensor = tensor.unflatten(dim, [tensor_dim.global_size for tensor_dim in self._tensor_dims]) for i, tensor_dim in reversed(list(enumerate(self._tensor_dims))): tensor = tensor_dim.global_to_local(tensor, dim + i) @@ -138,6 +202,12 @@ def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = F class ConcatenatedTensorDim(TensorDim): + """ + A complex tensor dimension that results from concatenating tensors. + + All sub-dimensions should have the same `parallel_dim` (may be None). TODO: Allow for more complex scenarios? + """ + def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): parallel_dim = tensor_dims[0].parallel_dim for dim, tensor_dim in enumerate(tensor_dims[1:]): @@ -152,12 +222,19 @@ def __init__(self, name: str, tensor_dims: tuple[TensorDim, ...]): self._tensor_dims = tensor_dims def replace_parallel_dim(self, distributed_dim: DistributedDim) -> typing.Self: + """ + Create a copy of the tensor dimension, where the parallel dimension is replaced by `distributed_dim`, + but the local size remains the same. + """ assert self.is_parallel return ConcatenatedTensorDim( self.name, tuple(tensor_dim.replace_parallel_dim(distributed_dim) for tensor_dim in self._tensor_dims) ) def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor": + """ + Partially reconstruct a global tensor from local `tensor` slices whose dimension `dim` is described by `self`. 
+        """
         import torch
 
         return (
@@ -179,6 +256,10 @@ def local_to_global(self, tensor: "torch.Tensor", dim: int = 0) -> "torch.Tensor
     def local_to_global_partial(
         self, tensor: "torch.Tensor", dim: int = 0, fill_value: float | int = -1
     ) -> "torch.Tensor":
+        """
+        Partially reconstruct a global tensor from a local `tensor` whose dimension `dim` is described by `self`,
+        populating other ranks with `fill_value`.
+        """
         import torch
 
         return (
@@ -198,6 +279,9 @@ def local_to_global_partial(
         )
 
     def global_to_local(self, tensor: "torch.Tensor", dim: int = 0, expand: bool = False) -> "torch.Tensor":
+        """
+        Partially recover a local tensor slice from a global `tensor` whose dimension `dim` is described by `self`.
+        """
         if self.is_parallel and expand:
             raise NotImplementedError()
         import torch
diff --git a/fast_llm/tensor.py b/fast_llm/tensor.py
index d080e6a1e..c17df9d0c 100644
--- a/fast_llm/tensor.py
+++ b/fast_llm/tensor.py
@@ -240,8 +240,12 @@ def validate(self, tensor: torch.Tensor, device: torch.device | None = None) ->
         return validate_tensor(tensor, self, device)
 
     def replace_tensor_parallel_dim(self, distributed_dim: DistributedDim) -> "TensorMeta":
-        # Replace the tensor-parallel `DistributedDim` in `meta`.
-        # Note: This will turn `ParameterMeta` into `TensorMeta`
+        """
+        Replace the tensor-parallel `DistributedDim` in `meta`, preserving the local size.
+        Required for advanced tensor manipulations,
+        e.g. turning tensor-parallel slices of a tensor into slices for a different tensor-parallel size.
+        Note: This will turn `ParameterMeta` into `TensorMeta`.
+        """
         if not self.is_tensor_parallel:
             return self
         dims = list(self.dims)
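The local/global mapping documented above can be illustrated with a small standalone sketch. It uses plain PyTorch only; `world_size`, `rank`, the shapes, and `fill_value` are illustrative assumptions rather than Fast-LLM objects, and the sketch simply mirrors the chunk/narrow/copy logic of `TensorDim.global_to_local` and `TensorDim.local_to_global_partial` shown in the diff (the gather-based `local_to_global` is omitted since it requires a process group).

import torch

# Illustrative values; in Fast-LLM they would come from a DistributedDim (assumption).
world_size, rank = 4, 1
fill_value = -1

global_tensor = torch.arange(8 * 3).reshape(8, 3)  # dim 0: global size 8, local size 2 per rank

# global_to_local: take the rank-th chunk along the parallel dimension.
local_tensor = global_tensor.chunk(world_size, dim=0)[rank]  # shape (2, 3), rows 2..3

# local_to_global_partial: rebuild the global shape from one rank only,
# filling the entries owned by other ranks with `fill_value`.
partial = local_tensor.new_full((world_size, *local_tensor.shape), fill_value)
partial.narrow(0, rank, 1).copy_(local_tensor.unsqueeze(0))
partial = partial.flatten(0, 1)  # shape (8, 3)

assert torch.equal(partial[2:4], global_tensor[2:4])  # this rank's slice is recovered
assert (partial[:2] == fill_value).all() and (partial[4:] == fill_value).all()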
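A second sketch shows why the `CompositeTensorDim` docstring warns that a composite dimension with a parallelized sub-dimension is not equivalent to an atomic `TensorDim` of the same size: the local slice owned by a given rank differs. Again this is plain PyTorch with assumed values (two heads of head size 4, head size parallelized over two ranks), not Fast-LLM code; it mirrors the unflatten/map/flatten loop of `CompositeTensorDim.global_to_local` in the diff.

import torch

world_size = 2
global_flat = torch.arange(8)  # composite dim: (heads=2) x (head_dim=4), head_dim parallel (assumption)

# Atomic TensorDim of size 8: global_to_local is a plain chunk along the dimension.
atomic_local_rank0 = global_flat.chunk(world_size, 0)[0]  # tensor([0, 1, 2, 3])

# CompositeTensorDim: unflatten into sub-dimensions, apply the mapping to the
# parallel sub-dimension, then flatten back.
composite = global_flat.unflatten(0, (2, 4))  # (heads, head_dim)
composite_local_rank0 = composite.chunk(world_size, 1)[0].flatten(0, 1)  # tensor([0, 1, 4, 5])

assert not torch.equal(atomic_local_rank0, composite_local_rank0)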