Skip to content

Commit 8304674

Browse files
committed
[layout] fix ldsm trans passing
1 parent 7f934c6 commit 8304674

File tree

2 files changed

+25
-16
lines changed

2 files changed

+25
-16
lines changed

tilelang/intrinsics/mma_macro_generator.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ def ldmatrix_a(self,
211211
local_size_a = self.local_size_a
212212
a_dtype = self.a_dtype
213213
a_transposed = self.a_transposed
214+
ldsm_trans = self.a_transposed
214215
# ldmatrix cannot be used for int8 + trans case.
215216
ldmatrix_available = not (DataType(a_dtype).bits != 16 and a_transposed)
216217

@@ -239,7 +240,6 @@ def _warp_ldmatrix_a(
239240
):
240241
stride = A_shared_buf.shape[-1]
241242
tx, _, warp_m = self.extract_thread_binding(thread_binding)
242-
trans = self.a_transposed
243243

244244
for i in T.serial(warp_rows):
245245
# Assign A_shared_buf_elem
@@ -249,18 +249,18 @@ def _warp_ldmatrix_a(
249249
if ldmatrix_available:
250250
T.ptx_ldmatrix(
251251
a_dtype,
252-
T.bool(trans),
252+
T.bool(ldsm_trans),
253253
4,
254254
".b16",
255255
A_local_buf.data,
256256
i * local_size_a,
257257
T.address_of(A_shared_buf_elem),
258-
get_ldmatrix_offset("A", tx, 0, stride, a_dtype, a_transposed),
258+
get_ldmatrix_offset("A", tx, 0, stride, a_dtype, ldsm_trans),
259259
)
260260
else:
261261
for j in T.serial(local_size_a):
262262
mi, mk = mma_load_layout(tx, j)
263-
A_local_buf[i * local_size_a + j] = A_shared_buf[wk + mk, wi + mi] if trans else A_shared_buf[wi + mi, wk + mk]
263+
A_local_buf[i * local_size_a + j] = A_shared_buf[wk + mk, wi + mi] if a_transposed else A_shared_buf[wi + mi, wk + mk]
264264

265265
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
266266

@@ -277,6 +277,7 @@ def ldmatrix_b(self,
277277
local_size_b = self.local_size_b
278278
b_dtype = self.b_dtype
279279
b_transposed = self.b_transposed
280+
ldsm_trans = not b_transposed
280281
thread_binding = self.get_thread_binding()
281282
replicate_b = (self.n_dim == 16)
282283
# ldmatrix cannot be used for int8 + trans case.
@@ -305,7 +306,6 @@ def _warp_ldmatrix_b(
305306
):
306307
stride = B_shared_buf.shape[-1]
307308
tx, warp_n, _ = self.extract_thread_binding(thread_binding)
308-
trans = not b_transposed
309309

310310
for i in T.serial(warp_cols):
311311
# Assign B_shared_elem
@@ -320,13 +320,13 @@ def _warp_ldmatrix_b(
320320

321321
T.ptx_ldmatrix(
322322
b_dtype,
323-
T.bool(trans),
323+
T.bool(ldsm_trans),
324324
4 if replicate_b else 2,
325325
".b16",
326326
B_local_buf.data,
327327
i * local_size_b,
328328
T.address_of(B_shared_buf_elem),
329-
get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed),
329+
get_ldmatrix_offset("B", tx, 0, stride, b_dtype, ldsm_trans),
330330
)
331331

332332
else:

tilelang/intrinsics/mma_sp_macro_generator.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ def ldmatrix_a(self,
277277
local_size_a = self.local_size_a
278278
a_dtype = self.a_dtype
279279
a_transposed = self.a_transposed
280+
ldsm_trans = self.a_transposed
280281
# ldmatrix cannot be used for int8 + trans case.
281282
ldmatrix_available = not (DataType(a_dtype).bits != 16 and a_transposed)
282283

@@ -305,7 +306,6 @@ def _warp_ldmatrix_a(
305306
):
306307
stride = A_shared_buf.shape[-1]
307308
tx, _, warp_m = self.extract_thread_binding(thread_binding)
308-
trans = self.a_transposed
309309

310310
for i in T.serial(warp_rows):
311311
# Assign A_shared_buf_elem
@@ -315,7 +315,7 @@ def _warp_ldmatrix_a(
315315
if ldmatrix_available:
316316
T.ptx_ldmatrix(
317317
a_dtype,
318-
T.bool(trans),
318+
T.bool(ldsm_trans),
319319
4,
320320
".b16",
321321
A_local_buf.data,
@@ -326,7 +326,7 @@ def _warp_ldmatrix_a(
326326
else:
327327
for j in T.serial(local_size_a):
328328
mi, mk = mma_load_layout(tx, j)
329-
A_local_buf[i * local_size_a + j] = A_shared_buf[wk + mk, wi + mi] if trans else A_shared_buf[wi + mi, wk + mk]
329+
A_local_buf[i * local_size_a + j] = A_shared_buf[wk + mk, wi + mi] if a_transposed else A_shared_buf[wi + mi, wk + mk]
330330

331331
return _warp_ldmatrix_a(A_local_buf, A_shared_buf, ki, thread_binding, rk)
332332

@@ -411,11 +411,11 @@ def ldmatrix_b(self,
411411
local_size_b = self.local_size_b
412412
b_dtype = self.b_dtype
413413
b_transposed = self.b_transposed
414+
ldsm_trans = not b_transposed
414415
thread_binding = self.get_thread_binding()
415416
replicate_b = (self.n_dim == 16)
416417
# ldmatrix cannot be used for int8 + trans case.
417-
ldmatrix_available = False # TODO: use ldmatrix when possible
418-
418+
ldmatrix_available = not (DataType(b_dtype).bits != 16 and not b_transposed)
419419
def mma_load_layout(i, j):
420420
return i, j
421421

@@ -439,8 +439,6 @@ def _warp_ldmatrix_b(
439439
):
440440
stride = B_shared_buf.shape[-1]
441441
tx, warp_n, _ = self.extract_thread_binding(thread_binding)
442-
trans = not b_transposed
443-
444442
for i in T.serial(warp_cols):
445443
# Assign B_shared_elem
446444
wi, wk = (
@@ -454,13 +452,24 @@ def _warp_ldmatrix_b(
454452

455453
T.ptx_ldmatrix(
456454
b_dtype,
457-
T.bool(trans),
455+
T.bool(ldsm_trans),
458456
4 if replicate_b else 2,
459457
".b16",
460458
B_local_buf.data,
461459
i * local_size_b,
462460
T.address_of(B_shared_buf_elem),
463-
get_ldmatrix_offset("B", tx, 0, stride, b_dtype, b_transposed),
461+
get_ldmatrix_offset("B", tx, 0, stride, b_dtype, ldsm_trans),
462+
)
463+
464+
T.ptx_ldmatrix(
465+
b_dtype,
466+
T.bool(ldsm_trans),
467+
4 if replicate_b else 2,
468+
".b16",
469+
B_local_buf.data,
470+
i * local_size_b + lift(local_size_b) // 2,
471+
T.address_of(B_shared_buf_elem),
472+
get_ldmatrix_offset("B", tx, 8, stride, b_dtype, ldsm_trans),
464473
)
465474

466475
else:

0 commit comments

Comments
 (0)