Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions csrc/kernels.hip
Original file line number Diff line number Diff line change
Expand Up @@ -2109,6 +2109,15 @@ __global__ void kdequant_mm_int32_fp16(
#define DENORM 1.0f/127.0f
#define MAX_SPARSE_COUNT 32
#define SMEM_SIZE 8*256
// define compile time macro for warp size with ROCm 7
// Refer: https://rocm.docs.amd.com/projects/HIP/en/latest/how-to/hip_cpp_language_extensions.html#warpsize
#if (HIP_VERSION_MAJOR >= 7)
#if defined(__GFX8__) || defined(__GFX9__)
#define warpSize 64
#else
#define warpSize 32
#endif
#endif
#define WARP_SIZE warpSize
template <typename T, int SPMM_ITEMS, int BITS>
__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
Expand Down
12 changes: 12 additions & 0 deletions csrc/ops.hip
Original file line number Diff line number Diff line change
Expand Up @@ -692,9 +692,21 @@ template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int
//warpsize - 32
int num_blocks = (m+3)/4;
//warpsize - 64

// define compile time macro for warp size with ROCm 7
// Refer: https://rocm.docs.amd.com/projects/HIP/en/latest/how-to/hip_cpp_language_extensions.html#warpsize

#if (HIP_VERSION_MAJOR >= 7)
#if defined(__GFX8__) || defined(__GFX9__)
#define warpSize 64
#else
#define warpSize 32
#endif
#else
if (warpSize == 64) {
num_blocks = (m+1)/2;
}
#endif

hipLaunchKernelGGL(( kgemm_4bit_inference_naive<T, 128, BITS>), dim3(num_blocks), dim3(128), 0, stream, m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
CUDA_CHECK_RETURN(hipPeekAtLastError());
Expand Down