Commit e9ee6b3

[model] fix: qwen3vl models shape mismatch error with SP (#3735)
1 parent 9d4554b commit e9ee6b3

File tree

1 file changed (+40, -3 lines)

verl/models/transformers/monkey_patch.py

Lines changed: 40 additions & 3 deletions
@@ -127,6 +127,8 @@ def _create_ulysses_wrapped_decoder_forward(original_forward):
     def ulysses_wrapped_decoder_forward(self, *args, **kwargs):
         inputs_embeds = kwargs.get("inputs_embeds")
         position_ids = kwargs.get("position_ids")
+        visual_pos_masks = kwargs.get("visual_pos_masks")
+        deepstack_visual_embeds = kwargs.get("deepstack_visual_embeds")
         call_kwargs = kwargs.copy()

         current_ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
@@ -139,6 +141,43 @@ def ulysses_wrapped_decoder_forward(self, *args, **kwargs):
         if slice_now:
             call_kwargs["inputs_embeds"] = slice_input_tensor(inputs_embeds, dim=1, padding=False)
             call_kwargs["position_ids"] = slice_input_tensor(position_ids, dim=-1, padding=False)
+            # Also slice visual_pos_masks and deepstack_visual_embeds for Qwen3 VL models
+            if visual_pos_masks is not None:
+                original_visual_mask = visual_pos_masks
+                sliced_visual_mask = slice_input_tensor(visual_pos_masks, dim=1, padding=False)
+                call_kwargs["visual_pos_masks"] = sliced_visual_mask
+
+                if deepstack_visual_embeds is not None:
+                    sliced_embeds = []
+
+                    num_visual_before = original_visual_mask.sum().item()
+                    num_visual_in_shard = sliced_visual_mask.sum().item()
+
+                    if num_visual_in_shard > 0 and num_visual_before > 0:
+                        # Calculate which visual embeddings belong to this shard
+                        # We need to find the offset of visual tokens in this shard
+                        from verl.utils.ulysses import get_ulysses_sequence_parallel_rank
+
+                        rank = get_ulysses_sequence_parallel_rank()
+                        seq_len = original_visual_mask.shape[1]
+                        local_seq_len = seq_len // current_ulysses_sp_size
+                        start_idx = rank * local_seq_len
+                        end_idx = start_idx + local_seq_len
+
+                        # Get total visual tokens before and up to the end of the shard's sequence slice
+                        # This correctly handles batches by summing across all samples
+                        visual_start = original_visual_mask[:, :start_idx].sum().item() if start_idx > 0 else 0
+                        visual_end = original_visual_mask[:, :end_idx].sum().item()
+
+                        # Slice each tensor in deepstack_visual_embeds
+                        for embed in deepstack_visual_embeds:
+                            sliced_embeds.append(embed[visual_start:visual_end])
+                    else:
+                        # No visual tokens in this shard, create empty tensors to maintain gradient flow
+                        for embed in deepstack_visual_embeds:
+                            sliced_embeds.append(embed[:0])
+                    call_kwargs["deepstack_visual_embeds"] = sliced_embeds
+
             self._needs_initial_slice = False
         try:
             return original_forward(self, *args, **call_kwargs)
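
For reference, the offset bookkeeping in the hunk above can be illustrated with a minimal, self-contained sketch. It uses plain PyTorch with made-up sizes (sp_size, seq_len, the toy mask, and the embedding values are all hypothetical), not verl's slice_input_tensor or SP rank helpers: each rank counts the visual tokens that fall before and inside its contiguous sequence slice and takes the matching rows of the deepstack visual embeddings, and a rank whose slice holds no visual tokens ends up with an empty slice, mirroring the embed[:0] branch.

# Minimal sketch of the per-rank visual-embedding slicing (toy sizes, plain PyTorch).
import torch

sp_size = 2        # hypothetical Ulysses SP world size
seq_len = 8        # hypothetical sequence length, divisible by sp_size

# Toy mask: 3 visual tokens at positions 1..3, so rank 1's shard holds none of them.
visual_pos_masks = torch.zeros(1, seq_len, dtype=torch.bool)
visual_pos_masks[:, 1:4] = True

# One deepstack level: one embedding row per visual token.
deepstack_embeds = torch.arange(3 * 4, dtype=torch.float32).reshape(3, 4)

local_seq_len = seq_len // sp_size
recovered = []
for rank in range(sp_size):
    start_idx = rank * local_seq_len
    end_idx = start_idx + local_seq_len

    # Same offsets as the patch: visual tokens before / up to the end of this shard.
    visual_start = visual_pos_masks[:, :start_idx].sum().item() if start_idx > 0 else 0
    visual_end = visual_pos_masks[:, :end_idx].sum().item()

    # Rank 1 gets an empty slice here, matching the patch's embed[:0] branch.
    recovered.append(deepstack_embeds[visual_start:visual_end])

# Concatenating the per-rank slices recovers every visual embedding exactly once.
assert torch.equal(torch.cat(recovered), deepstack_embeds)

Concatenating the per-rank slices yielding the full embedding tensor is the invariant the sliced path needs to preserve; the sketch only illustrates that bookkeeping, not the actual verl code path.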
@@ -290,9 +329,7 @@ def state_dict(self, *args, **kwargs):
         from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
             Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention,
         )
-        from transformers.models.qwen2_vl.modeling_qwen2_vl import (
-            Qwen2VLFlashAttention2 as Qwen2VLAttention,
-        )
+        from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 as Qwen2VLAttention

         if use_remove_padding or ulysses_sp_size > 1:
             from verl.models.transformers.qwen2_vl import qwen2_vl_attn_forward
