Skip to content

Commit 016dd1c

Browse files
committed
[layout] refactor fp16/bf16 layout
1 parent 540a803 commit 016dd1c

File tree

4 files changed: +18 additions, -23 deletions

4 files changed: +18 additions, -23 deletions

tilelang/intrinsics/mma_layout.py

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -164,7 +164,7 @@ def mma_load_b_32x16_to_shared_16x32_layout(thread_id, local_id):
164164
col = 16 * (local_id % 8 // 4) + (thread_id % 4) * 4 + (local_id % 4)
165165
return row, col
166166

167-
def mma_load_b_32x4_to_shared_16x8_layout_16bit(thread_id, local_id):
167+
def mma_load_b_32x8_to_shared_16x16_layout(thread_id, local_id):
168168
"""
169169
groupID = %laneid >> 2
170170
threadID_in_group = %laneid % 4
@@ -174,14 +174,10 @@ def mma_load_b_32x4_to_shared_16x8_layout_16bit(thread_id, local_id):
174174
175175
col = groupID
176176
"""
177-
row = (thread_id % 4) * 2 + (local_id % 2) + (local_id // 2) * 8
178-
col = (thread_id // 4)
177+
col = (thread_id % 4) * 2 + (local_id % 2) + (local_id // 2) * 8
178+
row = (thread_id // 4) + 8 * (local_id // 4)
179179
return row, col
180180

181-
def mma_load_b_32x8_to_shared_16x16_layout_16bit(thread_id, local_id):
182-
row, col = mma_load_b_32x4_to_shared_16x8_layout_16bit(thread_id, local_id % 4)
183-
return row, col + 8 * (local_id // 4)
184-
185181
def shared_16x16_to_mma_32x8_smoothlayout(i, j):
186182
return (i * 2 + j // 8, j % 8)
187183

tilelang/intrinsics/mma_macro_generator.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -18,8 +18,7 @@
1818
shared_16x32_to_mma_32x16_layout_sr_b,
1919
mma_load_a_32x4_to_shared_16x8_layout,
2020
mma_load_b_32x4_to_shared_16x8_layout,
21-
mma_load_b_32x4_to_shared_16x8_layout_16bit,
22-
mma_load_b_32x8_to_shared_16x16_layout_16bit,
21+
mma_load_b_32x8_to_shared_16x16_layout,
2322
mma_load_a_32x16_to_shared_16x32_layout,
2423
mma_load_b_32x16_to_shared_16x32_layout,
2524
mma_load_a_32x8_to_shared_16x16_layout,
@@ -290,7 +289,7 @@ def mma_load_layout(i, j):
290289
if DataType(b_dtype).bits == 8:
291290
mma_load_layout = mma_load_b_32x16_to_shared_16x32_layout
292291
elif DataType(b_dtype).bits == 16:
293-
mma_load_layout = mma_load_b_32x8_to_shared_16x16_layout_16bit if replicate_b else mma_load_b_32x4_to_shared_16x8_layout_16bit
292+
mma_load_layout = mma_load_b_32x8_to_shared_16x16_layout
294293
elif DataType(b_dtype).bits == 32:
295294
mma_load_layout = mma_load_b_32x4_to_shared_16x8_layout
296295
else:
@@ -334,8 +333,8 @@ def _warp_ldmatrix_b(
334333
# load 16x32 data from shared buffer to local buffer
335334
# must be transposed.
336335
for j in T.serial(local_size_b):
337-
mk, mi = mma_load_layout(tx, j)
338-
B_local_buf[i * local_size_b + j] = B_shared_buf[wk + mk, wi + mi] if trans else B_shared_buf[wi + mi, wk + mk]
336+
mi, mk = mma_load_layout(tx, j)
337+
B_local_buf[i * local_size_b + j] = B_shared_buf[wi + mi, wk + mk] if trans else B_shared_buf[wk + mk, wi + mi]
339338

340339
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
341340

tilelang/intrinsics/mma_sp_layout.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -2,22 +2,22 @@
22

33
from .mma_layout import (
44
mma_load_a_32x8_to_shared_16x16_layout,
5-
mma_load_b_32x4_to_shared_16x8_layout_16bit,
6-
5+
mma_load_b_32x4_to_shared_8x16_layout_16bit,
76
)
87

98
def mma_sp_load_a_32x8_to_shared_16x32_layout(thread_id, local_id):
109
return mma_load_a_32x8_to_shared_16x16_layout(thread_id, local_id)
1110

12-
def mma_sp_load_b_32x8_to_shared_32x8_layout(thread_id, local_id):
13-
return mma_load_b_32x4_to_shared_16x8_layout_16bit(thread_id, local_id)
11+
def mma_sp_load_b_32x8_to_shared_8x64_layout(thread_id, local_id):
12+
return mma_load_b_32x8_to_shared_8x32_layout(thread_id, local_id)
1413

15-
def mma_sp_load_b_32x16_to_shared_32x16_layout(thread_id, local_id):
16-
row, col = mma_load_b_32x4_to_shared_16x8_layout_16bit(thread_id, local_id % 8)
17-
return row, col + 8 * (local_id // 8)
14+
def mma_sp_load_b_32x16_to_shared_16x64_layout(thread_id, local_id):
15+
row, col = mma_load_b_32x16_to_shared_16x32_layout(thread_id, local_id % 8)
16+
return row, col + 8 * (local_id // 8)
1817

18+
def mma_sp_load_b_32x16_to_shared_16x32_layout(thread_id, local_id):
19+
return mma_load_b_32x8_to_shared_16x16_layout(thread_id, local_id)
1920

20-
def get_logical_id(thread_id: int) -> int:
2121
return (thread_id // 4) * 2 + (thread_id % 4) % 2
2222

2323
def metadata_load_32x4_to_shared_16x4_layout_8bit(thread_id: int, local_id: int) -> Tuple[int, int]:

tilelang/intrinsics/mma_sp_macro_generator.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -380,7 +380,7 @@ def mma_load_layout(i, j):
380380
# if DataType(b_dtype).bits == 8:
381381
# mma_load_layout = mma_load_b_32x16_to_shared_16x32_layout
382382
if DataType(b_dtype).bits == 16:
383-
mma_load_layout = mma_sp_load_b_32x16_to_shared_32x16_layout if replicate_b else mma_sp_load_b_32x8_to_shared_32x8_layout
383+
mma_load_layout = mma_sp_load_b_32x16_to_shared_16x32_layout
384384
# elif DataType(b_dtype).bits == 32:
385385
# mma_load_layout = mma_load_b_32x4_to_shared_16x8_layout
386386
else:
@@ -425,8 +425,8 @@ def _warp_ldmatrix_b(
425425
# load 16x32 data from shared buffer to local buffer
426426
# must be transposed.
427427
for j in T.serial(local_size_b):
428-
mk, mi = mma_load_layout(tx, j)
429-
B_local_buf[i * local_size_b + j] = B_shared_buf[wk + mk, wi + mi] if trans else B_shared_buf[wi + mi, wk + mk]
428+
mi, mk = mma_load_layout(tx, j)
429+
B_local_buf[i * local_size_b + j] = B_shared_buf[wi + mi, wk + mk] if trans else B_shared_buf[wk + mk, wi + mi]
430430

431431
return _warp_ldmatrix_b(B_local_buf, B_shared_buf, ki, thread_binding, rk)
432432

0 commit comments

Comments
 (0)