
LoRA support on llama4 #19819

Open - wants to merge 2 commits into main

Conversation


@frank-wei frank-wei commented Jun 18, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Currently, llama4 does not support LoRA. This PR adds LoRA support.

Test Plan

Verified on an internal llama4 model. Further verification on larger models is still needed, but I do not have a LoRA adapter for them; help from users who need this feature would be appreciated.
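
For reference, a minimal offline check could look like the sketch below. The model id and adapter path are placeholders (not the internal model or adapter used for verification), and a real Llama 4 run would also need appropriate parallelism and memory settings.

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder checkpoint id; a real deployment would also set
# tensor_parallel_size / max_model_len appropriately for Llama 4.
llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    enable_lora=True,
    max_lora_rank=16,
)

outputs = llm.generate(
    ["What is the capital of France?"],
    SamplingParams(max_tokens=32),
    # Placeholder adapter path; the integer id just needs to be unique per adapter.
    lora_request=LoRARequest("test-adapter", 1, "/path/to/lora_adapter"),
)
print(outputs[0].outputs[0].text)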

Test Result

(Optional) Documentation Update


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run the fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀


@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @frank-wei, I'm Gemini Code Assist [1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces support for LoRA (Low-Rank Adaptation) to the llama4 model within the vllm framework. The primary goal is to enable the use of LoRA adapters with llama4 for fine-tuning and efficient deployment, specifically for the Llama4ForConditionalGeneration class, by making it compatible with LoRA-specific operations and module mapping.

Highlights

  • LoRA Integration: The Llama4ForConditionalGeneration class now implements the SupportsLoRA interface, enabling LoRA (Low-Rank Adaptation) functionality for the llama4 model.
  • Multimodal Module Mapping: A new get_mm_mapping method has been added to Llama4ForConditionalGeneration. This method provides module prefixes (language_model, connector, tower_model) essential for correctly applying LoRA adapters in a multimodal context.
  • Dependency Updates: Necessary imports for MultiModelKeys and SupportsLoRA have been added to vllm/model_executor/models/mllama4.py to support the new LoRA integration.
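
Taken together, the highlighted changes roughly take the shape sketched below. This is only an illustration of the diff described above; the exact import paths and surrounding code follow vllm/model_executor/models/mllama4.py and may differ slightly from the actual patch.

from torch import nn

from vllm.model_executor.models.interfaces import (SupportsLoRA,
                                                   SupportsMultiModal,
                                                   SupportsPP)
from vllm.model_executor.models.module_mapping import MultiModelKeys


class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
                                     SupportsLoRA):
    # Packed-module mapping used for LoRA weight loading (attention projections).
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    # ... existing multimodal implementation ...

    def get_mm_mapping(self) -> MultiModelKeys:
        """Get the module prefix in multimodal models."""
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector.",
            tower_model="vision_model.",
        )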
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature              | Command             | Description
Code Review          | /gemini review      | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary     | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help                 | /gemini help        | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the root of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@mergify mergify bot added the llama Related to Llama models label Jun 18, 2025
@frank-wei frank-wei changed the title Feature lora LoRA support on llama4 Jun 18, 2025

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces LoRA support for the mllama4 model by having Llama4ForConditionalGeneration implement the SupportsLoRA interface and by adding a helper method, get_mm_mapping, likely intended to aid in applying LoRA to multimodal components.

Key areas for the author to consider:

  1. Testing: The PR description mentions initial verification but highlights the need for more extensive testing, especially with actual LoRA adapters and on larger models. It would be helpful to detail the internal verification steps and outline a clearer plan for further testing, even if community help is needed for specific adapters.
  2. Scope of LoRA: Please clarify if LoRA is intended to be applicable to embedding layers (like token embeddings or lm_head) for mllama4, or if it's primarily for attention and MLP layers. The current changes rely on default empty configurations for LoRA on embeddings.
  3. Packed Module Mapping: The packed_modules_mapping in Llama4ForConditionalGeneration seems specific. It would be good to confirm if this mapping is intentionally different from potentially more comprehensive mappings in other Llama-style models, particularly regarding MLP layers.

Overall, the changes seem to lay the necessary groundwork for LoRA support. Addressing the points above will help ensure robustness and clarity.

Comment on lines +716 to +717
class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
SupportsLoRA):


medium

The class Llama4ForConditionalGeneration now correctly inherits SupportsLoRA. This is a key step for enabling LoRA support. Two points to consider:

  1. Embedding LoRA: The SupportsLoRA interface provides default empty embedding_modules and embedding_padding_modules. If this model requires LoRA to be applied to specific embedding layers (e.g., word token embeddings wte, or the lm_head), these class variables should be overridden in Llama4ForConditionalGeneration. For example:

    class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
        # ... existing packed_modules_mapping ...
    
        embedding_modules = {
            "embed_tokens": "language_model.model.embed_tokens",  # Example path
        }
        # embedding_padding_modules = [...] # If needed for specific embeddings
        # ... rest of the class ...

    If LoRA is not intended for embedding layers in mllama4, then using the defaults from SupportsLoRA is acceptable. Please clarify the intended scope.

  2. Packed Modules for LoRA: The packed_modules_mapping defined on lines 718-720 is {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}. For context, the Llama4ForCausalLM class (defined in vllm/model_executor/models/llama4.py, not this mllama4.py file, but potentially related or a source of inspiration) often has a more comprehensive mapping, e.g., {"qkv_proj": [...], "gate_up_proj": ["gate_proj", "up_proj"]}.
    Is it intentional that Llama4ForConditionalGeneration has this specific mapping, potentially excluding MLP layers like gate_proj and up_proj from LoRA application? If Llama4ForConditionalGeneration shares a similar transformer block structure with other Llama models where MLP LoRA is common, consider aligning these mappings or clarifying the rationale for the difference.
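
For illustration, here is a side-by-side of the two mappings discussed in point 2 above. This is a sketch for comparison only, not a recommendation for which one this PR should adopt.

# Mapping added in this PR for Llama4ForConditionalGeneration (attention only):
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
}

# Broader pattern commonly seen in other Llama-style models, which also tells the
# LoRA loader how split gate_proj/up_proj adapter weights map onto the fused
# gate_up_proj layer:
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}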

Comment on lines +941 to +949
def get_mm_mapping(self) -> MultiModelKeys:
    """
    Get the module prefix in multimodal models
    """
    return MultiModelKeys.from_string_field(
        language_model="language_model",
        connector="multi_modal_projector.",
        tower_model="vision_model.",
    )


medium

This new method get_mm_mapping provides module prefixes for different components of a multimodal model. This is likely used by the LoRA mechanism (or other systems) to correctly apply or map configurations (like LoRA weights) in a multimodal context.

  • The hardcoded prefixes ("language_model", "multi_modal_projector.", "vision_model.") are specific to the mllama4 architecture. This is generally fine for a model-specific implementation.
  • Could you confirm if these prefixes are always static for all Mllama4 variants? If there's a possibility of them changing (e.g., based on hf_config), deriving them dynamically might offer more flexibility. However, if they are fixed architectural constants for mllama4, the current approach is fine.
  • Consider adding a brief comment within the method or in the docstring explaining the origin or significance of these specific prefix strings if it's not immediately obvious from the model's standard architecture documentation. For example, clarifying that language_model refers to the text processing backbone, vision_model to the image encoder, and multi_modal_projector to the connector module.
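
One way to address the last point is to keep the static prefixes but document what each one refers to, roughly as sketched below. This is only an illustration of the reviewer's suggestion, not code from the PR; the method would live on Llama4ForConditionalGeneration.

from vllm.model_executor.models.module_mapping import MultiModelKeys

def get_mm_mapping(self) -> MultiModelKeys:
    """Get the module prefixes of the multimodal components.

    For mllama4:
      - language_model:        the text decoder backbone
      - multi_modal_projector: the connector that projects vision features into
                               the language model's embedding space
      - vision_model:          the image encoder (tower model)
    """
    return MultiModelKeys.from_string_field(
        language_model="language_model",
        connector="multi_modal_projector.",
        tower_model="vision_model.",
    )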


@houseroad houseroad left a comment


Looks good to me. We just need an adapter to verify with.


@22quinn 22quinn left a comment


Maybe this adapter can be used for testing? https://huggingface.co/ImranzamanML/llama4-medqa/blob/main/adapter_config.json
Or we can create a dummy adapter.
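
A dummy adapter could be created along the lines of the sketch below, using PEFT. The checkpoint id, rank, and target modules are illustrative; it assumes AutoModelForCausalLM resolves to the text backbone for this checkpoint and that enough host memory is available (a smaller config-compatible model would also work for a pure loading test).

from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Placeholder checkpoint id.
base = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-4-Scout-17B-16E-Instruct")

lora_cfg = LoraConfig(
    r=8,
    lora_alpha=16,
    # Only attention projections, so no MoE/router modules are touched.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)

# Freshly initialized LoRA weights (B is zero-initialized), so the adapter is
# effectively a no-op and suitable for exercising the loading path.
peft_model = get_peft_model(base, lora_cfg)
peft_model.save_pretrained("/tmp/llama4-dummy-lora")  # writes adapter_config.json + weights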


@jeejeelee jeejeelee left a comment


Considering that our current MoE layer doesn't support LoRA yet, llama4 may not be able to fully support LoRA.

@frank-wei

Considering that our current MoE layer doesn't support LoRA yet, llama4 may not be able to fully support LoRA.

@jeejeelee, we probably need to ask people not to add adapters for the MoE modules.
Thanks @22quinn for the pointer. I will verify this adapter. It looks like its target_modules do not include the MoE modules, which are what we can support right now (a quick way to check is sketched below).


frank-wei commented Jun 21, 2025

I tried the adapter from @22quinn; that adapter was trained with module names that are not supported. The shared expert module has linear operators named gate_up_proj and down_proj, but the adapter has modules called gate_proj and up_proj, which do not align.

Log for shared expert:
(VllmWorker rank=7 pid=481083) name= language_model.model.layers.26.feed_forward.shared_expert module LlamaMLP(
(VllmWorker rank=7 pid=481083)   (gate_up_proj): MergedColumnParallelLinear(in_features=5120, output_features=2048, bias=False, tp_size=8, gather_output=False)
(VllmWorker rank=7 pid=481083)   (down_proj): RowParallelLinear(input_features=1024, output_features=5120, bias=False, tp_size=8, reduce_results=False)
(VllmWorker rank=7 pid=481083)   (act_fn): SiluAndMul()

Error:
ValueError: While loading /home/wwei6/venv/llama4-adapter/llama4-medqa, expected target modules in ['linear', 'linear_1', 'gate_up_proj', 'v_proj', 'router', 'fc2', 'k_proj', 'o_proj', 'fc1', 'down_proj', 'q_proj'] but received ['language_model.model.layers.0.feed_forward.shared_expert.gate_proj', 'language_model.model.layers.0.feed_forward.shared_expert.up_proj'....

At this stage, I think the functionality is in place, but an adapter is needed for further checking. I verified a small model internally, but not the big OSS one. I can follow up if any user has their own adapters.

Signed-off-by: Wei Wei <[email protected]>
Signed-off-by: Wei Wei <[email protected]>