From fcb3c4e18fa62926e52e0bd157bdeccacbc4ba76 Mon Sep 17 00:00:00 2001
From: Gheorghe-Teodor Bercea
Date: Fri, 20 Dec 2024 08:48:02 -0600
Subject: [PATCH] Improve performance of casted elementwise add operations.

---
 aten/src/ATen/native/cuda/CUDALoops.cuh | 32 +++++++++++++++++++++++++++++++-
 1 file changed, 31 insertions(+), 1 deletion(-)

diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh
index 4dcb9a3450e8a..7120ea775044a 100644
--- a/aten/src/ATen/native/cuda/CUDALoops.cuh
+++ b/aten/src/ATen/native/cuda/CUDALoops.cuh
@@ -231,6 +231,36 @@ static void launch_legacy_kernel(int64_t N, const func_t& f) {
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
 
+template <int nt, int vt, typename func_t>
+C10_LAUNCH_BOUNDS_2(nt, 4)
+__global__ void elementwise_kernel_strided(int N, func_t f) {
+  int tid = threadIdx.x;
+  int idx = nt * vt * blockIdx.x + tid;
+  int step = nt * vt * gridDim.x;
+  while (idx < N) {
+#pragma unroll
+    for (int i = 0; i < vt; i++) {
+      if ((idx + nt * i) < N) {
+        f(idx + nt * i);
+      }
+    }
+    idx += step;
+  }
+}
+
+template <int nt, int vt, typename func_t>
+static void launch_legacy_kernel_strided(int64_t N, const func_t& f) {
+  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
+  if (N == 0) {
+    return;
+  }
+  dim3 block(nt);
+  dim3 grid(8192);
+  auto stream = at::cuda::getCurrentCUDAStream();
+  elementwise_kernel_strided<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
 template <typename traits, typename func_t, typename index_t, size_t... INDEX>
 C10_HOST_DEVICE typename traits::result_type invoke_impl(
     const func_t& f,
@@ -348,7 +378,7 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
       dtypes[i] = iter.dtype(i);
       strides[i] = inner_strides[i];
     }
-    launch_legacy_kernel<512, 1>(numel, [=] GPU_LAMBDA(int idx) {
+    launch_legacy_kernel_strided<512, 4>(numel, [=] GPU_LAMBDA(int idx) {
       void* out = data[0] + strides[0] * idx;
       arg0_t result = invoke(f, &data[1], &strides[1], &dtypes[1], idx);
       c10::cast_and_store<arg0_t>(dtypes[0], out, result);
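
For reference, below is a minimal standalone sketch of the pattern this patch
introduces: a grid-stride loop with a compile-time unroll factor (vt) launched
on a fixed-size grid, here applied to a float -> double casted add. The names
(strided_kernel, launch_strided) and the example itself are illustrative
assumptions, not the ATen implementation; plain CUDA stands in for the
c10/ATen pieces (C10_LAUNCH_BOUNDS_2, TORCH_INTERNAL_ASSERT, the stream
query), and the device lambda requires compiling with nvcc --extended-lambda.

#include <cstdio>
#include <cuda_runtime.h>

// Grid-stride kernel: each block covers nt * vt elements per trip, then the
// whole grid strides forward, so a fixed-size grid can cover any N.
template <int nt, int vt, typename func_t>
__global__ void strided_kernel(int N, func_t f) {
  int idx = nt * vt * blockIdx.x + threadIdx.x;
  int step = nt * vt * gridDim.x;
  while (idx < N) {
#pragma unroll
    for (int i = 0; i < vt; i++) {
      if (idx + nt * i < N) {
        f(idx + nt * i);
      }
    }
    idx += step;
  }
}

// Fixed-size grid, mirroring the patch's dim3 grid(8192): the stride loop
// absorbs any remainder instead of sizing the grid to N.
template <int nt, int vt, typename func_t>
void launch_strided(int N, const func_t& f) {
  if (N == 0) return;
  strided_kernel<nt, vt, func_t><<<8192, nt>>>(N, f);
}

int main() {
  const int N = 1 << 20;
  float* a;
  double* out;
  cudaMallocManaged(&a, N * sizeof(float));
  cudaMallocManaged(&out, N * sizeof(double));
  for (int i = 0; i < N; i++) a[i] = 1.5f;

  // Casted elementwise add: load float, compute and store as double, the
  // kind of mixed-dtype loop the patched launch path handles.
  launch_strided<512, 4>(N, [=] __device__(int i) {
    out[i] = static_cast<double>(a[i]) + 2.0;
  });
  cudaDeviceSynchronize();
  printf("out[0] = %f\n", out[0]); // expect 3.5

  cudaFree(a);
  cudaFree(out);
  return 0;
}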