Commit df30991

[feat] Hybrid Mamba model with Mamba and discrete Mamba 2 layers (#194)
1 parent 5180937 commit df30991

File tree

23 files changed: +1550 −21 lines


.github/workflows/ci.yaml

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ jobs:
       run: |
         pip install "torch>=2.2.2"
         pip install pybind11
-        FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"
+        FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"

     - name: Run tests
       run: pytest .

.github/workflows/docs.yaml

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ jobs:
     - run: |
         pip install "torch>=2.2.2"
         pip install pybind11
-        FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"
+        FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE FLASH_ATTENTION_FORCE_BUILD=TRUE MAMBA_SKIP_CUDA_BUILD=TRUE MAMBA_FORCE_BUILD=TRUE CAUSAL_CONV1D_FORCE_BUILD=TRUE CAUSAL_CONV1D_SKIP_CUDA_BUILD=TRUE pip install --no-build-isolation -e ".[CORE,OPTIONAL,DEV,DOCS]"

     - name: Build the documentation
       run: mkdocs build
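The new MAMBA_* and CAUSAL_CONV1D_* variables mirror the existing flash-attention flags: they skip the CUDA extension build for the mamba_ssm and causal_conv1d packages so the CPU-only CI runners can still install the editable package. As a quick sanity check after an install like the one above, a small script can report which of the optional kernel packages ended up importable. This is a hedged sketch, not part of the commit, and it uses only the standard library:

import importlib.util

# Optional extension packages pulled in by the ".[CORE,OPTIONAL,DEV,DOCS]" install.
for name in ("flash_attn", "mamba_ssm", "causal_conv1d"):
    spec = importlib.util.find_spec(name)
    print(f"{name}: {'importable' if spec is not None else 'not installed'}")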

fast_llm/functional/config.py

Lines changed: 3 additions & 0 deletions
@@ -43,6 +43,7 @@ class ActivationType(str, enum.Enum):
     silu = "silu"
     relu = "relu"
     squared_relu = "squared_relu"
+    identity = "identity"

     @property
     def activation_fn(self) -> typing.Callable[["torch.Tensor"], "torch.Tensor"]:
@@ -70,6 +71,7 @@ def _set_activation_fn_map() -> None:
         ActivationType.silu: torch.nn.functional.silu,
         ActivationType.relu: torch.nn.functional.relu,
         ActivationType.squared_relu: lambda x: torch.pow(torch.nn.functional.relu(x), 2),
+        ActivationType.identity: lambda x: x,
     }


@@ -80,6 +82,7 @@ def _set_activation_fn_map() -> None:
     ActivationType.silu: "silu",
     ActivationType.relu: "relu",
     ActivationType.squared_relu: "relu2",
+    ActivationType.identity: "identity",
 }
 _ACTIVATION_HF_NAMES_INV = {value: key for key, value in _ACTIVATION_HF_NAMES.items()}
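The new identity member makes "no activation" expressible through the same ActivationType machinery the other layers use, which the SSM layers added below rely on. A minimal usage sketch, assuming fast_llm and torch are installed (the tensor values are arbitrary):

import torch
from fast_llm.functional.config import ActivationType

x = torch.randn(4)
act = ActivationType.identity
# activation_fn resolves to the `lambda x: x` registered above, so the output
# is the input unchanged.
assert torch.equal(act.activation_fn(x), x)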

fast_llm/functional/triton/mlp.py

Lines changed: 4 additions & 0 deletions
@@ -119,6 +119,10 @@ def triton_mlp_activation_backward_kernel(
         grad = 2 * relu_out
         if gated or recompute:
             out = relu_out * relu_out
+    elif activation_type == _TritonActivationType.identity:
+        grad = 1
+        if gated or recompute:
+            out = input_
     else:
         raise NotImplementedError()
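The branch follows directly from the derivative of the identity function: d/dx x = 1, and when the forward output has to be recomputed (gated or recompute paths) it is simply the input again. The same relationship written in plain PyTorch rather than Triton, as a sanity-check sketch that is not part of the commit:

import torch

input_ = torch.randn(8, requires_grad=True)
out = input_.clone()  # forward pass of the identity "activation"
out.backward(torch.ones_like(out))
# The gradient of the identity is 1 everywhere, matching `grad = 1` in the kernel.
assert torch.equal(input_.grad, torch.ones_like(input_))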

fast_llm/layers/language_model/config.py

Lines changed: 10 additions & 0 deletions
@@ -5,6 +5,7 @@
 from fast_llm.engine.config_utils.tensor_space import TensorDim, TensorSpace
 from fast_llm.engine.distributed.config import DistributedDimNames
 from fast_llm.functional.config import CrossEntropyImpl
+from fast_llm.layers.ssm.config import SSMArchitectureConfig, SSMConfig
 from fast_llm.layers.transformer.config import TransformerArchitectureConfig, TransformerConfig
 from fast_llm.utils import Assert

@@ -43,6 +44,13 @@ class LanguageModelArchitectureConfig(BaseModelArchitectureConfig):
         desc="Configuration for the transformer architecture.",
         hint=FieldHint.core,
     )
+
+    ssm: SSMArchitectureConfig = Field(
+        default_factory=SSMArchitectureConfig,
+        desc="Configuration for the SSM architecture.",
+        hint=FieldHint.core,
+    )
+
     max_position_embeddings: int = Field(
         default=2048,
         desc="Number of absolute position embeddings, if applicable.",
@@ -125,6 +133,8 @@ class LanguageModelBaseConfig(LanguageModelArchitectureConfig, BaseModelConfig):
     architecture_class = LanguageModelArchitectureConfig

     transformer: TransformerConfig = FieldUpdate(default_factory=TransformerConfig)
+    ssm: SSMConfig = FieldUpdate(default_factory=SSMConfig)
+
     init_method_std_embed: float = Field(
         default=None,
         desc="Initialization scale for the vocabulary embedding and output weights (logits).",

fast_llm/layers/ssm/config.py

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
from fast_llm.config import Field, FieldHint, FieldUpdate, check_field, config_class
from fast_llm.engine.base_model.config import BaseModelArchitectureConfig, BaseModelConfig
from fast_llm.functional.config import ActivationType
from fast_llm.layers.common.config import NormalizationArchitectureConfig, NormalizationConfig
from fast_llm.utils import Assert


class SSMDimNames:
    model_dim = "model_dim"  # Model dimension (D)
    state_dim = "state_dim"  # State dimension (N)
    conv_dim = "conv_dim"  # Dimension of the conv1d input in mamba layers
    inner_dim = "inner_dim"  # Inner dimension after expansion
    dt_rank = "dt_rank"  # Rank of Δ
    inner_proj_mamba = "inner_proj_mamba"  # Inner projection dimension for mamba
    inner_proj_mamba2 = "inner_proj_mamba2"  # Inner projection dimension for mamba2
    x_proj_dim = "x_proj_dim"  # X projection dimension
    head_dim = "head_dim"  # Dimension of the mamba2 head (P)
    conv_kernel_size = "conv_kernel_size"  # Kernel size of the conv1d in mamba layers
    qk_heads = "qk_heads"  # Number of QK heads
    v_heads = "v_heads"  # Number of V heads


@config_class()
class SSMArchitectureConfig(BaseModelArchitectureConfig):
    _abstract = False

    # Normalization
    normalization: NormalizationArchitectureConfig = Field(
        default_factory=NormalizationArchitectureConfig,
        desc="Configuration for the normalization layers architecture.",
        hint=FieldHint.core,
    )

    expansion_factor: int = Field(
        default=2, desc="Expansion factor for Mamba blocks.", hint=FieldHint.core, valid=check_field(Assert.gt, 0)
    )

    state_size: int = Field(
        default=16,
        desc="State size for Mamba blocks.",
        hint=FieldHint.core,
        valid=check_field(Assert.gt, 0),
    )
    conv_kernel_dimension: int = Field(
        default=4,
        desc="Conv kernel dimension for Mamba blocks.",
        hint=FieldHint.core,
        valid=check_field(Assert.gt, 0),
    )

    # Layer parameters
    add_bias_linear: bool = Field(
        default=False,
        desc="Whether to use bias in SSM layers",
        hint=FieldHint.core,
    )

    dt_rank: int = Field(
        default=None,
        desc="Rank of the Δ projection matrix. If 'None', will be set to ceil(hidden_size/16)",
        hint=FieldHint.core,
    )

    chunk_size: int = Field(
        default=256,
        desc="Chunk size for Mamba2 blocks.",
        hint=FieldHint.core,
    )

    n_qk_heads: int = Field(
        default=32,
        desc="Number of QK heads for Mamba2 blocks.",
        hint=FieldHint.core,
    )

    n_v_heads: int = Field(
        default=32,
        desc="Number of V heads for Mamba2 blocks.",
        hint=FieldHint.core,
    )

    activation_type: ActivationType = Field(
        default=None,
        desc="The MLP intermediate activation type. Default: SiLU for gated MLP, GeLU otherwise.",
        hint=FieldHint.core,
    )

    def _validate(self) -> None:
        with self._set_implicit_default():
            if self.activation_type is None:
                self.activation_type = ActivationType.silu
            if self.dt_rank is None:
                self.dt_rank = -1  # Placeholder; overwritten during SSM validation.

        super()._validate()


@config_class()
class SSMConfig(SSMArchitectureConfig, BaseModelConfig):
    """Configuration for a Structured State Space Model (SSM) layer."""

    normalization: NormalizationConfig = FieldUpdate(default_factory=NormalizationConfig)

    debug_ssm: bool = Field(
        default=False,
        desc="debug_ssm",
        hint=FieldHint.optional,
    )

    dt_min: float = Field(
        default=0.001,
        desc="Minimum step size for discretization",
        hint=FieldHint.core,
        valid=check_field(Assert.gt, 0),
    )

    dt_max: float = Field(
        default=0.1,
        desc="Maximum step size for discretization",
        hint=FieldHint.core,
        valid=check_field(Assert.gt, 0),
    )

    dt_init_floor: float = Field(
        default=1e-4,
        desc="Minimum value for initializing dt",
        hint=FieldHint.core,
        valid=check_field(Assert.gt, 0),
    )

    def _validate(self) -> None:
        """Validate configuration parameters."""

        super()._validate()
        Assert.geq(self.dt_max, self.dt_min)
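A usage sketch for the new config classes. The keyword-argument construction and the public validate() entry point are assumptions about the fast_llm config base class; the field names, implicit defaults, and the dt_max >= dt_min check come from the file above.

from fast_llm.layers.ssm.config import SSMConfig

cfg = SSMConfig(state_size=16, dt_min=0.001, dt_max=0.1)  # assumed kwargs constructor
cfg.validate()  # assumed public wrapper around _validate()
assert cfg.activation_type.value == "silu"  # implicit default set in _validate()
assert cfg.dt_rank == -1                    # placeholder until SSM validation overwrites it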
