Add support for more NVIDIA devices #160
Conversation
Did you test which implementation each sm should be routed to in
Update: Currently the sm89 implementation is SageAttention2++, so it's faster than the old sm89 implementation. Let's see if it's faster than the sm90 implementation on newer GPUs. SageAttention3 is in early access (and still has some precision issues). It uses fp4 quantization, so it's faster than SageAttention2++ on sm120, but I don't know if it can be directly compiled for other Blackwell GPUs. |
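For anyone who wants to check which path is faster on their own GPU, a minimal timing sketch along these lines can help. It assumes the sageattention package is installed and simply compares sageattn against PyTorch's scaled_dot_product_attention with CUDA-event timing; shapes and dtypes are illustrative, and this is not one of the repo's bench scripts.

# Rough speed check: sageattn vs. torch SDPA on the current GPU.
# Illustrative sketch, not the repo's official benchmark.
import torch
import torch.nn.functional as F
from sageattention import sageattn

def time_fn(fn, *args, iters=20, warmup=5):
    for _ in range(warmup):
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # ms per call

# HND layout: (batch, heads, seq_len, head_dim)
q, k, v = (torch.randn(2, 32, 4096, 128, dtype=torch.float16, device="cuda") for _ in range(3))

print("sageattn  :", time_fn(lambda a, b, c: sageattn(a, b, c, tensor_layout="HND", is_causal=False), q, k, v), "ms")
print("torch sdpa:", time_fn(F.scaled_dot_product_attention, q, k, v), "ms")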
I'm testing with Ada, Hopper (GH200), Jetson Orin, and RTX 5090/GB200. |
For Blackwell it's the same as the RTX 50 series with Triton 3.3.x. |
Isn't more needed to handle the B200? The commit only seems to get past the setup process. For example, sm100 for the B200 is not a case handled in core.py (it skips straight to sm120). Line 135 in core.py
Otherwise it seems to throw this error:
Copying one of the other cases doesn't seem to be enough:
Still results in:
|
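For reference, the dispatch in core.py is keyed on an arch string, presumably derived from the device's compute capability, which is why an unhandled sm100 falls through to the error above. Below is a minimal sketch of that kind of dispatch with a placeholder sm100 branch; which kernel and parameters are actually right for sm100 is exactly what is being discussed here.

# Sketch of compute-capability-based dispatch, mirroring the style of core.py.
# The sm100 branch is a placeholder, not a validated choice.
import torch

def pick_arch(device_index: int = 0) -> str:
    major, minor = torch.cuda.get_device_capability(device_index)
    return f"sm{major}{minor}"  # e.g. (10, 0) on a B200 -> "sm100"

arch = pick_arch()
if arch == "sm89":
    kernel = "qk_int8_pv_fp8_cuda"        # SageAttention2++ path
elif arch == "sm90":
    kernel = "qk_int8_pv_fp8_cuda_sm90"   # Hopper path (uses wgmma)
elif arch == "sm100":
    kernel = "qk_int8_pv_fp8_cuda"        # placeholder: reuse the sm89 kernel until a tuned sm100 path exists
elif arch == "sm120":
    kernel = "qk_int8_pv_fp8_cuda"        # RTX 50 series path
else:
    raise ValueError(f"Unsupported CUDA architecture: {arch}")
print(arch, "->", kernel)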
@pftq Ok, let's do it. First you need to build the binary. Then modify the function. Then you can test whether the implementation is correct (for example, with my test script https://github.com/woct0rdho/SageAttention/blob/main/tests/test_sageattn.py), and which of the implementations (triton/sm89/sm90) is the fastest (maybe using the scripts at https://github.com/thu-ml/SageAttention/tree/main/bench). |
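A minimal version of such a correctness check (the same idea as tests/test_sageattn.py, but not that script): compare sageattn's output against PyTorch's reference attention and look at the worst-case error.

# Quick correctness check: compare sageattn against torch SDPA in fp16.
# Illustrative only; the repo's tests/test_sageattn.py is the authoritative test.
import torch
import torch.nn.functional as F
from sageattention import sageattn

torch.manual_seed(0)
q, k, v = (torch.randn(1, 8, 2048, 128, dtype=torch.float16, device="cuda") for _ in range(3))

out = sageattn(q, k, v, tensor_layout="HND", is_causal=True)
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)

diff = (out.float() - ref.float()).abs()
print("max abs err :", diff.max().item())
print("mean abs err:", diff.mean().item())
# sageattn quantizes QK to int8 and PV to fp8, so expect small but nonzero error.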
See the bottom of my earlier reply - right now I'm just getting a CUDA error, so I'd need to get past that first. |
Were you able to get this SageAttention running on B200s? |
+1 |
BZZ2: "the sm89 implementation is SageAttention2++" - but SageAttention2++ on SM100 (BZZ2) performs worse than SageAttention on SM90 (HZZ2). Is there any code optimized for SM100?
1. Benchmark: https://github.com/thu-ml/SageAttention/tree/main/bench
$ python bench_qk_int8_pv_fp8_cuda.py --pv_accum_dtype fp32+fp16 --quant_gran per_warp
CUDA QK Int8 PV FP8 Benchmark
batch: 4, head: 32, headdim: 128, pv_accum_dtype: fp32+fp16, fused_v: False
is_causal: False
1024 flops:369.53967935191037
2048 flops:404.9683418973956
4096 flops:419.0692280433787
8192 flops:425.7005528441506
16384 flops:432.06802961968805
32768 flops:433.55603269730267
is_causal: True
1024 flops:250.66197667633227
2048 flops:324.3266958350166
4096 flops:369.54497019445324
8192 flops:395.62348971339713
16384 flops:409.47553355350226
32768 flops:416.7046366646864
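(For reading these numbers: assuming the bench scripts follow the usual attention-FLOPs convention of 4 * batch * heads * seq_len^2 * head_dim, halved for causal, the flops column is TFLOPS. A rough reconstruction, to be checked against the actual scripts in the bench directory:)

# Rough reconstruction of how the bench scripts likely turn timings into TFLOPS.
# Assumption: 4 * batch * heads * seq_len^2 * head_dim matmul FLOPs, halved for causal.
def attention_tflops(batch, heads, seq_len, head_dim, ms_per_iter, is_causal):
    flops = 4 * batch * heads * seq_len * seq_len * head_dim
    if is_causal:
        flops //= 2
    return flops / (ms_per_iter * 1e-3) / 1e12

# Example: batch=4, heads=32, headdim=128, seq_len=32768, non-causal is ~70 TFLOPs of work,
# so ~433 TFLOPS above corresponds to roughly 160 ms per call.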
$ python bench_qk_int8_pv_fp8_cuda_sm90.py --pv_accum_dtype fp32+fp32 --quant_gran per_thread
CUDA QK Int8 PV FP8 SM90 Benchmark
batch: 4, head: 32, headdim: 128
is_causal: False
1024 flops:678.563889696416
2048 flops:802.0022759455021
4096 flops:874.6756686620378
8192 flops:903.8851925032727
16384 flops:896.3888132200065
32768 flops:903.401759607428
is_causal: True
1024 flops:456.897014541662
2048 flops:636.5962171472885
4096 flops:769.9143136658597
8192 flops:822.0096982392442
16384 flops:863.3824950649109
32768 flops:885.0549123325416
2. Code diff
$ git diff
diff --git a/sageattention/core.py b/sageattention/core.py
index c8829e4..3ca9c10 100644
--- a/sageattention/core.py
+++ b/sageattention/core.py
@@ -146,6 +146,8 @@ def sageattn(
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp16")
elif arch == "sm90":
return sageattn_qk_int8_pv_fp8_cuda_sm90(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32")
+ elif arch == "sm100":
+ return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp16")
elif arch == "sm120":
return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp16") # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120.
else:
@@ -952,4 +954,4 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90(
if return_lse:
return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504
else:
- return o
\ No newline at end of file
+ return o
diff --git a/setup.py b/setup.py
index 5e4779d..05d6f0d 100644
--- a/setup.py
+++ b/setup.py
@@ -28,10 +28,11 @@ HAS_SM80 = False
HAS_SM86 = False
HAS_SM89 = False
HAS_SM90 = False
+HAS_SM100 = False
HAS_SM120 = False
# Supported NVIDIA GPU architectures.
-SUPPORTED_ARCHS = {"8.0", "8.6", "8.9", "9.0", "12.0"}
+SUPPORTED_ARCHS = {"8.0", "8.6", "8.9", "9.0", "10.0", "12.0"}
# Compiler flags.
CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
@@ -109,6 +110,9 @@ for capability in compute_capabilities:
elif capability.startswith("9.0"):
HAS_SM90 = True
num = "90a" # need to use sm90a instead of sm90 to use wgmma ptx instruction.
+ elif capability.startswith("10.0"):
+ HAS_SM100 = True
+ num = "100"
elif capability.startswith("12.0"):
HAS_SM120 = True
num = "120" # need to use sm120a to use mxfp8/mxfp4/nvfp4 instructions.
@@ -132,7 +136,7 @@ if HAS_SM80 or HAS_SM86 or HAS_SM89 or HAS_SM90 or HAS_SM120:
)
ext_modules.append(qattn_extension)
-if HAS_SM89 or HAS_SM120:
+if HAS_SM89 or HAS_SM120 or HAS_SM100:
qattn_extension = CUDAExtension(
name="sageattention._qattn_sm89",
sources=[
@@ -179,6 +183,7 @@ fused_extension = CUDAExtension(
)
ext_modules.append(fused_extension)
+print(f"nvcc_flags: {NVCC_FLAGS}")
parallel = None
if 'EXT_PARALLEL' in os.environ: |
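For context on the setup.py side: the num strings set per architecture presumably end up as nvcc -gencode flags, so the added 10.0 entry would compile the sm89-style kernels for sm_100 as well. A sketch of that common pattern follows; the real setup.py may assemble NVCC_FLAGS differently.

# Sketch of how per-arch "num" strings typically become nvcc -gencode flags
# in a PyTorch CUDAExtension build; mirrors the values in the diff above.
compute_capabilities = {"8.9", "9.0", "10.0", "12.0"}
NVCC_FLAGS = ["-O3", "-std=c++17"]

for capability in sorted(compute_capabilities):
    if capability.startswith("9.0"):
        num = "90a"   # sm90a is needed for wgmma ptx instructions
    elif capability.startswith("10.0"):
        num = "100"   # proposed entry for B100/B200/GB200
    elif capability.startswith("12.0"):
        num = "120"   # RTX 50 series
    else:
        num = capability.replace(".", "")
    NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]

print(NVCC_FLAGS)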
@Pangm I'm not familiar with sm90 and sm100, but AFAIK there are already some special optimizations for sm90. For example, sm90 uses |
Hi, I have independently done almost exactly the same as above, treating the B200 the same as the RTX 5090, and found that although I was successful in enabling SageAttention in ComfyUI, my WAN 2.2 generation times became significantly slower, maybe around 2.5x slower with it enabled.
Support (compute capabilities):
Jetson Orin: 8.7
Jetson Thor: 10.1
Blackwell B100/B200/GB200: 10.0
Spark: 11.0
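To see which of these entries a given machine reports, and hence which setup.py/core.py branch it would take, the compute capability can be queried directly; a small sketch:

# Print the compute capability and the arch string a core.py-style dispatch would see.
import torch

for i in range(torch.cuda.device_count()):
    major, minor = torch.cuda.get_device_capability(i)
    print(f"cuda:{i} {torch.cuda.get_device_name(i)} -> capability {major}.{minor} (sm{major}{minor})")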