diff --git a/csrc/kernels.hip b/csrc/kernels.hip index ec3f7f025..17abc0f31 100644 --- a/csrc/kernels.hip +++ b/csrc/kernels.hip @@ -2109,6 +2109,15 @@ __global__ void kdequant_mm_int32_fp16( #define DENORM 1.0f/127.0f #define MAX_SPARSE_COUNT 32 #define SMEM_SIZE 8*256 +// define compile time macro for warp size with ROCm 7 +// Refer: https://rocm.docs.amd.com/projects/HIP/en/latest/how-to/hip_cpp_language_extensions.html#warpsize +#if (HIP_VERSION_MAJOR >= 7) +#if defined(__GFX8__) || defined(__GFX9__) + #define warpSize 64 +#else + #define warpSize 32 +#endif +#endif #define WARP_SIZE warpSize template __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB) diff --git a/csrc/ops.hip b/csrc/ops.hip index 260b74b30..a0ad3da77 100644 --- a/csrc/ops.hip +++ b/csrc/ops.hip @@ -692,9 +692,21 @@ template void gemm_4bit_inference_naive(int m, int n, int //warpsize - 32 int num_blocks = (m+3)/4; //warpsize - 64 + +// define compile time macro for warp size with ROCm 7 +// Refer: https://rocm.docs.amd.com/projects/HIP/en/latest/how-to/hip_cpp_language_extensions.html#warpsize + +#if (HIP_VERSION_MAJOR >= 7) +#if defined(__GFX8__) || defined(__GFX9__) + #define warpSize 64 +#else + #define warpSize 32 +#endif +#else if (warpSize == 64) { num_blocks = (m+1)/2; } +#endif hipLaunchKernelGGL(( kgemm_4bit_inference_naive), dim3(num_blocks), dim3(128), 0, stream, m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); CUDA_CHECK_RETURN(hipPeekAtLastError());