28 changes: 27 additions & 1 deletion megatron/core/models/mamba/mamba_layer_specs.py
@@ -6,8 +6,10 @@
TERowParallelLinear,
)
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules
from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules
from megatron.core.ssm.parallel_hybrid_layer import ParallelHybridLayer, ParallelHybridLayerSubmodules
from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules
from megatron.core.ssm.mlp_layer import MLPLayer
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
@@ -64,5 +66,29 @@
mlp_bda=get_bias_dropout_add,
),
),

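# Hybrid layer combining a Mamba mixer and self-attention (see ParallelHybridLayerSubmodules).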
parallel_hybrid_layer=ModuleSpec(
module=ParallelHybridLayer,
submodules=ParallelHybridLayerSubmodules(
mamba_mixer=ModuleSpec(
module=MambaMixer,
submodules=MambaMixerSubmodules(
in_proj=TELayerNormColumnParallelLinear, out_proj=TERowParallelLinear
),
),
parallel_hybrid_bda=get_bias_dropout_add,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
),
),
),
),
),
)
5 changes: 5 additions & 0 deletions megatron/core/models/mamba/mamba_model.py
@@ -31,6 +31,8 @@ class MambaModel(LanguageModule):
(used with pipeline parallelism). Defaults to True.
hybrid_attention_ratio (float, optional): The target ratio of attention
layers to total layers
parallel_hybrid_ratio (float, optional): The target ratio of parallel hybrid
layers to total layers
hybrid_mlp_ratio (float, optional): The target ratio of mlp layers to total layers
hybrid_override_pattern (str, optional): The hybrid layer pattern to override with
post_process (bool, optional): Include an output layer (used with pipeline parallelism).
@@ -60,6 +62,7 @@ def __init__(
max_sequence_length: int,
pre_process: bool = True,
hybrid_attention_ratio: float = 0.0,
parallel_hybrid_ratio: float = 0.0,
hybrid_mlp_ratio: float = 0.0,
hybrid_override_pattern: str = None,
post_process: bool = True,
@@ -84,6 +87,7 @@ def __init__(
self.max_sequence_length = max_sequence_length
self.pre_process = pre_process
self.hybrid_attention_ratio = hybrid_attention_ratio
self.parallel_hybrid_ratio = parallel_hybrid_ratio
self.hybrid_mlp_ratio = hybrid_mlp_ratio
self.hybrid_override_pattern = hybrid_override_pattern
self.post_process = post_process
@@ -121,6 +125,7 @@ def __init__(
self.config,
pre_process=self.pre_process,
hybrid_attention_ratio=self.hybrid_attention_ratio,
parallel_hybrid_ratio=self.parallel_hybrid_ratio,
hybrid_mlp_ratio=self.hybrid_mlp_ratio,
hybrid_override_pattern=self.hybrid_override_pattern,
post_process=self.post_process,
22 changes: 21 additions & 1 deletion megatron/core/ssm/mamba_block.py
@@ -29,6 +29,7 @@
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_layer import TransformerLayer
from megatron.core.ssm.parallel_hybrid_layer import ParallelHybridLayer
from megatron.core.transformer.utils import sharded_state_dict_default
from megatron.core.utils import WrappedTensor, deprecate_inference_params, make_viewless_tensor

@@ -86,6 +87,7 @@ class MambaStackSubmodules:
mamba_layer: Union[ModuleSpec, type] = IdentityOp
attention_layer: Union[ModuleSpec, type] = IdentityOp
mlp_layer: Union[ModuleSpec, type] = IdentityOp
parallel_hybrid_layer: Union[ModuleSpec, type] = IdentityOp


class MambaStack(MegatronModule):
@@ -123,6 +125,7 @@ def __init__(
pre_process: bool = True,
hybrid_attention_ratio: float = 0.0,
hybrid_mlp_ratio: float = 0.0,
parallel_hybrid_ratio: float = 0.0,
hybrid_override_pattern: str = None,
post_layer_norm: bool = True,
post_process: bool = True,
@@ -146,11 +149,13 @@ def __init__(
self.hybrid_attention_ratio = hybrid_attention_ratio
self.hybrid_mlp_ratio = hybrid_mlp_ratio
self.hybrid_override_pattern = hybrid_override_pattern
self.parallel_hybrid_ratio = parallel_hybrid_ratio

layer_type_list = allocate_layers(
self.config.num_layers,
self.hybrid_attention_ratio,
self.hybrid_mlp_ratio,
self.parallel_hybrid_ratio,
self.hybrid_override_pattern,
)

@@ -188,6 +193,13 @@ def __init__(
layer_number=i + 1,
pg_collection=pg_collection,
)
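# Parallel hybrid layer: a Mamba mixer and self-attention combined in one layer (ParallelHybridLayer).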
elif layer_type == LayerSymbols.PARALLEL:
layer = build_module(
submodules.parallel_hybrid_layer,
config=self.config,
layer_number=i + 1 + pp_layer_offset,
model_comm_pgs=model_comm_pgs,
)
else:
assert False, "unexpected layer_type"
self.layers.append(layer)
@@ -333,7 +345,15 @@ def forward(
)
with inner_fp8_context:
if isinstance(layer, TransformerLayer):
hidden_states, _ = layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
inference_context=inference_context,
rotary_pos_emb=rotary_pos_emb,
sequence_len_offset=sequence_len_offset,
)
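# Unlike TransformerLayer, ParallelHybridLayer returns hidden_states directly rather than a tuple.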
if isinstance(layer, ParallelHybridLayer):
hidden_states = layer(
hidden_states=hidden_states,
attention_mask=attention_mask,
inference_context=inference_context,
87 changes: 48 additions & 39 deletions megatron/core/ssm/mamba_hybrid_layer_allocation.py
@@ -18,13 +18,13 @@ class Symbols:
MAMBA = "M"
ATTENTION = "*"
MLP = "-"
VALID = {MAMBA, ATTENTION, MLP}
PARALLEL = "P"
VALID = {MAMBA, ATTENTION, MLP, PARALLEL}


def _allocate_auto(
total_layers_count: int, target_attention_ratio: float, target_mlp_ratio: float
    total_layers_count: int,
    target_attention_ratio: float,
    target_mlp_ratio: float,
    target_parallel_hybrid_ratio: float,
) -> list:
# First, allocate attention (evenly spaced, starting and ending with mamba)
attention_layers_count: int = round(total_layers_count * target_attention_ratio)
mamba_layers_count: int = total_layers_count - attention_layers_count
mamba_sections_count: int = attention_layers_count + 1
@@ -39,8 +39,6 @@ def _allocate_auto(
else:
x -= 1

# Next, allocate mlp
# (evenly distributed, but right-justified, not replacing attention)
mlp_layers_count: int = round(total_layers_count * target_mlp_ratio)
if mlp_layers_count > 0:
mamba_layers_count -= mlp_layers_count
@@ -55,6 +53,26 @@
else:
x -= 1

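# Finally, allocate parallel hybrid layers
# (evenly distributed, not replacing attention or mlp)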
parallel_layers_count: int = round(total_layers_count * target_parallel_hybrid_ratio)
if parallel_layers_count > 0:
remaining_mamba_count = layer_type_list.count(Symbols.MAMBA)
if remaining_mamba_count > 0:
if parallel_layers_count >= remaining_mamba_count:
for l in range(total_layers_count):
if layer_type_list[l] == Symbols.MAMBA:
layer_type_list[l] = Symbols.PARALLEL
else:
mamba_to_parallel_ratio: float = (
    remaining_mamba_count - parallel_layers_count
) / parallel_layers_count

x: float = mamba_to_parallel_ratio
for l in range(total_layers_count):
if layer_type_list[l] == Symbols.MAMBA:
if x < 0.5:
layer_type_list[l] = Symbols.PARALLEL
x += mamba_to_parallel_ratio
else:
x -= 1

return layer_type_list


@@ -85,20 +103,21 @@ def allocate_layers(
total_layers_count: int,
target_attention_ratio: float,
target_mlp_ratio: float,
target_parallel_hybrid_ratio: float,
override_pattern: str = None,
) -> list:
assert total_layers_count > 0
assert target_attention_ratio >= 0.0 and target_attention_ratio <= 1.0
assert target_mlp_ratio >= 0.0 and target_mlp_ratio <= 1.0
assert target_attention_ratio + target_mlp_ratio <= 1.0
# Note: target_mamba_ratio = 1.0 - target_attention_ratio - target_mlp_ratio
assert target_parallel_hybrid_ratio >= 0.0 and target_parallel_hybrid_ratio <= 1.0
assert target_attention_ratio + target_mlp_ratio + target_parallel_hybrid_ratio <= 1.0

layer_type_list = _allocate_auto(total_layers_count, target_attention_ratio, target_mlp_ratio)
layer_type_list = _allocate_auto(
    total_layers_count, target_attention_ratio, target_mlp_ratio, target_parallel_hybrid_ratio
)

if override_pattern is not None:
layer_type_list_override = _allocate_override(total_layers_count, override_pattern)
log_single_rank(logger, logging.INFO, "Using hybrid override pattern")
if (target_attention_ratio > 0.0 or target_mlp_ratio > 0.0) and not _layer_counts_match(
if (
    target_attention_ratio > 0.0
    or target_mlp_ratio > 0.0
    or target_parallel_hybrid_ratio > 0.0
) and not _layer_counts_match(
layer_type_list_override, layer_type_list
):
raise ValueError(
@@ -116,18 +135,21 @@
log_single_rank(logger, logging.INFO, f"B: {''.join(layer_type_list_override)}")
layer_type_list = layer_type_list_override

if target_attention_ratio > 0.0 or target_mlp_ratio > 0.0 or override_pattern is not None:
if (
    target_attention_ratio > 0.0
    or target_mlp_ratio > 0.0
    or target_parallel_hybrid_ratio > 0.0
    or override_pattern is not None
):
actual_attention_layers_count = layer_type_list.count(Symbols.ATTENTION)
actual_attention_ratio = actual_attention_layers_count / total_layers_count
actual_mlp_layers_count = layer_type_list.count(Symbols.MLP)
actual_mlp_ratio = actual_mlp_layers_count / total_layers_count
actual_parallel_layers_count = layer_type_list.count(Symbols.PARALLEL)
actual_parallel_ratio = actual_parallel_layers_count / total_layers_count
allocation_string = "".join(layer_type_list)
log_single_rank(
logger,
logging.INFO,
f"Hybrid allocation ({Symbols.MAMBA} is mamba, "
f"{Symbols.ATTENTION} is attention, "
f"{Symbols.MLP} is mlp):",
f"{Symbols.MLP} is mlp, "
f"{Symbols.PARALLEL} is parallel):",
)
log_single_rank(logger, logging.INFO, allocation_string)
log_single_rank(
@@ -153,39 +175,26 @@
f"Target mlp ratio: {target_mlp_ratio:.2f}. "
f"Actual mlp ratio: {actual_mlp_ratio:.2f}.",
)
log_single_rank(
logger,
logging.INFO,
f"{actual_parallel_layers_count} parallel layers in " f"{total_layers_count} total layers.",
)
log_single_rank(
logger,
logging.INFO,
f"Target parallel ratio: {target_parallel_hybrid_ratio:.2f}. "
f"Actual parallel ratio: {actual_parallel_ratio:.2f}.",
)
return layer_type_list


if __name__ == "__main__":
test_cases = [
# (10, 0.2, 0.0),
# (48, 0.0, 0.0), # will not print anything
# (48, 0.1, 0.0),
# 48, 0.3, 0.0),
# (48, 0.5, 0.0),
# (48, 0.6, 0.0),
# (48, 0.7, 0.0),
# (10, 0.0, 0.1),
# (10, 0.0, 0.3),
# (10, 0.0, 0.5),
# (10, 0.1, 0.1),
# (10, 0.2, 0.2),
# (10, 0.3, 0.3),
# (10, 0.5, 0.5),
# (48, 0.2, 0.3),
# (48, 0.5, 0.2),
# (48, 0.5, 0.2, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
# (48, 0.25, 0.25, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
# (48, 0.25, 0.25, "MM-*MM-*MM*-MM*-MM*-MM*-M*M-M*M-M*M-M*M-*MM-*MM-"),
# (48, 0.0, 0.2, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
# (48, 0.2, 0.0, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
# (48, 0.0, 0.0, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"),
# (48, 0.5, 0.5),
# (10, 0.3, 0.2, "MMM*-*M*M-"),
# (10, 0.3, 0.2, "MM*M-*M*M-"),
(9, 0.0, 0.0, "M*-M*-M*-"),
(9, 0.0, 0.0, "MMMMMMMMM"),
(9, 0.0, 0.0, 0.0, "M*-M*-M*-"),
(9, 0.0, 0.0, 0.0, "MMMMMMMMM"),
(10, 0.2, 0.1, 0.2),
]
for t in test_cases:
print("")
allocate_layers(*t)