diff --git a/megatron/core/models/mamba/mamba_layer_specs.py b/megatron/core/models/mamba/mamba_layer_specs.py index 8ef4a2ab3e..e60ec6ad84 100755 --- a/megatron/core/models/mamba/mamba_layer_specs.py +++ b/megatron/core/models/mamba/mamba_layer_specs.py @@ -6,8 +6,10 @@ TERowParallelLinear, ) from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.transformer.identity_op import IdentityOp from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules from megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules +from megatron.core.ssm.parallel_hybrid_layer import ParallelHybridLayer, ParallelHybridLayerSubmodules from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules from megatron.core.ssm.mlp_layer import MLPLayer from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules @@ -64,5 +66,29 @@ mlp_bda=get_bias_dropout_add, ), ), + + parallel_hybrid_layer=ModuleSpec( + module=ParallelHybridLayer, + submodules=ParallelHybridLayerSubmodules( + mamba_mixer=ModuleSpec( + module=MambaMixer, + submodules=MambaMixerSubmodules( + in_proj=TELayerNormColumnParallelLinear, out_proj=TERowParallelLinear + ), + ), + parallel_hybrid_bda=get_bias_dropout_add, + self_attention=ModuleSpec( + module=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + ), + ), + ), + ), + ), ), ) diff --git a/megatron/core/models/mamba/mamba_model.py b/megatron/core/models/mamba/mamba_model.py index fb3df5e23f..dab9ac42ae 100644 --- a/megatron/core/models/mamba/mamba_model.py +++ b/megatron/core/models/mamba/mamba_model.py @@ -31,6 +31,8 @@ class MambaModel(LanguageModule): (used with pipeline parallelism). Defaults to True. hybrid_attention_ratio (float, optional): The target ratio of attention layers to total layers + parallel_hybrid_ratio (float, optional): The target ratio of parallel hybrid + layers to total layers hybrid_mlp_ratio (float, optional): The target ratio of mlp layers to total layers hybrid_override_pattern (str, optional): The hybrid layer pattern to override with post_process (bool, optional): Include an output layer (used with pipeline parallelism).
@@ -60,6 +62,7 @@ def __init__( max_sequence_length: int, pre_process: bool = True, hybrid_attention_ratio: float = 0.0, + parallel_hybrid_ratio: float = 0.0, hybrid_mlp_ratio: float = 0.0, hybrid_override_pattern: str = None, post_process: bool = True, @@ -84,6 +87,7 @@ def __init__( self.max_sequence_length = max_sequence_length self.pre_process = pre_process self.hybrid_attention_ratio = hybrid_attention_ratio + self.parallel_hybrid_ratio = parallel_hybrid_ratio self.hybrid_mlp_ratio = hybrid_mlp_ratio self.hybrid_override_pattern = hybrid_override_pattern self.post_process = post_process @@ -121,6 +125,7 @@ def __init__( self.config, pre_process=self.pre_process, hybrid_attention_ratio=self.hybrid_attention_ratio, + parallel_hybrid_ratio=self.parallel_hybrid_ratio, hybrid_mlp_ratio=self.hybrid_mlp_ratio, hybrid_override_pattern=self.hybrid_override_pattern, post_process=self.post_process, diff --git a/megatron/core/ssm/mamba_block.py b/megatron/core/ssm/mamba_block.py index cd8eb21bae..65466169e6 100644 --- a/megatron/core/ssm/mamba_block.py +++ b/megatron/core/ssm/mamba_block.py @@ -29,6 +29,7 @@ from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_layer import TransformerLayer +from megatron.core.ssm.parallel_hybrid_layer import ParallelHybridLayer from megatron.core.transformer.utils import sharded_state_dict_default from megatron.core.utils import WrappedTensor, deprecate_inference_params, make_viewless_tensor @@ -86,6 +87,7 @@ class MambaStackSubmodules: mamba_layer: Union[ModuleSpec, type] = IdentityOp attention_layer: Union[ModuleSpec, type] = IdentityOp mlp_layer: Union[ModuleSpec, type] = IdentityOp + parallel_hybrid_layer: Union[ModuleSpec, type] = IdentityOp class MambaStack(MegatronModule): @@ -123,6 +125,7 @@ def __init__( pre_process: bool = True, hybrid_attention_ratio: float = 0.0, hybrid_mlp_ratio: float = 0.0, + parallel_hybrid_ratio: float = 0.0, hybrid_override_pattern: str = None, post_layer_norm: bool = True, post_process: bool = True, @@ -146,11 +149,13 @@ def __init__( self.hybrid_attention_ratio = hybrid_attention_ratio self.hybrid_mlp_ratio = hybrid_mlp_ratio self.hybrid_override_pattern = hybrid_override_pattern + self.parallel_hybrid_ratio = parallel_hybrid_ratio layer_type_list = allocate_layers( self.config.num_layers, self.hybrid_attention_ratio, self.hybrid_mlp_ratio, + self.parallel_hybrid_ratio, self.hybrid_override_pattern, ) @@ -188,6 +193,13 @@ def __init__( layer_number=i + 1, pg_collection=pg_collection, ) + elif layer_type == LayerSymbols.PARALLEL: + layer = build_module( + submodules.parallel_hybrid_layer, + config=self.config, + layer_number=i + 1 + pp_layer_offset, + model_comm_pgs=model_comm_pgs, + ) else: assert False, "unexpected layer_type" self.layers.append(layer) @@ -333,7 +345,15 @@ def forward( ) with inner_fp8_context: if isinstance(layer, TransformerLayer): - hidden_states, _ = layer( + hidden_states, _ = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_context=inference_context, + rotary_pos_emb=rotary_pos_emb, + sequence_len_offset=sequence_len_offset, + ) + elif isinstance(layer, ParallelHybridLayer): + hidden_states = layer( hidden_states=hidden_states, attention_mask=attention_mask, inference_context=inference_context, rotary_pos_emb=rotary_pos_emb, diff --git a/megatron/core/ssm/mamba_hybrid_layer_allocation.py index 26972b5454..9332a18e73 100644 ---
a/megatron/core/ssm/mamba_hybrid_layer_allocation.py +++ b/megatron/core/ssm/mamba_hybrid_layer_allocation.py @@ -18,13 +18,13 @@ class Symbols: MAMBA = "M" ATTENTION = "*" MLP = "-" - VALID = {MAMBA, ATTENTION, MLP} + PARALLEL = "P" + VALID = {MAMBA, ATTENTION, MLP, PARALLEL} def _allocate_auto( - total_layers_count: int, target_attention_ratio: float, target_mlp_ratio: float + total_layers_count: int, target_attention_ratio: float, target_mlp_ratio: float, target_parallel_hybrid_ratio: float ) -> list: - # First, allocate attention (evenly spaced, starting and ending with mamba) attention_layers_count: int = round(total_layers_count * target_attention_ratio) mamba_layers_count: int = total_layers_count - attention_layers_count mamba_sections_count: int = attention_layers_count + 1 @@ -39,8 +39,6 @@ def _allocate_auto( else: x -= 1 - # Next, allocate mlp - # (evenly distributed, but right-justified, not replacing attention) mlp_layers_count: int = round(total_layers_count * target_mlp_ratio) if mlp_layers_count > 0: mamba_layers_count -= mlp_layers_count @@ -55,6 +53,26 @@ def _allocate_auto( else: x -= 1 + parallel_layers_count: int = round(total_layers_count * target_parallel_hybrid_ratio) + if parallel_layers_count > 0: + remaining_mamba_count = layer_type_list.count(Symbols.MAMBA) + if remaining_mamba_count > 0: + if parallel_layers_count >= remaining_mamba_count: + for l in range(total_layers_count): + if layer_type_list[l] == Symbols.MAMBA: + layer_type_list[l] = Symbols.PARALLEL + else: + mamba_to_parallel_ratio: float = (remaining_mamba_count - parallel_layers_count) / parallel_layers_count + + x: float = mamba_to_parallel_ratio + for l in range(total_layers_count): + if layer_type_list[l] == Symbols.MAMBA: + if x < 0.5: + layer_type_list[l] = Symbols.PARALLEL + x += mamba_to_parallel_ratio + else: + x -= 1 + return layer_type_list @@ -85,20 +103,21 @@ def allocate_layers( total_layers_count: int, target_attention_ratio: float, target_mlp_ratio: float, + target_parallel_hybrid_ratio: float, override_pattern: str = None, ) -> list: assert total_layers_count > 0 assert target_attention_ratio >= 0.0 and target_attention_ratio <= 1.0 assert target_mlp_ratio >= 0.0 and target_mlp_ratio <= 1.0 - assert target_attention_ratio + target_mlp_ratio <= 1.0 - # Note: target_mamba_ratio = 1.0 - target_attention_ratio - target_mlp_ratio + assert target_parallel_hybrid_ratio >= 0.0 and target_parallel_hybrid_ratio <= 1.0 + assert target_attention_ratio + target_mlp_ratio + target_parallel_hybrid_ratio <= 1.0 - layer_type_list = _allocate_auto(total_layers_count, target_attention_ratio, target_mlp_ratio) + layer_type_list = _allocate_auto(total_layers_count, target_attention_ratio, target_mlp_ratio, target_parallel_hybrid_ratio) if override_pattern is not None: layer_type_list_override = _allocate_override(total_layers_count, override_pattern) log_single_rank(logger, logging.INFO, "Using hybrid override pattern") - if (target_attention_ratio > 0.0 or target_mlp_ratio > 0.0) and not _layer_counts_match( + if (target_attention_ratio > 0.0 or target_mlp_ratio > 0.0 or target_parallel_hybrid_ratio > 0.0) and not _layer_counts_match( layer_type_list_override, layer_type_list ): raise ValueError( @@ -116,18 +135,21 @@ def allocate_layers( log_single_rank(logger, logging.INFO, f"B: {''.join(layer_type_list_override)}") layer_type_list = layer_type_list_override - if target_attention_ratio > 0.0 or target_mlp_ratio > 0.0 or override_pattern is not None: + if target_attention_ratio > 0.0 or 
target_mlp_ratio > 0.0 or target_parallel_hybrid_ratio > 0.0 or override_pattern is not None: actual_attention_layers_count = layer_type_list.count(Symbols.ATTENTION) actual_attention_ratio = actual_attention_layers_count / total_layers_count actual_mlp_layers_count = layer_type_list.count(Symbols.MLP) actual_mlp_ratio = actual_mlp_layers_count / total_layers_count + actual_parallel_layers_count = layer_type_list.count(Symbols.PARALLEL) + actual_parallel_ratio = actual_parallel_layers_count / total_layers_count allocation_string = "".join(layer_type_list) log_single_rank( logger, logging.INFO, f"Hybrid allocation ({Symbols.MAMBA} is mamba, " f"{Symbols.ATTENTION} is attention, " - f"{Symbols.MLP} is mlp):", + f"{Symbols.MLP} is mlp, " + f"{Symbols.PARALLEL} is parallel):", ) log_single_rank(logger, logging.INFO, allocation_string) log_single_rank( @@ -153,39 +175,26 @@ def allocate_layers( f"Target mlp ratio: {target_mlp_ratio:.2f}. " f"Actual mlp ratio: {actual_mlp_ratio:.2f}.", ) + log_single_rank( + logger, + logging.INFO, + f"{actual_parallel_layers_count} parallel layers in " f"{total_layers_count} total layers.", + ) + log_single_rank( + logger, + logging.INFO, + f"Target parallel ratio: {target_parallel_hybrid_ratio:.2f}. " + f"Actual parallel ratio: {actual_parallel_ratio:.2f}.", + ) return layer_type_list if __name__ == "__main__": test_cases = [ - # (10, 0.2, 0.0), - # (48, 0.0, 0.0), # will not print anything - # (48, 0.1, 0.0), - # 48, 0.3, 0.0), - # (48, 0.5, 0.0), - # (48, 0.6, 0.0), - # (48, 0.7, 0.0), - # (10, 0.0, 0.1), - # (10, 0.0, 0.3), - # (10, 0.0, 0.5), - # (10, 0.1, 0.1), - # (10, 0.2, 0.2), - # (10, 0.3, 0.3), - # (10, 0.5, 0.5), - # (48, 0.2, 0.3), - # (48, 0.5, 0.2), - # (48, 0.5, 0.2, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"), - # (48, 0.25, 0.25, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"), - # (48, 0.25, 0.25, "MM-*MM-*MM*-MM*-MM*-MM*-M*M-M*M-M*M-M*M-*MM-*MM-"), - # (48, 0.0, 0.2, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"), - # (48, 0.2, 0.0, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"), - # (48, 0.0, 0.0, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"), - # (48, 0.5, 0.5), - # (10, 0.3, 0.2, "MMM*-*M*M-"), - # (10, 0.3, 0.2, "MM*M-*M*M-"), - (9, 0.0, 0.0, "M*-M*-M*-"), - (9, 0.0, 0.0, "MMMMMMMMM"), + (9, 0.0, 0.0, 0.0, "M*-M*-M*-"), + (9, 0.0, 0.0, 0.0, "MMMMMMMMM"), + (10, 0.2, 0.1, 0.2), ] for t in test_cases: print("") - allocate_layers(*t) + allocate_layers(*t) \ No newline at end of file diff --git a/megatron/core/ssm/parallel_hybrid_layer.py b/megatron/core/ssm/parallel_hybrid_layer.py new file mode 100644 index 0000000000..322615eea9 --- /dev/null +++ b/megatron/core/ssm/parallel_hybrid_layer.py @@ -0,0 +1,200 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, Tri Dao, Albert Gu. + +# Some of this code was adopted from https://github.com/state-spaces/mamba/ +# This source code is licensed under the Apache license found in the +# LICENSE file in the root directory of this source tree. 
+ +from dataclasses import dataclass +from typing import Optional, Union, Tuple + +import torch + +from megatron.core.inference.contexts import BaseInferenceContext +from megatron.core.process_groups_config import ModelCommProcessGroups +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.module import MegatronModule +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig + +from megatron.core.ssm.mamba_mixer import MambaMixerSubmodules +from megatron.core.transformer.attention import SelfAttentionSubmodules + + +@dataclass +class ParallelHybridLayerSubmodules: + """Configuration class for specifying the submodules of a parallel hybrid layer.""" + mamba_mixer: Union[ModuleSpec, type] = IdentityOp + self_attention: Union[ModuleSpec, type] = IdentityOp + input_layernorm: Union[ModuleSpec, type] = IdentityOp + parallel_hybrid_bda: Union[ModuleSpec, type] = IdentityOp + + +class ParallelHybridLayer(MegatronModule): + """A parallel hybrid layer that combines Mamba and Attention components.""" + + def __init__( + self, + config: TransformerConfig, + submodules: ParallelHybridLayerSubmodules, + layer_number: int = 1, + residual_in_fp32=False, + model_comm_pgs: ModelCommProcessGroups = None, + ): + super().__init__(config) + assert model_comm_pgs is not None, "model_comm_pgs must be provided for ParallelHybridLayer" + + self.config = config + self.layer_number = layer_number + self.residual_in_fp32 = residual_in_fp32 + self.hidden_dropout = config.hidden_dropout + + self.input_layernorm = build_module( + submodules.input_layernorm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + + mamba_submodules = MambaMixerSubmodules( + in_proj=submodules.mamba_mixer.submodules.in_proj, + out_proj=submodules.mamba_mixer.submodules.out_proj, + ) + + self.mamba_mixer = build_module( + submodules.mamba_mixer.module, + submodules=mamba_submodules, + config=self.config, + layer_number=layer_number, + d_model=self.config.hidden_size, + model_comm_pgs=model_comm_pgs + ) + + attention_optional_kwargs = {} + if self.config.context_parallel_size > 1 and self.config.cp_comm_type is not None: + if isinstance(self.config.cp_comm_type, list): + attention_optional_kwargs["cp_comm_type"] = self.config.cp_comm_type[self.layer_number] + else: + attention_optional_kwargs["cp_comm_type"] = self.config.cp_comm_type + model_comm_pgs = ModelCommProcessGroups.use_mpu_process_groups() + attention_optional_kwargs["model_comm_pgs"] = model_comm_pgs + + attention_submodules = SelfAttentionSubmodules( + linear_qkv=submodules.self_attention.module.submodules.linear_qkv, + core_attention=submodules.self_attention.module.submodules.core_attention, + linear_proj=submodules.self_attention.module.submodules.linear_proj, + q_layernorm=getattr(submodules.self_attention.module.submodules, 'q_layernorm', None), + k_layernorm=getattr(submodules.self_attention.module.submodules, 'k_layernorm', None), + ) + + self.self_attention = build_module( + submodules.self_attention.module, + submodules=attention_submodules, + config=self.config, + layer_number=self.layer_number, + **attention_optional_kwargs, + ) + + self.parallel_hybrid_bda = build_module(submodules.parallel_hybrid_bda) + self.bias_dropout_add_exec_handler = torch.enable_grad + + def forward( + self, + hidden_states: torch.Tensor, + 
attention_mask: Optional[torch.Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + rotary_pos_cos: Optional[torch.Tensor] = None, + rotary_pos_sin: Optional[torch.Tensor] = None, + attention_bias: Optional[torch.Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[int] = None, + position_ids: Optional[torch.Tensor] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + ): + residual = hidden_states + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + hidden_states = hidden_states.to(dtype=self.config.params_dtype) + hidden_states = self.input_layernorm(hidden_states) + + outputs = [] + biases = [] + + mamba_output, mamba_bias = self.mamba_mixer( + hidden_states, + inference_context=inference_context, + ) + outputs.append(mamba_output) + if mamba_bias is not None: + biases.append(mamba_bias) + + attn_output, attn_bias = self.self_attention( + hidden_states, + attention_mask=attention_mask, + inference_context=inference_context, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ) + outputs.append(attn_output) + if attn_bias is not None: + biases.append(attn_bias) + + combined_output = sum(outputs) + combined_bias = sum(biases) if biases else None + + out_with_bias = (combined_output, combined_bias) + + with self.bias_dropout_add_exec_handler(): + hidden_states = self.parallel_hybrid_bda( + training=self.training, + fused=self.config.bias_dropout_fusion + )(out_with_bias, residual, self.hidden_dropout) + + return hidden_states + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): + """Allocate inference cache for both components.""" + caches = {} + + if self.mamba_mixer is not None: + mamba_cache = self.mamba_mixer.allocate_inference_cache( + batch_size, max_seqlen, dtype + ) + caches['mamba'] = mamba_cache + + if self.self_attention is not None: + pass + + return caches + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + from megatron.core.transformer.utils import sharded_state_dict_default + + sharded_state_dict = {} + + norm_sd = sharded_state_dict_default( + self.input_layernorm, f'{prefix}input_layernorm.', sharded_offsets, metadata + ) + sharded_state_dict.update(norm_sd) + + mamba_sd = sharded_state_dict_default( + self.mamba_mixer, f'{prefix}mamba_mixer.', sharded_offsets, metadata + ) + sharded_state_dict.update(mamba_sd) + + attn_sd = sharded_state_dict_default( + self.self_attention, f'{prefix}self_attention.', sharded_offsets, metadata + ) + sharded_state_dict.update(attn_sd) + + return sharded_state_dict + + \ No newline at end of file diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index d138e0cfa8..905b536c44 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1251,8 +1251,7 @@ def core_transformer_config_from_args(args, config_class=None): elif hasattr(args, 'kitchen_recipe_number') and args.kitchen_recipe_number is not None: kw_args['use_kitchen'] = True kw_args['quant_recipe'] = kitchen_quantization_recipe_config(args.kitchen_recipe_number) - - + # Return config. 
return config_class(**kw_args) @@ -3147,6 +3146,7 @@ def _add_experimental_args(parser): '--hidden-size * expand // --mamba-head-dim') group.add_argument('--is-hybrid-model', default=False, action="store_true", help='Indicates whether the model is a hybrid model.') + group.add_argument('--parallel-hybrid-ratio', type=float, default=0.0, help='Ratio of parallel hybrid layers.') group.add_argument('--disable-mamba-mem-eff-path', default=False, action="store_true", help='Disable Mamba efficient path.') group.add_argument('--yaml-cfg', type=str, default=None, diff --git a/tools/checkpoint/loader_parallelhybrid.py b/tools/checkpoint/loader_parallelhybrid.py new file mode 100644 index 0000000000..b32f76215a --- /dev/null +++ b/tools/checkpoint/loader_parallelhybrid.py @@ -0,0 +1,319 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +import json +import os +import sys +import torch +import types + +from loader_base import MegatronCheckpointLoaderBase + + +def add_arguments(parser): + """Add command-line arguments relevant to Falcon-H1 model loading.""" + group = parser.add_argument_group(title='Falcon-H1 loader') + + group.add_argument('--true-vocab-size', type=int, default=None, + help='Original size of vocab; if specified, trims padding from embedding table.') + group.add_argument('--vocab-file', type=str, default=None, + help='Path to a vocab file. If specified, determines vocab size to trim padding.') + group.add_argument('--megatron-path', type=str, default=None, + help='Base directory of Megatron repository') + group.add_argument('--position-embedding-type', + type=str, + default='learned_absolute', + choices=['learned_absolute', 'rope'], + help='Type of position embedding.') + group.add_argument('--loader-transformer-impl', default='local', + choices=['local', 'transformer_engine'], + help='Which Transformer implementation to use.') + + +class MegatronCheckpointLoaderFalconH1(MegatronCheckpointLoaderBase): + """ + Falcon-H1 specific checkpoint loader that handles hybrid architecture + with alternating Mamba+Attention layers and MLP-only layers. + + Architecture: + - Even layers (0,2,4,...): Hybrid (Mamba mixer + Self-attention) + - Odd layers (1,3,5,...): MLP-only + """ + + def build_sys_argv(self): + """ + Construct a sys.argv list for Megatron's argument parser. 
+ """ + return [ + *super().build_sys_argv(), + '--position-embedding-type', self.args.position_embedding_type, + ] + + def import_model_provider(self): + """Return the Mamba model provider for Falcon-H1.""" + from pretrain_mamba import model_provider + return model_provider + + def is_hybrid_layer(self, layer_idx): + """Determine if a layer is hybrid (Mamba + Attention) or MLP-only.""" + return layer_idx % 2 == 0 + + def extract_mamba_weights(self, model, layer_idx): + """Extract Mamba mixer weights from a hybrid layer.""" + layer_name = f"decoder.layers.{layer_idx}.mamba_mixer" + + mamba_weights = {} + + # Get the mamba mixer module + mamba_mixer = None + for name, module in model.named_modules(): + if name == layer_name: + mamba_mixer = module + break + + if mamba_mixer is None: + raise ValueError(f"Could not find mamba_mixer at layer {layer_idx}") + + # Extract Mamba-specific parameters + mamba_weights["A_log"] = getattr(mamba_mixer, 'A_log', None) + mamba_weights["D"] = getattr(mamba_mixer, 'D', None) + mamba_weights["dt_bias"] = getattr(mamba_mixer, 'dt_bias', None) + + # Conv1D weights + if hasattr(mamba_mixer, 'conv1d'): + mamba_weights["conv1d_weight"] = mamba_mixer.conv1d.weight + mamba_weights["conv1d_bias"] = mamba_mixer.conv1d.bias + + # Input and output projections + if hasattr(mamba_mixer, 'in_proj'): + mamba_weights["in_proj_weight"] = mamba_mixer.in_proj.weight + # Note: pre_norm_weight is extracted separately above + + if hasattr(mamba_mixer, 'out_proj'): + mamba_weights["out_proj_weight"] = mamba_mixer.out_proj.weight + + # Norm weights - GET BOTH TYPES + if hasattr(mamba_mixer, 'norm'): + mamba_weights["internal_norm_weight"] = mamba_mixer.norm.weight + + # Pre-norm weight (from in_proj layer norm) + if hasattr(mamba_mixer, 'in_proj') and hasattr(mamba_mixer.in_proj, 'layer_norm_weight'): + mamba_weights["pre_norm_weight"] = mamba_mixer.in_proj.layer_norm_weight + + return mamba_weights + + def extract_attention_weights(self, model, layer_idx): + """Extract self-attention weights from a hybrid layer.""" + layer_name = f"decoder.layers.{layer_idx}.self_attention" + + attention_weights = {} + + # Get the self attention module + self_attention = None + for name, module in model.named_modules(): + if name == layer_name: + self_attention = module + break + + if self_attention is None: + raise ValueError(f"Could not find self_attention at layer {layer_idx}") + + # QKV projection + if hasattr(self_attention, 'linear_qkv'): + attention_weights["qkv_weight"] = self_attention.linear_qkv.weight + attention_weights["qkv_norm_weight"] = getattr(self_attention.linear_qkv, 'layer_norm_weight', None) + + # Output projection + if hasattr(self_attention, 'linear_proj'): + attention_weights["proj_weight"] = self_attention.linear_proj.weight + + return attention_weights + + def extract_mlp_weights(self, model, layer_idx): + """Extract MLP weights from an MLP-only layer.""" + layer_name = f"decoder.layers.{layer_idx}.mlp" + + mlp_weights = {} + + # Get the MLP module + mlp = None + for name, module in model.named_modules(): + if name == layer_name: + mlp = module + break + + if mlp is None: + raise ValueError(f"Could not find mlp at layer {layer_idx}") + + # FC1 (first linear layer) + if hasattr(mlp, 'linear_fc1'): + mlp_weights["fc1_weight"] = mlp.linear_fc1.weight + mlp_weights["fc1_norm_weight"] = getattr(mlp.linear_fc1, 'layer_norm_weight', None) + + # FC2 (second linear layer) + if hasattr(mlp, 'linear_fc2'): + mlp_weights["fc2_weight"] = mlp.linear_fc2.weight + + return mlp_weights + 
+ def send_model_over_queue(self): + """Send Falcon-H1 model over the queue with proper hybrid layer handling.""" + # Send metadata first + self.send_metadata_over_queue() + + # Get model parameters + tp_size = self.margs.tensor_model_parallel_size + pp_size = self.margs.pipeline_model_parallel_size + vp_size = self.margs.virtual_pipeline_model_parallel_size or 1 + + # Get first pipeline models for embeddings/final norm + first_pipeline_models = self.all_models[0][0] + + # 1) Send embeddings + message = {} + for i, model in enumerate(first_pipeline_models): + # Extract embedding weights + for name, param in model.named_parameters(): + if 'embedding.word_embeddings.weight' in name: + if i == 0: + message["word embeddings"] = param + else: + message["word embeddings"] = torch.cat([message["word embeddings"], param], dim=0) + elif 'position_embeddings.weight' in name and self.md.position_embedding_type == 'learned_absolute': + if i == 0: # Only take from rank 0 + message["position embeddings"] = param + + if "position embeddings" not in message: + message["position embeddings"] = None + + self.queue_put("embeddings", message) + + # 2) Process each layer based on type + total_layer_num = 0 + for vp_rank in range(vp_size): + for pp_rank in range(pp_size): + models = self.all_models[pp_rank][vp_rank] + + # Determine number of layers in this model shard + model = models[0] + layer_count = 0 + max_layer_idx = -1 + for name, _ in model.named_parameters(): + if 'decoder.layers.' in name: + # Extract layer index + parts = name.split('.') + if len(parts) > 2 and parts[2].isdigit(): + layer_idx = int(parts[2]) + max_layer_idx = max(max_layer_idx, layer_idx) + + num_layers = max_layer_idx + 1 if max_layer_idx >= 0 else 0 + + for layer_idx in range(num_layers): + if self.is_hybrid_layer(total_layer_num): + # Process hybrid layer (Mamba + Attention) + message = {} + + # Collect Mamba weights across TP ranks + mamba_weights_per_rank = [] + attention_weights_per_rank = [] + + for model_tp in models: + mamba_weights = self.extract_mamba_weights(model_tp, layer_idx) + attention_weights = self.extract_attention_weights(model_tp, layer_idx) + mamba_weights_per_rank.append(mamba_weights) + attention_weights_per_rank.append(attention_weights) + + # Mamba components (typically not sharded across TP) + message["mamba A_log"] = mamba_weights_per_rank[0]["A_log"] + message["mamba D"] = mamba_weights_per_rank[0]["D"] + message["mamba dt_bias"] = mamba_weights_per_rank[0]["dt_bias"] + message["mamba conv1d weight"] = mamba_weights_per_rank[0]["conv1d_weight"] + message["mamba conv1d bias"] = mamba_weights_per_rank[0]["conv1d_bias"] + message["mamba pre norm weight"] = mamba_weights_per_rank[0]["pre_norm_weight"] + message["mamba internal norm weight"] = mamba_weights_per_rank[0]["internal_norm_weight"] + + # Mamba projections (may be sharded) + if len(mamba_weights_per_rank) > 1 and mamba_weights_per_rank[1]["in_proj_weight"] is not None: + # Concatenate across TP ranks + message["mamba in_proj weight"] = torch.cat([w["in_proj_weight"] for w in mamba_weights_per_rank], dim=0) + message["mamba out_proj weight"] = torch.cat([w["out_proj_weight"] for w in mamba_weights_per_rank], dim=1) + else: + message["mamba in_proj weight"] = mamba_weights_per_rank[0]["in_proj_weight"] + message["mamba out_proj weight"] = mamba_weights_per_rank[0]["out_proj_weight"] + + # Attention components (sharded across TP) + message["attention input norm weight"] = attention_weights_per_rank[0]["qkv_norm_weight"] + + # Concatenate QKV and dense 
weights across TP ranks + if len(attention_weights_per_rank) > 1: + message["attention qkv weight"] = torch.cat([w["qkv_weight"] for w in attention_weights_per_rank], dim=0) + message["attention dense weight"] = torch.cat([w["proj_weight"] for w in attention_weights_per_rank], dim=1) + else: + message["attention qkv weight"] = attention_weights_per_rank[0]["qkv_weight"] + message["attention dense weight"] = attention_weights_per_rank[0]["proj_weight"] + + self.queue_put(f"hybrid layer {total_layer_num}", message) + + else: + # Process MLP-only layer + message = {} + + # Collect MLP weights across TP ranks + mlp_weights_per_rank = [] + for model_tp in models: + mlp_weights = self.extract_mlp_weights(model_tp, layer_idx) + mlp_weights_per_rank.append(mlp_weights) + + # MLP norm (not sharded) + message["mlp input norm weight"] = mlp_weights_per_rank[0]["fc1_norm_weight"] + + # MLP weights (sharded across TP) + if len(mlp_weights_per_rank) > 1: + message["mlp fc1 weight"] = torch.cat([w["fc1_weight"] for w in mlp_weights_per_rank], dim=0) + message["mlp fc2 weight"] = torch.cat([w["fc2_weight"] for w in mlp_weights_per_rank], dim=1) + else: + message["mlp fc1 weight"] = mlp_weights_per_rank[0]["fc1_weight"] + message["mlp fc2 weight"] = mlp_weights_per_rank[0]["fc2_weight"] + + self.queue_put(f"mlp layer {total_layer_num}", message) + + total_layer_num += 1 + + # 3) Send final norm + message = {} + for name, param in models[0].named_parameters(): + if 'decoder.final_norm.weight' in name: + message["weight"] = param + break + self.queue_put("final norm", message) + + # 4) Send output layer + if self.md.output_layer: + message = {} + output_weights = [] + for model in models: + for name, param in model.named_parameters(): + if 'output_layer.weight' in name: + output_weights.append(param) + break + + if output_weights: + if len(output_weights) > 1: + message["weight"] = torch.cat(output_weights, dim=0) + else: + message["weight"] = output_weights[0] + self.queue_put("output layer", message) + + self.queue.put("done") + + +def load_checkpoint(queue, args): + """ + Required top-level function that creates the loader, + calls its .load(), and handles exceptions by signaling 'exit'. + """ + loader = MegatronCheckpointLoaderFalconH1(args, queue) + try: + loader.load() + except Exception as e: + queue.put("exit") + raise e \ No newline at end of file diff --git a/tools/checkpoint/saver_parallelhybrid_hf.py b/tools/checkpoint/saver_parallelhybrid_hf.py new file mode 100644 index 0000000000..0414fc2096 --- /dev/null +++ b/tools/checkpoint/saver_parallelhybrid_hf.py @@ -0,0 +1,500 @@ +import sys +import os +import gc +import math +import json +from pathlib import Path +from shutil import rmtree + +import torch +import torch.multiprocessing as mp +from transformers import ( + AutoModelForCausalLM, + FalconH1Config, + FalconH1ForCausalLM, + GenerationConfig, +) + +sys.path.append(os.path.abspath( + os.path.join(os.path.dirname(__file__), + os.path.pardir, + os.path.pardir))) +try: + from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding +except ModuleNotFoundError: + print("Unable to import Megatron. 
Exiting.") + exit(1) + +def add_arguments(parser): + group = parser.add_argument_group(title="Parallel Hybrid HF saver.") + group.add_argument( + "--hf-tokenizer", + type=str, + default=None, + help="HF tokenizer (example: tiiuae/Falcon-H1-0.5B-Instruct)", + ) + group.add_argument( + "--check-eq-hf", + type=str, + default=None, + help="check equality with HF model, example: tiiuae/Falcon-H1-1.5B-Instruct", + ) + group.add_argument( + "--save-chat-model", + action='store_true', + help="flag to save chat model or not", + ) + +def perform_check( + state_dict: dict[str, torch.Tensor], ref_state_dict: dict[str, torch.Tensor] +) -> dict[str, torch.Tensor]: + """ + Given a reference state dict, check that state_dict is equal to it + then pop the keys from ref_state_dict + """ + for key in state_dict: + if key in ref_state_dict: + if not torch.equal(ref_state_dict[key], state_dict[key]): + print(f"Warning: Mismatch found in {key}") + ref_state_dict.pop(key) + else: + print(f"Warning: Key {key} not found in reference model") + return ref_state_dict + +def save_layer( + state_dict: dict[str, torch.Tensor], + index_dict: dict, + dir_path: str, + filename: str, + check_reference: bool = False, + ref_state_dict: dict[str, torch.Tensor] = None, +) -> tuple[dict, dict[str, torch.Tensor]]: + """check state dict against a reference one if needed + update index_dict + save state dict + """ + if check_reference and ref_state_dict is not None: + ref_state_dict = perform_check(state_dict, ref_state_dict) + for layer_name, weight_matrix in state_dict.items(): + index_dict["weight_map"][layer_name] = filename + index_dict["metadata"]["total_size"] += weight_matrix.numel() + print(f"saving state dict to {dir_path}/{filename}") + torch.save(state_dict, f"{dir_path}/{filename}") + return index_dict, ref_state_dict + +def is_hybrid_layer(layer_idx: int) -> bool: + """Determine if a layer is hybrid (Mamba + Attention) or MLP-only""" + return layer_idx % 2 == 0 + +def process_hybrid_layer_weights(message: dict, layer_idx: int, falcon_h1_config: FalconH1Config) -> dict[str, torch.Tensor]: + """Process weights for hybrid layers (Mamba + Attention)""" + state_dict = {} + + # Mamba mixer components + state_dict[f"model.layers.{layer_idx}.mamba.A_log"] = message["mamba A_log"] + state_dict[f"model.layers.{layer_idx}.mamba.D"] = message["mamba D"] + state_dict[f"model.layers.{layer_idx}.mamba.dt_bias"] = message["mamba dt_bias"] + state_dict[f"model.layers.{layer_idx}.mamba.conv1d.weight"] = message["mamba conv1d weight"] + state_dict[f"model.layers.{layer_idx}.mamba.conv1d.bias"] = message["mamba conv1d bias"] + state_dict[f"model.layers.{layer_idx}.mamba.in_proj.weight"] = message["mamba in_proj weight"] + state_dict[f"model.layers.{layer_idx}.mamba.out_proj.weight"] = message["mamba out_proj weight"] + + state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = message["mamba pre norm weight"] + state_dict[f"model.layers.{layer_idx}.mamba.norm.weight"] = message["mamba internal norm weight"] + + # Self-attention components - PROPER QKV SPLITTING + qkv_weight = message["attention qkv weight"] + + # using standard Llama QKV layout + head_size = falcon_h1_config.hidden_size // falcon_h1_config.num_attention_heads # 128 + heads_per_group = falcon_h1_config.num_attention_heads // falcon_h1_config.num_key_value_heads # 4 + qkv_total_heads = falcon_h1_config.num_attention_heads + 2 * falcon_h1_config.num_key_value_heads # 12 + + # Reshape QKV to [12, 128, 1024] like Llama does + qkv_weights = 
qkv_weight.reshape([qkv_total_heads, head_size, falcon_h1_config.hidden_size]) + + # Create slices for Q, K, V exactly like Llama saver + q_slice = torch.cat([ + torch.arange( + (heads_per_group + 2) * i, + (heads_per_group + 2) * i + heads_per_group, + ) + for i in range(falcon_h1_config.num_key_value_heads) + ]) + k_slice = torch.arange(heads_per_group, qkv_total_heads, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_heads, (heads_per_group + 2)) + + # Extract Q, K, V using Llama's slicing approach + state_dict[f"model.layers.{layer_idx}.self_attn.q_proj.weight"] = qkv_weights[q_slice].reshape(-1, falcon_h1_config.hidden_size) + state_dict[f"model.layers.{layer_idx}.self_attn.k_proj.weight"] = qkv_weights[k_slice].reshape(-1, falcon_h1_config.hidden_size) + state_dict[f"model.layers.{layer_idx}.self_attn.v_proj.weight"] = qkv_weights[v_slice].reshape(-1, falcon_h1_config.hidden_size) + + # Attention output projection + state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.weight"] = message["attention dense weight"] + + # Attention layer norm + state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = message["attention input norm weight"] + + return state_dict + +def process_mlp_layer_weights(message: dict, layer_idx: int, falcon_h1_config: FalconH1Config) -> dict[str, torch.Tensor]: + """Process weights for MLP-only layers""" + state_dict = {} + + # MLP components - FIXED NAMES TO FEED_FORWARD + mlp_fc1_weight = message["mlp fc1 weight"] + + # Split gate and up projections (assuming SwiGLU like Llama) + intermediate_size = falcon_h1_config.intermediate_size + + # Split the fc1 weight into gate_proj and up_proj + gate_proj_weight = mlp_fc1_weight[:intermediate_size, :] + up_proj_weight = mlp_fc1_weight[intermediate_size:, :] + + state_dict[f"model.layers.{layer_idx}.feed_forward.gate_proj.weight"] = gate_proj_weight + state_dict[f"model.layers.{layer_idx}.feed_forward.up_proj.weight"] = up_proj_weight + state_dict[f"model.layers.{layer_idx}.feed_forward.down_proj.weight"] = message["mlp fc2 weight"] + + # MLP layer norm - FIXED NAME + state_dict[f"model.layers.{layer_idx}.pre_ff_layernorm.weight"] = message["mlp input norm weight"] + + return state_dict + +def save_checkpoint(queue: mp.Queue, args): + def queue_get(name=None): + val = queue.get() + if val == "exit": + print("Loader exited, exiting saver") + exit(1) + if name is not None and args.checking and val["name"] != name: + val_name = val["name"] + print( + f'Unexpected message. Expecting "{name}" but got "{val_name}". Exiting saver.' 
+ ) + exit(1) + if name is not None: + print(f"received {name}") + return val + + md = queue_get() + + ### Verify compatibility of args + if not hasattr(md, "checkpoint_args"): + raise ValueError("missing checkpoint_args in metadata") + + # Falcon-H1 specific validations + if not hasattr(md.checkpoint_args, 'hybrid_architecture'): + print("Warning: hybrid_architecture not specified in checkpoint_args, assuming Falcon-H1") + + torch_dtype = torch.float32 + if md.checkpoint_args.bf16: + torch_dtype = torch.bfloat16 + if md.checkpoint_args.fp16: + raise ValueError("bf16 and fp16 cannot be both set.") + elif md.checkpoint_args.fp16: + torch_dtype = torch.float16 + if md.checkpoint_args.bf16: + raise ValueError("bf16 and fp16 cannot be both set.") + + ### init + save_dir = Path(args.save_dir) + tmp_save_dir = save_dir / "tmp" + save_dir.mkdir(exist_ok=True) + tmp_save_dir.mkdir(exist_ok=True) + index_dict = { + "weight_map": {}, + "metadata": {"total_size": 0}, + } + tokenizer = None + ref_state_dict = None + + ### prepare a reference model if needed + if args.check_eq_hf: + print(f"preparing checks with given HF model {args.check_eq_hf}") + ref_model = AutoModelForCausalLM.from_pretrained(args.check_eq_hf, trust_remote_code=True) + ref_state_dict = ref_model.state_dict() + + ### save tokenizer conf files + if args.hf_tokenizer: + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(args.hf_tokenizer) + print(f"saving tokenizer to {args.save_dir}") + tokenizer.save_pretrained(args.save_dir) + + ### save config.json + falcon_h1_config = FalconH1Config( + # Basic model parameters from checkpoint + vocab_size=md.true_vocab_size if md.true_vocab_size else md.checkpoint_args.padded_vocab_size, + hidden_size=md.checkpoint_args.hidden_size, + intermediate_size=md.checkpoint_args.ffn_hidden_size, + num_hidden_layers=md.checkpoint_args.num_layers, + num_attention_heads=md.checkpoint_args.num_attention_heads, + num_key_value_heads=md.checkpoint_args.num_query_groups, + max_position_embeddings=md.checkpoint_args.max_position_embeddings, + rms_norm_eps=md.checkpoint_args.norm_epsilon, + tie_word_embeddings=not md.checkpoint_args.untie_embeddings_and_output_weights, + attention_dropout=md.checkpoint_args.attention_dropout, + + # Mamba parameters from checkpoint + mamba_d_state=md.checkpoint_args.mamba_state_dim, + mamba_d_conv=md.checkpoint_args.d_conv, + mamba_expand=md.checkpoint_args.expand, + mamba_d_ssm=md.checkpoint_args.d_inner, + mamba_n_heads=md.checkpoint_args.d_inner // md.checkpoint_args.mamba_head_dim, + mamba_d_head=md.checkpoint_args.mamba_head_dim, + mamba_n_groups=md.checkpoint_args.mamba_num_groups, + mamba_chunk_size=md.checkpoint_args.chunk_size, + mamba_conv_bias=md.checkpoint_args.conv_bias, + mamba_proj_bias=md.checkpoint_args.add_bias_linear, + mamba_norm_before_gate=md.checkpoint_args.norm_before_gate, + mamba_rms_norm=md.checkpoint_args.rmsnorm, + + # RoPE parameters from checkpoint + rope_theta=md.checkpoint_args.rotary_base, + + # Bias parameters from checkpoint + attention_bias=md.checkpoint_args.add_bias_linear, + mlp_bias=md.checkpoint_args.add_bias_linear, + projectors_bias=md.checkpoint_args.add_bias_linear, + + # Token IDs - from tokenizer if available, otherwise defaults + pad_token_id=getattr(tokenizer, 'pad_token_id', 0) if tokenizer else 0, + bos_token_id=getattr(tokenizer, 'bos_token_id', 1) if tokenizer else 1, + eos_token_id=getattr(tokenizer, 'eos_token_id', 2) if tokenizer else 2, + + # Parameters using FalconH1Config defaults (not in 
checkpoint) + hidden_act="silu", + initializer_range=0.02, + use_cache=True, + num_logits_to_keep=1, + rope_scaling=None, + + # Model metadata + torch_dtype=torch_dtype, + architectures=["FalconH1ForCausalLM"], + model_type="falcon_h1", + transformers_version="4.52.0", + ) + + if args.hf_tokenizer: + falcon_h1_config.eos_token_id = tokenizer.eos_token_id + falcon_h1_config.bos_token_id = tokenizer.bos_token_id + + print(f"saving config.json to {tmp_save_dir}") + falcon_h1_config.save_pretrained(tmp_save_dir) + + ### save embedding layer + def pad_weight(orig_word_embed, true_vocab_size): + if true_vocab_size is not None: + # figure out what our padded vocab size is + orig_vocab_size = orig_word_embed.shape[0] + md.checkpoint_args.padded_vocab_size = _vocab_size_with_padding(true_vocab_size, md.checkpoint_args) + + # Cut out extra padding we don't need + if orig_vocab_size > md.checkpoint_args.padded_vocab_size: + full_word_embed = orig_word_embed[0:md.checkpoint_args.padded_vocab_size,:] + + # Expanding embedding to larger size by replicating final entry + elif orig_vocab_size < md.checkpoint_args.padded_vocab_size: + padding_size = md.checkpoint_args.padded_vocab_size - orig_vocab_size + full_word_embed = torch.cat(( + orig_word_embed, + orig_word_embed[-1].unsqueeze(0).expand(padding_size, -1))) + + # Same size! + else: + full_word_embed = orig_word_embed + else: + print("Original vocab size not specified, leaving embedding table as-is. " + "If you've changed the tensor parallel size this could cause problems.") + md.checkpoint_args.padded_vocab_size = orig_word_embed.shape[0] + full_word_embed = orig_word_embed + return full_word_embed + + state_dict = { + "model.embed_tokens.weight": pad_weight(queue_get("embeddings")["word embeddings"], md.true_vocab_size) + } + index_dict, ref_state_dict = save_layer( + state_dict, + index_dict, + dir_path=tmp_save_dir, + filename="pytorch_model-embedding.bin", + check_reference=args.check_eq_hf, + ref_state_dict=ref_state_dict, + ) + + for i_layer in range(falcon_h1_config.num_hidden_layers): + state_dict = {} + + if is_hybrid_layer(i_layer): + # Process hybrid layer (Mamba + Attention) - EVEN layers + message = queue_get(f"hybrid layer {i_layer}") + + # Add Mamba + Attention components from Megatron + hybrid_weights = process_hybrid_layer_weights(message, i_layer, falcon_h1_config) + state_dict.update(hybrid_weights) + + # Add MISSING MLP components (configured to output zeros = identity for addition) + mlp_intermediate_size = falcon_h1_config.intermediate_size + state_dict.update({ + # Gate and up can be anything since down_proj will zero everything out + f"model.layers.{i_layer}.feed_forward.gate_proj.weight": torch.randn( + mlp_intermediate_size, falcon_h1_config.hidden_size, + dtype=torch_dtype + ) * 0.01, + f"model.layers.{i_layer}.feed_forward.up_proj.weight": torch.randn( + mlp_intermediate_size, falcon_h1_config.hidden_size, + dtype=torch_dtype + ) * 0.01, + # KEY: down_proj = 0 makes entire MLP output zero + f"model.layers.{i_layer}.feed_forward.down_proj.weight": torch.zeros( + falcon_h1_config.hidden_size, mlp_intermediate_size, + dtype=torch_dtype + ), + f"model.layers.{i_layer}.pre_ff_layernorm.weight": torch.ones( + falcon_h1_config.hidden_size, dtype=torch_dtype + ), + }) + + else: + # Process MLP-only layer - ODD layers + message = queue_get(f"mlp layer {i_layer}") + + # Add MLP components from Megatron + mlp_weights = process_mlp_layer_weights(message, i_layer, falcon_h1_config) + state_dict.update(mlp_weights) + + # Add MISSING 
Mamba components (configured to output zeros = identity for addition) + mamba_intermediate_size = ( + falcon_h1_config.mamba_d_ssm if falcon_h1_config.mamba_d_ssm + else int(falcon_h1_config.mamba_expand * falcon_h1_config.hidden_size) + ) + conv_dim = mamba_intermediate_size + 2 * falcon_h1_config.mamba_n_groups * falcon_h1_config.mamba_d_state + projection_size = mamba_intermediate_size + conv_dim + falcon_h1_config.mamba_n_heads + + state_dict.update({ + f"model.layers.{i_layer}.mamba.A_log": torch.log(torch.arange(1, falcon_h1_config.mamba_n_heads + 1, dtype=torch_dtype)), + f"model.layers.{i_layer}.mamba.D": torch.ones(falcon_h1_config.mamba_n_heads, dtype=torch_dtype), + f"model.layers.{i_layer}.mamba.dt_bias": torch.ones(falcon_h1_config.mamba_n_heads, dtype=torch_dtype), + f"model.layers.{i_layer}.mamba.conv1d.weight": torch.randn( + conv_dim, 1, falcon_h1_config.mamba_d_conv, dtype=torch_dtype + ) * 0.01, + f"model.layers.{i_layer}.mamba.conv1d.bias": torch.zeros(conv_dim, dtype=torch_dtype), + f"model.layers.{i_layer}.mamba.in_proj.weight": torch.randn( + projection_size, falcon_h1_config.hidden_size, dtype=torch_dtype + ) * 0.01, + # KEY: out_proj = 0 makes entire Mamba output zero + f"model.layers.{i_layer}.mamba.out_proj.weight": torch.zeros( + falcon_h1_config.hidden_size, mamba_intermediate_size, dtype=torch_dtype + ), + f"model.layers.{i_layer}.mamba.norm.weight": torch.ones(mamba_intermediate_size, dtype=torch_dtype), + }) + + # Add MISSING Attention components (configured to output zeros = identity for addition) + head_dim = falcon_h1_config.hidden_size // falcon_h1_config.num_attention_heads + state_dict.update({ + f"model.layers.{i_layer}.self_attn.q_proj.weight": torch.randn( + falcon_h1_config.num_attention_heads * head_dim, + falcon_h1_config.hidden_size, dtype=torch_dtype + ) * 0.01, + f"model.layers.{i_layer}.self_attn.k_proj.weight": torch.randn( + falcon_h1_config.num_key_value_heads * head_dim, + falcon_h1_config.hidden_size, dtype=torch_dtype + ) * 0.01, + f"model.layers.{i_layer}.self_attn.v_proj.weight": torch.randn( + falcon_h1_config.num_key_value_heads * head_dim, + falcon_h1_config.hidden_size, dtype=torch_dtype + ) * 0.01, + # KEY: o_proj = 0 makes entire attention output zero + f"model.layers.{i_layer}.self_attn.o_proj.weight": torch.zeros( + falcon_h1_config.hidden_size, + falcon_h1_config.num_attention_heads * head_dim, + dtype=torch_dtype + ), + f"model.layers.{i_layer}.input_layernorm.weight": torch.ones( + falcon_h1_config.hidden_size, dtype=torch_dtype + ), + }) + index_dict, ref_state_dict = save_layer( + state_dict, + index_dict, + dir_path=tmp_save_dir, + filename=f"pytorch_model-{i_layer + 1}.bin", + check_reference=args.check_eq_hf, + ref_state_dict=ref_state_dict, + ) + + + ### save final norm and output layer + state_dict = { + "model.final_layernorm.weight": queue_get("final norm")["weight"] +} + if md.checkpoint_args.untie_embeddings_and_output_weights: + state_dict["lm_head.weight"] = pad_weight(queue_get("output layer")["weight"], md.true_vocab_size) + + index_dict, ref_state_dict = save_layer( + state_dict, + index_dict, + dir_path=tmp_save_dir, + filename="pytorch_model-lm-head.bin", + check_reference=args.check_eq_hf, + ref_state_dict=ref_state_dict, + ) + + # final check + if ref_state_dict: + remaining_keys = list(ref_state_dict.keys()) + print(f"Warning: reference state dict has {len(remaining_keys)} additional layers not present in converted model:") + for key in remaining_keys[:10]: # Show first 10 + print(f" - {key}") + if 
len(remaining_keys) > 10: + print(f" ... and {len(remaining_keys) - 10} more") + + ### save index dict + index_dict["metadata"]["total_size"] *= { + torch.float32: 4, + torch.float16: 2, + torch.bfloat16: 2, + }[torch_dtype] + print(f"saving {tmp_save_dir}/pytorch_model.bin.index.json") + with open(f"{tmp_save_dir}/pytorch_model.bin.index.json", "w") as f: + json.dump(index_dict, f) + + ### load then save model in HF format + # Make space so we can load the model properly now. + del state_dict + gc.collect() + print(f"Loading the converted pytorch checkpoint in a Falcon-H1 HF model from {tmp_save_dir}") + model = FalconH1ForCausalLM.from_pretrained( + str(tmp_save_dir), torch_dtype=torch_dtype, low_cpu_mem_usage=True, trust_remote_code=True + ) + + # Avoid saving this as part of the config. + if hasattr(model.config, '_name_or_path'): + del model.config._name_or_path + model.config.torch_dtype = torch_dtype + print(f"Saving in the Transformers safe tensors format to {args.save_dir}") + model.save_pretrained(args.save_dir, safe_serialization=True) + + ### save chat config + generation_config = ( + GenerationConfig( + do_sample=True, + temperature=0.6, + top_p=0.9, + bos_token_id=falcon_h1_config.bos_token_id, + eos_token_id=falcon_h1_config.eos_token_id, + ) + if args.save_chat_model + else GenerationConfig( + _from_model_config=True, + bos_token_id=falcon_h1_config.bos_token_id, + eos_token_id=falcon_h1_config.eos_token_id, + ) + ) + print(f"Saving generation config to {args.save_dir}") + generation_config.save_pretrained(args.save_dir) + + ### cleanup tmp + print(f"Deleting {tmp_save_dir}") + rmtree(tmp_save_dir)
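Note on ParallelHybridLayer.forward (megatron/core/ssm/parallel_hybrid_layer.py above): both branches consume the same pre-normalized input, their outputs and optional biases are summed, and a single bias-dropout-add is applied against the shared residual. Below is a minimal sketch of that combination pattern; the class name, sizes, and the nn.Linear / nn.MultiheadAttention stand-ins are illustrative assumptions, not Megatron APIs.

```python
import torch
import torch.nn as nn

class ToyParallelHybrid(nn.Module):
    """Toy stand-in for the parallel-hybrid pattern: shared pre-norm, two
    parallel branches, summed outputs, one dropout-add on the residual."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.mamba_branch = nn.Linear(hidden_size, hidden_size)  # stands in for MambaMixer
        self.attn_branch = nn.MultiheadAttention(hidden_size, num_heads=4, batch_first=True)
        self.dropout = nn.Dropout(p=0.1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        x = self.norm(hidden_states)                      # single shared input_layernorm
        mamba_out = self.mamba_branch(x)
        attn_out, _ = self.attn_branch(x, x, x, need_weights=False)
        combined = mamba_out + attn_out                   # branch outputs are summed
        return residual + self.dropout(combined)          # analogous to parallel_hybrid_bda

x = torch.randn(2, 16, 64)  # [batch, seq, hidden]
print(ToyParallelHybrid(64)(x).shape)  # torch.Size([2, 16, 64])
```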
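Note on _allocate_auto (megatron/core/ssm/mamba_hybrid_layer_allocation.py above): parallel-hybrid placement reuses the fractional-accumulator spacing already used for MLP layers, converting a remaining Mamba slot to "P" whenever the accumulator drops below 0.5 so the conversions land roughly evenly. A standalone sketch of just that spreading step, simplified from the hunk above (the function name and toy pattern are assumptions):

```python
def spread_parallel(layers: list, n_parallel: int) -> list:
    """Convert n_parallel of the remaining 'M' slots to 'P', roughly evenly spaced."""
    mamba_idx = [i for i, s in enumerate(layers) if s == "M"]
    if n_parallel <= 0:
        return layers
    if n_parallel >= len(mamba_idx):
        for i in mamba_idx:
            layers[i] = "P"
        return layers
    ratio = (len(mamba_idx) - n_parallel) / n_parallel  # Mamba layers kept per conversion
    x = ratio
    for i in mamba_idx:
        if x < 0.5:
            layers[i] = "P"
            x += ratio
        else:
            x -= 1
    return layers

print("".join(spread_parallel(list("MMMMMMMMMM"), 3)))  # MMPMMMPMMP
```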
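Note on process_hybrid_layer_weights (tools/checkpoint/saver_parallelhybrid_hf.py above): the grouped QKV weight is reshaped to [num_attention_heads + 2 * num_key_value_heads, head_dim, hidden_size] and the query/key/value head rows are gathered per KV group, following the Llama saver's slicing. A self-contained sketch with small made-up dimensions (not Falcon-H1's real sizes):

```python
import torch

hidden, n_heads, n_kv_heads = 64, 8, 2            # toy sizes for illustration only
head_dim = hidden // n_heads                      # 8
g = n_heads // n_kv_heads                         # query heads per KV group = 4
qkv_total_heads = n_heads + 2 * n_kv_heads        # 12

qkv_weight = torch.randn(qkv_total_heads * head_dim, hidden)  # grouped linear_qkv weight
qkv = qkv_weight.reshape(qkv_total_heads, head_dim, hidden)

# Per KV group the head rows are laid out as [q_0 .. q_{g-1}, k, v].
q_rows = torch.cat([torch.arange((g + 2) * i, (g + 2) * i + g) for i in range(n_kv_heads)])
k_rows = torch.arange(g, qkv_total_heads, g + 2)
v_rows = torch.arange(g + 1, qkv_total_heads, g + 2)

q_proj = qkv[q_rows].reshape(-1, hidden)   # [n_heads * head_dim, hidden]
k_proj = qkv[k_rows].reshape(-1, hidden)   # [n_kv_heads * head_dim, hidden]
v_proj = qkv[v_rows].reshape(-1, hidden)
print(q_proj.shape, k_proj.shape, v_proj.shape)
```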
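Note on the filled-in components in the saver: for branches the Megatron checkpoint does not have (the MLP in hybrid layers, Mamba and attention in MLP-only layers), the exported weights zero the branch's output projection (down_proj, out_proj, o_proj) so the branch contributes nothing to the residual sum. A tiny sanity check of that trick with a toy module (not the saver code):

```python
import torch
import torch.nn as nn

# A branch whose output projection is zeroed produces exactly zero, so adding it
# to the residual is a no-op, mirroring the zeroed down_proj / out_proj / o_proj.
branch = nn.Sequential(nn.Linear(16, 32), nn.SiLU(), nn.Linear(32, 16, bias=False))
nn.init.zeros_(branch[2].weight)

x = torch.randn(4, 16)
assert torch.equal(x + branch(x), x)  # branch output is exactly zero
print("zeroed output projection makes the branch an identity under the residual add")
```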