diff --git a/vllm/model_executor/models/mllama4.py b/vllm/model_executor/models/mllama4.py
index bf4bd309eea..843382392f1 100644
--- a/vllm/model_executor/models/mllama4.py
+++ b/vllm/model_executor/models/mllama4.py
@@ -41,6 +41,7 @@
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.utils import initialize_model
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
@@ -54,7 +55,8 @@
 from vllm.multimodal.utils import run_dp_sharded_vision_model
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
+                         SupportsMultiModal, SupportsPP)
 from .llama4 import Llama4ForCausalLM
 from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
@@ -711,8 +713,8 @@ def get_dummy_mm_data(
     info=Mllama4ProcessingInfo,
     dummy_inputs=Mllama4DummyInputsBuilder,
 )
-class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                     SupportsPP):
+class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
+                                     SupportsLoRA):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
     }
@@ -935,3 +937,13 @@ def load_weights(self, weights: Iterable[tuple[str,
             weight_loader(param, loaded_weight)
             updated_params.add(name)
         return updated_params
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="language_model",
+            connector="multi_modal_projector.",
+            tower_model="vision_model.",
+        )
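
For context, a minimal usage sketch of what this change enables: with SupportsLoRA declared on Llama4ForConditionalGeneration and get_mm_mapping() telling vLLM which prefixes belong to the language model, connector, and vision tower, a LoRA adapter can be applied at inference time through the standard vLLM LoRA path. The model ID, adapter path, and rank below are illustrative assumptions, not values taken from this change.

# Usage sketch (not part of the diff). Assumes a Llama 4 checkpoint and a
# hypothetical local LoRA adapter; adjust paths and rank to your setup.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",  # assumed checkpoint
    enable_lora=True,   # turn on vLLM's LoRA machinery
    max_lora_rank=64,   # must be >= the adapter's rank
)

outputs = llm.generate(
    ["Describe the attached image in one sentence."],
    SamplingParams(temperature=0.0, max_tokens=64),
    # name, integer id, and path of a hypothetical adapter
    lora_request=LoRARequest("my_adapter", 1, "/path/to/llama4-lora"),
)
print(outputs[0].outputs[0].text)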