Description
```python
import torch

import tilelang
from tilelang import language as T


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },
)
def get_buggy_kernel(hidden):
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens, hidden), 'float']):
        with T.Kernel(num_tokens, threads=128) as pid:
            smem = T.alloc_shared((hidden,), dtype='float')
            T.copy(x[pid, :], smem)
            T.cumsum(smem)
            if T.get_thread_binding() == 0:
                # Intended to print the last element, but the negative
                # index is lowered verbatim to smem[-1] in the CUDA below.
                T.print(smem[-1])

    return buggy_kernel


if __name__ == '__main__':
    kernel = get_buggy_kernel(128)
    print(kernel.get_kernel_source())
    x = torch.zeros((1, 128), dtype=torch.float, device='cuda')
    kernel(x)
```

Generated CUDA:
extern "C" __global__ void __launch_bounds__(128, 1) buggy_kernel_kernel(float* __restrict__ x, int num_tokens) {
extern __shared__ __align__(1024) float smem[];
#pragma unroll
for (int i = 0; i < 1; ++i) {
smem[((int)threadIdx.x)] = x[((((int64_t)((int)blockIdx.x)) * (int64_t)128) + ((int64_t)((int)threadIdx.x)))];
}
tl::fence_proxy_async();
__syncthreads();
tl::CumSum1D<128, false>::run((&(smem[0])), (&(smem[0])), 128);
__syncthreads();
if (((int)threadIdx.x) == 0) {
debug_print_var("expr<smem[-1]>", smem[-1]);
}
}smem[-1] is wrong, so I suggest to support negative range as indices. Or at least, the compiler should print a warning. If the indices are variables, the situation may be more complex (e.g., adding more code to judge the range and lower the performance?). So a simplified solution is also OK for me, i.e., solving/warning at least the constant index cases.
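For the constant-index case, a minimal sketch of what such a frontend fix could look like, assuming the buffer extent is known at trace time (the helper name and placement are hypothetical, not existing tilelang API):

```python
# Hypothetical frontend helper (not part of tilelang): fold constant
# negative indices into their positive equivalents before lowering.
def normalize_constant_index(idx, extent):
    """Map Python-style negative constant indices to absolute offsets.

    Only handles compile-time-constant `idx`; symbolic indices would
    need runtime range checks and are passed through unchanged here.
    """
    if isinstance(idx, int):
        if idx < 0:
            idx += extent  # smem[-1] -> smem[extent - 1]
        if not (0 <= idx < extent):
            raise IndexError(f"index {idx} out of bounds for extent {extent}")
    return idx

assert normalize_constant_index(-1, 128) == 127  # last element
assert normalize_constant_index(5, 128) == 5     # positive indices unchanged
```

With something like this in the buffer-indexing path, `T.print(smem[-1])` would lower to `smem[127]` for `hidden == 128`, and an out-of-range constant would fail at trace time instead of reading out of bounds on the GPU.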