
Commit cf77edd

cherry pick from pr #2655 (#2668)
1 parent 39e4293 commit cf77edd

10 files changed, +1193 -455 lines changed

paddleformers/generation/utils.py

Lines changed: 61 additions & 8 deletions
@@ -63,6 +63,35 @@
 ]


+def _make_sliding_window_mask(input_shape, past_key_values_length=0, window_size=5):
+    """
+    Generate a sliding window mask that restricts each position to attend only to historical positions within the window.
+    Format: [bsz, 1, tgt_seq_len, src_seq_len], where True indicates allowed attention and False indicates masking.
+    """
+    batch_size, seq_length = input_shape
+    # Total sequence length = historical sequence length + current sequence length (for generating the complete mask)
+    total_length = past_key_values_length + seq_length
+
+    # Initialize the mask with all False values
+    mask = paddle.zeros((seq_length, total_length), dtype=paddle.bool)
+
+    for i in range(seq_length):
+        # Absolute position of the current row in the total sequence (including the historical sequence)
+        current_pos = past_key_values_length + i
+        # Window start position: max(0, current position - window size + 1)
+        start = max(0, current_pos - window_size + 1)
+        # Window end position: the current position (causal restriction, a token cannot look beyond itself)
+        end = current_pos + 1  # The slice is half-open, hence the +1
+        # Mark the window range as True (attention allowed)
+        mask[i, start:end] = True
+
+    # Expand dimensions to [bsz, 1, tgt_seq_len, src_seq_len]
+    mask = mask.unsqueeze(0).unsqueeze(0)
+    # Tile the mask across the batch dimension
+    mask = paddle.tile(mask, repeat_times=[batch_size, 1, 1, 1])
+    return mask
+
+
 def get_unfinished_flag(
     input_ids: Tensor, unfinished_flag: Tensor, eos_token_id: Union[int, list[int], list[list[int]]]
 ) -> Tensor:
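For clarity, here is a minimal, dependency-free sketch of the attention pattern _make_sliding_window_mask builds. The toy sliding_window_rows helper and the sizes used below are illustrative only, not part of the change:

def sliding_window_rows(seq_length, past_key_values_length=0, window_size=3):
    # Reference re-implementation of the masking loop using plain Python lists.
    total_length = past_key_values_length + seq_length
    rows = []
    for i in range(seq_length):
        current_pos = past_key_values_length + i
        start = max(0, current_pos - window_size + 1)
        end = current_pos + 1  # half-open slice; causal upper bound
        rows.append([start <= j < end for j in range(total_length)])
    return rows

# With seq_length=4, no cache, and window_size=3, each row may attend to at most
# the 3 most recent positions (itself included):
for row in sliding_window_rows(4, window_size=3):
    print("".join("x" if allowed else "." for allowed in row))
# x...
# xx..
# xxx.
# .xxx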
@@ -354,29 +383,53 @@ def prepare_attention_mask_for_generation(input_ids, pad_token_id, eos_token_id)
         return attention_mask

     @staticmethod
-    def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype):
+    def _prepare_decoder_attention_mask(
+        attention_mask, input_shape, past_key_values_length, dtype, sliding_window_size=None
+    ):
+        # Step 1: Process the input mask to produce the basic expanded mask
         if attention_mask is not None:
             # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
             if len(attention_mask.shape) == 2:
                 expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1])
-                # For decoding phase in generation, seq_length = 1, we don't need to add causal mask
+                # Outside single-step decoding, combine the causal mask with the sliding window mask
                 if input_shape[-1] > 1:
-                    combined_attention_mask = _make_causal_mask(
-                        input_shape, past_key_values_length=past_key_values_length
-                    )
+                    # Generate the basic causal mask (prevents future information leakage)
+                    causal_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
+                    # Generate the sliding window mask (limits the historical attention range)
+                    if sliding_window_size is not None and sliding_window_size > 0:
+                        window_mask = _make_sliding_window_mask(
+                            input_shape, past_key_values_length=past_key_values_length, window_size=sliding_window_size
+                        )
+                        # Intersect the sliding window mask with the causal mask (both restrictions must hold)
+                        combined_attention_mask = causal_mask & window_mask
+                    else:
+                        # Use the causal mask directly when the sliding window is disabled
+                        combined_attention_mask = causal_mask
+
+                    # Combine with the user-provided mask (e.g., a padding mask)
                     if get_env_device() in ["npu", "mlu", "intel_hpu"]:
                         expanded_attn_mask = expanded_attn_mask.astype("bool") & combined_attention_mask.astype("bool")
                     else:
                         expanded_attn_mask = expanded_attn_mask & combined_attention_mask
             # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
             elif len(attention_mask.shape) == 3:
                 expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool")
-            # if attention_mask is already 4-D, do nothing
+            # A 4-D mask is used directly
            else:
                 expanded_attn_mask = attention_mask
         else:
-            expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
-        # Convert bool attention_mask to float attention mask, which will be added to attention_scores later
+            # When no input mask is provided, generate the causal mask + sliding window mask (if enabled)
+            causal_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
+            if sliding_window_size is not None and sliding_window_size > 0:
+                window_mask = _make_sliding_window_mask(
+                    input_shape, past_key_values_length=past_key_values_length, window_size=sliding_window_size
+                )
+                expanded_attn_mask = causal_mask & window_mask
+            else:
+                # Use the causal mask directly when the sliding window is disabled
+                expanded_attn_mask = causal_mask
+
+        # Step 2: Convert the boolean mask to a numerical mask (adapted to different devices)
         if get_env_device() in ["npu", "mlu", "intel_hpu"]:
             x = paddle.to_tensor(0.0, dtype="float32")
             y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
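As a hedged sketch (not the exact library code path above), the two operations the updated method performs once both boolean masks exist are intersecting them and turning the result into an additive float mask, mirroring the paddle.where(...) pattern at the end of the method. The combine_and_convert name below is hypothetical:

import paddle

def combine_and_convert(causal_mask, window_mask):
    # Both inputs are bool tensors of shape [bsz, 1, tgt_len, src_len];
    # a position may be attended only if the causal AND the window mask allow it.
    combined = causal_mask & window_mask
    # Allowed positions become 0.0 and blocked positions become the float32 minimum,
    # so the result can simply be added to the attention scores before softmax.
    x = paddle.to_tensor(0.0, dtype="float32")
    y = paddle.to_tensor(paddle.finfo(paddle.float32).min, dtype="float32")
    return paddle.where(combined, x, y)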

paddleformers/transformers/configuration_utils.py

Lines changed: 17 additions & 0 deletions
@@ -1269,3 +1269,20 @@ def get_configuration_file(configuration_files: List[str]) -> str:
             break

     return configuration_file
+
+
+ALLOWED_LAYER_TYPES = (
+    "full_attention",
+    "sliding_attention",
+)
+
+
+def layer_type_validation(layer_types: List[str], num_hidden_layers: Optional[int] = None):
+    """Check that `layer_types` is correctly defined."""
+    if not all(layer_type in ALLOWED_LAYER_TYPES for layer_type in layer_types):
+        raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES}")
+    if num_hidden_layers is not None and num_hidden_layers != len(layer_types):
+        raise ValueError(
+            f"`num_hidden_layers` ({num_hidden_layers}) must be equal to the number of layer types "
+            f"({len(layer_types)})"
+        )
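Assuming the module is importable from an installed paddleformers, a short usage sketch of the new validator (the values below are illustrative):

from paddleformers.transformers.configuration_utils import layer_type_validation

# Passes: every entry is allowed and the count matches num_hidden_layers.
layer_type_validation(["full_attention", "sliding_attention"], num_hidden_layers=2)

# Raises ValueError: "cross_attention" is not in ALLOWED_LAYER_TYPES.
try:
    layer_type_validation(["cross_attention", "full_attention"], num_hidden_layers=2)
except ValueError as err:
    print(err)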

paddleformers/transformers/qwen2/configuration.py

Lines changed: 11 additions & 1 deletion
@@ -14,7 +14,7 @@
 # limitations under the License.
 """Qwen2 model configuration"""

-from ..configuration_utils import PretrainedConfig
+from ..configuration_utils import PretrainedConfig, layer_type_validation

 __all__ = [
     "Qwen2Config",
@@ -129,6 +129,7 @@ def __init__(
         attention_dropout=0.0,
         rope_scaling_factor=1.0,
         rope_scaling_type=None,
+        layer_types=None,
         pp_seg_method="layer:Qwen2DecoderLayer",
         **kwargs,
     ):
@@ -167,6 +168,14 @@ def __init__(

         self.pp_seg_method = pp_seg_method

+        self.layer_types = layer_types
+        if self.layer_types is None:
+            self.layer_types = [
+                "sliding_attention" if self.use_sliding_window and i >= self.max_window_layers else "full_attention"
+                for i in range(self.num_hidden_layers)
+            ]
+        layer_type_validation(self.layer_types, self.num_hidden_layers)
+
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
@@ -190,5 +199,6 @@ def __init__(
                 "pp_seg_method",
                 "dpo_config",
                 "kto_config",
+                "layer_types",
             ]
         )
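For reference, a hedged sketch of the default layer_types the config derives when none is passed; the field values below are illustrative, not taken from any released checkpoint:

# Mirrors the list comprehension added to Qwen2Config.__init__.
use_sliding_window = True
max_window_layers = 2
num_hidden_layers = 4

layer_types = [
    "sliding_attention" if use_sliding_window and i >= max_window_layers else "full_attention"
    for i in range(num_hidden_layers)
]
print(layer_types)
# ['full_attention', 'full_attention', 'sliding_attention', 'sliding_attention']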
