diff --git a/sageattention/core.py b/sageattention/core.py
index 1121f922..4399801e 100644
--- a/sageattention/core.py
+++ b/sageattention/core.py
@@ -27,19 +27,19 @@ from .triton.quant_per_thread import per_thread_int8 as per_thread_int8_triton
 
 try:
-    from . import _qattn_sm80
+    from . import sm80_compile
     SM80_ENABLED = True
 except:
     SM80_ENABLED = False
 
 try:
-    from . import _qattn_sm89
+    from . import sm89_compile
     SM89_ENABLED = True
 except:
     SM89_ENABLED = False
 
 try:
-    from . import _qattn_sm90
+    from . import sm90_compile
     SM90_ENABLED = True
 except:
     SM90_ENABLED = False
@@ -52,9 +52,10 @@ from typing import Any, List, Literal, Optional, Tuple, Union
 import warnings
-
 import subprocess
 import re
+
+
 def get_cuda_version():
     try:
         output = subprocess.check_output(['nvcc', '--version']).decode()
@@ -66,6 +67,7 @@ def get_cuda_version():
         print("Failed to get CUDA version:", e)
         return None, None
 
+
 def get_cuda_arch_versions():
     cuda_archs = []
     for i in range(torch.cuda.device_count()):
@@ -73,6 +75,7 @@ def get_cuda_arch_versions():
         cuda_archs.append(f"sm{major}{minor}")
     return cuda_archs
 
+
 def sageattn(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -151,7 +154,7 @@ def sageattn(
     else:
         raise ValueError(f"Unsupported CUDA architecture: {arch}")
 
-@torch.compiler.disable
+
 def sageattn_qk_int8_pv_fp16_triton(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -294,7 +297,7 @@ def sageattn_qk_int8_pv_fp16_triton(
     else:
         return o
 
-@torch.compiler.disable
+
 def sageattn_varlen(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -411,7 +414,7 @@ def sageattn_varlen(
 
     return o
 
-@torch.compiler.disable
+
 def sageattn_qk_int8_pv_fp16_cuda(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -566,17 +569,17 @@ def sageattn_qk_int8_pv_fp16_cuda(
 
     if pv_accum_dtype == 'fp32':
         v = v.to(torch.float16)
-        lse = _qattn_sm80.qk_int8_sv_f16_accum_f32_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
+        lse = sm80_compile.qk_int8_sv_f16_accum_f32_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
     elif pv_accum_dtype == "fp16":
         if smooth_v:
             smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout)
-            lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(q_int8, k_int8, smoothed_v, o, q_scale, k_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
+            lse = sm80_compile.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(q_int8, k_int8, smoothed_v, o, q_scale, k_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
         else:
             v = v.to(torch.float16)
-            lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
+            lse = sm80_compile.qk_int8_sv_f16_accum_f16_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
     elif pv_accum_dtype == "fp16+fp32":
         v = v.to(torch.float16)
-        lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
+        lse = sm80_compile.qk_int8_sv_f16_accum_f16_attn_inst_buf(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
     else:
         raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}")
 
@@ -587,7 +590,7 @@ def sageattn_qk_int8_pv_fp16_cuda(
     else:
         return o
 
-@torch.compiler.disable
+
 def sageattn_qk_int8_pv_fp8_cuda(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -756,13 +759,13 @@ def sageattn_qk_int8_pv_fp8_cuda(
 
     if pv_accum_dtype == "fp32":
         if smooth_v:
-            lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
+            lse = sm89_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
         else:
-            lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
+            lse = sm89_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
     elif pv_accum_dtype == "fp32+fp32":
-        lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
+        lse = sm89_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
     elif pv_accum_dtype == "fp32+fp16":
-        lse = _qattn_sm89.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
+        lse = sm89_compile.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
 
     o = o[..., :head_dim_og]
 
@@ -771,7 +774,7 @@ def sageattn_qk_int8_pv_fp8_cuda(
     else:
         return o
 
-@torch.compiler.disable
+
 def sageattn_qk_int8_pv_fp8_cuda_sm90(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -921,13 +924,13 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90(
 
     if pv_accum_dtype == "fp32":
         raise NotImplementedError("Please use pv_accum_dtype='fp32+fp32' for sm90.")
-        lse = _qattn_sm90.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
+        lse = sm90_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
     elif pv_accum_dtype == "fp32+fp32":
-        lse = _qattn_sm90.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
+        lse = sm90_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse)
 
     o = o[..., :head_dim_og]
 
     if return_lse:
         return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504
     else:
-        return o
\ No newline at end of file
+        return o
diff --git a/sageattention/sm80_compile.py b/sageattention/sm80_compile.py
new file mode 100644
index 00000000..ac5db6e4
--- /dev/null
+++ b/sageattention/sm80_compile.py
@@ -0,0 +1,149 @@
+from . import _qattn_sm80
+import torch
+
+
+@torch.library.custom_op("sageattention::qk_int8_sv_f16_accum_f16_attn", mutates_args=(), device_types="cuda")
+def qk_int8_sv_f16_accum_f16_attn(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    query_scale: torch.Tensor,
+    key_scale: torch.Tensor,
+    tensor_layout: int,
+    is_causal: int,
+    qk_quant_gran: int,
+    sm_scale: float,
+    return_lse: int,
+) -> torch.Tensor:
+    """
+    Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP16 accumulation.
+    """
+    return _qattn_sm80.qk_int8_sv_f16_accum_f16_attn(
+        query, key, value, output, query_scale, key_scale, tensor_layout,
+        is_causal, qk_quant_gran, sm_scale, return_lse
+    )
+
+
+@torch.library.custom_op("sageattention::qk_int8_sv_f16_accum_f32_attn", mutates_args=(), device_types="cuda")
+def qk_int8_sv_f16_accum_f32_attn(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    query_scale: torch.Tensor,
+    key_scale: torch.Tensor,
+    tensor_layout: int,
+    is_causal: int,
+    qk_quant_gran: int,
+    sm_scale: float,
+    return_lse: int,
+) -> torch.Tensor:
+    """
+    Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP32 accumulation.
+    """
+    return _qattn_sm80.qk_int8_sv_f16_accum_f32_attn(
+        query, key, value, output, query_scale, key_scale, tensor_layout,
+        is_causal, qk_quant_gran, sm_scale, return_lse
+    )
+
+
+@torch.library.custom_op("sageattention::qk_int8_sv_f16_accum_f16_attn_inst_buf", mutates_args=(), device_types="cuda")
+def qk_int8_sv_f16_accum_f16_attn_inst_buf(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    query_scale: torch.Tensor,
+    key_scale: torch.Tensor,
+    tensor_layout: int,
+    is_causal: int,
+    qk_quant_gran: int,
+    sm_scale: float,
+    return_lse: int,
+) -> torch.Tensor:
+    """
+    Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP16 accumulation and an intermediate buffer (the pv_accum_dtype="fp16+fp32" path in core.py).
+    """
+    return _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf(
+        query, key, value, output, query_scale, key_scale, tensor_layout,
+        is_causal, qk_quant_gran, sm_scale, return_lse
+    )
+
+
+@torch.library.custom_op("sageattention::qk_int8_sv_f16_accum_f16_fuse_v_mean_attn", mutates_args=(), device_types="cuda")
+def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    query_scale: torch.Tensor,
+    key_scale: torch.Tensor,
+    value_mean: torch.Tensor,
+    tensor_layout: int,
+    is_causal: int,
+    qk_quant_gran: int,
+    sm_scale: float,
+    return_lse: int,
+) -> torch.Tensor:
+    """
+    Custom CUDA kernel for SageAttention with INT8 quantization for Q and K, FP16 PV with FP16 accumulation, fused with compensation for the subtracted V mean (the smooth_v path in core.py).
+    """
+    return _qattn_sm80.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(
+        query, key, value, output, query_scale, key_scale, value_mean,
+        tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse
+    )
+
+
+def sm80_qk_fake_impl(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    query_scale: torch.Tensor,
+    key_scale: torch.Tensor,
+    tensor_layout: int,
+    is_causal: int,
+    qk_quant_gran: int,
+    sm_scale: float,
+    return_lse: int,
+) -> torch.Tensor:
+    batch_size = query.size(0)
+
+    if tensor_layout == 0:
+        num_qo_heads = query.size(2)
+        qo_len = query.size(1)
+    else:
+        num_qo_heads = query.size(1)
+        qo_len = query.size(2)
+
+    if return_lse:
+        lse = torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device="cuda")
+    else:
+        lse = torch.empty((0))
+    return lse
+
+torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f16_attn")(sm80_qk_fake_impl)
+torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f32_attn")(sm80_qk_fake_impl)
+torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f16_attn_inst_buf")(sm80_qk_fake_impl)
+
+
+@torch.library.register_fake("sageattention::qk_int8_sv_f16_accum_f16_fuse_v_mean_attn")
+def qk_int8_sv_f16_accum_f16_fuse_v_mean_attn_fake_impl(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    query_scale: torch.Tensor,
+    key_scale: torch.Tensor,
+    value_mean: torch.Tensor,
+    tensor_layout: int,
+    is_causal: int,
+    qk_quant_gran: int,
+    sm_scale: float,
+    return_lse: int,
+) -> torch.Tensor:
+    return sm80_qk_fake_impl(
+        query, key, value, output, query_scale, key_scale, tensor_layout,
+        is_causal, qk_quant_gran, sm_scale, return_lse
+    )
\ No newline at end of file
diff --git a/sageattention/sm89_compile.py b/sageattention/sm89_compile.py
new file mode 100644
index 00000000..42e56e2a
--- /dev/null
+++ b/sageattention/sm89_compile.py
@@ -0,0 +1,146 @@
+from . import _qattn_sm89
+import torch
+
+
+@torch.library.custom_op("sageattention_sm89::qk_int8_sv_f8_accum_f32_fuse_v_scale_attn", mutates_args=(), device_types="cuda")
+def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    query_scale: torch.Tensor,
+    key_scale: torch.Tensor,
+    value_scale: torch.Tensor,
+    tensor_layout: int,
+    is_causal: int,
+    qk_quant_gran: int,
+    sm_scale: float,
+    return_lse: int,
+) -> torch.Tensor:
+    return _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(
+        query, key, value, output, query_scale, key_scale, value_scale,
+        tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse
+    )
+
+
+
+@torch.library.custom_op("sageattention_sm89::qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf", mutates_args=(), device_types="cuda")
+def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    query_scale: torch.Tensor,
+    key_scale: torch.Tensor,
+    value_scale: torch.Tensor,
+    tensor_layout: int,
+    is_causal: int,
+    qk_quant_gran: int,
+    sm_scale: float,
+    return_lse: int,
+) -> torch.Tensor:
+    return _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(
+        query, key, value, output, query_scale, key_scale, value_scale,
+        tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse
+    )
+
+
+@torch.library.custom_op("sageattention_sm89::qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf", mutates_args=(), device_types="cuda")
+def qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    query_scale: torch.Tensor,
+    key_scale: torch.Tensor,
+    value_scale: torch.Tensor,
+    tensor_layout: int,
+    is_causal: int,
+    qk_quant_gran: int,
+    sm_scale: float,
+    return_lse: int,
+) -> torch.Tensor:
+    return _qattn_sm89.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(
+        query, key, value, output, query_scale, key_scale, value_scale,
+        tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse
+    )
+
+
+def sm89_qk_with_key_value(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    query_scale: torch.Tensor,
+    key_scale: torch.Tensor,
+    value_scale: torch.Tensor,
+    tensor_layout: int,
+    is_causal: int,
+    qk_quant_gran: int,
+    sm_scale: float,
+    return_lse: int,
+) -> torch.Tensor:
+    batch_size = query.size(0)
+
+    if tensor_layout == 0:
+        num_qo_heads = query.size(2)
+        qo_len = query.size(1)
+    else:
+        num_qo_heads = query.size(1)
+        qo_len = query.size(2)
+
+    if return_lse:
+        lse = torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device="cuda")
+    else:
+        lse = torch.empty((0))
+    return lse
+
+
+torch.library.register_fake("sageattention_sm89::qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf")(sm89_qk_with_key_value)
+torch.library.register_fake("sageattention_sm89::qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf")(sm89_qk_with_key_value)
+torch.library.register_fake("sageattention_sm89::qk_int8_sv_f8_accum_f32_fuse_v_scale_attn")(sm89_qk_with_key_value)
+
+
+@torch.library.custom_op("sageattention_sm89::qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn", mutates_args=(), device_types="cuda")
+def qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    query_scale: torch.Tensor,
+    key_scale: torch.Tensor,
+    value_scale: torch.Tensor,
+    value_mean: torch.Tensor,
+    tensor_layout: int,
+    is_causal: int,
+    qk_quant_gran: int,
+    sm_scale: float,
+    return_lse: int,
+) -> torch.Tensor:
+    return _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(
+        query, key, value, output, query_scale, key_scale, value_scale,
+        value_mean, tensor_layout, is_causal, qk_quant_gran, sm_scale,
+        return_lse
+    )
+
+
+@torch.library.register_fake("sageattention_sm89::qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn")
+def sm89_qk_with_key_value_mean(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    query_scale: torch.Tensor,
+    key_scale: torch.Tensor,
+    value_scale: torch.Tensor,
+    value_mean: torch.Tensor,
+    tensor_layout: int,
+    is_causal: int,
+    qk_quant_gran: int,
+    sm_scale: float,
+    return_lse: int,
+) -> torch.Tensor:
+    return sm89_qk_with_key_value(
+        query, key, value, output, query_scale, key_scale, value_scale,
+        tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse
+    )
diff --git a/sageattention/sm90_compile.py b/sageattention/sm90_compile.py
new file mode 100644
index 00000000..60847c01
--- /dev/null
+++ b/sageattention/sm90_compile.py
@@ -0,0 +1,94 @@
+from . import _qattn_sm90
+import torch
+
+
+@torch.library.custom_op("sageattention_sm90::qk_int8_sv_f8_accum_f32_attn_inst_buf", mutates_args=(), device_types="cuda")
+def qk_int8_sv_f8_accum_f32_attn_inst_buf(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    query_scale: torch.Tensor,
+    key_scale: torch.Tensor,
+    tensor_layout: int,
+    is_causal: int,
+    qk_quant_gran: int,
+    sm_scale: float,
+    return_lse: int,
+) -> torch.Tensor:
+    return _qattn_sm90.qk_int8_sv_f8_accum_f32_attn_inst_buf(
+        query, key, value, output, query_scale, key_scale, tensor_layout,
+        is_causal, qk_quant_gran, sm_scale, return_lse
+    )
+
+
+@torch.library.register_fake("sageattention_sm90::qk_int8_sv_f8_accum_f32_attn_inst_buf")
+def qk_int8_sv_f8_accum_f32_attn_inst_buf_fake_impl(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    query_scale: torch.Tensor,
+    key_scale: torch.Tensor,
+    tensor_layout: int,
+    is_causal: int,
+    qk_quant_gran: int,
+    sm_scale: float,
+    return_lse: int,
+) -> torch.Tensor:
+    batch_size = query.size(0)
+
+    if tensor_layout == 0:
+        num_qo_heads = query.size(2)
+        qo_len = query.size(1)
+    else:
+        num_qo_heads = query.size(1)
+        qo_len = query.size(2)
+
+    if return_lse:
+        lse = torch.empty((batch_size, num_qo_heads, qo_len), dtype=torch.float32, device="cuda")
+    else:
+        lse = torch.empty((0))
+    return lse
+
+
+@torch.library.custom_op("sageattention_sm90::qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf", mutates_args=(), device_types="cuda")
+def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    query_scale: torch.Tensor,
+    key_scale: torch.Tensor,
+    value_scale: torch.Tensor,
+    tensor_layout: int,
+    is_causal: int,
+    qk_quant_gran: int,
+    sm_scale: float,
+    return_lse: int,
+) -> torch.Tensor:
+    return _qattn_sm90.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(
+        query, key, value, output, query_scale, key_scale, value_scale,
+        tensor_layout, is_causal, qk_quant_gran, sm_scale, return_lse
+    )
+
+
+@torch.library.register_fake("sageattention_sm90::qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf")
+def qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf_fake_impl(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    output: torch.Tensor,
+    query_scale: torch.Tensor,
+    key_scale: torch.Tensor,
+    value_scale: torch.Tensor,
+    tensor_layout: int,
+    is_causal: int,
+    qk_quant_gran: int,
+    sm_scale: float,
+    return_lse: int,
+) -> torch.Tensor:
+    return qk_int8_sv_f8_accum_f32_attn_inst_buf_fake_impl(
+        query, key, value, output, query_scale, key_scale, tensor_layout,
+        is_causal, qk_quant_gran, sm_scale, return_lse
+    )
diff --git a/setup.py b/setup.py
index 5e4779dd..81084ed6 100644
--- a/setup.py
+++ b/setup.py
@@ -132,7 +132,7 @@ def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     )
     ext_modules.append(qattn_extension)
 
-if HAS_SM89 or HAS_SM120:
+if HAS_SM89 or HAS_SM90 or HAS_SM120:
     qattn_extension = CUDAExtension(
         name="sageattention._qattn_sm89",
         sources=[
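
Reviewer note (not part of the patch): wrapping the raw _qattn_sm80/_qattn_sm89/_qattn_sm90 extension entry points in torch.library.custom_op and registering fake (meta) implementations is what lets torch.compile trace through the sageattn* functions instead of graph-breaking on them, which is why the @torch.compiler.disable decorators in core.py can be dropped. A minimal usage sketch, assuming the package is built with these new modules and run on a supported GPU; the shapes below are purely illustrative:

    import torch
    from sageattention import sageattn

    # HND layout: (batch, heads, seq_len, head_dim)
    q = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")
    k = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")
    v = torch.randn(1, 8, 1024, 128, dtype=torch.float16, device="cuda")

    # Previously sageattn was wrapped in @torch.compiler.disable and forced a graph break;
    # with the custom-op registrations above it can be captured by torch.compile.
    # fullgraph=True is the stricter check that no graph break occurs.
    compiled_attn = torch.compile(sageattn, fullgraph=True)
    out = compiled_attn(q, k, v, tensor_layout="HND", is_causal=False)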