Skip to content

Commit bb63338

Browse files
committed
Add support for LoRA
Signed-off-by: Wei Wei <[email protected]>
1 parent 12575cf commit bb63338

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

vllm/model_executor/models/mllama4.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,11 @@
5454
from vllm.multimodal.utils import run_dp_sharded_vision_model
5555
from vllm.sequence import IntermediateTensors
5656

57-
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
57+
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP, SupportsLoRA
5858
from .llama4 import Llama4ForCausalLM
5959
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
6060
merge_multimodal_embeddings)
61-
61+
from vllm.model_executor.models.module_mapping import MultiModelKeys
6262

6363
class Llama4ImagePatchInputs(TypedDict):
6464
type: Literal["pixel_values"]
@@ -711,8 +711,9 @@ def get_dummy_mm_data(
711711
info=Mllama4ProcessingInfo,
712712
dummy_inputs=Mllama4DummyInputsBuilder,
713713
)
714-
class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
715-
SupportsPP):
714+
class Llama4ForConditionalGeneration(
715+
nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
716+
):
716717
packed_modules_mapping = {
717718
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
718719
}
@@ -935,3 +936,13 @@ def load_weights(self, weights: Iterable[tuple[str,
935936
weight_loader(param, loaded_weight)
936937
updated_params.add(name)
937938
return updated_params
939+
940+
def get_mm_mapping(self) -> MultiModelKeys:
941+
"""
942+
Get the module prefix in multimodal models
943+
"""
944+
return MultiModelKeys.from_string_field(
945+
language_model="language_model",
946+
connector="multi_modal_projector.",
947+
tower_model="vision_model.",
948+
)

0 commit comments

Comments (0)