diff --git a/gpu/bitnet_kernels/bitgemm.cu b/gpu/bitnet_kernels/bitgemm.cu
new file mode 100644
index 00000000..eda5e267
--- /dev/null
+++ b/gpu/bitnet_kernels/bitgemm.cu
@@ -0,0 +1,224 @@
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <mma.h>
+#include <cstdint>
+#include <iostream>
+#include <stdexcept>
+
+template <typename T1, typename T2>
+__device__ void decode_i2s_to_i8s(T1 *_i2s, T2 *_i8s, const int N = 16) {
+    // convert N int2b_t values (default 16) to N int8b_t -> N/4 int32 words
+    uint *i8s = reinterpret_cast<uint *>(_i8s);
+
+    // i2s = {e7,e6,e5,e4,e3,e2,e1,e0}
+    // also require interleave {e7,e3,e6,e2,e5,e1,e4,e0}
+    uint const i2s = *_i2s;
+
+    // Extract each 2-bit code into its own byte lane, then shift it back to a signed int8.
+    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010: (a & b) | c
+    static constexpr uint BOTTOM_MASK = 0x03030303;      // keep the low 2 bits of each byte
+    static constexpr uint I4s_TO_I8s_MAGIC_NUM = 0x00000000;
+
+#pragma unroll
+    for (int i = 0; i < (N / 4); i++) {
+        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                     : "=r"(i8s[i])
+                     : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK),
+                       "n"(I4s_TO_I8s_MAGIC_NUM), "n"(immLut));
+        i8s[i] = __vsubss4(i8s[i], 0x02020202);
+    }
+}
+
+template <int N, int K, int BLOCK_SIZE_M, int BLOCK_SIZE_N>
+__global__ void int8_int2_gemm_fused_kernel(
+    const int8_t *__restrict__ A,
+    const int32_t *__restrict__ B_compressed,
+    __nv_bfloat16 *__restrict__ C,
+    int M,
+    const __nv_bfloat16 *__restrict__ s,  // MODIFICATION: s is now bfloat16
+    const __nv_bfloat16 *__restrict__ ws) // MODIFICATION: ws is now bfloat16
+{
+    // --- GEMM Calculation Stage (largely unchanged) ---
+    constexpr int WMMA_M = 16;
+    constexpr int WMMA_N = 16;
+    constexpr int WMMA_K = 16;
+    constexpr int BLOCK_SIZE_K = 32;
+    constexpr int WARPS_M = 2;
+    constexpr int WARPS_N = 2;
+    constexpr int M_ITER = BLOCK_SIZE_M / WMMA_M / WARPS_M;
+    constexpr int N_ITER = BLOCK_SIZE_N / WMMA_N / WARPS_N;
+
+    const int blockM = blockIdx.y * BLOCK_SIZE_M;
+    const int blockN = blockIdx.x * BLOCK_SIZE_N;
+    const int warpM = threadIdx.y;
+    const int warpN = threadIdx.z;
+    const int tid = threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * blockDim.x * blockDim.y;
+
+    constexpr int PAD_A = 16;
+    constexpr int PAD_B = 16;
+
+    __shared__ int8_t shared_A[BLOCK_SIZE_M][BLOCK_SIZE_K + PAD_A];
+    __shared__ int8_t shared_B[BLOCK_SIZE_N][BLOCK_SIZE_K + PAD_B];
+
+    nvcuda::wmma::fragment<nvcuda::wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, int> c_frags[M_ITER][N_ITER];
+    nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, int8_t, nvcuda::wmma::row_major> a_frag;
+    nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, int8_t, nvcuda::wmma::col_major> b_frag;
+
+    #pragma unroll
+    for (int m_iter = 0; m_iter < M_ITER; m_iter++) {
+        #pragma unroll
+        for (int n_iter = 0; n_iter < N_ITER; n_iter++) {
+            nvcuda::wmma::fill_fragment(c_frags[m_iter][n_iter], 0);
+        }
+    }
+
+    const bool m_valid = blockM < M;
+
+    for (int k_block = 0; k_block < K; k_block += BLOCK_SIZE_K) {
+        __syncthreads();
+        // Load A tile
+        for (int load_idx = tid; load_idx < (BLOCK_SIZE_M * BLOCK_SIZE_K / 16); load_idx += blockDim.x * blockDim.y * blockDim.z) {
+            int local_m = (load_idx * 16) / BLOCK_SIZE_K;
+            int local_k = (load_idx * 16) % BLOCK_SIZE_K;
+            int global_m = blockM + local_m;
+            int global_k = k_block + local_k;
+            if (m_valid && global_m < M) {
+                *((int4*)&shared_A[local_m][local_k]) = *((int4*)&A[global_m * K + global_k]);
+            } else {
+                *((int4*)&shared_A[local_m][local_k]) = {0};
+            }
+        }
+        // Load B tile
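+        // Each thread handles one 16-byte chunk of the B tile: it reads a single
+        // 32-bit word of compressed weights (16 two-bit codes), expands it to
+        // 16 int8 values with decode_i2s_to_i8s, and writes them to shared memory
+        // as one int4 vector.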
+        int chunk_n = (tid * 16 / BLOCK_SIZE_K);
+        int chunk_k = (tid * 16) % BLOCK_SIZE_K;
+        if (chunk_n < BLOCK_SIZE_N) {
+            int global_n = blockN + chunk_n;
+            int global_k = k_block + chunk_k;
+            int n_block = global_n / 16;
+            int k_block_32 = global_k / 32;
+            int k_offset_in_block = chunk_k % 32;
+            int in_block_n = chunk_n % 16;
+            int compressed_block_idx = n_block * (K / 32) + k_block_32;
+            int tile_idx = in_block_n / 8 * 16 + in_block_n % 8 + (k_offset_in_block / 16) * 8;
+            int32_t compressed = B_compressed[compressed_block_idx * 32 + tile_idx];
+            int8_t decompressed[16];
+            decode_i2s_to_i8s(&compressed, decompressed);
+            *((int4*)&shared_B[chunk_n][chunk_k]) = *((int4*)decompressed);
+        }
+        __syncthreads();
+        // Perform MMA
+        #pragma unroll
+        for (int m_iter = 0; m_iter < M_ITER; m_iter++) {
+            #pragma unroll
+            for (int n_iter = 0; n_iter < N_ITER; n_iter++) {
+                #pragma unroll
+                for (int wmma_k = 0; wmma_k < BLOCK_SIZE_K; wmma_k += WMMA_K) {
+                    const int tile_m = (warpM + m_iter * WARPS_M) * WMMA_M;
+                    const int tile_n = (warpN + n_iter * WARPS_N) * WMMA_N;
+                    nvcuda::wmma::load_matrix_sync(a_frag, &shared_A[tile_m][wmma_k], BLOCK_SIZE_K + PAD_A);
+                    nvcuda::wmma::load_matrix_sync(b_frag, &shared_B[tile_n][wmma_k], BLOCK_SIZE_K + PAD_B);
+                    nvcuda::wmma::mma_sync(c_frags[m_iter][n_iter], a_frag, b_frag, c_frags[m_iter][n_iter]);
+                }
+            }
+        }
+    }
+
+    // --- Fused Post-Processing and Store Stage ---
+    __shared__ int32_t shared_C[BLOCK_SIZE_M][BLOCK_SIZE_N];
+
+    #pragma unroll
+    for (int m_iter = 0; m_iter < M_ITER; m_iter++) {
+        #pragma unroll
+        for (int n_iter = 0; n_iter < N_ITER; n_iter++) {
+            const int tile_m = (warpM + m_iter * WARPS_M) * WMMA_M;
+            const int tile_n = (warpN + n_iter * WARPS_N) * WMMA_N;
+            nvcuda::wmma::store_matrix_sync(
+                &shared_C[tile_m][tile_n],
+                c_frags[m_iter][n_iter],
+                BLOCK_SIZE_N,
+                nvcuda::wmma::mem_row_major);
+        }
+    }
+    __syncthreads();
+
+    for (int i = tid; i < BLOCK_SIZE_M * BLOCK_SIZE_N; i += blockDim.x * blockDim.y * blockDim.z) {
+        const int m = i / BLOCK_SIZE_N;
+        const int n = i % BLOCK_SIZE_N;
+
+        const int global_m = blockM + m;
+        const int global_n = blockN + n;
+
+        if (global_m < M) {
+            int32_t val = shared_C[m][n];
+            float float_val = static_cast<float>(val);
+
+            // MODIFICATION: Load bfloat16 scales and convert to float for calculation
+            float s_val = __bfloat162float(s[global_m]);
+            float_val /= s_val;
+
+            int ws_idx = 0;
+            if (N == 3840) {
+                ws_idx = global_n / (3840 / 6);
+            } else if (N == 13824) {
+                ws_idx = global_n / (13824 / 2);
+            }
+
+            float ws_val = __bfloat162float(ws[ws_idx]);
+            float_val *= ws_val;
+
+            __nv_bfloat16 bf16_val = __float2bfloat16(float_val);
+            C[global_m * N + global_n] = bf16_val;
+        }
+    }
+}
+
+extern "C" void bitlinear_fused_int8xint2(
+    int8_t *input0, int8_t *input1,
+    __nv_bfloat16 *output0,
+    int M, int N, int K,
+    __nv_bfloat16 *s,  // MODIFICATION: s is now bfloat16
+    __nv_bfloat16 *ws, // MODIFICATION: ws is now bfloat16
+    cudaStream_t stream = 0) {
+
+    constexpr int BLOCK_SIZE_M = 64;
+    constexpr int BLOCK_SIZE_N = 64;
+
+    const dim3 gridDim(N / BLOCK_SIZE_N, (M + BLOCK_SIZE_M - 1) / BLOCK_SIZE_M, 1);
+    const dim3 blockDim(32, 2, 2);
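+    // Each block computes one 64x64 output tile with a 2x2 arrangement of warps
+    // (blockDim = 32 threads x 2 x 2); gridDim.x covers the N/64 column tiles and
+    // gridDim.y the ceil(M/64) row tiles.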
+
+    // Kernel launch now passes the bfloat16 pointers
+    if (N == 3840 && K == 2560) {
+        int8_int2_gemm_fused_kernel<3840, 2560, BLOCK_SIZE_M, BLOCK_SIZE_N>
+            <<<gridDim, blockDim, 0, stream>>>(
+                input0, (int32_t *)input1, output0, M, s, ws);
+    } else if (N == 2560 && K == 2560) {
+        int8_int2_gemm_fused_kernel<2560, 2560, BLOCK_SIZE_M, BLOCK_SIZE_N>
+            <<<gridDim, blockDim, 0, stream>>>(
+                input0, (int32_t *)input1, output0, M, s, ws);
+    } else if (N == 13824 && K == 2560) {
+        int8_int2_gemm_fused_kernel<13824, 2560, BLOCK_SIZE_M, BLOCK_SIZE_N>
+            <<<gridDim, blockDim, 0, stream>>>(
+                input0, (int32_t *)input1, output0, M, s, ws);
+    } else if (N == 2560 && K == 6912) {
+        int8_int2_gemm_fused_kernel<2560, 6912, BLOCK_SIZE_M, BLOCK_SIZE_N>
+            <<<gridDim, blockDim, 0, stream>>>(
+                input0, (int32_t *)input1, output0, M, s, ws);
+    } else {
+        std::cerr << "Error: Unsupported matrix dimensions for bitlinear_int8xint2. "
+                  << "Required kernel: M=" << M << ", N=" << N << ", K=" << K << std::endl;
+        std::cerr << "Supported configurations:" << std::endl;
+        std::cerr << "  - N=3840, K=2560" << std::endl;
+        std::cerr << "  - N=2560, K=2560" << std::endl;
+        std::cerr << "  - N=13824, K=2560" << std::endl;
+        std::cerr << "  - N=2560, K=6912" << std::endl;
+        throw std::runtime_error("Unsupported matrix dimensions for bitlinear_int8xint2");
+    }
+
+    cudaError_t launch_error = cudaGetLastError();
+    if (launch_error != cudaSuccess) {
+        std::cerr << "CUDA kernel launch failed: " << cudaGetErrorString(launch_error) << std::endl;
+        throw std::runtime_error("CUDA kernel launch failed");
+    }
+}
diff --git a/gpu/bitnet_kernels/bitnet_kernels.cu b/gpu/bitnet_kernels/bitnet_kernels.cu
index 6e615809..16650005 100644
--- a/gpu/bitnet_kernels/bitnet_kernels.cu
+++ b/gpu/bitnet_kernels/bitnet_kernels.cu
@@ -2,7 +2,7 @@
 extern "C" void bitlinear_int8xint2(int8_t* input0, int8_t* input1, __nv_bfloat16* output0, __nv_bfloat16* s, __nv_bfloat16* ws, int M, int N, int K, cudaStream_t stream){
     if (M == 1 && N == 3840 && K == 2560){
-        ladder_int8xint2_kernel<1, 3840, 2560, 3, 8, 16><<>>(input0, input1, output0, s, ws);
+        ladder_int8xint2_kernel<1, 3840, 2560, 6, 8, 16><<>>(input0, input1, output0, s, ws);
     }
     else if (M == 1 && N == 2560 && K == 2560){
         ladder_int8xint2_kernel<1, 2560, 2560, 1, 8, 16><<>>(input0, input1, output0, s, ws);
diff --git a/gpu/bitnet_kernels/compile.sh b/gpu/bitnet_kernels/compile.sh
index 1e22741d..6d6e40ca 100644
--- a/gpu/bitnet_kernels/compile.sh
+++ b/gpu/bitnet_kernels/compile.sh
@@ -1,3 +1,4 @@
 nvcc -std=c++17 -Xcudafe --diag_suppress=177 --compiler-options -fPIC -lineinfo --shared bitnet_kernels.cu -lcuda -gencode=arch=compute_80,code=compute_80 -o libbitnet.so
+nvcc -std=c++17 -Xcudafe --diag_suppress=177 --compiler-options -fPIC -lineinfo --shared bitgemm.cu -lcuda -gencode=arch=compute_80,code=compute_80 -o libgemm.so
diff --git a/gpu/convert_checkpoint.py b/gpu/convert_checkpoint.py
index 797ad1db..0b0ba90f 100755
--- a/gpu/convert_checkpoint.py
+++ b/gpu/convert_checkpoint.py
@@ -47,7 +47,7 @@ def convert_int8_to_int2(weight):
         wk_weight, wb_scale = quant_weight_int8(wk)
         wv_weight, wc_scale = quant_weight_int8(wv)
         wqkv_weight = torch.cat([wq_weight, wk_weight, wv_weight], dim=0)
-        wqkv_scale = torch.cat([wa_scale, wb_scale, wc_scale, zero], dim=0)
+        wqkv_scale = torch.cat([wa_scale, wa_scale, wa_scale, wa_scale, wb_scale, wc_scale], dim=0)

         int2_result[key] = convert_int8_to_int2(wqkv_weight)
         int2_result[key.replace('weight', 'weight_scale')] = wqkv_scale
@@ -62,7 +62,7 @@ def convert_int8_to_int2(weight):
         w1_weight, w1_scale = quant_weight_int8(w1)
         w3_weight, w3_scale = quant_weight_int8(w3)
         w13_weight = torch.cat([w1_weight, w3_weight], dim=0)
-        w13_scale = torch.cat([w1_scale, w3_scale, zero, zero], dim=0)
+        w13_scale = torch.cat([w1_scale, w3_scale, zero, zero, zero, zero], dim=0)

         int2_result[key] = convert_int8_to_int2(w13_weight)
         int2_result[key.replace('weight', 'weight_scale')] = w13_scale
@@ -72,7 +72,7 @@ def convert_int8_to_int2(weight):
             fp16_result[key] = w13_weight
     elif 'w2' in key or 'wo' in key:
         weight, scale = quant_weight_int8(value)
-        scale = torch.cat([scale, zero, zero, zero], dim=0)
+        scale = torch.cat([scale, zero, zero, zero, zero, zero], dim=0)

         int2_result[key] = convert_int8_to_int2(weight)
         int2_result[key.replace('weight', 'weight_scale')] = scale
diff --git a/gpu/generate.py b/gpu/generate.py
index 638ed7b3..63415a07 100755
--- a/gpu/generate.py
+++ b/gpu/generate.py
@@ -53,7 +53,7 @@ def build(
     """
     start_time = time.time()

-    model_args_prefill = fast.ModelArgs(use_kernel=False)
+    model_args_prefill = fast.ModelArgs(use_kernel=True)
     model_args_decode = fast.ModelArgs(use_kernel=True)

     tokenizer = Tokenizer("./tokenizer.model")
@@ -63,11 +63,9 @@ def build(
     prefill_model = fast.Transformer(model_args_prefill)
     decode_model = fast.Transformer(model_args_decode)

-    fp16_ckpt_path = str(Path(ckpt_dir) / "model_state_fp16.pt")
-    fp16_checkpoint = torch.load(fp16_ckpt_path, map_location="cpu")
     int2_ckpt_path = str(Path(ckpt_dir) / "model_state_int2.pt")
     int2_checkpoint = torch.load(int2_ckpt_path, map_location="cpu")
-    prefill_model.load_state_dict(fp16_checkpoint, strict=True)
+    prefill_model.load_state_dict(int2_checkpoint, strict=True)
     decode_model.load_state_dict(int2_checkpoint, strict=True)

     torch.cuda.synchronize()
diff --git a/gpu/model.py b/gpu/model.py
index cd5abec0..fab4ad86 100755
--- a/gpu/model.py
+++ b/gpu/model.py
@@ -17,8 +17,11 @@
 import ctypes

 bitnet_lib = ctypes.CDLL('bitnet_kernels/libbitnet.so')
+gemm_lib = ctypes.CDLL('bitnet_kernels/libgemm.so')

-def bitnet_int8xint2_linear(input0, input1, s, ws):
+import numpy as np
+
+def bitnet_int8xint2_linear_gemv(input0, input1, s, ws):
     out_shape = list(input0.shape)
     out_shape[-1] = input1.shape[0]

@@ -36,6 +39,32 @@ def bitnet_int8xint2_linear(input0, input1, s, ws):
     return ret

+def bitnet_int8xint2_linear_gemm(input0, input1, s, ws):
+    out_shape = list(input0.shape)
+    out_shape[-1] = input1.shape[0]
+
+    stream = torch.cuda.current_stream()
+
+    M = input0.shape[0]
+    if len(out_shape) == 3:
+        M *= input0.shape[1]
+    N = input1.shape[0]
+    K = input1.shape[1] * 4
+
+    ret = torch.zeros(*out_shape, dtype=torch.bfloat16, device=input0.device)
+
+    gemm_lib.bitlinear_fused_int8xint2(*[ctypes.c_void_p(input0.data_ptr()),
+                                         ctypes.c_void_p(input1.data_ptr()),
+                                         ctypes.c_void_p(ret.data_ptr()),
+                                         ctypes.c_int(M),
+                                         ctypes.c_int(N),
+                                         ctypes.c_int(K),
+                                         ctypes.c_void_p(s.data_ptr()),
+                                         ctypes.c_void_p(ws.data_ptr()),
+                                         ctypes.c_void_p(stream.cuda_stream)])
+
+    return ret.reshape(*out_shape)
+
 @dataclass
 class ModelArgs:
     dim: int = 2560
@@ -63,7 +92,7 @@ def __init__(self, in_features: int, out_features: int, bias: bool = False):
         self.out_features = out_features

         self.weight = torch.nn.Parameter(torch.zeros(out_features, in_features//4, dtype=torch.int8), requires_grad=False)
-        self.weight_scale = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16), requires_grad=False)
+        self.weight_scale = torch.nn.Parameter(torch.zeros(6, dtype=torch.bfloat16), requires_grad=False)

     @torch.compile
     def quant_input(self, input):
@@ -72,7 +101,10 @@ def quant_input(self, input):

     def forward(self, input):
         input, s = self.quant_input(input)
-        return bitnet_int8xint2_linear(input, self.weight, s, self.weight_scale)
+        if input.shape[0] == 1:
+            return bitnet_int8xint2_linear_gemv(input, self.weight, s, self.weight_scale)
+        else:
+            return bitnet_int8xint2_linear_gemm(input, self.weight, s, self.weight_scale)

 class BitLinear(nn.Linear):
     @torch.compile
diff --git a/gpu/test_gemm.py b/gpu/test_gemm.py
new file mode 100644
index 00000000..a3acf3c4
--- /dev/null
+++ b/gpu/test_gemm.py
@@ -0,0 +1,64 @@
+import torch
+from pack_weight import convert_weight_int8_to_int2
+from torch.profiler import profile, record_function, ProfilerActivity
+import ctypes
+import numpy as np
+from torch.utils import benchmark
+from model import bitnet_int8xint2_linear_gemm
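+
+# Checks the fused W2A8 GEMM against an int32 NumPy matmul reference and
+# benchmarks it against a plain bf16 torch.matmul over the model's (N, K) shapes.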
+gemm_lib = ctypes.CDLL('bitnet_kernels/libgemm.so')
+# set all seeds
+torch.manual_seed(42)
+np.random.seed(42)
+
+def bit_linear_int8xint2(input0, weight, out, M, N, K):
+    stream = torch.cuda.current_stream()
+    gemm_lib.bitlinear_int8xint2(*[
+        ctypes.c_void_p(input0.data_ptr()),
+        ctypes.c_void_p(weight.data_ptr()),
+        ctypes.c_void_p(out.data_ptr()),
+        ctypes.c_int(M),
+        ctypes.c_int(N),
+        ctypes.c_int(K),
+        ctypes.c_void_p(stream.cuda_stream),])
+
+M = 512
+test_list = [
+    (2560, 2560),
+    (3840, 2560),
+    (13824, 2560),
+    (2560, 6912),
+]
+for N, K in test_list:
+    weight = torch.randint(-1, 2, (N, K), dtype=torch.int8, device='cuda')
+    weight_scale = torch.ones(1, dtype=torch.bfloat16, device='cuda')
+    weight_compressed = convert_weight_int8_to_int2(weight).to('cuda')
+    weight_np = weight.cpu().to(torch.int32).T.numpy()
+    stream = torch.cuda.current_stream()
+    input0 = torch.randint(-128, 127, (M, K), dtype=torch.int8, device='cuda')
+    input0_np = input0.cpu().to(torch.int32).numpy()
+    out_np = np.matmul(input0_np, weight_np)
+    weight_bf16 = weight.to(torch.bfloat16).T
+    input0_bf16 = input0.to(torch.bfloat16)
+    s = torch.ones(M, dtype=torch.bfloat16, device='cuda')
+    ws = torch.ones(6, dtype=torch.bfloat16, device='cuda')
+    out = bitnet_int8xint2_linear_gemm(input0, weight_compressed, s, ws)
+    t0 = benchmark.Timer(
+        stmt="out_kernel = bitnet_int8xint2_linear_gemm(input0, weight_compressed, s, ws)",
+        setup="from __main__ import input0, weight_compressed, s, ws, bitnet_int8xint2_linear_gemm",
+        num_threads=1,
+    )
+
+    t1 = benchmark.Timer(
+        stmt="out_bf16 = torch.matmul(input0_bf16, weight_bf16)",
+        setup="from __main__ import input0_bf16, weight_bf16",
+        num_threads=1,
+    )
+
+    time0 = t0.timeit(10)
+    time1 = t1.timeit(10)
+
+    print(f'Shape {M, N, K}, W2A8: {time0.mean * 1e6:.2f}us, torch BF16: {time1.mean * 1e6:.2f}us')
+    out_np = torch.tensor(out_np).cuda().to(torch.bfloat16)
+
+    print(f'custom == np: {torch.all(out == out_np)}')
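
For reference, the bit layout consumed by decode_i2s_to_i8s can be reproduced on the host. The sketch below is derived only from the lop3/__vsubss4 sequence in bitgemm.cu: each 32-bit word carries 16 two-bit codes, byte j of output word i takes bits 2*i + 8*j, and subtracting 2 maps stored codes {1, 2, 3} back to ternary weights {-1, 0, +1}. The element ordering produced by convert_weight_int8_to_int2 (the interleave noted in the kernel comment) is assumed to match and is not verified here; decode_i2s_to_i8s_ref is a name introduced for this sketch only.

import numpy as np

def decode_i2s_to_i8s_ref(packed_u32: int) -> np.ndarray:
    """Host-side mirror of the CUDA decode: one uint32 -> 16 signed int8 weights."""
    out = np.zeros(16, dtype=np.int8)
    for i in range(4):          # four output int32 words, as in the unrolled loop
        for j in range(4):      # four little-endian bytes per word
            code = (packed_u32 >> (2 * i + 8 * j)) & 0x3   # lop3 with BOTTOM_MASK
            out[4 * i + j] = code - 2                       # __vsubss4(x, 0x02020202)
    return out

# A word whose every 2-bit field is 0b10 decodes to all-zero weights.
assert (decode_i2s_to_i8s_ref(0xAAAAAAAA) == 0).all()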