@@ -127,6 +127,8 @@ def _create_ulysses_wrapped_decoder_forward(original_forward):
     def ulysses_wrapped_decoder_forward(self, *args, **kwargs):
         inputs_embeds = kwargs.get("inputs_embeds")
         position_ids = kwargs.get("position_ids")
+        visual_pos_masks = kwargs.get("visual_pos_masks")
+        deepstack_visual_embeds = kwargs.get("deepstack_visual_embeds")
         call_kwargs = kwargs.copy()

         current_ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
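
The next hunk gates on `slice_now` and shards `inputs_embeds`/`position_ids` with `slice_input_tensor`. As a rough mimic of what that helper does with `padding=False` (an assumption: the sequence length divides evenly by the SP world size; `slice_input_tensor_mimic` is a hypothetical stand-in, not verl's implementation):

```python
import torch

def slice_input_tensor_mimic(x: torch.Tensor, dim: int, rank: int, sp_size: int) -> torch.Tensor:
    # Keep only this rank's contiguous 1/sp_size chunk along `dim`;
    # with padding=False, x.size(dim) is assumed divisible by sp_size.
    return x.chunk(sp_size, dim=dim)[rank]
```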
@@ -139,6 +141,43 @@ def ulysses_wrapped_decoder_forward(self, *args, **kwargs):
         if slice_now:
             call_kwargs["inputs_embeds"] = slice_input_tensor(inputs_embeds, dim=1, padding=False)
             call_kwargs["position_ids"] = slice_input_tensor(position_ids, dim=-1, padding=False)
+            # Also slice visual_pos_masks and deepstack_visual_embeds for Qwen3 VL models
+            if visual_pos_masks is not None:
+                original_visual_mask = visual_pos_masks
+                sliced_visual_mask = slice_input_tensor(visual_pos_masks, dim=1, padding=False)
+                call_kwargs["visual_pos_masks"] = sliced_visual_mask
+
+                if deepstack_visual_embeds is not None:
+                    sliced_embeds = []
+
+                    num_visual_before = original_visual_mask.sum().item()
+                    num_visual_in_shard = sliced_visual_mask.sum().item()
+
+                    if num_visual_in_shard > 0 and num_visual_before > 0:
+                        # Determine which visual embeddings belong to this shard by
+                        # locating the offset of its visual tokens in the full sequence
+                        from verl.utils.ulysses import get_ulysses_sequence_parallel_rank
+
+                        rank = get_ulysses_sequence_parallel_rank()
+                        seq_len = original_visual_mask.shape[1]
+                        local_seq_len = seq_len // current_ulysses_sp_size
+                        start_idx = rank * local_seq_len
+                        end_idx = start_idx + local_seq_len
+
+                        # Count the visual tokens before the shard's slice and up to its end;
+                        # summing over the batch dimension handles batched inputs
+                        visual_start = original_visual_mask[:, :start_idx].sum().item() if start_idx > 0 else 0
+                        visual_end = original_visual_mask[:, :end_idx].sum().item()
+
+                        # Slice each tensor in deepstack_visual_embeds
+                        for embed in deepstack_visual_embeds:
+                            sliced_embeds.append(embed[visual_start:visual_end])
+                    else:
+                        # No visual tokens in this shard: keep empty tensors to preserve gradient flow
+                        for embed in deepstack_visual_embeds:
+                            sliced_embeds.append(embed[:0])
+                    call_kwargs["deepstack_visual_embeds"] = sliced_embeds
+
         self._needs_initial_slice = False
         try:
             return original_forward(self, *args, **call_kwargs)
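
As a sanity check on the offset arithmetic above, here is a minimal standalone sketch (toy mask, hypothetical sizes) showing that the per-rank `embeds[visual_start:visual_end]` windows are contiguous and partition the visual embeddings exactly, with no token counted twice or dropped:

```python
import torch

# Toy case: batch=2, seq_len=8, sp_size=2; True marks a visual-token position.
mask = torch.tensor([
    [False, True, True, False, True, False, False, False],
    [True, False, False, True, False, False, True, False],
])
sp_size, local_len = 2, mask.shape[1] // 2

covered = 0
for rank in range(sp_size):
    start_idx, end_idx = rank * local_len, (rank + 1) * local_len
    visual_start = mask[:, :start_idx].sum().item() if start_idx > 0 else 0
    visual_end = mask[:, :end_idx].sum().item()
    covered += visual_end - visual_start
    print(f"rank {rank}: embeds[{visual_start}:{visual_end}]")  # [0:4], then [4:6]

assert covered == mask.sum().item()  # every visual embedding lands on exactly one rank
```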
@@ -290,9 +329,7 @@ def state_dict(self, *args, **kwargs):
         from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
             Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention,
         )
-        from transformers.models.qwen2_vl.modeling_qwen2_vl import (
-            Qwen2VLFlashAttention2 as Qwen2VLAttention,
-        )
+        from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 as Qwen2VLAttention

         if use_remove_padding or ulysses_sp_size > 1:
             from verl.models.transformers.qwen2_vl import qwen2_vl_attn_forward
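
This hunk only collapses the Qwen2-VL import onto one line; the patch itself, where the aliased classes receive the Ulysses-aware forward, sits outside the diff. A sketch of the presumable pattern (an assumption based on the imports above, not code shown in this commit):

```python
# Assumption: both aliased attention classes are rebound to the
# sequence-parallel forward imported above, so the sequence/heads
# all-to-all happens inside attention.
Qwen2VLAttention.forward = qwen2_vl_attn_forward
Qwen2_5_VLAttention.forward = qwen2_vl_attn_forward  # hypothetical: same forward serving both variants
```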