@@ -21,43 +21,41 @@ def flash_context_attention(
 ):
     num_q_heads, dim = query_states.shape[1:3]
     num_kv_heads = value_states.shape[1]
-    batch = q_start_loc.shape[0]

-    for i in range(batch):
-        if torch.equal(q_seq_len[i], kv_seq_len[i]):
-            ext_ops.context_attention(
-                query_states,
-                key_states,
-                value_states,
-                q_start_loc[i:i + 1],
-                q_seq_len[i:i + 1],
-                num_q_heads,
-                num_kv_heads,
-                attn_mask=context.attention_mask[i:i + 1],
-                attn_output=attn_output,
-            )
-        else:
-            key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
-            value_cache = value_cache.reshape(1, kv_cache_len,
-                                              num_kv_heads * dim)
-            ext_ops.paged_prefill_attention(
-                query_states,
-                key_cache,
-                value_cache,
-                block_offsets,
-                block_size,
-                q_start_loc[i:i + 1],
-                q_seq_len[i:i + 1],
-                kv_seq_len[i:i + 1],
-                num_q_heads,
-                num_kv_heads,
-                attn_mask=context.attention_mask[i:i + 1],
-                attn_output=attn_output,
-            )
+    if context.is_unpaged_prefill:
+        ext_ops.prefill_attention(
+            query_states,
+            key_states,
+            value_states,
+            q_start_loc,
+            q_seq_len,
+            context.max_q_seq_length,
+            num_q_heads,
+            num_kv_heads,
+            attn_mask=context.attention_mask,
+            attn_output=attn_output,
+        )
+    else:
+        key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
+        value_cache = value_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
+        ext_ops.paged_prefill_attention(
+            query_states,
+            key_cache,
+            value_cache,
+            block_offsets,
+            block_size,
+            q_start_loc,
+            q_seq_len,
+            kv_seq_len,
+            num_q_heads,
+            num_kv_heads,
+            attn_mask=context.attention_mask,
+            attn_output=attn_output,
+        )


 def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
-                          block_offsets, block_size):
+                          max_kv_seq_len, block_offsets, block_size):
     num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1]
     ext_ops.paged_decode_attention(
         q,
@@ -66,6 +64,7 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
         block_offsets,
         block_size,
         kv_seq_len,
+        max_kv_seq_len,
         num_q_heads,
         num_kv_heads,
         attn_output=attn_output.view(q.shape),
@@ -115,6 +114,7 @@ def paged_attention_fwd(
             v,
             attn_output,
             kv_seqlens,
+            context.max_kv_seq_length,
             block_offsets,
             block_size,
         )