@@ -114,15 +114,20 @@ def _attention_inner(
     offs_n,
     BLOCK_M,
     BLOCK_N,
-    HEAD_DIM,
+    HEAD_DIM_ORIG: tl.constexpr,
+    HEAD_DIM: tl.constexpr,
     local_iter,
     local_iter_end,
+    use_64_indexing: tl.constexpr,
 ):
     """
     Performs attention calculation for a (maybe partial) output tile
     """
+    # Define head-dimension mask for padded dims
+    offs_k_local = tl.arange(0, HEAD_DIM)
+    mask_k_cols_local = offs_k_local < HEAD_DIM_ORIG
     for l_iter in range(local_iter, local_iter_end):
-        k = tl.load(k_ptrs)
+        k = tl.load(k_ptrs, mask=mask_k_cols_local[:, None], other=0.0)
         qk = tl.dot(q, k) * qk_scale
 
         if causal:
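Why the kernel now carries both HEAD_DIM_ORIG and HEAD_DIM: `tl.arange` only accepts power-of-two extents, so a head dimension such as 96 has to be padded up (to HEAD_DIM) and the padded columns masked to zero on every load, which keeps them out of the `tl.dot` accumulation. A minimal pure-Python sketch of the padding rule the launcher presumably applies (the helper name is illustrative; Triton's own `triton.next_power_of_2` would do the same):

```python
# Sketch of the HEAD_DIM_ORIG -> HEAD_DIM padding rule (illustrative).
def next_power_of_2(n: int) -> int:
    return 1 << (n - 1).bit_length()

HEAD_DIM_ORIG = 96                          # actual head dimension
HEAD_DIM = next_power_of_2(HEAD_DIM_ORIG)   # 128, padded for tl.arange
print(HEAD_DIM)
```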
@@ -152,17 +157,24 @@ def _attention_inner(
 
         # Update accumulator
         acc = acc * alpha[:, None]
-        v = tl.load(v_ptrs)
+        v = tl.load(v_ptrs, mask=mask_k_cols_local[None, :], other=0.0)
         acc += tl.dot(p.to(v.dtype), v)
 
         # Update stats
         l_ij = tl.sum(p, 1)
         l_i = l_i * alpha + l_ij
         m_i = m_ij.to(m_i.dtype)
 
-        # update k/v pointer
-        v_ptrs += BLOCK_N * stride_vn
-        k_ptrs += BLOCK_N * stride_kn
+        # update k/v pointers with optional 64-bit indexing to avoid overflow
+        if use_64_indexing:
+            BLOCK_N64 = tl.full((), BLOCK_N, tl.int64)
+            stride_kn64 = tl.full((), stride_kn, tl.int64)
+            stride_vn64 = tl.full((), stride_vn, tl.int64)
+            v_ptrs += BLOCK_N64 * stride_vn64
+            k_ptrs += BLOCK_N64 * stride_kn64
+        else:
+            v_ptrs += BLOCK_N * stride_vn
+            k_ptrs += BLOCK_N * stride_kn
     return m_i, l_i, acc
 
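The `use_64_indexing` branch addresses int32 overflow: Triton integer arithmetic defaults to 32-bit, and for long contexts the K/V pointer increment accumulated over many `BLOCK_N * stride_kn` steps can exceed 2^31 - 1 elements. Materializing one operand as an int64 scalar via `tl.full((), ..., tl.int64)` promotes the whole product, and hence the pointer math, to 64-bit. A back-of-envelope host-side check for when the flag might be set (hypothetical helper; element-count upper bound):

```python
# Hypothetical heuristic for choosing use_64_indexing: enable 64-bit
# pointer math once the largest element offset into K/V could exceed
# signed 32-bit range.
INT32_MAX = 2**31 - 1

def needs_64bit_indexing(batch: int, num_heads_k: int, n_ctx: int, head_dim: int) -> bool:
    max_offset = batch * num_heads_k * n_ctx * head_dim  # upper bound, in elements
    return max_offset > INT32_MAX

print(needs_64bit_indexing(batch=8, num_heads_k=8, n_ctx=1_048_576, head_dim=128))  # True
```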
@@ -222,10 +234,12 @@ def la_persistent(
     stride_om,  # n_ctx_q
     stride_oh,  # Head
     stride_on,  # head_dim
+    n_ctx_q_rows,
     stride_oph,  # total_programs
     stride_opm,  # n_ctx_q
     stride_opn,  # head_dim
     HEADS_PER_XCD: tl.constexpr,
+    HEAD_DIM_ORIG: tl.constexpr,
     HEAD_DIM: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -243,6 +257,10 @@ def la_persistent(
     tiles_per_head: tl.constexpr,
     num_splits: tl.constexpr,
     max_output_tile_cnt: tl.constexpr,
+    num_heads_q: tl.constexpr,
+    num_heads_k: tl.constexpr,
+    gqa_group_size: tl.constexpr,
+    use_64_indexing: tl.constexpr,
 ):
     if is_pod:
         current_pid = pod_pid
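The new `num_heads_q` / `num_heads_k` / `gqa_group_size` constexprs thread grouped-query attention (GQA) through the kernel: several consecutive query heads share one K/V head. A sketch of how the launcher presumably derives the group size (values are illustrative; MHA is the special case of group size 1, MQA the case `num_heads_k == 1`):

```python
# Illustrative derivation of gqa_group_size on the host side.
num_heads_q, num_heads_k = 32, 8
assert num_heads_q % num_heads_k == 0, "Q heads must be a multiple of K/V heads"
gqa_group_size = num_heads_q // num_heads_k  # 4 query heads per K/V head
```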
@@ -321,6 +339,7 @@ def la_persistent(
         xcd_id=xcd_id,
         HEADS_PER_XCD=HEADS_PER_XCD,
         HEAD_DIM=HEAD_DIM,
+        HEAD_DIM_ORIG=HEAD_DIM_ORIG,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
         MASKED_BLOCKS=MASKED_BLOCKS,
@@ -335,6 +354,8 @@ def la_persistent(
         max_tiles_per_wg=max_tiles_per_wg,
         tiles_per_head=tiles_per_head,
         num_splits=num_splits,
+        gqa_group_size=gqa_group_size,
+        use_64_indexing=use_64_indexing,
     )
 
@@ -372,6 +393,7 @@ def la_persistent_inner(
     xcd_id,  # The XCD the pid belongs to
     HEADS_PER_XCD,
     HEAD_DIM: tl.constexpr,
+    HEAD_DIM_ORIG: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     MASKED_BLOCKS: tl.constexpr,
@@ -386,6 +408,8 @@ def la_persistent_inner(
     max_tiles_per_wg: tl.constexpr,
     tiles_per_head: tl.constexpr,
     num_splits: tl.constexpr,
+    gqa_group_size: tl.constexpr,
+    use_64_indexing: tl.constexpr,
 ):
 
     tl.assume(stride_qm > 0)  # n_ctx_q
@@ -478,10 +502,13 @@ def la_persistent_inner(
     # Q/K/V/O offsets calculation needs global head index.
     # When XCD_REMAP=False, xcd_id=0
     tile_head_idx_global = HEADS_PER_XCD * xcd_id + tile_head_idx
+    # Map Q head index to K/V head index via GQA grouping
+    tile_khead_idx_global = tile_head_idx_global // gqa_group_size
 
     offs_m = tl.arange(0, BLOCK_M)
     offs_n = tl.arange(0, BLOCK_N)
     offs_k = tl.arange(0, HEAD_DIM)
+    mask_k_cols = offs_k < HEAD_DIM_ORIG
 
     if causal:
         b_seq_size = tile_batch_idx * num_n_blocks
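The integer division `tile_head_idx_global // gqa_group_size` is the standard GQA head mapping: each block of `gqa_group_size` consecutive Q heads resolves to the same K/V head. A quick illustration with the group size of 4 assumed above:

```python
# Q heads 0-3 -> K/V head 0, Q heads 4-7 -> K/V head 1, etc.
gqa_group_size = 4
for q_head in range(8):
    print(f"Q head {q_head} -> K/V head {q_head // gqa_group_size}")
```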
@@ -495,13 +522,13 @@ def la_persistent_inner(
 
         k_offs = (
             (b_seq_size + local_iter) * BLOCK_N * stride_kn
-            + tile_head_idx_global * stride_kh
+            + tile_khead_idx_global * stride_kh
            + offs_n[None, :] * stride_kn
             + offs_k[:, None] * stride_kk
         )
         v_offs = (
             (b_seq_size + local_iter) * BLOCK_N * stride_vn
-            + tile_head_idx_global * stride_vh
+            + tile_khead_idx_global * stride_vh
             + offs_n[:, None] * stride_vn
             + offs_k[None, :] * stride_vk
         )
@@ -531,7 +558,7 @@ def la_persistent_inner(
     l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
     acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
 
-    q = tl.load(q_ptrs)
+    q = tl.load(q_ptrs, mask=mask_k_cols[None, :], other=0.0)
 
     m_i, l_i, acc = _attention_inner(
         q,
@@ -550,9 +577,11 @@ def la_persistent_inner(
         offs_n,
         BLOCK_M,
         BLOCK_N,
-        HEAD_DIM,
-        local_iter,
-        local_iter_end,
+        HEAD_DIM_ORIG=HEAD_DIM_ORIG,
+        HEAD_DIM=HEAD_DIM,
+        local_iter=local_iter,
+        local_iter_end=local_iter_end,
+        use_64_indexing=use_64_indexing,
     )
 
     # initialize pointer to m and l
@@ -732,8 +761,13 @@ def la_persistent_inner(
 
             acc0 = acc0 / l_i[:, None]
             acc1 = acc1 / l_i[:, None]
-            tl.store(o_ptrs0, acc0.to(Out.type.element_ty))
-            tl.store(o_ptrs1, acc1.to(Out.type.element_ty))
+            COLS_HALF: tl.constexpr = HEAD_DIM // 2
+            offs0 = tl.arange(0, COLS_HALF)
+            offs1 = tl.arange(0, COLS_HALF) + COLS_HALF
+            mask_cols0 = offs0 < HEAD_DIM_ORIG
+            mask_cols1 = offs1 < HEAD_DIM_ORIG
+            tl.store(o_ptrs0, acc0.to(Out.type.element_ty), mask=mask_cols0[None, :])
+            tl.store(o_ptrs1, acc1.to(Out.type.element_ty), mask=mask_cols1[None, :])
 
     # update iter
     iter = iter + (local_iter_end - local_iter)
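Because the output tile is written as two half-width column blocks through `o_ptrs0` and `o_ptrs1`, the padding mask has to be rebuilt per half: the first half covers columns `[0, HEAD_DIM/2)` and the second `[HEAD_DIM/2, HEAD_DIM)`, each compared against `HEAD_DIM_ORIG`. A NumPy sketch of the mask arithmetic, assuming HEAD_DIM=128 and HEAD_DIM_ORIG=96:

```python
import numpy as np

HEAD_DIM, HEAD_DIM_ORIG = 128, 96
cols_half = HEAD_DIM // 2
offs0 = np.arange(cols_half)              # columns 0..63
offs1 = np.arange(cols_half) + cols_half  # columns 64..127
print(int((offs0 < HEAD_DIM_ORIG).sum()))  # 64: first half fully valid
print(int((offs1 < HEAD_DIM_ORIG).sum()))  # 32: only columns 64..95 valid
```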