@@ -277,6 +277,7 @@ def ldmatrix_a(self,
277277 local_size_a = self .local_size_a
278278 a_dtype = self .a_dtype
279279 a_transposed = self .a_transposed
280+ ldsm_trans = self .a_transposed
280281 # ldmatrix cannot be used for int8 + trans case.
281282 ldmatrix_available = not (DataType (a_dtype ).bits != 16 and a_transposed )
282283
@@ -305,7 +306,6 @@ def _warp_ldmatrix_a(
305306 ):
306307 stride = A_shared_buf .shape [- 1 ]
307308 tx , _ , warp_m = self .extract_thread_binding (thread_binding )
308- trans = self .a_transposed
309309
310310 for i in T .serial (warp_rows ):
311311 # Assign A_shared_buf_elem
@@ -315,7 +315,7 @@ def _warp_ldmatrix_a(
315315 if ldmatrix_available :
316316 T .ptx_ldmatrix (
317317 a_dtype ,
318- T .bool (trans ),
318+ T .bool (ldsm_trans ),
319319 4 ,
320320 ".b16" ,
321321 A_local_buf .data ,
@@ -326,7 +326,7 @@ def _warp_ldmatrix_a(
326326 else :
327327 for j in T .serial (local_size_a ):
328328 mi , mk = mma_load_layout (tx , j )
329- A_local_buf [i * local_size_a + j ] = A_shared_buf [wk + mk , wi + mi ] if trans else A_shared_buf [wi + mi , wk + mk ]
329+ A_local_buf [i * local_size_a + j ] = A_shared_buf [wk + mk , wi + mi ] if a_transposed else A_shared_buf [wi + mi , wk + mk ]
330330
331331 return _warp_ldmatrix_a (A_local_buf , A_shared_buf , ki , thread_binding , rk )
332332
@@ -411,11 +411,11 @@ def ldmatrix_b(self,
411411 local_size_b = self .local_size_b
412412 b_dtype = self .b_dtype
413413 b_transposed = self .b_transposed
414+ ldsm_trans = not b_transposed
414415 thread_binding = self .get_thread_binding ()
415416 replicate_b = (self .n_dim == 16 )
416417 # ldmatrix cannot be used for int8 + trans case.
417- ldmatrix_available = False # TODO: use ldmatrix when possible
418-
418+ ldmatrix_available = not (DataType (b_dtype ).bits != 16 and not b_transposed )
419419 def mma_load_layout (i , j ):
420420 return i , j
421421
@@ -439,8 +439,6 @@ def _warp_ldmatrix_b(
439439 ):
440440 stride = B_shared_buf .shape [- 1 ]
441441 tx , warp_n , _ = self .extract_thread_binding (thread_binding )
442- trans = not b_transposed
443-
444442 for i in T .serial (warp_cols ):
445443 # Assign B_shared_elem
446444 wi , wk = (
@@ -454,13 +452,24 @@ def _warp_ldmatrix_b(
454452
455453 T .ptx_ldmatrix (
456454 b_dtype ,
457- T .bool (trans ),
455+ T .bool (ldsm_trans ),
458456 4 if replicate_b else 2 ,
459457 ".b16" ,
460458 B_local_buf .data ,
461459 i * local_size_b ,
462460 T .address_of (B_shared_buf_elem ),
463- get_ldmatrix_offset ("B" , tx , 0 , stride , b_dtype , b_transposed ),
461+ get_ldmatrix_offset ("B" , tx , 0 , stride , b_dtype , ldsm_trans ),
462+ )
463+
464+ T .ptx_ldmatrix (
465+ b_dtype ,
466+ T .bool (ldsm_trans ),
467+ 4 if replicate_b else 2 ,
468+ ".b16" ,
469+ B_local_buf .data ,
470+ i * local_size_b + lift (local_size_b ) // 2 ,
471+ T .address_of (B_shared_buf_elem ),
472+ get_ldmatrix_offset ("B" , tx , 8 , stride , b_dtype , ldsm_trans ),
464473 )
465474
466475 else :
0 commit comments