|
63 | 63 | ]
|
64 | 64 |
|
65 | 65 |
|
| 66 | +def _make_sliding_window_mask(input_shape, past_key_values_length=0, window_size=5): |
| 67 | + """ |
| 68 | + Generate a sliding window mask that restricts each position to attending only to in-window positions at or before itself. |
| 69 | + Shape: [bsz, 1, tgt_seq_len, src_seq_len]; True means attention is allowed, False means masked. |
| 70 | + """ |
| 71 | + batch_size, seq_length = input_shape |
| 72 | + # Total sequence length = historical sequence length + current sequence length (for generating complete mask) |
| 73 | + total_length = past_key_values_length + seq_length |
| 74 | + |
| 75 | + # Initialize mask with all False values |
| 76 | + mask = paddle.zeros((seq_length, total_length), dtype=paddle.bool) |
| 77 | + |
| 78 | + for i in range(seq_length): |
| 79 | + # Absolute position of this query in the total sequence (history included) |
| 80 | + current_pos = past_key_values_length + i |
| 81 | + # Window start: max(0, current position - window size + 1) |
| 82 | + start = max(0, current_pos - window_size + 1) |
| 83 | + # Window end: the current position itself (causal restriction, no look-ahead) |
| 84 | + end = current_pos + 1 # slices are half-open on the right, hence the +1 |
| 85 | + # Mark window range as True (allow attention) |
| 86 | + mask[i, start:end] = True |
| 87 | + |
| 88 | + # Expand dimensions to [bsz, 1, tgt_seq_len, src_seq_len] |
| 89 | + mask = mask.unsqueeze(0).unsqueeze(0) |
| 91 | + # Replicate the mask for every sample in the batch |
| 91 | + mask = paddle.tile(mask, repeat_times=[batch_size, 1, 1, 1]) |
| 92 | + return mask |
| 93 | + |
| 94 | + |
66 | 95 | def get_unfinished_flag(
|
67 | 96 | input_ids: Tensor, unfinished_flag: Tensor, eos_token_id: Union[int, list[int], list[list[int]]]
|
68 | 97 | ) -> Tensor:
|
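As a quick illustration of what the new helper produces, here is a minimal sketch (not part of the diff) assuming `paddle` is installed and `_make_sliding_window_mask` is in scope as defined above:

```python
import paddle  # the helper above already depends on paddle

# Toy case: batch of 1, sequence length 4, no KV cache, window of 2.
mask = _make_sliding_window_mask((1, 4), past_key_values_length=0, window_size=2)
print(mask.shape)  # [1, 1, 4, 4]
# Row i may attend to keys {max(0, i - 1), i}:
# [[True , False, False, False],
#  [True , True , False, False],
#  [False, True , True , False],
#  [False, False, True , True ]]
```

Note the loop issues one slice assignment per query position; the same mask could be built with a single broadcast comparison (allow key position `k` for query position `q` iff `k <= q` and `k > q - window_size`), which may be worth it for long prompts.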
@@ -354,29 +383,53 @@ def prepare_attention_mask_for_generation(input_ids, pad_token_id, eos_token_id)
|
354 | 383 | return attention_mask
|
355 | 384 |
|
356 | 385 | @staticmethod
|
357 | | - def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype): |
| 386 | + def _prepare_decoder_attention_mask( |
| 387 | + attention_mask, input_shape, past_key_values_length, dtype, sliding_window_size=None |
| 388 | + ): |
| 389 | + # Step 1: Process input mask to generate basic expanded mask |
358 | 390 | if attention_mask is not None:
|
359 | 391 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
360 | 392 | if len(attention_mask.shape) == 2:
|
361 | 393 | expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1])
|
362 | | - # For decoding phase in generation, seq_length = 1, we don't need to add causal mask |
| 394 | + # Outside single-step decoding (seq_len > 1), combine the causal and sliding window masks |
363 | 395 | if input_shape[-1] > 1:
|
364 | | - combined_attention_mask = _make_causal_mask( |
365 | | - input_shape, past_key_values_length=past_key_values_length |
366 | | - ) |
| 396 | + # Generate basic causal mask (prevent future information leakage) |
| 397 | + causal_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) |
| 398 | + # Generate sliding window mask (limit historical attention range) |
| 399 | + if sliding_window_size is not None and sliding_window_size > 0: |
| 400 | + window_mask = _make_sliding_window_mask( |
| 401 | + input_shape, past_key_values_length=past_key_values_length, window_size=sliding_window_size |
| 402 | + ) |
| 403 | + # Intersect the sliding window and causal masks (both constraints must hold) |
| 404 | + combined_attention_mask = causal_mask & window_mask |
| 405 | + else: |
| 406 | + combined_attention_mask = ( |
| 407 | + causal_mask # Use causal mask directly when sliding window is disabled |
| 408 | + ) |
| 409 | + |
| 410 | + # Combine with user-provided mask (e.g., padding mask) |
367 | 411 | if get_env_device() in ["npu", "mlu", "intel_hpu"]:
|
368 | 412 | expanded_attn_mask = expanded_attn_mask.astype("bool") & combined_attention_mask.astype("bool")
|
369 | 413 | else:
|
370 | 414 | expanded_attn_mask = expanded_attn_mask & combined_attention_mask
|
371 | 415 | # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
|
372 | 416 | elif len(attention_mask.shape) == 3:
|
373 | 417 | expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool")
|
374 | | - # if attention_mask is already 4-D, do nothing |
| 418 | + # If attention_mask is already 4-D, use it as-is |
375 | 419 | else:
|
376 | 420 | expanded_attn_mask = attention_mask
|
377 | 421 | else:
|
378 | | - expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) |
379 | | - # Convert bool attention_mask to float attention mask, which will be added to attention_scores later |
| 422 | + # No input mask provided: build the causal mask, plus the sliding window mask if enabled |
| 423 | + causal_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) |
| 424 | + if sliding_window_size is not None and sliding_window_size > 0: |
| 425 | + window_mask = _make_sliding_window_mask( |
| 426 | + input_shape, past_key_values_length=past_key_values_length, window_size=sliding_window_size |
| 427 | + ) |
| 428 | + expanded_attn_mask = causal_mask & window_mask |
| 429 | + else: |
| 430 | + expanded_attn_mask = causal_mask # Use causal mask directly when sliding window is disabled |
| 431 | + |
| 432 | + # Step 2: Convert the boolean mask to an additive float mask (handling differs by device) |
380 | 433 | if get_env_device() in ["npu", "mlu", "intel_hpu"]:
|
381 | 434 | x = paddle.to_tensor(0.0, dtype="float32")
|
382 | 435 | y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
|
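For the combination step in this hunk, a minimal sketch (again not part of the diff, assuming both helpers are in scope) of how the two boolean masks intersect:

```python
# Both masks are bool tensors of shape [bsz, 1, tgt_seq_len, src_seq_len].
causal_mask = _make_causal_mask((1, 6), past_key_values_length=0)
window_mask = _make_sliding_window_mask((1, 6), past_key_values_length=0, window_size=3)
combined = causal_mask & window_mask  # a key survives only if BOTH masks allow it
```

As written, `window_mask` never admits a future key (its window ends at the current position), so the intersection mainly acts as a guard that keeps the result causal even if the window construction changes.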
|
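The hunk is cut off here, but the step it begins is the usual bool-to-additive conversion the removed comment described ("added to attention_scores later"). A hypothetical standalone demo, not the PR's exact code:

```python
import paddle

# True (keep) becomes 0.0 and False (mask) becomes the dtype's minimum, so
# adding the result to raw attention scores suppresses masked positions
# before the softmax.
dtype = paddle.float32
bool_mask = paddle.to_tensor([[True, False], [True, True]])
zeros = paddle.zeros(bool_mask.shape, dtype=dtype)
lowest = paddle.full(bool_mask.shape, paddle.finfo(dtype).min, dtype=dtype)
additive_mask = paddle.where(bool_mask, zeros, lowest)
```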