69 changes: 61 additions & 8 deletions paddleformers/generation/utils.py
@@ -63,6 +63,35 @@
]


def _make_sliding_window_mask(input_shape, past_key_values_length=0, window_size=5):
"""
Generate a sliding window mask that restricts each position to attend only to itself and the preceding positions inside the window.
Shape: [bsz, 1, tgt_seq_len, src_seq_len], where True means attention is allowed and False means it is masked.
"""
batch_size, seq_length = input_shape
# Total sequence length = historical sequence length + current sequence length (for generating complete mask)
total_length = past_key_values_length + seq_length

# Initialize mask with all False values
mask = paddle.zeros((seq_length, total_length), dtype=paddle.bool)

for i in range(seq_length):
# Absolute position of current location in the total sequence (including historical sequence)
current_pos = past_key_values_length + i
# Window start position: max(0, current position - window size + 1)
start = max(0, current_pos - window_size + 1)
# Window end position: current position (causal mask restriction, cannot exceed self)
end = current_pos + 1  # slice is left-closed, right-open, hence the +1
# Mark window range as True (allow attention)
mask[i, start:end] = True

# Expand dimensions to [bsz, 1, tgt_seq_len, src_seq_len]
mask = mask.unsqueeze(0).unsqueeze(0)
# Copy to each sample in batch_size
mask = paddle.tile(mask, repeat_times=[batch_size, 1, 1, 1])
return mask
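# For example, with input_shape=(1, 4), no cached history and window_size=2, each
# query position may attend to itself and at most one earlier position:
#
#   _make_sliding_window_mask((1, 4), past_key_values_length=0, window_size=2)[0, 0]
#   -> [[ True, False, False, False],
#       [ True,  True, False, False],
#       [False,  True,  True, False],
#       [False, False,  True,  True]]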


def get_unfinished_flag(
input_ids: Tensor, unfinished_flag: Tensor, eos_token_id: Union[int, list[int], list[list[int]]]
) -> Tensor:
@@ -354,29 +383,53 @@ def prepare_attention_mask_for_generation(input_ids, pad_token_id, eos_token_id)
return attention_mask

@staticmethod
def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype):
def _prepare_decoder_attention_mask(
attention_mask, input_shape, past_key_values_length, dtype, sliding_window_size=None
):
# Step 1: Process input mask to generate basic expanded mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
if len(attention_mask.shape) == 2:
expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1])
# During single-step decoding in generation (seq_length == 1) no causal mask is needed
# For multi-token inputs, combine the causal mask with the sliding window mask
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, past_key_values_length=past_key_values_length
)
# Generate basic causal mask (prevent future information leakage)
causal_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
# Generate sliding window mask (limit historical attention range)
if sliding_window_size is not None and sliding_window_size > 0:
window_mask = _make_sliding_window_mask(
input_shape, past_key_values_length=past_key_values_length, window_size=sliding_window_size
)
# Take intersection of sliding window mask and causal mask (satisfy both restrictions)
combined_attention_mask = causal_mask & window_mask
else:
combined_attention_mask = causal_mask  # Use causal mask directly when sliding window is disabled

# Combine with user-provided mask (e.g., padding mask)
if get_env_device() in ["npu", "mlu", "intel_hpu"]:
expanded_attn_mask = expanded_attn_mask.astype("bool") & combined_attention_mask.astype("bool")
else:
expanded_attn_mask = expanded_attn_mask & combined_attention_mask
# [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len]
elif len(attention_mask.shape) == 3:
expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool")
# if attention_mask is already 4-D, do nothing
# 4D mask is used directly
else:
expanded_attn_mask = attention_mask
else:
expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
# Convert bool attention_mask to float attention mask, which will be added to attention_scores later
# When no input mask is provided, build the causal mask plus the sliding window mask (if enabled)
causal_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length)
if sliding_window_size is not None and sliding_window_size > 0:
window_mask = _make_sliding_window_mask(
input_shape, past_key_values_length=past_key_values_length, window_size=sliding_window_size
)
expanded_attn_mask = causal_mask & window_mask
else:
expanded_attn_mask = causal_mask # Use causal mask directly when sliding window is disabled

# Step 2: Convert boolean mask to numerical mask (adapt to different devices)
if get_env_device() in ["npu", "mlu", "intel_hpu"]:
x = paddle.to_tensor(0.0, dtype="float32")
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
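As a standalone illustration of the technique applied above (the helper below is invented for this sketch and does not reuse the module's own utilities): the causal and sliding-window constraints are intersected, and the boolean result is then converted into an additive float mask that can be added to attention scores, mirroring Step 2.

import paddle

def sliding_causal_additive_mask(seq_len: int, window_size: int):
    # Boolean constraints: query i may attend to key j only if j <= i (causal)
    # and i - j < window_size (sliding window).
    idx = paddle.arange(seq_len)
    causal = idx.unsqueeze(1) >= idx.unsqueeze(0)                 # [seq_len, seq_len]
    window = (idx.unsqueeze(1) - idx.unsqueeze(0)) < window_size  # [seq_len, seq_len]
    allowed = causal & window                                     # both restrictions must hold

    # Boolean -> additive mask: 0 where attention is allowed, a very large
    # negative value where it is blocked.
    zeros = paddle.zeros([seq_len, seq_len], dtype="float32")
    blocked = paddle.full([seq_len, seq_len], paddle.finfo(paddle.float32).min, dtype="float32")
    return paddle.where(allowed, zeros, blocked)

# sliding_causal_additive_mask(4, 2)[2] is 0.0 at positions 1 and 2 and a large
# negative value elsewhere, i.e. position 2 attends only to positions 1 and 2.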
17 changes: 17 additions & 0 deletions paddleformers/transformers/configuration_utils.py
@@ -1269,3 +1269,20 @@ def get_configuration_file(configuration_files: List[str]) -> str:
break

return configuration_file


ALLOWED_LAYER_TYPES = (
"full_attention",
"sliding_attention",
)


def layer_type_validation(layer_types: List[str], num_hidden_layers: Optional[int] = None):
"""Check that `layer_types` is correctly defined."""
if not all(layer_type in ALLOWED_LAYER_TYPES for layer_type in layer_types):
raise ValueError(f"The `layer_types` entries must be in {ALLOWED_LAYER_TYPES}")
if num_hidden_layers is not None and num_hidden_layers != len(layer_types):
raise ValueError(
f"`num_hidden_layers` ({num_hidden_layers}) must be equal to the number of layer types "
f"({len(layer_types)})"
)
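For instance, the validation behaves as follows (hypothetical calls, shown only to illustrate the contract):

layer_type_validation(["full_attention", "sliding_attention"], num_hidden_layers=2)  # passes
layer_type_validation(["full_attention", "window_attention"])   # raises ValueError: entry not in ALLOWED_LAYER_TYPES
layer_type_validation(["full_attention"], num_hidden_layers=2)  # raises ValueError: length mismatch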
12 changes: 11 additions & 1 deletion paddleformers/transformers/qwen2/configuration.py
@@ -14,7 +14,7 @@
# limitations under the License.
"""Qwen2 model configuration"""

from ..configuration_utils import PretrainedConfig
from ..configuration_utils import PretrainedConfig, layer_type_validation

__all__ = [
"Qwen2Config",
@@ -129,6 +129,7 @@ def __init__(
attention_dropout=0.0,
rope_scaling_factor=1.0,
rope_scaling_type=None,
layer_types=None,
pp_seg_method="layer:Qwen2DecoderLayer",
**kwargs,
):
@@ -167,6 +168,14 @@

self.pp_seg_method = pp_seg_method

self.layer_types = layer_types
if self.layer_types is None:
self.layer_types = [
"sliding_attention" if self.use_sliding_window and i >= self.max_window_layers else "full_attention"
for i in range(self.num_hidden_layers)
]
layer_type_validation(self.layer_types, self.num_hidden_layers)

super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
@@ -190,5 +199,6 @@ def __init__(
"pp_seg_method",
"dpo_config",
"kto_config",
"layer_types",
]
)
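For example, assuming use_sliding_window and max_window_layers are accepted as constructor arguments (the attribute references above suggest they are), the default resolves to:

cfg = Qwen2Config(num_hidden_layers=4, use_sliding_window=True, max_window_layers=2)
# cfg.layer_types == ["full_attention", "full_attention",
#                     "sliding_attention", "sliding_attention"]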