190 changes: 190 additions & 0 deletions examples/gemm/example_gemm_intrinsics_dcu.py
@@ -0,0 +1,190 @@
from tilelang import tvm as tvm
from tvm import DataType
import tilelang
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mmac_macro_generator import (
    MatrixCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func
from tilelang import disable_cache

disable_cache()

def make_swizzle_layout(shared_buf):
    dtype = shared_buf.dtype
    shape = shared_buf.shape

    can_swizzle = shape[-1] * DataType(dtype).bits == 512
    if not can_swizzle:
        return T.Layout(shape, lambda *args: args)

    def transform_func(i, j):
        new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
        return [new_warp_i, new_warp_j]

    return T.Layout(shape, transform_func)
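An editorial aside (not part of the PR): make_swizzle_layout only swizzles when one row of the shared buffer is exactly 512 bits wide; otherwise it falls back to the identity layout. A minimal sketch of that condition for a few dtype/shape combinations:

    # Which (dtype bits, last dim) combinations satisfy the 512-bit condition
    for bits, last_dim in [(16, 32), (8, 64), (16, 16)]:
        print(bits, last_dim, "swizzle" if last_dim * bits == 512 else "identity")
    # float16 x 32 -> 512 bits -> swizzle
    # int8    x 64 -> 512 bits -> swizzle
    # float16 x 16 -> 256 bits -> identity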


@tilelang.jit(out_idx=[2])
@simplify_prim_func
def tl_matmul(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
):
    assert in_dtype in [
        "float16",
        "int8",
    ], "Currently only float16 and int8 are supported"
    assert out_dtype in [
        "float16",
        "float32",
        "int32",
    ], "Currently only float16, float32 and int32 are supported"

    micro_size_x = micro_size_y = micro_size_k = 16

    if out_dtype == "int32":
        micro_size_k = 32

Comment on lines +50 to +52

⚠️ Potential issue | 🟠 Major

micro_size_k should depend on in_dtype, not out_dtype

Int8 paths require K-fragment=32 regardless of the final out dtype. Gate on in_dtype to avoid incorrect tiling when out_dtype is int32 but inputs aren’t int8.

Apply this diff:

-    if out_dtype == "int32":
-        micro_size_k = 32
+    if in_dtype == "int8":
+        micro_size_k = 32
🤖 Prompt for AI Agents
In examples/gemm/example_gemm_intrinsics_dcu.py around lines 50 to 52, the code
sets micro_size_k based on out_dtype but it should be gated on in_dtype: change
the condition to check if in_dtype == "int8" (or equivalent int8 type sentinel
used elsewhere) and set micro_size_k = 32 for int8 input paths; leave other
dtype branches unchanged so int8 inputs always use K-fragment=32 even if
out_dtype is int32.

    # This is a debug config
    block_row_warps = 2
    block_col_warps = 2
    warp_row_tiles = 64
    warp_col_tiles = 64
    # chunk = 32 if in_dtype == "float16" else 64
    chunk = 32
    shared_scope = "shared.dyn"

    # Pipeline Stage
    stage = 2

    block_M = block_row_warps * warp_row_tiles
    block_N = block_col_warps * warp_col_tiles
    block_K = chunk

Comment on lines +65 to +68

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Tile edge handling missing; kernel reads/writes out of bounds for non-multiple sizes

Grid uses ceildiv, but loads/stores assume perfect multiples of block_M/N/K. Either add guards in load/store loops or assert divisibility.

Minimal safety precondition (quick fix):

     block_K = chunk
@@
-    A_shape = (M, K)
+    # Precondition for this example kernel (simplifies control flow)
+    assert M % block_M == 0 and N % block_N == 0 and K % block_K == 0, \
+        "Example kernel currently requires M,N,K to be multiples of block sizes"
+    A_shape = (M, K)

Recommended follow-up (optional): add bounds checks in the A/B loads and C stores so arbitrary M,N,K work. I can draft that if you want.

Also applies to: 108-159


I need to examine the actual code to verify the concerns raised in the review comment.


🏁 Script executed:

#!/bin/bash
# Retrieve the file and examine the kernel implementation
if [ -f "examples/gemm/example_gemm_intrinsics_dcu.py" ]; then
    echo "=== File Content ===" 
    cat examples/gemm/example_gemm_intrinsics_dcu.py
else
    echo "File not found at examples/gemm/example_gemm_intrinsics_dcu.py"
    find . -name "example_gemm_intrinsics_dcu.py" -type f 2>/dev/null
fi



Add bounds-checking guards or assertion for non-multiple tensor dimensions

The kernel grid uses ceildiv (line 99) but load/store loops assume perfect multiples. With M, N, or K not divisible by block_M, block_N, or block_K, the loops read/write out of bounds. The current test (M=N=K=16384) masks this because all are divisible by their block sizes.

Either add bounds checks in the A/B loads (lines 117–121) and C store (lines 138–145), or add an early assertion:

     block_K = chunk
@@
+    assert M % block_M == 0 and N % block_N == 0 and K % block_K == 0, \
+        "Example kernel requires M, N, K to be multiples of block sizes (M:%d, N:%d, K:%d; block_M:%d, block_N:%d, block_K:%d)" % (M, N, K, block_M, block_N, block_K)
     A_shape = (M, K)
🤖 Prompt for AI Agents
In examples/gemm/example_gemm_intrinsics_dcu.py around lines 65–68, the computed
block sizes (block_M, block_N, block_K) are used later by loads/stores that
assume tensor dimensions are exact multiples; add either runtime assertions
early (assert M % block_M == 0 and N % block_N == 0 and K % block_K == 0) or,
preferably, add bounds-check guards around A/B loads (lines ~117–121) and the C
store (lines ~138–145): before reading A or B elements check the computed global
row/col indices against M/N/K and substitute zero (or a safe value) for
out-of-bounds loads; before writing C check indices and skip stores outside M/N,
ensuring no out-of-bounds memory access.
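For reference only (an editorial sketch, not part of the PR), this is the behaviour the comment asks for, written as plain NumPy: partial tiles are zero-padded on load and clipped on store, so arbitrary M, N, K produce correct results. The block sizes and the (N, K) layout of B mirror the example kernel.

    import numpy as np

    def blocked_gemm_ref(A, B, block_M=128, block_N=128, block_K=32):
        # B is stored as (N, K), matching the kernel's layout; C = A @ B.T
        M, K = A.shape
        N = B.shape[0]
        C = np.zeros((M, N), dtype=np.float32)
        for m0 in range(0, M, block_M):
            for n0 in range(0, N, block_N):
                acc = np.zeros((block_M, block_N), dtype=np.float32)
                for k0 in range(0, K, block_K):
                    a_tile = np.zeros((block_M, block_K), dtype=np.float32)
                    b_tile = np.zeros((block_N, block_K), dtype=np.float32)
                    # Guarded loads: copy only the in-bounds region, rest stays zero
                    a = A[m0:m0 + block_M, k0:k0 + block_K]
                    b = B[n0:n0 + block_N, k0:k0 + block_K]
                    a_tile[:a.shape[0], :a.shape[1]] = a
                    b_tile[:b.shape[0], :b.shape[1]] = b
                    acc += a_tile @ b_tile.T
                # Guarded store: clip the tile to the valid output region
                mh = min(block_M, M - m0)
                nw = min(block_N, N - n0)
                C[m0:m0 + mh, n0:n0 + nw] = acc[:mh, :nw]
        return C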

    A_shape = (M, K)
    B_shape = (N, K)
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_N, block_K)
    C_shared_shape = (
        block_M // micro_size_x,
        block_N // micro_size_y,
        micro_size_x,
        micro_size_y,
    )

    warp_size = 64
    threads = warp_size * (block_row_warps * block_col_warps)
    local_size_a = (micro_size_x * micro_size_k) // warp_size
    local_size_b = (micro_size_y * micro_size_k) // warp_size
    local_size_c = (micro_size_x * micro_size_y) // warp_size
    warp_rows = warp_row_tiles // micro_size_x
    warp_cols = warp_col_tiles // micro_size_y

    # MMAC Wrapper to Auto Generate Code for MMAC
    mmac_emitter = MatrixCoreIntrinEmitter(
        a_dtype=in_dtype,
        b_dtype=in_dtype,
        accum_dtype=accum_dtype,
        a_transposed=False,
        b_transposed=True,
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=warp_row_tiles,
        warp_col_tiles=warp_col_tiles,
        chunk=chunk,
    )

    @T.prim_func
    def gemm_intrinsics(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):

            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

            T.annotate_layout({
                A_shared: make_swizzle_layout(A_shared),
                B_shared: make_swizzle_layout(B_shared),
            })

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)

            T.clear(C_local)

            for ko in T.Pipelined((K // block_K), num_stages=stage):

                # Load A into shared memory
                for i, k in T.Parallel(block_M, block_K):
                    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]

                # Load B into shared memory
                for j, k in T.Parallel(block_N, block_K):
                    B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]

                for ki in T.serial(0, (block_K // micro_size_k)):

                    # Load A into fragment
                    mmac_emitter.ldmatrix_a(A_local, A_shared, ki)

                    # Load B into fragment
                    mmac_emitter.ldmatrix_b(B_local, B_shared, ki)

                    # Perform Matrix Multiplication
                    mmac_emitter.mmac(A_local, B_local, C_local)

            # Perform STMatrix
            mmac_emitter.stmatrix(C_local, C_shared)

            # Store shared into global
            for i, j in T.Parallel(block_M, block_N):
                C[by * block_M + i, bx * block_N + j] = C_shared[
                    j // micro_size_y,
                    i // micro_size_x,
                    i % micro_size_x,
                    j % micro_size_y,
                ]
Comment on lines +153 to +158

⚠️ Potential issue | 🔴 Critical

C_shared indexing appears transposed vs declared shape

C_shared is declared as (block_M // micro_size_x, block_N // micro_size_y, micro_size_x, micro_size_y), but store uses [j-group, i-group, i-rem, j-rem]. Swap i/j groups to match the shape.

Apply this diff:

-                C[by * block_M + i, bx * block_N + j] = C_shared[
-                    j // micro_size_y,
-                    i // micro_size_x,  
-                    i % micro_size_x,
-                    j % micro_size_y,
-                ]
+                C[by * block_M + i, bx * block_N + j] = C_shared[
+                    i // micro_size_x,
+                    j // micro_size_y,
+                    i % micro_size_x,
+                    j % micro_size_y,
+                ]
🤖 Prompt for AI Agents
In examples/gemm/example_gemm_intrinsics_dcu.py around lines 153 to 158, the
indexing used to read from C_shared is transposed relative to its declared shape
(C_shared declared as (block_M // micro_size_x, block_N // micro_size_y,
micro_size_x, micro_size_y)); fix the read to match that shape by swapping the
group indices so the access becomes C_shared[i // micro_size_x, j //
micro_size_y, i % micro_size_x, j % micro_size_y] instead of the current [j //
micro_size_y, i // micro_size_x, i % micro_size_x, j % micro_size_y].
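A quick NumPy check (editorial, not part of the PR) of why the group indices must follow the declared (M-group, N-group, m, n) order, assuming the stmatrix step fills C_shared in its declared index order: swapping the group indices reads the wrong tile whenever the i-group and j-group differ.

    import numpy as np

    micro_x, micro_y = 16, 16
    blk_m, blk_n = 64, 64
    ref = np.arange(blk_m * blk_n).reshape(blk_m, blk_n)
    # Pack into the declared C_shared shape: (blk_m // 16, blk_n // 16, 16, 16)
    packed = ref.reshape(blk_m // micro_x, micro_x, blk_n // micro_y, micro_y).transpose(0, 2, 1, 3)

    i, j = 3, 40  # i-group = 0, j-group = 2
    assert packed[i // micro_x, j // micro_y, i % micro_x, j % micro_y] == ref[i, j]
    assert packed[j // micro_y, i // micro_x, i % micro_x, j % micro_y] != ref[i, j]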


    return gemm_intrinsics


def ref_program(A, B):
    return A @ B.T


def main():
    M, N, K = 16384, 16384, 16384
    in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32"
    kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
    src_code = kernel.get_kernel_source()
    # src_code is the generated cuda source
    assert src_code is not None

    profiler = kernel.get_profiler()

    latency = profiler.do_bench(profiler.func, warmup=25)

    print(latency)
    print(kernel.get_kernel_source())
    # Ensure that the latency is not None
    assert latency is not None

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


if __name__ == "__main__":
    main()
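If a quicker smoke test is wanted (editorial suggestion, not part of the PR), any shapes that are multiples of block_M = block_N = 128 and block_K = 32 exercise the same code path:

    # Hypothetical smaller run; same dtypes, block-aligned shapes
    kernel = tl_matmul(1024, 1024, 1024, "float16", "float16", "float32")
    profiler = kernel.get_profiler()
    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)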


161 changes: 161 additions & 0 deletions examples/minference/ops/vertical_slash_index.hip
@@ -0,0 +1,161 @@
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include <assert.h>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <torch/extension.h>

#include <hip/hip_runtime.h>

__device__ void save_blocks(int* block_offset, int range_start, int range_end, int block_size, int& block_count) {
    for (int idx = range_start; idx < range_end; idx += block_size) {
        block_offset[block_count++] = idx;
    }
}

__global__ void convert_vertical_slash_indexes_kernel(
    const int* seqlens,          // [BATCH, ]
    const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,    // [BATCH, N_HEADS, NNZ_S]
    int* block_count,            // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,           // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,           // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,           // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int N_HEADS,
    int N_ROWS,
    int BLOCK_SIZE_M,
    int BLOCK_SIZE_N,
    int NNZ_V,
    int NNZ_S
) {
    const int batch_idx = blockIdx.y;
    const int head_idx = blockIdx.x;
    const int group_idx = blockIdx.z;

    int seqlen = seqlens[batch_idx];
    int block_idx_m = group_idx * blockDim.x + threadIdx.x;
    int start_m = block_idx_m * BLOCK_SIZE_M;
    if (start_m >= seqlen) {
        return;
    }
    int end_m = start_m + BLOCK_SIZE_M;
    vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
    slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
    int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
    block_count += row_offset;
    block_offset += row_offset * NNZ_S;
    column_count += row_offset;
    column_index += row_offset * NNZ_V;

Comment on lines +40 to +54

⚠️ Potential issue | 🔴 Critical

Guard against row index overflow.

Threads with block_idx_m ≥ N_ROWS can compute a valid start_m < seqlen (if seqlen > context), causing OOB on row_offset. Add an explicit guard.

Apply:

   int seqlen = seqlens[batch_idx];
   int block_idx_m = group_idx * blockDim.x + threadIdx.x;
+  if (block_idx_m >= N_ROWS) {
+      return;
+  }
   int start_m = block_idx_m * BLOCK_SIZE_M;
   if (start_m >= seqlen) {
       return;
   }
🤖 Prompt for AI Agents
In examples/minference/ops/vertical_slash_index.hip around lines 40 to 54,
threads where block_idx_m >= N_ROWS can still have start_m < seqlen and will
compute row_offset and use it causing out-of-bounds accesses; add an explicit
guard right after computing block_idx_m (before computing row_offset and any
row-dependent offsets) that returns when block_idx_m >= N_ROWS so subsequent
uses of row_offset, block_count/offset, column_count/index are safe.

    int tmp_col_cnt = 0, tmp_blk_cnt = 0;
    int s = 0, v = 0;
    int v_idx = vertical_indexes[v++];
    int s_idx = slash_indexes[s++];
    while (s_idx >= end_m) {
        s_idx = slash_indexes[s++];
    }
    s_idx = max(end_m - s_idx, BLOCK_SIZE_M);
Comment on lines +55 to +62

⚠️ Potential issue | 🔴 Critical

Fix OOB reads when NNZ_S/NNZ_V are zero and bound the pre-scan.

Accessing vertical_indexes[v++] and slash_indexes[s++] without checking NNZ_* risks OOB. The pre-loop while also lacks a bound on s.

Apply:

-    int tmp_col_cnt = 0, tmp_blk_cnt = 0;
-    int s = 0, v = 0;
-    int v_idx = vertical_indexes[v++];
-    int s_idx = slash_indexes[s++];
-    while (s_idx >= end_m) {
-        s_idx = slash_indexes[s++];
-    }
-    s_idx = max(end_m - s_idx, BLOCK_SIZE_M);
+    int tmp_col_cnt = 0, tmp_blk_cnt = 0;
+    int s = 0, v = 0;
+    // Safe init of v_idx
+    int v_idx = (NNZ_V > 0) ? vertical_indexes[v++] : (end_m + BLOCK_SIZE_M);
+    // Handle NNZ_S == 0 early
+    if (NNZ_S == 0) {
+        block_count[0] = 0;
+        column_count[0] = 0;
+        return;
+    }
+    int s_idx = slash_indexes[s++];
+    while (s < NNZ_S && s_idx >= end_m) {
+        s_idx = slash_indexes[s++];
+    }
+    if (s_idx >= end_m) {
+        // No slash indices relevant for this row
+        block_count[0] = 0;
+        column_count[0] = 0;
+        return;
+    }
+    s_idx = max(end_m - s_idx, BLOCK_SIZE_M);

    int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
    while (1) {
        if (v_idx < range_end) {
            if (v_idx < range_start) {
                column_index[tmp_col_cnt++] = v_idx;
            }
            if (v < NNZ_V) {
                v_idx = vertical_indexes[v++];
            } else {
                v_idx = end_m + BLOCK_SIZE_M;
            }
        } else {
            if (s < NNZ_S) {
                s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M);
            } else {
                save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
Comment on lines +67 to +78

⚠️ Potential issue | 🔴 Critical

Cap per-row writes to avoid overflow of column_index and block_offset.

tmp_col_cnt and tmp_blk_cnt can exceed NNZ_V/NNZ_S; cap writes and bound save_blocks.

Apply:

-                column_index[tmp_col_cnt++] = v_idx;
+                if (tmp_col_cnt < NNZ_V) {
+                    column_index[tmp_col_cnt++] = v_idx;
+                }
@@
-                save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
+                save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, NNZ_S);
@@
-                save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
+                save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, NNZ_S);

And update save_blocks to accept a max:

-__device__ void save_blocks(int* block_offset, int range_start, int range_end, int block_size, int& block_count) {
-    for (int idx = range_start; idx < range_end; idx += block_size) {
-        block_offset[block_count++] = idx;
-    }
-}
+__device__ __forceinline__ void save_blocks(int* block_offset,
+                                            int range_start,
+                                            int range_end,
+                                            int block_size,
+                                            int& blk_cnt,
+                                            int max_blocks) {
+    for (int idx = range_start; idx < range_end && blk_cnt < max_blocks; idx += block_size) {
+        block_offset[blk_cnt++] = idx;
+    }
+}

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In examples/minference/ops/vertical_slash_index.hip around lines 67 to 78,
tmp_col_cnt and tmp_blk_cnt can grow past their backing limits causing
out-of-bounds writes; cap increments and writes so they never exceed NNZ_V and
NNZ_S respectively and pass a max limit into save_blocks. Change the code paths
that write into column_index and block_offset to check (tmp_col_cnt < NNZ_V) and
(tmp_blk_cnt < NNZ_S) before assigning/incrementing and, when calling
save_blocks, pass a new max parameter (e.g., max_tmp_blk_cnt) instead of the raw
tmp_blk_cnt; then update save_blocks signature and implementation to accept that
max and ensure it only processes up to that bounded count and validates indices
before accessing arrays.

                break;
            }
            if (s_idx > range_end + BLOCK_SIZE_M) {
                save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
                range_start = s_idx - BLOCK_SIZE_M;
                range_end = s_idx;
            } else if (s_idx > range_end) {
                range_end += BLOCK_SIZE_M;
            }
        }
    }

    block_count[0] = tmp_blk_cnt;
    column_count[0] = tmp_col_cnt;
}

void convert_vertical_slash_indexes_64x64(
    const int* seqlens,          // [BATCH, ]
    const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,    // [BATCH, N_HEADS, NNZ_S]
    int* block_count,            // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,           // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,           // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,           // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int BATCH_SIZE,
    int N_HEADS,
    int N_ROWS,
    int NNZ_V,
    int NNZ_S
) {
    const int BLOCK_SIZE_M = 64;
    const int BLOCK_SIZE_N = 64;
    const int N_THREADS = 64;
    const dim3 dimBlock(N_THREADS);
    const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
    hipLaunchKernelGGL(( convert_vertical_slash_indexes_kernel), dim3(dimGrid), dim3(dimBlock), 0, 0,
        seqlens, vertical_indexes, slash_indexes,
        block_count, block_offset, column_count, column_index,
        N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, NNZ_V, NNZ_S
    );
Comment on lines +114 to +118

⚠️ Potential issue | 🔴 Critical

Use PyTorch’s current stream, not stream 0; add launch error check.

Launching on stream 0 breaks PyTorch stream semantics and can race other ops. Also add a kernel launch check.

Apply:

+#include <ATen/cuda/CUDAContext.h>  // works for CUDA and ROCm builds
@@
-   hipLaunchKernelGGL(( convert_vertical_slash_indexes_kernel), dim3(dimGrid), dim3(dimBlock), 0, 0, 
+   auto stream = at::cuda::getCurrentCUDAStream();
+   hipLaunchKernelGGL(( convert_vertical_slash_indexes_kernel), dim3(dimGrid), dim3(dimBlock), 0, stream.stream(),
         seqlens, vertical_indexes, slash_indexes,
         block_count, block_offset, column_count, column_index,
         N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, NNZ_V, NNZ_S
     );
+   AT_CUDA_CHECK(hipGetLastError());

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In examples/minference/ops/vertical_slash_index.hip around lines 114-118, the
kernel is launched on stream 0 and lacks a launch error check; change the launch
to use PyTorch's current HIP stream (retrieve the current stream from the
ATen/C10 API rather than hardcoding 0) and after hipLaunchKernelGGL add a kernel
launch error check (call hipGetLastError() and handle/report the error or
throw/log if non-zero) so the launch respects PyTorch stream semantics and
failures are detected.

}

std::vector<at::Tensor> convert_vertical_slash_indexes(
    torch::Tensor seqlens,          // [BATCH, ]
    torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
    torch::Tensor slash_indexes,    // [BATCH, N_HEADS, NNZ_S]
    int context_size,
    int block_size_M,
    int block_size_N
) {
    assert(block_size_M == 64);
    assert(block_size_N == 64);

    hipSetDevice(seqlens.get_device());

    int batch_size = slash_indexes.size(0);
    int num_heads = slash_indexes.size(1);
    int nnz_slash = slash_indexes.size(2);
    int nnz_vertical = vertical_indexes.size(2);
    int num_rows = (context_size + block_size_M - 1) / block_size_M;

    torch::Tensor block_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options());
    torch::Tensor block_offset = torch::zeros({batch_size, num_heads, num_rows, nnz_slash}, seqlens.options());
    torch::Tensor column_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options());
    torch::Tensor column_index = torch::zeros({batch_size, num_heads, num_rows, nnz_vertical}, seqlens.options());

    convert_vertical_slash_indexes_64x64(
        seqlens.data_ptr<int>(),
        vertical_indexes.data_ptr<int>(),
        slash_indexes.data_ptr<int>(),
        block_count.data_ptr<int>(),
        block_offset.data_ptr<int>(),
        column_count.data_ptr<int>(),
        column_index.data_ptr<int>(),
        batch_size,
        num_heads,
        num_rows,
        nnz_vertical,
        nnz_slash
    );
Comment on lines +121 to +158

⚠️ Potential issue | 🟠 Major

Validate inputs, use DeviceGuard, enforce dtypes/contiguity; replace assert with TORCH_CHECK.

Ensure tensors live on the same device/stream, are int32, and contiguous. Avoid hipSetDevice; use DeviceGuard. Assert() is compiled out in Release.

Apply:

+#include <c10/core/DeviceGuard.h>
@@
-    assert(block_size_M == 64);
-    assert(block_size_N == 64);
+    TORCH_CHECK(block_size_M == 64, "block_size_M must be 64");
+    TORCH_CHECK(block_size_N == 64, "block_size_N must be 64");
@@
-    hipSetDevice(seqlens.get_device());
+    c10::DeviceGuard guard(seqlens.device());
+    TORCH_CHECK(seqlens.is_cuda(), "seqlens must be on CUDA/HIP device");
+    TORCH_CHECK(vertical_indexes.is_cuda() && slash_indexes.is_cuda(), "Inputs must be CUDA/HIP tensors");
+    TORCH_CHECK(vertical_indexes.device() == seqlens.device() && slash_indexes.device() == seqlens.device(),
+                "All inputs must be on the same device");
+    TORCH_CHECK(seqlens.scalar_type() == at::kInt, "seqlens must be int32");
+    TORCH_CHECK(vertical_indexes.scalar_type() == at::kInt && slash_indexes.scalar_type() == at::kInt,
+                "vertical_indexes/slash_indexes must be int32");
+
+    seqlens = seqlens.contiguous();
+    vertical_indexes = vertical_indexes.contiguous();
+    slash_indexes = slash_indexes.contiguous();
@@
-    int batch_size = slash_indexes.size(0);
-    int num_heads = slash_indexes.size(1);
-    int nnz_slash = slash_indexes.size(2);
-    int nnz_vertical = vertical_indexes.size(2);
+    int batch_size = static_cast<int>(slash_indexes.size(0));
+    int num_heads = static_cast<int>(slash_indexes.size(1));
+    int nnz_slash = static_cast<int>(slash_indexes.size(2));
+    int nnz_vertical = static_cast<int>(vertical_indexes.size(2));


    return { block_count, block_offset, column_count, column_index };
}
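For orientation (illustrative numbers, not from the PR), the allocations above give output shapes like the following for batch=1, heads=32, context_size=8192, nnz_v=1000, nnz_s=200 and block_size_M=64:

    batch, heads, context, nnz_v, nnz_s, block_m = 1, 32, 8192, 1000, 200, 64
    num_rows = (context + block_m - 1) // block_m   # 128
    # block_count  -> (1, 32, 128)
    # block_offset -> (1, 32, 128, 200)
    # column_count -> (1, 32, 128)
    # column_index -> (1, 32, 128, 1000)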
18 changes: 18 additions & 0 deletions src/layout/gemm_layouts.cc
@@ -156,6 +156,23 @@ Fragment makeGemmSparseFragmentC(const int block_m, const int block_n,
  return block_layout;
}

Fragment makeGemmFragmentCDCU(const int block_m, const int block_n,
                              const int warp_m, const int warp_n,
                              const int element_size) {
  if (element_size == 64)
    LOG(FATAL) << "Not supported";
  ICHECK(block_m % warp_m == 0);
  ICHECK(block_n % warp_n == 0);
  ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
  ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n;
  auto base_layout = makeGemmFragmentC16x16CDNA()->Repeat({1, 1}, false);
  auto warp_layout =
      base_layout->Repeat({warp_m / 16, warp_n / 16}, false, false);
  auto block_layout =
      warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, true);
  return block_layout;
}

Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n,
                               const int warp_m, const int warp_n,
                               const int element_size) {
@@ -730,6 +747,7 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
  if (!k_inner && element_size == 8) // int8 KxN
    return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
  else if (mat_continuous % (vector_size * 8) == 0)
    // return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
    return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
  else if (mat_continuous % (vector_size * 4) == 0)
    return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
3 changes: 3 additions & 0 deletions src/layout/layout.h
@@ -150,6 +150,9 @@ Fragment makeGemmSparseFragmentC(const int block_m, const int block_n,
Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n,
                               const int warp_m, const int warp_n,
                               const int element_size);
Fragment makeGemmFragmentCDCU(const int block_m, const int block_n,
                              const int warp_m, const int warp_n,
                              const int element_size);
Fragment makeGemmFragmentCHopper(const int block_m, const int block_n,
                                 const int warp_m, const int warp_n,
                                 const int element_size);