[Feature request] Support negative indices #997

@LyricZhao

Description

import torch
import tilelang
from tilelang import language as T


@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },
)
def get_buggy_kernel(hidden):
    num_tokens = T.symbolic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens, hidden), 'float']):
        with T.Kernel(num_tokens, threads=128) as pid:
            smem = T.alloc_shared((hidden, ), dtype='float')
            T.copy(x[pid, :], smem)
            T.cumsum(smem)
            if T.get_thread_binding() == 0:
                T.print(smem[-1])  # intended: the last element, i.e. smem[hidden - 1]

    return buggy_kernel


if __name__ == '__main__':
    kernel = get_buggy_kernel(128)
    print(kernel.get_kernel_source())

    x = torch.zeros((1, 128), dtype=torch.float, device='cuda')
    kernel(x)

Generated CUDA (output of kernel.get_kernel_source()):

extern "C" __global__ void __launch_bounds__(128, 1) buggy_kernel_kernel(float* __restrict__ x, int num_tokens) {
  extern __shared__ __align__(1024) float smem[];
  #pragma unroll
  for (int i = 0; i < 1; ++i) {
    smem[((int)threadIdx.x)] = x[((((int64_t)((int)blockIdx.x)) * (int64_t)128) + ((int64_t)((int)threadIdx.x)))];
  }
  tl::fence_proxy_async();
  __syncthreads();
  tl::CumSum1D<128, false>::run((&(smem[0])), (&(smem[0])), 128);
  __syncthreads();
  if (((int)threadIdx.x) == 0) {
    debug_print_var("expr<smem[-1]>", smem[-1]);
  }
}
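The difference between what smem[-1] means in Python and what it means in the emitted CUDA can be modeled in plain Python. This is a toy model under stated assumptions. Shared memory is treated as one flat list, and smem sits at a hypothetical base offset BASE inside it. Neither choice reflects TileLang's or CUDA's actual memory layout. The point is only that C evaluates smem[idx] as *(smem + idx), so -1 lands before the buffer instead of wrapping to its end.

```python
# Toy model (assumption for illustration only): shared memory is one flat list,
# and the buffer `smem` starts at a hypothetical non-zero offset BASE inside it.
HIDDEN = 8
BASE = 4  # hypothetical start offset of `smem`; not TileLang's real layout

# Label every slot so we can see which memory an access actually touches.
flat = [f"other[{i}]" for i in range(BASE)] + [f"smem[{i}]" for i in range(HIDDEN)]


def python_index(idx: int) -> str:
    """Python semantics: a negative index counts back from the end of the buffer."""
    buf = flat[BASE:BASE + HIDDEN]
    return buf[idx]


def c_index(idx: int) -> str:
    """C/CUDA semantics: smem[idx] is *(smem + idx), plain pointer arithmetic."""
    return flat[BASE + idx]


print(python_index(-1))  # the last element of smem
print(c_index(-1))       # the slot just before smem starts
```

Running this prints smem[7] for the Python-style access and other[3] for the C-style one. In real CUDA, the second case is an out-of-bounds access with undefined behavior.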

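smem[-1] is wrong. The index is emitted verbatim into the CUDA source, where it does not refer to the last element (as it would in Python) but to the element one position before the start of smem, an out-of-bounds access. I suggest supporting negative indices, or at least having the compiler print a warning. If the indices are variables, the situation may be more complex: extra code would be needed to check the sign at runtime, which might lower performance. So a simplified solution is also OK for me, i.e., resolving, or at least warning about, the constant-index cases.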
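The constant-versus-variable distinction above can be sketched in plain Python. normalize_index is a hypothetical helper, not part of TileLang's API. The C-style ternary it returns for symbolic indices only illustrates the extra runtime check (a compare and select per access). It is not what TileLang generates. Symbolic indices are represented as strings purely for illustration.

```python
import warnings
from typing import Union


def normalize_index(idx: Union[int, str], extent: int) -> Union[int, str]:
    """Hypothetical frontend rewrite of a single buffer index.

    - Constant in [-extent, 0): wrapped to idx + extent (Python-style).
    - Constant in [0, extent): returned unchanged.
    - Constant outside [-extent, extent): warn and return it unchanged.
    - Symbolic index (a string expression here): the sign is unknown at
      compile time, so emit a runtime select, which costs extra work per access.
    """
    if isinstance(idx, int):
        if -extent <= idx < 0:
            return idx + extent
        if 0 <= idx < extent:
            return idx
        warnings.warn(f"constant index {idx} is out of bounds for extent {extent}")
        return idx
    # Symbolic case: guard the access with a sign check (illustrative C syntax).
    return f"(({idx}) < 0 ? ({idx}) + {extent} : ({idx}))"


print(normalize_index(-1, 128))   # constant case, folded at compile time
print(normalize_index("i", 128))  # variable case, needs a runtime check
```

The constant case costs nothing at runtime, which is why handling or warning about constant indices alone already covers the repro above.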

Labels: enhancement (New feature or request)