LoRA support on llama4 #19819

Open. Wants to merge 2 commits into base: main.
18 changes: 15 additions & 3 deletions vllm/model_executor/models/mllama4.py
@@ -41,6 +41,7 @@
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.utils import initialize_model
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -54,7 +55,8 @@
from vllm.multimodal.utils import run_dp_sharded_vision_model
from vllm.sequence import IntermediateTensors

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
                         SupportsMultiModal, SupportsPP)
from .llama4 import Llama4ForCausalLM
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                    merge_multimodal_embeddings)
@@ -711,8 +713,8 @@ def get_dummy_mm_data(
    info=Mllama4ProcessingInfo,
    dummy_inputs=Mllama4DummyInputsBuilder,
)
class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
                                     SupportsPP):
class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
                                     SupportsLoRA):
Comment on lines +716 to +717 (Contributor):

medium

The class Llama4ForConditionalGeneration now correctly inherits SupportsLoRA. This is a key step for enabling LoRA support. Two points to consider:

  1. Embedding LoRA: The SupportsLoRA interface provides default empty embedding_modules and embedding_padding_modules. If this model requires LoRA to be applied to specific embedding layers (e.g., word token embeddings wte, or the lm_head), these class variables should be overridden in Llama4ForConditionalGeneration. For example:

    class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
        # ... existing packed_modules_mapping ...
    
        embedding_modules = {
            "embed_tokens": "language_model.model.embed_tokens",  # Example path
        }
        # embedding_padding_modules = [...] # If needed for specific embeddings
        # ... rest of the class ...

    If LoRA is not intended for embedding layers in mllama4, then using the defaults from SupportsLoRA is acceptable. Please clarify the intended scope.

  2. Packed Modules for LoRA: The packed_modules_mapping defined on lines 718-720 is {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}. For context, the Llama4ForCausalLM class (defined in vllm/model_executor/models/llama4.py, not this mllama4.py file, but potentially related or a source of inspiration) often has a more comprehensive mapping, e.g., {"qkv_proj": [...], "gate_up_proj": ["gate_proj", "up_proj"]}.
    Is it intentional that Llama4ForConditionalGeneration has this specific mapping, potentially excluding MLP layers like gate_proj and up_proj from LoRA application? If Llama4ForConditionalGeneration shares a similar transformer block structure with other Llama models where MLP LoRA is common, consider aligning these mappings or clarifying the rationale for the difference.

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }
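
For comparison with the mapping above, the broader Llama-family mapping the reviewer refers to in point 2 would look roughly like the sketch below. This is illustrative only, not part of the diff, and it is only appropriate if the Llama4 text backbone actually fuses gate_proj and up_proj into a single gate_up_proj module:

    # Sketch only (not part of this diff): the more comprehensive mapping used by
    # other Llama-family models, per point 2 of the review comment above.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }
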
@@ -935,3 +937,13 @@ def load_weights(self, weights: Iterable[tuple[str,
weight_loader(param, loaded_weight)
updated_params.add(name)
return updated_params

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector.",
            tower_model="vision_model.",
        )
Comment on lines +941 to +949 (Contributor):

medium

This new method get_mm_mapping provides module prefixes for different components of a multimodal model. This is likely used by the LoRA mechanism (or other systems) to correctly apply or map configurations (like LoRA weights) in a multimodal context.

  • The hardcoded prefixes ("language_model", "multi_modal_projector.", "vision_model.") are specific to the mllama4 architecture. This is generally fine for a model-specific implementation.
  • Could you confirm if these prefixes are always static for all Mllama4 variants? If there's a possibility of them changing (e.g., based on hf_config), deriving them dynamically might offer more flexibility. However, if they are fixed architectural constants for mllama4, the current approach is fine.
  • Consider adding a brief comment within the method or in the docstring explaining the origin or significance of these specific prefix strings if it's not immediately obvious from the model's standard architecture documentation. For example, clarifying that language_model refers to the text processing backbone, vision_model to the image encoder, and multi_modal_projector to the connector module.
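
To make the second bullet concrete, here is a minimal sketch of how the hardcoded prefixes could be checked against (or derived from) the instantiated model rather than trusted blindly. The helper names are hypothetical, and the normalization of MultiModelKeys fields to lists is an assumption, not a statement about vLLM internals:

    # Illustrative sketch, not part of the PR: check that every prefix returned
    # by get_mm_mapping() names a real top-level submodule of the model.
    import torch.nn as nn

    def _as_list(value):
        # MultiModelKeys fields may be a str or a list[str]; normalize defensively
        # since the exact field type is assumed here.
        return value if isinstance(value, list) else [value]

    def verify_mm_prefixes(model: nn.Module) -> None:
        mapping = model.get_mm_mapping()
        top_level = set(dict(model.named_children()))
        prefixes = (_as_list(mapping.language_model) +
                    _as_list(mapping.connector) +
                    _as_list(mapping.tower_model))
        for prefix in prefixes:
            if prefix.rstrip(".") not in top_level:
                raise ValueError(
                    f"get_mm_mapping() prefix {prefix!r} does not match any "
                    f"top-level submodule: {sorted(top_level)}")

If the prefixes ever become configuration-dependent, the same named_children() lookup could be a starting point for deriving them dynamically; if they are fixed architectural constants for mllama4, the static strings are fine, as the reviewer notes.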