
Commit f2887f6

Author: Lu Fang
Parent: 8b7b409

add assertions and merge impl

Signed-off-by: Lu Fang <[email protected]>

File tree: 3 files changed (+21 / -3 lines)

vllm/v1/attention/backends/flash_attn.py

Lines changed: 1 addition & 1 deletion
@@ -743,7 +743,7 @@ def use_cascade_attention(
     if common_prefix_len < 256:
         return False
     # Cascade attention is currently not supported with these variants.
-    if use_alibi or use_sliding_window:
+    if use_alibi or use_sliding_window or use_local_attention:
         return False
     # Too few queries. Probably not worth using cascade attention.
     # We use an arbitrary threshold of 8 queries. TODO: Tune this threshold.
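The net effect of this one-line change is that chunked local attention now also disqualifies cascade attention. A minimal sketch of the guard's behavior, assuming use_local_attention reaches use_cascade_attention as a plain boolean alongside use_alibi and use_sliding_window (the signature change is not visible in this hunk):

# Illustrative sketch only, not the vLLM implementation: any of these
# attention variants rules out cascade attention.
def cascade_supported(use_alibi: bool, use_sliding_window: bool,
                      use_local_attention: bool) -> bool:
    return not (use_alibi or use_sliding_window or use_local_attention)

assert cascade_supported(False, False, False)      # eligible for cascade
assert not cascade_supported(False, False, True)   # chunked local attention -> fall back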

vllm/v1/kv_cache_interface.py

Lines changed: 14 additions & 0 deletions
@@ -114,6 +114,9 @@ def merge(cls, specs: list[Self]) -> Self:
         merged_spec = super().merge(specs)
         sliding_window = set(spec.sliding_window for spec in specs
                              if spec.sliding_window is not None)
+        attention_chunk_size = set(spec.attention_chunk_size for spec in specs
+                                   if spec.attention_chunk_size is not None)
+
         if len(sliding_window) == 0:
             merged_spec.sliding_window = None
         elif len(sliding_window) == 1:
@@ -122,6 +125,17 @@ def merge(cls, specs: list[Self]) -> Self:
             raise ValueError(
                 "All sliding window layers in the same KV cache group "
                 "must have the same window size.")
+        if len(attention_chunk_size) == 0:
+            merged_spec.attention_chunk_size = None
+        elif len(attention_chunk_size) == 1:
+            merged_spec.attention_chunk_size = attention_chunk_size.pop()
+        else:
+            raise ValueError(
+                "All chunked local attention layers in the same KV cache group "
+                "must have the same chunk size.")
+        assert len(sliding_window) + len(attention_chunk_size) <= 1, (
+            "Model with both sliding window layers and chunked local attention "
+            "layers is not supported.")
         return merged_spec
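To illustrate the merge rules added here with a toy example (a simplified stand-in dataclass, not the vLLM spec classes): a group whose specs agree on one attention_chunk_size merges to that value, a group with no chunk size merges to None, and mixed sizes raise.

from dataclasses import dataclass
from typing import Optional

@dataclass
class ToySpec:  # hypothetical stand-in carrying only the merged field
    attention_chunk_size: Optional[int] = None

def merge_chunk_size(specs: list[ToySpec]) -> Optional[int]:
    # Mirrors the new branch: collect the distinct non-None chunk sizes.
    sizes = {s.attention_chunk_size for s in specs
             if s.attention_chunk_size is not None}
    if len(sizes) == 0:
        return None
    if len(sizes) == 1:
        return sizes.pop()
    raise ValueError("All chunked local attention layers in the same "
                     "KV cache group must have the same chunk size.")

assert merge_chunk_size([ToySpec(8192), ToySpec(8192)]) == 8192
assert merge_chunk_size([ToySpec(), ToySpec()]) is None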

vllm/v1/worker/gpu_model_runner.py

Lines changed: 6 additions & 2 deletions
@@ -2330,6 +2330,8 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:

             # TODO: Support other attention modules, e.g., cross-attention
             if attn_module.attn_type == AttentionType.DECODER:
+                use_local_attention = (self.attention_chunk_size is not None
+                                       and attn_module.use_irope)
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(
                         block_size=block_size,
@@ -2338,8 +2340,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                         dtype=self.kv_cache_dtype,
                         sliding_window=attn_module.sliding_window,
                         use_mla=use_mla)
-                elif self.attention_chunk_size is not None \
-                        and attn_module.use_irope:
+                    assert not use_local_attention, (
+                        "attention module can not be with ",
+                        "both local attention and sliding window")
+                elif use_local_attention:
                     kv_cache_spec[layer_name] = \
                         ChunkedLocalAttentionSpec(
                             block_size=block_size,
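For orientation, a hypothetical helper (not part of the commit) that classifies a decoder layer with the same branch order as the changed block; the real code constructs SlidingWindowSpec or ChunkedLocalAttentionSpec objects rather than returning labels.

from typing import Optional

def classify_decoder_layer(sliding_window: Optional[int],
                           attention_chunk_size: Optional[int],
                           use_irope: bool) -> str:
    # Sketch of the dispatch above, with the new flag computed up front.
    use_local_attention = attention_chunk_size is not None and use_irope
    if sliding_window is not None:
        # The added assertion: a layer must not be both sliding-window
        # and chunked local attention.
        assert not use_local_attention, (
            "attention module can not be with "
            "both local attention and sliding window")
        return "sliding_window"
    elif use_local_attention:
        return "chunked_local_attention"
    return "full_attention"

assert classify_decoder_layer(4096, None, False) == "sliding_window"
assert classify_decoder_layer(None, 8192, True) == "chunked_local_attention"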
