-
Notifications
You must be signed in to change notification settings - Fork 301
[DCU] Support the deployment and operation of tilelang on the Hygon DCU backend #1145
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,190 @@ | ||
| from tilelang import tvm as tvm | ||
| from tvm import DataType | ||
| import tilelang | ||
| import tilelang.language as T | ||
| from tilelang.intrinsics import get_swizzle_layout | ||
| from tilelang.intrinsics.mmac_macro_generator import ( | ||
| MatrixCoreIntrinEmitter,) | ||
| from tilelang.transform import simplify_prim_func | ||
| from tilelang import disable_cache | ||
|
|
||
| disable_cache() | ||
|
|
||
| def make_swizzle_layout(shared_buf): | ||
| dtype = shared_buf.dtype | ||
| shape = shared_buf.shape | ||
|
|
||
| can_swizzle = shape[-1] * DataType(dtype).bits == 512 | ||
| if not can_swizzle: | ||
| return T.Layout(shape, lambda *args: args) | ||
|
|
||
| def transform_func(i, j): | ||
| new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype) | ||
| return [new_warp_i, new_warp_j] | ||
|
|
||
| return T.Layout(shape, transform_func) | ||
|
|
||
|
|
||
| @tilelang.jit(out_idx=[2]) | ||
| @simplify_prim_func | ||
| def tl_matmul( | ||
| M, | ||
| N, | ||
| K, | ||
| in_dtype, | ||
| out_dtype, | ||
| accum_dtype, | ||
| ): | ||
| assert in_dtype in [ | ||
| "float16", | ||
| "int8", | ||
| ], "Currently only float16 and int8 are supported" | ||
| assert out_dtype in [ | ||
| "float16", | ||
| "float32", | ||
| "int32", | ||
| ], "Currently only float16, float32 and int32 are supported" | ||
|
|
||
| micro_size_x = micro_size_y = micro_size_k = 16 | ||
|
|
||
| if out_dtype == "int32": | ||
| micro_size_k = 32 | ||
|
|
||
| # This is a debug config | ||
| block_row_warps = 2 | ||
| block_col_warps = 2 | ||
| warp_row_tiles = 64 | ||
| warp_col_tiles = 64 | ||
| # chunk = 32 if in_dtype == "float16" else 64 | ||
| chunk = 32 | ||
| shared_scope = "shared.dyn" | ||
|
|
||
| # Pipeline Stage | ||
| stage = 2 | ||
|
|
||
| block_M = block_row_warps * warp_row_tiles | ||
| block_N = block_col_warps * warp_col_tiles | ||
| block_K = chunk | ||
|
|
||
|
Comment on lines
+65
to
+68
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chainTile edge handling missing; kernel reads/writes out of bounds for non-multiple sizes Grid uses ceildiv, but loads/stores assume perfect multiples of block_M/N/K. Either add guards in load/store loops or assert divisibility. Minimal safety precondition (quick fix): block_K = chunk
@@
- A_shape = (M, K)
+ # Precondition for this example kernel (simplifies control flow)
+ assert M % block_M == 0 and N % block_N == 0 and K % block_K == 0, \
+ "Example kernel currently requires M,N,K to be multiples of block sizes"
+ A_shape = (M, K)Recommended follow-up (optional): add bounds checks in the A/B loads and C stores so arbitrary M,N,K work. I can draft that if you want. Also applies to: 108-159 I need to examine the actual code to verify the concerns raised in the review comment. 🏁 Script executed: #!/bin/bash
# Retrieve the file and examine the kernel implementation
if [ -f "examples/gemm/example_gemm_intrinsics_dcu.py" ]; then
echo "=== File Content ==="
cat examples/gemm/example_gemm_intrinsics_dcu.py
else
echo "File not found at examples/gemm/example_gemm_intrinsics_dcu.py"
find . -name "example_gemm_intrinsics_dcu.py" -type f 2>/dev/null
fiLength of output: 5708 Add bounds-checking guards or assertion for non-multiple tensor dimensions The kernel grid uses Either add bounds checks in the A/B loads (lines 117–121) and C store (lines 138–145), or add an early assertion: block_K = chunk
@@
+ assert M % block_M == 0 and N % block_N == 0 and K % block_K == 0, \
+ "Example kernel requires M, N, K to be multiples of block sizes (M:%d, N:%d, K:%d; block_M:%d, block_N:%d, block_K:%d)" % (M, N, K, block_M, block_N, block_K)
A_shape = (M, K)🤖 Prompt for AI Agents |
||
| A_shape = (M, K) | ||
| B_shape = (N, K) | ||
| A_shared_shape = (block_M, block_K) | ||
| B_shared_shape = (block_N, block_K) | ||
| C_shared_shape = ( | ||
| block_M // micro_size_x, | ||
| block_N // micro_size_y, | ||
| micro_size_x, | ||
| micro_size_y, | ||
| ) | ||
|
|
||
| warp_size = 64 | ||
| threads = warp_size * (block_row_warps * block_col_warps) | ||
| local_size_a = (micro_size_x * micro_size_k) // warp_size | ||
| local_size_b = (micro_size_y * micro_size_k) // warp_size | ||
| local_size_c = (micro_size_x * micro_size_y) // warp_size | ||
| warp_rows = warp_row_tiles // micro_size_x | ||
| warp_cols = warp_col_tiles // micro_size_y | ||
|
|
||
| # MMAC Wrapper to Auto Generate Code for MMAC | ||
| mmac_emitter = MatrixCoreIntrinEmitter( | ||
| a_dtype=in_dtype, | ||
| b_dtype=in_dtype, | ||
| accum_dtype=accum_dtype, | ||
| a_transposed=False, | ||
| b_transposed=True, | ||
| block_row_warps=block_row_warps, | ||
| block_col_warps=block_col_warps, | ||
| warp_row_tiles=warp_row_tiles, | ||
| warp_col_tiles=warp_col_tiles, | ||
| chunk=chunk, | ||
| ) | ||
|
|
||
| @T.prim_func | ||
| def gemm_intrinsics( | ||
| A: T.Tensor(A_shape, in_dtype), | ||
| B: T.Tensor(B_shape, in_dtype), | ||
| C: T.Tensor((M, N), out_dtype), | ||
| ): | ||
| with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): | ||
|
|
||
| A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) | ||
| B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope) | ||
| C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope) | ||
| A_local = T.alloc_local((warp_rows * local_size_a), in_dtype) | ||
| B_local = T.alloc_local((warp_cols * local_size_b), in_dtype) | ||
| C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype) | ||
|
|
||
| T.annotate_layout({ | ||
| A_shared: make_swizzle_layout(A_shared), | ||
| B_shared: make_swizzle_layout(B_shared), | ||
| }) | ||
|
|
||
| # Improve L2 Cache | ||
| T.use_swizzle(panel_size=10) | ||
|
|
||
| T.clear(C_local) | ||
|
|
||
| for ko in T.Pipelined((K // block_K), num_stages=stage): | ||
|
|
||
| # Load A into shared memory | ||
| for i, k in T.Parallel(block_M, block_K): | ||
| A_shared[i, k] = A[by * block_M + i, ko * block_K + k] | ||
|
|
||
| # Load B into shared memory | ||
| for j, k in T.Parallel(block_N, block_K): | ||
| B_shared[j, k] = B[bx * block_N + j, ko * block_K + k] | ||
|
|
||
| for ki in T.serial(0, (block_K // micro_size_k)): | ||
|
|
||
| # Load A into fragment | ||
| mmac_emitter.ldmatrix_a(A_local, A_shared, ki) | ||
|
|
||
| # Load B into fragment | ||
| mmac_emitter.ldmatrix_b(B_local, B_shared, ki) | ||
|
|
||
| # Perform Matrix Multiplication | ||
| mmac_emitter.mmac(A_local, B_local, C_local) | ||
|
|
||
| # Perform STMatrix | ||
| mmac_emitter.stmatrix(C_local, C_shared) | ||
|
|
||
| # Store shared into global | ||
| for i, j in T.Parallel(block_M, block_N): | ||
| C[by * block_M + i, bx * block_N + j] = C_shared[ | ||
| j // micro_size_y, | ||
| i // micro_size_x, | ||
| i % micro_size_x, | ||
| j % micro_size_y, | ||
| ] | ||
|
Comment on lines
+153
to
+158
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. C_shared indexing appears transposed vs declared shape C_shared is declared as (block_M // micro_size_x, block_N // micro_size_y, micro_size_x, micro_size_y), but store uses [j-group, i-group, i-rem, j-rem]. Swap i/j groups to match the shape. Apply this diff: - C[by * block_M + i, bx * block_N + j] = C_shared[
- j // micro_size_y,
- i // micro_size_x,
- i % micro_size_x,
- j % micro_size_y,
- ]
+ C[by * block_M + i, bx * block_N + j] = C_shared[
+ i // micro_size_x,
+ j // micro_size_y,
+ i % micro_size_x,
+ j % micro_size_y,
+ ]🤖 Prompt for AI Agents |
||
|
|
||
| return gemm_intrinsics | ||
|
|
||
|
|
||
| def ref_program(A, B): | ||
| return A @ B.T | ||
|
|
||
|
|
||
| def main(): | ||
| M, N, K = 16384, 16384, 16384 | ||
| in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32" | ||
| kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) | ||
| src_code = kernel.get_kernel_source() | ||
| # src_code is the generated cuda source | ||
| assert src_code is not None | ||
|
|
||
| profiler = kernel.get_profiler() | ||
|
|
||
| latency = profiler.do_bench(profiler.func, warmup=25) | ||
|
|
||
| print(latency) | ||
| print(kernel.get_kernel_source()) | ||
| # Ensure that the latency is not None | ||
| assert latency is not None | ||
|
|
||
| profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,161 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // !!! This is a file automatically generated by hipify!!! | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <ATen/dtk_macros.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Copyright (c) Microsoft Corporation. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| // Licensed under the MIT license. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <assert.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <pybind11/pybind11.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <pybind11/stl.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <pybind11/numpy.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <torch/extension.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| #include <hip/hip_runtime.h> | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| __device__ void save_blocks(int* block_offset, int range_start, int range_end, int block_size, int& block_count) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for (int idx = range_start; idx < range_end; idx += block_size) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_offset[block_count++] = idx; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| __global__ void convert_vertical_slash_indexes_kernel( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int* seqlens, // [BATCH, ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int N_HEADS, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int N_ROWS, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int BLOCK_SIZE_M, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int BLOCK_SIZE_N, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int NNZ_V, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int NNZ_S | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int batch_idx = blockIdx.y; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int head_idx = blockIdx.x; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int group_idx = blockIdx.z; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int seqlen = seqlens[batch_idx]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int block_idx_m = group_idx * blockDim.x + threadIdx.x; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int start_m = block_idx_m * BLOCK_SIZE_M; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (start_m >= seqlen) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int end_m = start_m + BLOCK_SIZE_M; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_count += row_offset; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_offset += row_offset * NNZ_S; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| column_count += row_offset; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| column_index += row_offset * NNZ_V; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+40
to
+54
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Guard against row index overflow. Threads with block_idx_m ≥ N_ROWS can compute a valid start_m < seqlen (if seqlen > context), causing OOB on row_offset. Add an explicit guard. Apply: int seqlen = seqlens[batch_idx];
int block_idx_m = group_idx * blockDim.x + threadIdx.x;
+ if (block_idx_m >= N_ROWS) {
+ return;
+ }
int start_m = block_idx_m * BLOCK_SIZE_M;
if (start_m >= seqlen) {
return;
}📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int tmp_col_cnt = 0, tmp_blk_cnt = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int s = 0, v = 0; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int v_idx = vertical_indexes[v++]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int s_idx = slash_indexes[s++]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| while (s_idx >= end_m) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| s_idx = slash_indexes[s++]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| s_idx = max(end_m - s_idx, BLOCK_SIZE_M); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+55
to
+62
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fix OOB reads when NNZ_S/NNZ_V are zero and bound the pre-scan. Accessing vertical_indexes[v++] and slash_indexes[s++] without checking NNZ_* risks OOB. The pre-loop while also lacks a bound on s. Apply: - int tmp_col_cnt = 0, tmp_blk_cnt = 0;
- int s = 0, v = 0;
- int v_idx = vertical_indexes[v++];
- int s_idx = slash_indexes[s++];
- while (s_idx >= end_m) {
- s_idx = slash_indexes[s++];
- }
- s_idx = max(end_m - s_idx, BLOCK_SIZE_M);
+ int tmp_col_cnt = 0, tmp_blk_cnt = 0;
+ int s = 0, v = 0;
+ // Safe init of v_idx
+ int v_idx = (NNZ_V > 0) ? vertical_indexes[v++] : (end_m + BLOCK_SIZE_M);
+ // Handle NNZ_S == 0 early
+ if (NNZ_S == 0) {
+ block_count[0] = 0;
+ column_count[0] = 0;
+ return;
+ }
+ int s_idx = slash_indexes[s++];
+ while (s < NNZ_S && s_idx >= end_m) {
+ s_idx = slash_indexes[s++];
+ }
+ if (s_idx >= end_m) {
+ // No slash indices relevant for this row
+ block_count[0] = 0;
+ column_count[0] = 0;
+ return;
+ }
+ s_idx = max(end_m - s_idx, BLOCK_SIZE_M);📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| while (1) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (v_idx < range_end) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (v_idx < range_start) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| column_index[tmp_col_cnt++] = v_idx; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (v < NNZ_V) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| v_idx = vertical_indexes[v++]; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| v_idx = end_m + BLOCK_SIZE_M; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (s < NNZ_S) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } else { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+67
to
+78
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Cap per-row writes to avoid overflow of column_index and block_offset. tmp_col_cnt and tmp_blk_cnt can exceed NNZ_V/NNZ_S; cap writes and bound save_blocks. Apply: - column_index[tmp_col_cnt++] = v_idx;
+ if (tmp_col_cnt < NNZ_V) {
+ column_index[tmp_col_cnt++] = v_idx;
+ }
@@
- save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
+ save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, NNZ_S);
@@
- save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
+ save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, NNZ_S);And update save_blocks to accept a max: -__device__ void save_blocks(int* block_offset, int range_start, int range_end, int block_size, int& block_count) {
- for (int idx = range_start; idx < range_end; idx += block_size) {
- block_offset[block_count++] = idx;
- }
-}
+__device__ __forceinline__ void save_blocks(int* block_offset,
+ int range_start,
+ int range_end,
+ int block_size,
+ int& blk_cnt,
+ int max_blocks) {
+ for (int idx = range_start; idx < range_end && blk_cnt < max_blocks; idx += block_size) {
+ block_offset[blk_cnt++] = idx;
+ }
+}
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| break; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if (s_idx > range_end + BLOCK_SIZE_M) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| range_start = s_idx - BLOCK_SIZE_M; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| range_end = s_idx; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } else if (s_idx > range_end) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| range_end += BLOCK_SIZE_M; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_count[0] = tmp_blk_cnt; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| column_count[0] = tmp_col_cnt; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| void convert_vertical_slash_indexes_64x64( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int* seqlens, // [BATCH, ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int BATCH_SIZE, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int N_HEADS, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int N_ROWS, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int NNZ_V, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int NNZ_S | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int BLOCK_SIZE_M = 64; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int BLOCK_SIZE_N = 64; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const int N_THREADS = 64; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const dim3 dimBlock(N_THREADS); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hipLaunchKernelGGL(( convert_vertical_slash_indexes_kernel), dim3(dimGrid), dim3(dimBlock), 0, 0, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| seqlens, vertical_indexes, slash_indexes, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_count, block_offset, column_count, column_index, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, NNZ_V, NNZ_S | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+114
to
+118
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use PyTorch’s current stream, not stream 0; add launch error check. Launching on stream 0 breaks PyTorch stream semantics and can race other ops. Also add a kernel launch check. Apply: +#include <ATen/cuda/CUDAContext.h> // works for CUDA and ROCm builds
@@
- hipLaunchKernelGGL(( convert_vertical_slash_indexes_kernel), dim3(dimGrid), dim3(dimBlock), 0, 0,
+ auto stream = at::cuda::getCurrentCUDAStream();
+ hipLaunchKernelGGL(( convert_vertical_slash_indexes_kernel), dim3(dimGrid), dim3(dimBlock), 0, stream.stream(),
seqlens, vertical_indexes, slash_indexes,
block_count, block_offset, column_count, column_index,
N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, NNZ_V, NNZ_S
);
+ AT_CUDA_CHECK(hipGetLastError());
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| std::vector<at::Tensor> convert_vertical_slash_indexes( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch::Tensor seqlens, // [BATCH, ] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int context_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int block_size_M, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int block_size_N | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert(block_size_M == 64); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert(block_size_N == 64); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hipSetDevice(seqlens.get_device()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int batch_size = slash_indexes.size(0); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int num_heads = slash_indexes.size(1); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int nnz_slash = slash_indexes.size(2); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int nnz_vertical = vertical_indexes.size(2); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| int num_rows = (context_size + block_size_M - 1) / block_size_M; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch::Tensor block_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch::Tensor block_offset = torch::zeros({batch_size, num_heads, num_rows, nnz_slash}, seqlens.options()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch::Tensor column_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| torch::Tensor column_index = torch::zeros({batch_size, num_heads, num_rows, nnz_vertical}, seqlens.options()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| convert_vertical_slash_indexes_64x64( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| seqlens.data_ptr<int>(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vertical_indexes.data_ptr<int>(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| slash_indexes.data_ptr<int>(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_count.data_ptr<int>(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| block_offset.data_ptr<int>(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| column_count.data_ptr<int>(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| column_index.data_ptr<int>(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| batch_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_heads, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| num_rows, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| nnz_vertical, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| nnz_slash | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+121
to
+158
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Validate inputs, use DeviceGuard, enforce dtypes/contiguity; replace assert with TORCH_CHECK. Ensure tensors live on the same device/stream, are int32, and contiguous. Avoid hipSetDevice; use DeviceGuard. Assert() is compiled out in Release. Apply: +#include <c10/core/DeviceGuard.h>
@@
- assert(block_size_M == 64);
- assert(block_size_N == 64);
+ TORCH_CHECK(block_size_M == 64, "block_size_M must be 64");
+ TORCH_CHECK(block_size_N == 64, "block_size_N must be 64");
@@
- hipSetDevice(seqlens.get_device());
+ c10::DeviceGuard guard(seqlens.device());
+ TORCH_CHECK(seqlens.is_cuda(), "seqlens must be on CUDA/HIP device");
+ TORCH_CHECK(vertical_indexes.is_cuda() && slash_indexes.is_cuda(), "Inputs must be CUDA/HIP tensors");
+ TORCH_CHECK(vertical_indexes.device() == seqlens.device() && slash_indexes.device() == seqlens.device(),
+ "All inputs must be on the same device");
+ TORCH_CHECK(seqlens.scalar_type() == at::kInt, "seqlens must be int32");
+ TORCH_CHECK(vertical_indexes.scalar_type() == at::kInt && slash_indexes.scalar_type() == at::kInt,
+ "vertical_indexes/slash_indexes must be int32");
+
+ seqlens = seqlens.contiguous();
+ vertical_indexes = vertical_indexes.contiguous();
+ slash_indexes = slash_indexes.contiguous();
@@
- int batch_size = slash_indexes.size(0);
- int num_heads = slash_indexes.size(1);
- int nnz_slash = slash_indexes.size(2);
- int nnz_vertical = vertical_indexes.size(2);
+ int batch_size = static_cast<int>(slash_indexes.size(0));
+ int num_heads = static_cast<int>(slash_indexes.size(1));
+ int nnz_slash = static_cast<int>(slash_indexes.size(2));
+ int nnz_vertical = static_cast<int>(vertical_indexes.size(2));📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return { block_count, block_offset, column_count, column_index }; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
micro_size_k should depend on in_dtype, not out_dtype
Int8 paths require K-fragment=32 regardless of the final out dtype. Gate on in_dtype to avoid incorrect tiling when out_dtype is int32 but inputs aren’t int8.
Apply this diff:
📝 Committable suggestion
🤖 Prompt for AI Agents