@@ -54,11 +54,11 @@
 from vllm.multimodal.utils import run_dp_sharded_vision_model
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP, SupportsLoRA
 from .llama4 import Llama4ForCausalLM
 from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                     merge_multimodal_embeddings)
-
+from vllm.model_executor.models.module_mapping import MultiModelKeys
 
 class Llama4ImagePatchInputs(TypedDict):
     type: Literal["pixel_values"]
@@ -711,8 +711,9 @@ def get_dummy_mm_data(
     info=Mllama4ProcessingInfo,
     dummy_inputs=Mllama4DummyInputsBuilder,
 )
-class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                     SupportsPP):
+class Llama4ForConditionalGeneration(
+        nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
+):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
     }
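With `SupportsLoRA` declared on the class, Llama 4 can be served with adapters through vLLM's usual LoRA entry points. A minimal sketch of such a run, assuming a placeholder adapter path and an example checkpoint name (neither is part of this change):

```python
# Sketch only: checkpoint, adapter name, and adapter path are placeholders.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",  # example checkpoint
    enable_lora=True,   # exercises the SupportsLoRA path added here
    max_lora_rank=64,   # must be >= the adapter's rank
)

outputs = llm.generate(
    ["Describe the attached image."],
    SamplingParams(max_tokens=64),
    lora_request=LoRARequest("my_adapter", 1, "/path/to/adapter"),  # hypothetical adapter
)
```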
|
@@ -935,3 +936,13 @@ def load_weights(self, weights: Iterable[tuple[str,
             weight_loader(param, loaded_weight)
             updated_params.add(name)
         return updated_params
+
+    def get_mm_mapping(self) -> MultiModelKeys:
+        """
+        Get the module prefix in multimodal models
+        """
+        return MultiModelKeys.from_string_field(
+            language_model="language_model",
+            connector="multi_modal_projector.",
+            tower_model="vision_model.",
+        )
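For reference, `get_mm_mapping` only records which module prefixes belong to the language model, the connector, and the vision tower; the LoRA machinery consults these prefixes to decide which submodules adapters apply to. A small sketch of the same call outside the model class, using the field values from the diff above:

```python
from vllm.model_executor.models.module_mapping import MultiModelKeys

mapping = MultiModelKeys.from_string_field(
    language_model="language_model",
    connector="multi_modal_projector.",
    tower_model="vision_model.",
)

# Recent vLLM versions normalize each field to a list of module prefixes.
print(mapping.language_model, mapping.connector, mapping.tower_model)
```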