Skip to content

Commit e207ba6

Browse files
committed
complete patch
1 parent 36230f4 commit e207ba6

File tree

9 files changed

+343
-140
lines changed

9 files changed

+343
-140
lines changed

csrc/dispatch_utils.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,22 @@
3333
throw std::invalid_argument(err_msg.str()); \
3434
}
3535

// Dispatch a runtime head dimension to a compile-time constant for the
// QK (query/key) path. Inside the `...` body, HEAD_DIM is available as a
// `constexpr int`. Supported sizes: 64, 128, 192 — the 192 case is the
// addition for the *_dsk_* kernels (presumably DeepSeek-style heads where
// the RoPE part is concatenated onto the base head dim; confirm against
// the kernel implementations). Any other value throws
// std::invalid_argument, matching the sibling DISPATCH_* macros in this
// header (see the surrounding definitions, e.g. DISPATCH_CAUSAL).
#define DISPATCH_HEAD_DIM_QK(head_dim, HEAD_DIM, ...)      \
  if (head_dim == 64) {                                    \
    constexpr int HEAD_DIM = 64;                           \
    __VA_ARGS__                                            \
  } else if (head_dim == 128) {                            \
    constexpr int HEAD_DIM = 128;                          \
    __VA_ARGS__                                            \
  } else if (head_dim == 192) {                            \
    constexpr int HEAD_DIM = 192;                          \
    __VA_ARGS__                                            \
  } else {                                                 \
    std::ostringstream err_msg;                            \
    err_msg << "Unsupported head dim: " << int(head_dim);  \
    throw std::invalid_argument(err_msg.str());            \
  }
3652
#define DISPATCH_CAUSAL(is_causal, IS_CAUSAL, ...) \
3753
if (is_causal == 1) { \
3854
constexpr bool IS_CAUSAL = true; \

csrc/fused/fused.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ void quant_per_block_int8_fuse_sub_mean_cuda(
652652

653653
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, {
654654
DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, {
655-
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
655+
DISPATCH_HEAD_DIM_QK(head_dim, HEAD_DIM, {
656656

657657
CHECK_SHAPE(mean, batch_size, num_heads, head_dim);
658658
CHECK_SHAPE(output, input.size(0), input.size(1), input.size(2), input.size(3));
@@ -738,7 +738,7 @@ void quant_per_warp_int8_cuda(
738738
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, {
739739
DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, {
740740
DISPATCH_WARP_BLOCK_SIZE(warp_block_size, WARP_BLOCK_SIZE, {
741-
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
741+
DISPATCH_HEAD_DIM_QK(head_dim, HEAD_DIM, {
742742

743743
CHECK_SHAPE(output, input.size(0), input.size(1), input.size(2), input.size(3));
744744
CHECK_SHAPE(scale, batch_size, num_heads, (num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE * (BLOCK_SIZE / WARP_BLOCK_SIZE));

csrc/qattn/attn_cuda_sm90.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,21 @@ torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf(
2929
float sm_scale,
3030
int return_lse);
3131

32+
torch::Tensor qk_int8_sv_f8_accum_f32_attn_inst_buf_dsk_sm90(
33+
torch::Tensor query,
34+
torch::Tensor key,
35+
torch::Tensor query_pe,
36+
torch::Tensor key_pe,
37+
torch::Tensor value,
38+
torch::Tensor output,
39+
torch::Tensor query_scale,
40+
torch::Tensor key_scale,
41+
int tensor_layout,
42+
int is_causal,
43+
int qk_quant_gran,
44+
float sm_scale,
45+
int return_lse);
46+
3247
torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(
3348
torch::Tensor query,
3449
torch::Tensor key,
@@ -41,4 +56,20 @@ torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(
4156
int is_causal,
4257
int qk_quant_gran,
4358
float sm_scale,
59+
int return_lse);
60+
61+
torch::Tensor qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_dsk_sm90(
62+
torch::Tensor query,
63+
torch::Tensor key,
64+
torch::Tensor query_pe,
65+
torch::Tensor key_pe,
66+
torch::Tensor value,
67+
torch::Tensor output,
68+
torch::Tensor query_scale,
69+
torch::Tensor key_scale,
70+
torch::Tensor value_scale,
71+
int tensor_layout,
72+
int is_causal,
73+
int qk_quant_gran,
74+
float sm_scale,
4475
int return_lse);

csrc/qattn/pybind_sm90.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
2222
{
2323
m.def("qk_int8_sv_f8_accum_f32_attn_inst_buf", &qk_int8_sv_f8_accum_f32_attn_inst_buf);
2424
m.def("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf", &qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf);
25+
m.def("qk_int8_sv_f8_accum_f32_attn_inst_buf_dsk_sm90", &qk_int8_sv_f8_accum_f32_attn_inst_buf_dsk_sm90);
26+
m.def("qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_dsk_sm90", &qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_dsk_sm90);
2527
}

0 commit comments

Comments (0)