
Conversation

@johnnynunez commented Apr 22, 2025

Support:

Jetson Orin: 8.7
Jetson Thor: 10.1
Blackwell B100/B200/GB200: 10.0
Spark: 11.0
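
For anyone checking which of these a given machine falls under, here is a minimal sketch that maps the reported compute capability to the "smXY" strings that sageattn dispatches on in core.py (the package's own detection code may differ in detail):

import torch

def detect_arch(device="cuda:0"):
    # Map the device's compute capability to an "smXY" string,
    # e.g. (9, 0) -> "sm90", (10, 0) -> "sm100", (12, 0) -> "sm120".
    major, minor = torch.cuda.get_device_capability(device)
    return f"sm{major}{minor}"

print(detect_arch())  # e.g. "sm90" on H100/GH200, "sm100" on B200/GB200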

@woct0rdho commented Apr 22, 2025

Did you test which implementation each SM version should be routed to by sageattn in core.py?

Maybe we should eventually implement something like autotune to do this

Update: Currently the sm89 implementation is SageAttention2++, so it's faster than the old sm89 implementation. Let's see if it's faster than the sm90 implementation on newer GPUs.

SageAttention3 is in early access (and still has some precision issues). It uses fp4 quantization, so it's faster than SageAttention2++ on sm120, but I don't know if it can be directly compiled for other Blackwell GPUs.
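
For the autotune idea above, here is a rough sketch of what a one-time benchmark-and-cache dispatcher could look like. The kernel names are the ones in core.py; the timing and caching details are only an illustration, and it assumes both kernels were actually built for the current device:

import torch
from sageattention.core import (
    sageattn_qk_int8_pv_fp8_cuda,
    sageattn_qk_int8_pv_fp8_cuda_sm90,
)

_BEST = {}  # cache: device index -> fastest candidate for that device

def _time_ms(fn, q, k, v, iters=10):
    # Crude CUDA-event timing; a real autotuner would warm up and average more carefully.
    fn(q, k, v)  # warmup
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(q, k, v)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end)

def sageattn_autotuned(q, k, v, **kwargs):
    dev = q.device.index
    if dev not in _BEST:
        candidates = {
            "sm89-style": lambda a, b, c: sageattn_qk_int8_pv_fp8_cuda(a, b, c, **kwargs),
            "sm90-style": lambda a, b, c: sageattn_qk_int8_pv_fp8_cuda_sm90(a, b, c, **kwargs),
        }
        timings = {name: _time_ms(fn, q, k, v) for name, fn in candidates.items()}
        _BEST[dev] = candidates[min(timings, key=timings.get)]
    return _BEST[dev](q, k, v)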

@johnnynunez (Author) commented Apr 22, 2025

> Did you test which implementation each arch/device should be routed to by sageattn in core.py?
>
> Maybe we should eventually implement something like autotune to do this

I'm testing with Ada, Hopper (GH200), Jetson Orin, and RTX 5090 / GB200.

@johnnynunez (Author):

For Blackwell it's the same as the RTX 50 series with Triton 3.3.x.

@pftq commented May 17, 2025

Isn't more needed to handle the B200? The commit only seems to get past the setup process. For example, sm100 for the B200 is not a case handled in core.py (it skips from sm90 to sm120).

Line 135 in core.py

    elif arch == "sm90":
        return sageattn_qk_int8_pv_fp8_cuda_sm90(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32")
    elif arch == "sm120":
        return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32") # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120.

Otherwise seems to throw this error:

  File "/workspace/ComfyUI/venv/lib/python3.11/site-packages/sageattention/core.py", line 138, in sageattn
    raise ValueError(f"Unsupported CUDA architecture: {arch}")
ValueError: Unsupported CUDA architecture: sm100

Copying one of the other cases doesn't seem to be enough:

    elif arch == "sm100":
        return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32")

Still results in:

  File "/workspace/ComfyUI/venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 857, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/workspace/ComfyUI/venv/lib/python3.11/site-packages/sageattention/core.py", line 722, in sageattn_qk_int8_pv_fp8_cuda
    o = torch.empty(q.size(), dtype=dtype, device=q.device)
RuntimeError: CUDA error: no kernel image is available for execution on the device
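
For what it's worth, that error usually means the compiled extension contains no cubin/PTX for the running GPU at all, independent of the dispatch in core.py. A quick way to compare what the device reports against what the wheel was built for (a sketch; the cuobjdump step assumes the CUDA toolkit is on PATH):

import subprocess
import torch
import sageattention._qattn_sm89 as qattn  # the extension the sm100 branch above falls into

major, minor = torch.cuda.get_device_capability()
print(f"device capability: sm_{major}{minor}")  # e.g. sm_100 on a B200

# List the GPU architectures actually embedded in the shared object.
result = subprocess.run(["cuobjdump", "--list-elf", qattn.__file__],
                        capture_output=True, text=True)
print(result.stdout)  # should mention sm_100 if the kernels were built for it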

@woct0rdho commented May 17, 2025

@pftq Ok let's do it. First you need to build the binary like _qattn_sm89 but with the compiler flag for sm100. You need to add HAS_SM100 to the line if HAS_SM89 or HAS_SM120: and build it, otherwise you'll see no kernel image is available for execution on the device.

Then modify the function sageattn in core.py to use the implementation from that binary.

Then you can test whether the implementation is correct (such as with my test script https://github.com/woct0rdho/SageAttention/blob/main/tests/test_sageattn.py), and which of the implementations (triton/sm89/sm90) is the fastest (maybe using the scripts at https://github.com/thu-ml/SageAttention/tree/main/bench).
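
A minimal version of that correctness check could look like the sketch below (along the lines of the linked test script, not the script itself): compare sageattn against PyTorch's reference attention and look at the relative error. The shapes, layout string, and tolerance are arbitrary illustrations.

import torch
import torch.nn.functional as F
from sageattention import sageattn

torch.manual_seed(0)
shape = (2, 8, 1024, 128)  # (batch, heads, seq_len, head_dim), assuming "HND" layout
q = torch.randn(shape, dtype=torch.float16, device="cuda")
k = torch.randn(shape, dtype=torch.float16, device="cuda")
v = torch.randn(shape, dtype=torch.float16, device="cuda")

out = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
ref = F.scaled_dot_product_attention(q, k, v, is_causal=False)

# Quantized attention will not match exactly; check the relative error instead.
rel_err = ((out - ref).abs().max() / ref.abs().max()).item()
print(f"max relative error: {rel_err:.4f}")  # pick a threshold that suits you, e.g. < 1e-2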

@pftq commented May 17, 2025

See the bottom of my earlier reply. Right now I'm just getting a CUDA error, so I'd need to get past that first.

@shounak-ray:

Were you able to get this SageAttention running on B200s?

@Aya-ZIbra:

+1
We would like to evaluate this on B200 / GB200. Are there any branches that support these?

@Pangm commented Sep 1, 2025

> @pftq Ok let's do it. First you need to build the binary like _qattn_sm89 but with the compiler flag for sm100. You need to add HAS_SM100 to the line if HAS_SM89 or HAS_SM120: and build it, otherwise you'll see no kernel image is available for execution on the device.
>
> Then modify the function sageattn in core.py to use the implementation from that binary.
>
> Then you can test whether the implementation is correct (such as with my test script https://github.com/woct0rdho/SageAttention/blob/main/tests/test_sageattn.py), and which of the implementations (triton/sm89/sm90) is the fastest (maybe using the scripts at https://github.com/thu-ml/SageAttention/tree/main/bench).

BZZ2: the sm89 implementation (SageAttention2++), run on SM100
HZZ2: the sm90 implementation, run on SM90

SageAttention2++ on SM100 (BZZ2) performs worse than SageAttention on SM90 (HZZ2), roughly half the throughput in the benchmarks below. Is there any code optimized for SM100?
@woct0rdho @johnnynunez


1. Benchmark (https://github.com/thu-ml/SageAttention/tree/main/bench)

  1. BZZ2
$ python bench_qk_int8_pv_fp8_cuda.py --pv_accum_dtype fp32+fp16 --quant_gran per_warp
CUDA QK Int8 PV FP8 Benchmark
batch: 4, head: 32, headdim: 128, pv_accum_dtype: fp32+fp16, fused_v: False
is_causal: False
1024 flops:369.53967935191037
2048 flops:404.9683418973956
4096 flops:419.0692280433787
8192 flops:425.7005528441506
16384 flops:432.06802961968805
32768 flops:433.55603269730267
is_causal: True
1024 flops:250.66197667633227
2048 flops:324.3266958350166
4096 flops:369.54497019445324
8192 flops:395.62348971339713
16384 flops:409.47553355350226
32768 flops:416.7046366646864
  2. HZZ2:
$ python bench_qk_int8_pv_fp8_cuda_sm90.py --pv_accum_dtype fp32+fp32 --quant_gran per_thread
CUDA QK Int8 PV FP8 SM90 Benchmark
batch: 4, head: 32, headdim: 128
is_causal: False
1024 flops:678.563889696416
2048 flops:802.0022759455021
4096 flops:874.6756686620378
8192 flops:903.8851925032727
16384 flops:896.3888132200065
32768 flops:903.401759607428
is_causal: True
1024 flops:456.897014541662
2048 flops:636.5962171472885
4096 flops:769.9143136658597
8192 flops:822.0096982392442
16384 flops:863.3824950649109
32768 flops:885.0549123325416

2. Code diff

$ git diff
diff --git a/sageattention/core.py b/sageattention/core.py
index c8829e4..3ca9c10 100644
--- a/sageattention/core.py
+++ b/sageattention/core.py
@@ -146,6 +146,8 @@ def sageattn(
         return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp16")
     elif arch == "sm90":
         return sageattn_qk_int8_pv_fp8_cuda_sm90(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32")
+    elif arch == "sm100":
+        return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp16")
     elif arch == "sm120":
         return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp16") # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120.
     else:
@@ -952,4 +954,4 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90(
     if return_lse:
         return o, lse / 1.44269504 + lse_correction * sm_scale if smooth_k else lse / 1.44269504
     else:
-        return o
\ No newline at end of file
+        return o
diff --git a/setup.py b/setup.py
index 5e4779d..05d6f0d 100644
--- a/setup.py
+++ b/setup.py
@@ -28,10 +28,11 @@ HAS_SM80 = False
 HAS_SM86 = False
 HAS_SM89 = False
 HAS_SM90 = False
+HAS_SM100 = False
 HAS_SM120 = False

 # Supported NVIDIA GPU architectures.
-SUPPORTED_ARCHS = {"8.0", "8.6", "8.9", "9.0", "12.0"}
+SUPPORTED_ARCHS = {"8.0", "8.6", "8.9", "9.0", "10.0", "12.0"}

 # Compiler flags.
 CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"]
@@ -109,6 +110,9 @@ for capability in compute_capabilities:
     elif capability.startswith("9.0"):
         HAS_SM90 = True
         num = "90a" # need to use sm90a instead of sm90 to use wgmma ptx instruction.
+    elif capability.startswith("10.0"):
+        HAS_SM100 = True
+        num = "100"
     elif capability.startswith("12.0"):
         HAS_SM120 = True
         num = "120" # need to use sm120a to use mxfp8/mxfp4/nvfp4 instructions.
@@ -132,7 +136,7 @@ if HAS_SM80 or HAS_SM86 or HAS_SM89 or HAS_SM90 or HAS_SM120:
     )
     ext_modules.append(qattn_extension)

-if HAS_SM89 or HAS_SM120:
+if HAS_SM89 or HAS_SM120 or HAS_SM100:
     qattn_extension = CUDAExtension(
         name="sageattention._qattn_sm89",
         sources=[
@@ -179,6 +183,7 @@ fused_extension = CUDAExtension(
 )
 ext_modules.append(fused_extension)

+print(f"nvcc_flags: {NVCC_FLAGS}")

 parallel = None
 if 'EXT_PARALLEL' in os.environ:

@woct0rdho:

@Pangm I'm not familiar with sm90 and sm100 but AFAIK there are already some special optimizations for sm90. For example, sm90 uses CTA_Q = 64, CTA_K = 128 by default, while sm89 uses CTA_Q = 128, CTA_K = 64. So I guess there should also be some special optimizations for sm100.
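
Purely as an illustration of what per-arch tuning could look like on the dispatch side, the sm89/sm90 values below are the ones mentioned above, while the sm100 entry is a placeholder that would need benchmarking on real B200 hardware:

# Hypothetical per-arch tile configuration a dispatcher could consult.
TILE_CONFIG = {
    "sm89":  {"CTA_Q": 128, "CTA_K": 64},
    "sm90":  {"CTA_Q": 64,  "CTA_K": 128},
    "sm100": {"CTA_Q": 64,  "CTA_K": 128},  # placeholder, tune on B200
}

def tiles_for(arch: str):
    # Fall back to the sm89 defaults for unknown architectures.
    return TILE_CONFIG.get(arch, TILE_CONFIG["sm89"])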

@atagunov commented Sep 7, 2025

Hi, I have independently done almost exactly the same as above, treating the B200 the same as the RTX 5090, and found that although I was able to enable SageAttention in ComfyUI, my WAN 2.2 generation times became significantly slower with it enabled, roughly 2.5x slower.
