This repository was archived by the owner on Oct 14, 2025. It is now read-only.

Issue with Llama conversion for new release #24

@evellasques

I noticed that in the latest release, llama_module.py was replaced with falcon_module.py, and test_llama.sh now relies on megatron_gpt_pretraining.py, which uses MegatronGPTModel rather than llama_module.py.

The problem is that MegatronGPTModel relies on transformer.py (instead of llama_module.py), and there, for SwiGLU, the two separate MLP layers (dense_h_to_4h and dense_h_to_4h_2) have been replaced with a single layer that is twice as large:

        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
            hidden_size,
            2*ffn_hidden_size if self.glu_activation_family else ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True,
            resume_from_checkpoint=resume_from_checkpoint,
            use_cpu_initialization=use_cpu_initialization,
            bias=bias,
            sequence_parallel_enabled=sequence_parallel,
            no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce,
            gradient_accumulation_fusion=gradient_accumulation_fusion,
            transfer_with_static_ring=transfer_with_static_ring,
        )

In llama_module.py, by contrast, you had:

        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
            hidden_size,
            ffn_hidden_size,  # NOTE: When using geglu, divide ffn dim by 2/3 to keep overall params the same.
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True,
            use_cpu_initialization=use_cpu_initialization,
            bias=bias,
            sequence_parallel_enabled=sequence_parallel,
            no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce,
            gradient_accumulation_fusion=gradient_accumulation_fusion,
            transfer_with_static_ring=transfer_with_static_ring,
        )

        if activation in ['geglu', 'reglu', 'swiglu']:
            # Separate linear layer for *GLU activations.
            # Source: https://github.com/huggingface/transformers/blob/bee361c6f1f7704f8c688895f2f86f6e5ff84727/src/transformers/models/t5/modeling_t5.py#L292
            self.dense_h_to_4h_2 = tensor_parallel.ColumnParallelLinear(
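
Mathematically, the two layouts compute the same SwiGLU; only the weight layout differs. Below is a minimal NumPy sketch (not the actual NeMo/Megatron code) showing that a single projection of width 2*ffn_hidden_size, split in half, matches two separate projections. It assumes weights are stored as (out_features, in_features), which matches the [1376, 4096] shape in the error below, that biases are disabled, and that the fused output is split as [gate | up]. Which half transformer.py actually treats as the gate is an assumption to verify against the source.

```python
import numpy as np

def silu(x):
    # SiLU / swish activation: x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def swiglu_separate(x, w_gate, w_up):
    # llama_module.py style: two projections
    # (dense_h_to_4h -> gate, dense_h_to_4h_2 -> up)
    return silu(x @ w_gate.T) * (x @ w_up.T)

def swiglu_fused(x, w_fused):
    # transformer.py style: one projection of width 2 * ffn_hidden_size,
    # split in half along the output dimension (assumed order: gate, then up)
    h = x @ w_fused.T
    gate, up = np.split(h, 2, axis=-1)
    return silu(gate) * up

rng = np.random.default_rng(0)
hidden, ffn = 8, 6
x = rng.standard_normal((3, hidden))
w_gate = rng.standard_normal((ffn, hidden))
w_up = rng.standard_normal((ffn, hidden))

# The fused weight is the gate rows stacked on top of the up rows.
w_fused = np.concatenate([w_gate, w_up], axis=0)
print(np.allclose(swiglu_separate(x, w_gate, w_up), swiglu_fused(x, w_fused)))  # True
```

So the fused layer can reuse HF weights, but only if the converter stacks gate_proj and up_proj into one tensor in the order the forward pass expects.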

But in that case the Llama checkpoint conversion script has to change as well. It currently maps gate_proj and up_proj to two separate layers (dense_h_to_4h and dense_h_to_4h_2):

    translation = {
        "model.language_model.embedding.word_embeddings.weight": (1, "model.embed_tokens.weight", 0, 0),
        # a['model']['language_model']['word_embeddings']['weight']
        "input_layernorm.weight": (0, "input_layernorm.weight", None, 0),
        "self_attention.query_key_value.weight": (1, "self_attn.query_key_value.weight", 0, 0),
        "self_attention.dense.weight": (1, "self_attn.o_proj.weight", 1, 0),
        "post_attention_layernorm.weight": (0, "post_attention_layernorm.weight", None, 0),
        "self_attention.core_attention.rotary_emb.inv_freq": (0, "self_attn.rotary_emb.inv_freq", None, 0),
        "mlp.dense_h_to_4h.weight": (1, "mlp.gate_proj.weight", 0, 0),
        "mlp.dense_h_to_4h_2.weight": (1, "mlp.up_proj.weight", 0, 0),
        "mlp.dense_4h_to_h.weight": (1, "mlp.down_proj.weight", 1, 0),
        "model.language_model.encoder.final_layernorm.weight": (0, "model.norm.weight", None, 0),
        "model.language_model.output_layer.weight": (1, "lm_head.weight", 0, 0),
    }
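
One way the converter could build the fused weight is sketched below. The helper name merge_gate_up_for_tp is hypothetical and not part of the existing script; it assumes the third tuple element (0) is the tensor-parallel split dimension and that each rank's fused shard is split in half as [gate | up]. The key point is ordering: each projection is sharded across TP ranks first and then concatenated per rank. Concatenating the full gate_proj and up_proj and then sharding would give the first ranks only gate rows and the last ranks only up rows.

```python
import numpy as np

def merge_gate_up_for_tp(gate, up, tp_size):
    """Hypothetical helper: per-rank dense_h_to_4h shards from HF weights.

    gate, up: arrays of shape (ffn_hidden_size, hidden_size), i.e. HF
    mlp.gate_proj.weight and mlp.up_proj.weight.
    Returns tp_size arrays of shape (2 * ffn_hidden_size // tp_size, hidden_size).
    """
    if gate.shape != up.shape:
        raise ValueError("gate_proj and up_proj must have the same shape")
    if gate.shape[0] % tp_size != 0:
        raise ValueError("ffn_hidden_size must be divisible by tp_size")
    # Column-parallel: shard each projection along its output dim (dim 0) ...
    gate_shards = np.split(gate, tp_size, axis=0)
    up_shards = np.split(up, tp_size, axis=0)
    # ... then stack per rank so each rank holds [its gate rows; its up rows].
    return [np.concatenate([g, u], axis=0) for g, u in zip(gate_shards, up_shards)]

# Small usage example with toy sizes.
ffn, hidden, tp = 8, 4, 2
gate = np.arange(ffn * hidden, dtype=float).reshape(ffn, hidden)
up = gate + 1000.0
shards = merge_gate_up_for_tp(gate, up, tp)
print([s.shape for s in shards])  # [(8, 4), (8, 4)]
```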

This currently causes a crash when I try to load a checkpoint converted from HF Llama, because the model expects dense_h_to_4h to be the concatenation of gate_proj and up_proj from the HF checkpoint:

    RuntimeError: Error(s) in loading state_dict for MegatronGPTModel:
            size mismatch for model.language_model.encoder.layers.0.mlp.dense_h_to_4h.weight: copying a param with shape torch.Size([1376, 4096]) from checkpoint, the shape in current model is torch.Size([2752, 4096]).
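
The numbers in the error line up with this explanation. Assuming a Llama-7B checkpoint (hidden_size 4096, intermediate size 11008) and a tensor-parallel degree of 8 (inferred from 11008 / 1376, not stated above), the checkpoint shard holds only the gate_proj slice while the model expects gate and up stacked together:

```python
# Assumptions: Llama-7B sizes; tp_size inferred from 11008 / 1376.
hidden_size = 4096
ffn_hidden_size = 11008
tp_size = 8

# What the current converter writes per rank: gate_proj slice only.
checkpoint_shard = (ffn_hidden_size // tp_size, hidden_size)
# What the fused dense_h_to_4h expects per rank: gate slice + up slice.
model_shard = (2 * ffn_hidden_size // tp_size, hidden_size)
print(checkpoint_shard, model_shard)  # (1376, 4096) (2752, 4096)
```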
