
Commit 0e10941

kesavanramakrishnan, root, and valechen authored

Added in GQA and 64-bit indexing (#1226)

* Added GQA support for Lean Attention and 64 bit indexing
* reformatted files

Co-authored-by: root <[email protected]>
Co-authored-by: valechen <[email protected]>
1 parent c7e3d34 commit 0e10941

File tree

4 files changed, +268 -126 lines


aiter/ops/triton/_triton_kernels/lean_atten.py

Lines changed: 48 additions & 14 deletions
@@ -114,15 +114,20 @@ def _attention_inner(
     offs_n,
     BLOCK_M,
     BLOCK_N,
-    HEAD_DIM,
+    HEAD_DIM_ORIG: tl.constexpr,
+    HEAD_DIM: tl.constexpr,
     local_iter,
     local_iter_end,
+    use_64_indexing: tl.constexpr,
 ):
     """
     Performs attention calculation for an (maybe partial) output tile
     """
+    # Define head-dimension mask for padded dims
+    offs_k_local = tl.arange(0, HEAD_DIM)
+    mask_k_cols_local = offs_k_local < HEAD_DIM_ORIG
     for l_iter in range(local_iter, local_iter_end):
-        k = tl.load(k_ptrs)
+        k = tl.load(k_ptrs, mask=mask_k_cols_local[:, None], other=0.0)
         qk = tl.dot(q, k) * qk_scale

         if causal:
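
The masked load above is the standard Triton idiom for running the compute at a padded HEAD_DIM while only touching the first HEAD_DIM_ORIG columns in memory. A minimal standalone sketch of that idiom, assuming a CUDA/ROCm device is available (kernel and tensor names here are illustrative, not from the repo):

import torch
import triton
import triton.language as tl

@triton.jit
def _masked_row_copy(x_ptr, out_ptr, N_ORIG: tl.constexpr, N_PAD: tl.constexpr):
    # Read N_PAD lanes, but only the first N_ORIG are real; padded lanes load 0.0.
    offs = tl.arange(0, N_PAD)
    mask = offs < N_ORIG
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    tl.store(out_ptr + offs, x, mask=mask)

x = torch.randn(72, device="cuda")   # head_dim = 72, not a power of two
out = torch.zeros_like(x)
_masked_row_copy[(1,)](x, out, N_ORIG=72, N_PAD=triton.next_power_of_2(72))
assert torch.equal(x, out)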
@@ -152,17 +157,24 @@ def _attention_inner(

         # Update accumulator
         acc = acc * alpha[:, None]
-        v = tl.load(v_ptrs)
+        v = tl.load(v_ptrs, mask=mask_k_cols_local[None, :], other=0.0)
         acc += tl.dot(p.to(v.dtype), v)

         # Update stats
         l_ij = tl.sum(p, 1)
         l_i = l_i * alpha + l_ij
         m_i = m_ij.to(m_i.dtype)

-        # update k/v pointer
-        v_ptrs += BLOCK_N * stride_vn
-        k_ptrs += BLOCK_N * stride_kn
+        # update k/v pointer with optional 64-bit indexing to avoid overflow
+        if use_64_indexing:
+            BLOCK_N64 = tl.full((), BLOCK_N, tl.int64)
+            stride_kn64 = tl.full((), stride_kn, tl.int64)
+            stride_vn64 = tl.full((), stride_vn, tl.int64)
+            v_ptrs += BLOCK_N64 * stride_vn64
+            k_ptrs += BLOCK_N64 * stride_kn64
+        else:
+            v_ptrs += BLOCK_N * stride_vn
+            k_ptrs += BLOCK_N * stride_kn
     return m_i, l_i, acc
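
The int64 branch exists because a 32-bit offset silently wraps once the accumulated BLOCK_N * stride displacement passes 2**31 elements, which very long contexts can reach. A host-side sketch of the wraparound this guards against, with made-up numbers:

import ctypes

BLOCK_N = 64
stride_kn = 1_048_576      # illustrative per-row element stride for a large K layout
steps = 40                 # enough BLOCK_N advances to pass 2**31 elements

offset = BLOCK_N * stride_kn * steps      # 2,684,354,560 element offset
print(ctypes.c_int32(offset).value)       # -1610612736: what a 32-bit pointer bump would see
print(offset)                             #  2684354560: what the int64 path preserves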

@@ -222,10 +234,12 @@ def la_persistent(
     stride_om,  # n_ctx_q
     stride_oh,  # Head
     stride_on,  # head_dim
+    n_ctx_q_rows,
     stride_oph,  # total_programs
     stride_opm,  # n_ctx_q
     stride_opn,  # head_dim
     HEADS_PER_XCD: tl.constexpr,
+    HEAD_DIM_ORIG: tl.constexpr,
     HEAD_DIM: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -243,6 +257,10 @@ def la_persistent(
     tiles_per_head: tl.constexpr,
     num_splits: tl.constexpr,
     max_output_tile_cnt: tl.constexpr,
+    num_heads_q: tl.constexpr,
+    num_heads_k: tl.constexpr,
+    gqa_group_size: tl.constexpr,
+    use_64_indexing: tl.constexpr,
 ):
     if is_pod:
         current_pid = pod_pid
@@ -321,6 +339,7 @@ def la_persistent(
         xcd_id=xcd_id,
         HEADS_PER_XCD=HEADS_PER_XCD,
         HEAD_DIM=HEAD_DIM,
+        HEAD_DIM_ORIG=HEAD_DIM_ORIG,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
         MASKED_BLOCKS=MASKED_BLOCKS,
@@ -335,6 +354,8 @@ def la_persistent(
         max_tiles_per_wg=max_tiles_per_wg,
         tiles_per_head=tiles_per_head,
         num_splits=num_splits,
+        gqa_group_size=gqa_group_size,
+        use_64_indexing=use_64_indexing,
     )

@@ -372,6 +393,7 @@ def la_persistent_inner(
     xcd_id,  # The XCD the pid belongs to
     HEADS_PER_XCD,
     HEAD_DIM: tl.constexpr,
+    HEAD_DIM_ORIG: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     MASKED_BLOCKS: tl.constexpr,
@@ -386,6 +408,8 @@ def la_persistent_inner(
     max_tiles_per_wg: tl.constexpr,
     tiles_per_head: tl.constexpr,
     num_splits: tl.constexpr,
+    gqa_group_size: tl.constexpr,
+    use_64_indexing: tl.constexpr,
 ):

     tl.assume(stride_qm > 0)  # n_ctx_q
@@ -478,10 +502,13 @@ def la_persistent_inner(
     # Q/K/V/O offsets calculation needs global head index.
     # When XCD_REMAP=False, xcd_id=0
     tile_head_idx_global = HEADS_PER_XCD * xcd_id + tile_head_idx
+    # Map Q head index to K/V head index via GQA grouping
+    tile_khead_idx_global = tile_head_idx_global // gqa_group_size

     offs_m = tl.arange(0, BLOCK_M)
     offs_n = tl.arange(0, BLOCK_N)
     offs_k = tl.arange(0, HEAD_DIM)
+    mask_k_cols = offs_k < HEAD_DIM_ORIG

     if causal:
         b_seq_size = tile_batch_idx * num_n_blocks
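
The floor division above is the usual GQA convention: each block of gqa_group_size consecutive Q heads reads the same K/V head. A quick host-side sketch with assumed head counts:

num_heads_q, num_heads_k = 32, 8             # hypothetical GQA configuration
gqa_group_size = num_heads_q // num_heads_k  # 4 Q heads share each K/V head

for q_head in (0, 3, 4, 31):
    kv_head = q_head // gqa_group_size
    print(f"Q head {q_head:2d} -> K/V head {kv_head}")
# Q heads 0-3 -> K/V head 0, Q heads 4-7 -> K/V head 1, ..., Q head 31 -> K/V head 7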
@@ -495,13 +522,13 @@ def la_persistent_inner(

     k_offs = (
         (b_seq_size + local_iter) * BLOCK_N * stride_kn
-        + tile_head_idx_global * stride_kh
+        + tile_khead_idx_global * stride_kh
         + offs_n[None, :] * stride_kn
         + offs_k[:, None] * stride_kk
     )
     v_offs = (
         (b_seq_size + local_iter) * BLOCK_N * stride_vn
-        + tile_head_idx_global * stride_vh
+        + tile_khead_idx_global * stride_vh
         + offs_n[:, None] * stride_vn
         + offs_k[None, :] * stride_vk
     )
@@ -531,7 +558,7 @@ def la_persistent_inner(
     l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
     acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)

-    q = tl.load(q_ptrs)
+    q = tl.load(q_ptrs, mask=mask_k_cols[None, :], other=0.0)

     m_i, l_i, acc = _attention_inner(
         q,
@@ -550,9 +577,11 @@ def la_persistent_inner(
         offs_n,
         BLOCK_M,
         BLOCK_N,
-        HEAD_DIM,
-        local_iter,
-        local_iter_end,
+        HEAD_DIM_ORIG=HEAD_DIM_ORIG,
+        HEAD_DIM=HEAD_DIM,
+        local_iter=local_iter,
+        local_iter_end=local_iter_end,
+        use_64_indexing=use_64_indexing,
     )

     # initialize pointer to m and l
@@ -732,8 +761,13 @@ def la_persistent_inner(

     acc0 = acc0 / l_i[:, None]
     acc1 = acc1 / l_i[:, None]
-    tl.store(o_ptrs0, acc0.to(Out.type.element_ty))
-    tl.store(o_ptrs1, acc1.to(Out.type.element_ty))
+    COLS_HALF: tl.constexpr = HEAD_DIM // 2
+    offs0 = tl.arange(0, COLS_HALF)
+    offs1 = tl.arange(0, COLS_HALF) + COLS_HALF
+    mask_cols0 = offs0 < HEAD_DIM_ORIG
+    mask_cols1 = offs1 < HEAD_DIM_ORIG
+    tl.store(o_ptrs0, acc0.to(Out.type.element_ty), mask=mask_cols0[None, :])
+    tl.store(o_ptrs1, acc1.to(Out.type.element_ty), mask=mask_cols1[None, :])

     # update iter
     iter = iter + (local_iter_end - local_iter)
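
Since the final output is written as two half-width column blocks, each half needs its own column mask against the original head dimension; only the upper half can contain padding. A host-side sketch with assumed widths (128 padded, 96 original):

HEAD_DIM, HEAD_DIM_ORIG = 128, 96            # assumed padded vs. original widths
COLS_HALF = HEAD_DIM // 2                    # 64 columns per store

offs0 = list(range(COLS_HALF))                     # columns 0..63
offs1 = [c + COLS_HALF for c in range(COLS_HALF)]  # columns 64..127

mask_cols0 = [c < HEAD_DIM_ORIG for c in offs0]  # all True: the low half is fully real
mask_cols1 = [c < HEAD_DIM_ORIG for c in offs1]  # True for 64..95, False for the 96..127 padding
print(sum(mask_cols0), sum(mask_cols1))          # 64 32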

aiter/ops/triton/lean_atten.py

Lines changed: 86 additions & 9 deletions
@@ -22,9 +22,10 @@
 from bisect import bisect_right
 import triton
 import triton.language as tl
-from aiter.ops.triton._triton_kernels.lean_atten import la_persistent
+from aiter.ops.triton._triton_kernels.lean_atten import la_persistent, _get_config
 from aiter.ops.triton.utils.logger import AiterTritonLogger
 from aiter.ops.triton.utils.device_info import get_num_xcds
+from aiter.ops.triton.utils._triton import arch_info

 _LOGGER = AiterTritonLogger()

@@ -45,6 +46,7 @@ def persistent_lean_attention(
     sm_scale: torch.float16,
     causal: bool = True,  # causal masking
     config: Optional[dict] = None,
+    program_count: Optional[int] = None,
 ):
     """
     Lean Attention kernel.
@@ -55,7 +57,11 @@ def persistent_lean_attention(
     if config is None:
         config = _get_config(causal=causal, batch_size=batch_size)
     sm_count = arch_info.get_num_sms()
-    total_programs = sm_count * config["SM_CNT_FACTOR"]
+    total_programs = (
+        program_count
+        if program_count is not None
+        else sm_count * config["SM_CNT_FACTOR"]
+    )

     return _persistent_lean_attention(
         q=q,
@@ -112,7 +118,10 @@ def _persistent_lean_attention(
     assert (
         HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
     ), "Incompatible Q/K/V Hidden Dimensions"
-    assert HEAD_DIM_K in {16, 32, 64, 128, 256}
+    # Allow irregular head dims by padding compute width and masking I/O
+    HEAD_DIM_PADDED = triton.next_power_of_2(HEAD_DIM_K)
+    if HEAD_DIM_PADDED < 16:
+        HEAD_DIM_PADDED = 16

     # MASKED_BLOCKS is used for prefill/causal for BLOCK_M > BLOCK_N
     # For MI300, BLOCK_M=128, BLOCK_N=64 is better for performance
@@ -126,6 +135,9 @@ def _persistent_lean_attention(
     N_CTX_Q = q.shape[0] // batch_size
     N_CTX_K = k.shape[0]  # This is the sum of all ctx_n in a batch
     H = q.shape[1]
+    H_K = k.shape[1]
+    assert H % H_K == 0, "For GQA, the number of Q heads must be divisible by K/V heads"
+    GQA_GROUP_SIZE = H // H_K
     HEADS_PER_XCD = H // NUM_XCDS

     qk_scale = sm_scale * LOG_TWO_E
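
For concreteness, here is what the padded head dimension and GQA group size work out to for made-up tensor shapes (nothing below is taken from the PR):

import torch
import triton

q = torch.empty(1024, 40, 72)   # (batch * n_ctx_q, Q heads, head_dim) - hypothetical
k = torch.empty(4096, 8, 72)    # (sum of ctx_n, K/V heads, head_dim) - hypothetical

HEAD_DIM_K = k.shape[2]
HEAD_DIM_PADDED = max(triton.next_power_of_2(HEAD_DIM_K), 16)  # 72 -> 128, floor of 16
H, H_K = q.shape[1], k.shape[1]
assert H % H_K == 0, "Q heads must be a multiple of K/V heads"
GQA_GROUP_SIZE = H // H_K                                      # 40 // 8 = 5
print(HEAD_DIM_PADDED, GQA_GROUP_SIZE)                         # 128 5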
@@ -145,7 +157,7 @@ def _persistent_lean_attention(
         N_CTX_Q,
         N_CTX_K,
         H,
-        H,
+        H_K,
         BLOCK_M,
         BLOCK_N,
         total_programs,
@@ -178,6 +190,59 @@ def _persistent_lean_attention(
     if DEBUG:
         print(f"max_output_tile_cnt={max_output_tile_cnt}")

+    # Clamp to buffer capacity to avoid deadlocks
+    max_supported = min(
+        int(Mp.shape[0]), int(Lp.shape[0]), int(Op.shape[0]), int(locks.numel())
+    )
+    total_programs = min(total_programs, max_supported)
+
+    # Recompute schedule with clamped total_programs to keep splits consistent
+    (
+        num_m_blocks,
+        num_n_blocks,
+        high_load_wgs,
+        max_tiles_per_wg,
+        tiles_per_head,
+        total_programs,
+        num_splits,
+        even_split,
+    ) = get_num_splits_and_buffer_sizes(
+        causal,
+        batch_size,
+        N_CTX_Q,
+        N_CTX_K,
+        H,
+        H_K,
+        BLOCK_M,
+        BLOCK_N,
+        total_programs,
+        XCD_REMAP,
+        NUM_XCDS,
+    )
+
+    # Runtime safety checks
+    if not (Mp.dim() == 2 and Mp.shape[0] >= total_programs and Mp.shape[1] >= BLOCK_M):
+        raise ValueError(
+            f"Mp must have at least [total_programs, BLOCK_M] >= [{total_programs}, {BLOCK_M}], got {tuple(Mp.shape)}"
+        )
+    if not (Lp.dim() == 2 and Lp.shape[0] >= total_programs and Lp.shape[1] >= BLOCK_M):
+        raise ValueError(
+            f"Lp must have at least [total_programs, BLOCK_M] >= [{total_programs}, {BLOCK_M}], got {tuple(Lp.shape)}"
+        )
+    if not (
+        Op.dim() == 3
+        and Op.shape[0] >= total_programs
+        and Op.shape[1] >= N_CTX_Q
+        and Op.shape[2] >= HEAD_DIM_K
+    ):
+        raise ValueError(
+            f"Op must have shape[0] >= total_programs, rows >= N_CTX_Q, cols >= HEAD_DIM_K; got {tuple(Op.shape)} while required first dim >= {total_programs}, rows >= {N_CTX_Q}, cols >= {HEAD_DIM_K}"
+        )
+    if not (locks.numel() >= total_programs):
+        raise ValueError(
+            f"locks must have length >= total_programs ({total_programs}), got {locks.numel()}"
+        )
+
     max_output_tile_cnt = max_output_tile_cnt + 4

     grid = (total_programs, 1, 1)
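
Those checks fix the minimum shapes of the caller-provided scratch buffers. A hedged allocation sketch that satisfies them; the dtypes and sizes here are assumptions for illustration, not values taken from this diff:

import torch

total_programs = 608     # e.g. SM count * SM_CNT_FACTOR after clamping (assumed)
BLOCK_M = 128
N_CTX_Q = 2048
HEAD_DIM_K = 128

# Shapes chosen to satisfy the runtime checks above; float32/int32 dtypes are an assumption.
Mp = torch.empty((total_programs, BLOCK_M), dtype=torch.float32, device="cuda")
Lp = torch.empty((total_programs, BLOCK_M), dtype=torch.float32, device="cuda")
Op = torch.empty((total_programs, N_CTX_Q, HEAD_DIM_K), dtype=torch.float32, device="cuda")
locks = torch.zeros((total_programs,), dtype=torch.int32, device="cuda")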
@@ -220,10 +285,12 @@ def _persistent_lean_attention(
         o.stride(0),
         o.stride(1),
         o.stride(2),
+        N_CTX_Q,
         Op.stride(0),  # total_programs
         Op.stride(1),  # n_ctx_q
         Op.stride(2),  # head_dim
         HEADS_PER_XCD=HEADS_PER_XCD,
+        HEAD_DIM_ORIG=HEAD_DIM_K,
         HEAD_DIM=HEAD_DIM_K,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
@@ -245,6 +312,16 @@ def _persistent_lean_attention(
         num_warps=num_warps,
         num_stages=1,
         num_ctas=1,
+        num_heads_q=H,
+        num_heads_k=H_K,
+        gqa_group_size=GQA_GROUP_SIZE,
+        use_64_indexing=(
+            (k.stride(0) * N_CTX_K) >= (1 << 31)
+            or (v.stride(0) * N_CTX_K) >= (1 << 31)
+            or (Op.stride(0) * total_programs) >= (1 << 31)
+            or (Op.stride(1) * N_CTX_Q) >= (1 << 31)
+            or (o.stride(0) * N_CTX_Q) >= (1 << 31)
+        ),
         **config,
     )
     """
@@ -257,7 +334,7 @@ def _persistent_lean_attention(
     """
     # print(f"la kernel {la_kernel.n_regs} registers used, {la_kernel.n_spills} spills")
     ms = 0
-    return o, ms
+    return (o, ms)


 def get_num_splits_and_buffer_sizes(
@@ -281,8 +358,7 @@ def get_num_splits_and_buffer_sizes(
     num_m_blocks = (max_seqlen_q + BLOCK_M - 1) // BLOCK_M
     num_n_blocks = (max_seqlen_k + BLOCK_N - 1) // BLOCK_N

-    # TODO: Support Grouped-Query Attention
-    max_seqlen_q = max_seqlen_q * num_heads // num_heads_k
+    # Schedule over Q heads; K/V heads are mapped inside the kernel via gqa_group_size

     # print(f"block_m: {BLOCK_M}, block_n: {BLOCK_N} ")
     # print(f"num_m_block: {num_m_blocks}, num_n_block: {num_n_blocks} ")
@@ -303,10 +379,11 @@ def get_num_splits_and_buffer_sizes(
     # Decode or Not Causal
     tiles_per_head = num_m_blocks * num_n_blocks

+    # Total tiles across all Q heads
     if XCD_REMAP:
-        total_tiles = tiles_per_head * (num_heads_k // NUM_XCDS)
+        total_tiles = tiles_per_head * (num_heads // NUM_XCDS)
     else:
-        total_tiles = tiles_per_head * num_heads_k  # Total tiles across all heads
+        total_tiles = tiles_per_head * num_heads

     # StreamK Lean has as many threadblocks as SMs
     # This should be a function of tile size and number of scratchpad space