
Conversation

JohannesGaessler
Collaborator

This PR:

  • Refactors and deduplicates the CUDA vector FlashAttention kernels. As with the mma and tile kernels, the KQ accumulation and softmax are always done in FP32; if fast FP16 is available it is used for the KQ dot products as well as for the VKQ dot products and accumulation (a rough sketch of this scheme follows the list).
  • Decouples the number of threads from the head size. This enables the use of KV cache quantization for head sizes 64 and 256; previously only head size 128 was properly supported.
  • Refactors the memory layout to use larger copies, eliminate inter-warp dependencies during the main loop, and reduce intra-warp dependencies as well as shared memory I/O.
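
As a rough illustration of the precision scheme in the first bullet (a sketch only, not the actual kernel code; `kq_dot_fp16` and `ILLUSTRATIVE_D` are made-up names for this example): the KQ dot product itself runs on half2 when fast FP16 is available, and the single value it returns is converted to FP32 before it enters the softmax and its accumulators. On hardware without fast FP16 the same loop simply runs on float.

```cuda
#include <cuda_fp16.h>

#define ILLUSTRATIVE_D 128 // head size assumed for this sketch

// One KQ dot product over the head dimension. The multiplications and the
// per-thread partial sum use half2 (fast FP16), but the value returned to
// the caller -- which feeds the FP32 softmax/accumulation path -- is a float.
static __device__ float kq_dot_fp16(const half2 * K_h2, const half2 * Q_h2) {
    half2 sum_h2 = __float2half2_rn(0.0f);
#pragma unroll
    for (int i = 0; i < ILLUSTRATIVE_D/2; ++i) {
        sum_h2 = __hfma2(K_h2[i], Q_h2[i], sum_h2); // FP16 fused multiply-add
    }
    return __low2float(sum_h2) + __high2float(sum_h2); // convert once, to FP32
}
```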
Performance changes
| GPU | Model | KV type | Microbatch size | Test | t/s master | t/s cf0d098 | Speedup |
| --- | --- | --- | --- | --- | --- | --- | --- |
| MI50 | gemma 2B Q4_0 | f16 | 1 | pp16384 | 107.96 | 165.65 | 1.53 |
| MI50 | gemma 2B Q4_0 | f16 | 2 | pp16384 | 206.73 | 175.60 | 0.85 |
| MI50 | gemma 2B Q4_0 | f16 | 4 | pp16384 | 210.30 | 189.33 | 0.90 |
| MI50 | gemma 2B Q4_0 | f16 | 8 | pp16384 | 180.28 | 201.53 | 1.12 |
| MI50 | internlm2 ?B Q4_0 | f16 | 1 | pp16384 | 73.88 | 158.28 | 2.14 |
| MI50 | internlm2 ?B Q4_0 | f16 | 2 | pp16384 | 147.63 | 254.77 | 1.73 |
| MI50 | internlm2 ?B Q4_0 | f16 | 4 | pp16384 | 166.44 | 239.66 | 1.44 |
| MI50 | internlm2 ?B Q4_0 | f16 | 8 | pp16384 | 145.51 | 266.28 | 1.83 |
| MI50 | internlm2 ?B Q4_0 | q4_0 | 1 | pp16384 | 78.64 | 149.69 | 1.90 |
| MI50 | internlm2 ?B Q4_0 | q4_0 | 2 | pp16384 | 135.53 | 270.69 | 2.00 |
| MI50 | internlm2 ?B Q4_0 | q4_0 | 4 | pp16384 | 183.33 | 285.20 | 1.56 |
| MI50 | internlm2 ?B Q4_0 | q4_0 | 8 | pp16384 | 160.97 | 373.99 | 2.32 |
| MI50 | internlm2 ?B Q4_0 | q4_1 | 1 | pp16384 | 81.87 | 155.65 | 1.90 |
| MI50 | internlm2 ?B Q4_0 | q4_1 | 2 | pp16384 | 148.07 | 282.66 | 1.91 |
| MI50 | internlm2 ?B Q4_0 | q4_1 | 4 | pp16384 | 201.00 | 294.38 | 1.46 |
| MI50 | internlm2 ?B Q4_0 | q4_1 | 8 | pp16384 | 164.88 | 390.23 | 2.37 |
| MI50 | internlm2 ?B Q4_0 | q5_0 | 1 | pp16384 | 80.85 | 134.30 | 1.66 |
| MI50 | internlm2 ?B Q4_0 | q5_0 | 2 | pp16384 | 143.14 | 242.74 | 1.70 |
| MI50 | internlm2 ?B Q4_0 | q5_0 | 4 | pp16384 | 78.06 | 258.43 | 3.31 |
| MI50 | internlm2 ?B Q4_0 | q5_0 | 8 | pp16384 | 166.73 | 334.86 | 2.01 |
| MI50 | internlm2 ?B Q4_0 | q5_1 | 1 | pp16384 | 88.48 | 142.75 | 1.61 |
| MI50 | internlm2 ?B Q4_0 | q5_1 | 2 | pp16384 | 146.85 | 247.10 | 1.68 |
| MI50 | internlm2 ?B Q4_0 | q5_1 | 4 | pp16384 | 89.69 | 262.63 | 2.93 |
| MI50 | internlm2 ?B Q4_0 | q5_1 | 8 | pp16384 | 168.93 | 339.54 | 2.01 |
| MI50 | internlm2 ?B Q4_0 | q8_0 | 1 | pp16384 | 80.74 | 147.97 | 1.83 |
| MI50 | internlm2 ?B Q4_0 | q8_0 | 2 | pp16384 | 141.34 | 271.01 | 1.92 |
| MI50 | internlm2 ?B Q4_0 | q8_0 | 4 | pp16384 | 180.74 | 288.18 | 1.59 |
| MI50 | internlm2 ?B Q4_0 | q8_0 | 8 | pp16384 | 168.19 | 381.03 | 2.27 |
| MI50 | llama 1B Q4_0 | f16 | 1 | pp16384 | 177.31 | 176.48 | 1.00 |
| MI50 | llama 1B Q4_0 | f16 | 2 | pp16384 | 345.88 | 343.80 | 0.99 |
| MI50 | llama 1B Q4_0 | f16 | 4 | pp16384 | 468.29 | 469.14 | 1.00 |
| MI50 | llama 1B Q4_0 | f16 | 8 | pp16384 | 768.16 | 768.06 | 1.00 |
| P40 | gemma 2B Q4_0 | f16 | 1 | pp16384 | 120.01 | 136.15 | 1.13 |
| P40 | gemma 2B Q4_0 | f16 | 2 | pp16384 | 201.83 | 258.05 | 1.28 |
| P40 | gemma 2B Q4_0 | f16 | 4 | pp16384 | 238.96 | 321.69 | 1.35 |
| P40 | gemma 2B Q4_0 | f16 | 8 | pp16384 | 309.12 | 452.65 | 1.46 |
| P40 | internlm2 ?B Q4_0 | f16 | 1 | pp16384 | 114.38 | 116.74 | 1.02 |
| P40 | internlm2 ?B Q4_0 | f16 | 2 | pp16384 | 193.87 | 220.39 | 1.14 |
| P40 | internlm2 ?B Q4_0 | f16 | 4 | pp16384 | 232.12 | 318.61 | 1.37 |
| P40 | internlm2 ?B Q4_0 | f16 | 8 | pp16384 | 288.50 | 440.53 | 1.53 |
| P40 | internlm2 ?B Q4_0 | q4_0 | 1 | pp16384 | 100.44 | 121.99 | 1.21 |
| P40 | internlm2 ?B Q4_0 | q4_0 | 2 | pp16384 | 149.51 | 157.62 | 1.05 |
| P40 | internlm2 ?B Q4_0 | q4_0 | 4 | pp16384 | 171.23 | 187.23 | 1.09 |
| P40 | internlm2 ?B Q4_0 | q4_0 | 8 | pp16384 | 210.05 | 228.61 | 1.09 |
| P40 | internlm2 ?B Q4_0 | q4_1 | 1 | pp16384 | 101.70 | 128.41 | 1.26 |
| P40 | internlm2 ?B Q4_0 | q4_1 | 2 | pp16384 | 151.12 | 198.64 | 1.31 |
| P40 | internlm2 ?B Q4_0 | q4_1 | 4 | pp16384 | 172.61 | 240.67 | 1.39 |
| P40 | internlm2 ?B Q4_0 | q4_1 | 8 | pp16384 | 213.89 | 303.81 | 1.42 |
| P40 | internlm2 ?B Q4_0 | q5_0 | 1 | pp16384 | 60.55 | 84.31 | 1.39 |
| P40 | internlm2 ?B Q4_0 | q5_0 | 2 | pp16384 | 76.54 | 121.79 | 1.59 |
| P40 | internlm2 ?B Q4_0 | q5_0 | 4 | pp16384 | 81.46 | 138.27 | 1.70 |
| P40 | internlm2 ?B Q4_0 | q5_0 | 8 | pp16384 | 87.06 | 158.48 | 1.82 |
| P40 | internlm2 ?B Q4_0 | q5_1 | 1 | pp16384 | 62.99 | 101.47 | 1.61 |
| P40 | internlm2 ?B Q4_0 | q5_1 | 2 | pp16384 | 80.14 | 131.14 | 1.64 |
| P40 | internlm2 ?B Q4_0 | q5_1 | 4 | pp16384 | 85.35 | 148.91 | 1.74 |
| P40 | internlm2 ?B Q4_0 | q5_1 | 8 | pp16384 | 91.33 | 170.49 | 1.87 |
| P40 | internlm2 ?B Q4_0 | q8_0 | 1 | pp16384 | 102.09 | 124.77 | 1.22 |
| P40 | internlm2 ?B Q4_0 | q8_0 | 2 | pp16384 | 155.98 | 209.34 | 1.34 |
| P40 | internlm2 ?B Q4_0 | q8_0 | 4 | pp16384 | 179.36 | 261.06 | 1.46 |
| P40 | internlm2 ?B Q4_0 | q8_0 | 8 | pp16384 | 213.55 | 332.76 | 1.56 |
| P40 | llama 1B Q4_0 | f16 | 1 | pp16384 | 178.69 | 214.22 | 1.20 |
| P40 | llama 1B Q4_0 | f16 | 2 | pp16384 | 275.41 | 385.27 | 1.40 |
| P40 | llama 1B Q4_0 | f16 | 4 | pp16384 | 308.42 | 472.74 | 1.53 |
| P40 | llama 1B Q4_0 | f16 | 8 | pp16384 | 369.76 | 643.42 | 1.74 |
| RTX 3090 | gemma 2B Q4_0 | f16 | 1 | pp16384 | 337.91 | 336.99 | 1.00 |
| RTX 3090 | gemma 2B Q4_0 | f16 | 2 | pp16384 | 596.45 | 592.53 | 0.99 |
| RTX 3090 | gemma 2B Q4_0 | f16 | 4 | pp16384 | 1030.47 | 1032.56 | 1.00 |
| RTX 3090 | gemma 2B Q4_0 | f16 | 8 | pp16384 | 1352.03 | 1365.43 | 1.01 |
| RTX 3090 | internlm2 ?B Q4_0 | f16 | 1 | pp16384 | 297.98 | 298.00 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | f16 | 2 | pp16384 | 524.53 | 522.27 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | f16 | 4 | pp16384 | 937.79 | 936.12 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | f16 | 8 | pp16384 | 1412.27 | 1407.37 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | q4_0 | 1 | pp16384 | 320.62 | 320.41 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | q4_0 | 2 | pp16384 | 372.32 | 373.82 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | q4_0 | 4 | pp16384 | 679.37 | 679.34 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | q4_0 | 8 | pp16384 | 1091.64 | 1095.65 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | q4_1 | 1 | pp16384 | 342.22 | 342.24 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | q4_1 | 2 | pp16384 | 371.38 | 371.05 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | q4_1 | 4 | pp16384 | 674.19 | 676.69 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | q4_1 | 8 | pp16384 | 1081.41 | 1080.61 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | q5_0 | 1 | pp16384 | 284.63 | 285.62 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | q5_0 | 2 | pp16384 | 336.81 | 335.56 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | q5_0 | 4 | pp16384 | 614.13 | 613.81 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | q5_0 | 8 | pp16384 | 985.09 | 981.96 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | q5_1 | 1 | pp16384 | 320.54 | 319.61 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | q5_1 | 2 | pp16384 | 341.00 | 341.36 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | q5_1 | 4 | pp16384 | 621.39 | 618.65 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | q5_1 | 8 | pp16384 | 1005.26 | 1005.62 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | q8_0 | 1 | pp16384 | 323.34 | 324.65 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | q8_0 | 2 | pp16384 | 362.84 | 363.28 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | q8_0 | 4 | pp16384 | 660.38 | 659.74 | 1.00 |
| RTX 3090 | internlm2 ?B Q4_0 | q8_0 | 8 | pp16384 | 1044.19 | 1047.50 | 1.00 |
| RTX 3090 | llama 1B Q4_0 | f16 | 1 | pp16384 | 499.63 | 501.04 | 1.00 |
| RTX 3090 | llama 1B Q4_0 | f16 | 2 | pp16384 | 864.06 | 863.00 | 1.00 |
| RTX 3090 | llama 1B Q4_0 | f16 | 4 | pp16384 | 1502.10 | 1504.38 | 1.00 |
| RTX 3090 | llama 1B Q4_0 | f16 | 8 | pp16384 | 2312.14 | 2304.54 | 1.00 |
| RTX 4090 | gemma 2B Q4_0 | f16 | 1 | pp16384 | 442.21 | 443.69 | 1.00 |
| RTX 4090 | gemma 2B Q4_0 | f16 | 2 | pp16384 | 723.16 | 726.60 | 1.00 |
| RTX 4090 | gemma 2B Q4_0 | f16 | 4 | pp16384 | 1384.84 | 1382.12 | 1.00 |
| RTX 4090 | gemma 2B Q4_0 | f16 | 8 | pp16384 | 1998.59 | 1999.49 | 1.00 |
| RTX 4090 | internlm2 ?B Q4_0 | f16 | 1 | pp16384 | 373.72 | 374.73 | 1.00 |
| RTX 4090 | internlm2 ?B Q4_0 | f16 | 2 | pp16384 | 622.28 | 622.31 | 1.00 |
| RTX 4090 | internlm2 ?B Q4_0 | f16 | 4 | pp16384 | 1227.15 | 1231.67 | 1.00 |
| RTX 4090 | internlm2 ?B Q4_0 | f16 | 8 | pp16384 | 1983.17 | 1987.15 | 1.00 |
| RTX 4090 | internlm2 ?B Q4_0 | q4_0 | 1 | pp16384 | 429.58 | 449.08 | 1.05 |
| RTX 4090 | internlm2 ?B Q4_0 | q4_0 | 2 | pp16384 | 549.58 | 718.78 | 1.31 |
| RTX 4090 | internlm2 ?B Q4_0 | q4_0 | 4 | pp16384 | 1079.55 | 1081.15 | 1.00 |
| RTX 4090 | internlm2 ?B Q4_0 | q4_0 | 8 | pp16384 | 1764.52 | 1760.72 | 1.00 |
| RTX 4090 | internlm2 ?B Q4_0 | q4_1 | 1 | pp16384 | 429.71 | 454.61 | 1.06 |
| RTX 4090 | internlm2 ?B Q4_0 | q4_1 | 2 | pp16384 | 547.64 | 727.61 | 1.33 |
| RTX 4090 | internlm2 ?B Q4_0 | q4_1 | 4 | pp16384 | 1070.18 | 1071.83 | 1.00 |
| RTX 4090 | internlm2 ?B Q4_0 | q4_1 | 8 | pp16384 | 1745.26 | 1746.25 | 1.00 |
| RTX 4090 | internlm2 ?B Q4_0 | q5_0 | 1 | pp16384 | 384.70 | 416.84 | 1.08 |
| RTX 4090 | internlm2 ?B Q4_0 | q5_0 | 2 | pp16384 | 516.37 | 680.59 | 1.32 |
| RTX 4090 | internlm2 ?B Q4_0 | q5_0 | 4 | pp16384 | 1001.22 | 1007.09 | 1.01 |
| RTX 4090 | internlm2 ?B Q4_0 | q5_0 | 8 | pp16384 | 1632.30 | 1629.61 | 1.00 |
| RTX 4090 | internlm2 ?B Q4_0 | q5_1 | 1 | pp16384 | 404.47 | 436.09 | 1.08 |
| RTX 4090 | internlm2 ?B Q4_0 | q5_1 | 2 | pp16384 | 522.52 | 705.03 | 1.35 |
| RTX 4090 | internlm2 ?B Q4_0 | q5_1 | 4 | pp16384 | 1011.96 | 1014.96 | 1.00 |
| RTX 4090 | internlm2 ?B Q4_0 | q5_1 | 8 | pp16384 | 1655.43 | 1660.77 | 1.00 |
| RTX 4090 | internlm2 ?B Q4_0 | q8_0 | 1 | pp16384 | 412.88 | 423.43 | 1.03 |
| RTX 4090 | internlm2 ?B Q4_0 | q8_0 | 2 | pp16384 | 530.25 | 690.37 | 1.30 |
| RTX 4090 | internlm2 ?B Q4_0 | q8_0 | 4 | pp16384 | 1028.20 | 1029.95 | 1.00 |
| RTX 4090 | internlm2 ?B Q4_0 | q8_0 | 8 | pp16384 | 1657.93 | 1654.93 | 1.00 |
| RTX 4090 | llama 1B Q4_0 | f16 | 1 | pp16384 | 648.76 | 668.58 | 1.03 |
| RTX 4090 | llama 1B Q4_0 | f16 | 2 | pp16384 | 1036.57 | 1044.80 | 1.01 |
| RTX 4090 | llama 1B Q4_0 | f16 | 4 | pp16384 | 2035.72 | 2024.55 | 0.99 |
| RTX 4090 | llama 1B Q4_0 | f16 | 8 | pp16384 | 3238.55 | 3234.24 | 1.00 |
| RX 6800 | gemma 2B Q4_0 | f16 | 1 | pp16384 | 109.48 | 129.81 | 1.19 |
| RX 6800 | gemma 2B Q4_0 | f16 | 2 | pp16384 | 119.82 | 143.82 | 1.20 |
| RX 6800 | gemma 2B Q4_0 | f16 | 4 | pp16384 | 158.89 | 168.73 | 1.06 |
| RX 6800 | gemma 2B Q4_0 | f16 | 8 | pp16384 | 176.88 | 191.15 | 1.08 |
| RX 6800 | internlm2 ?B Q4_0 | f16 | 1 | pp16384 | 58.79 | 112.01 | 1.91 |
| RX 6800 | internlm2 ?B Q4_0 | f16 | 2 | pp16384 | 88.66 | 171.27 | 1.93 |
| RX 6800 | internlm2 ?B Q4_0 | f16 | 4 | pp16384 | 96.47 | 281.93 | 2.92 |
| RX 6800 | internlm2 ?B Q4_0 | f16 | 8 | pp16384 | 98.63 | 321.10 | 3.26 |
| RX 6800 | internlm2 ?B Q4_0 | q4_0 | 1 | pp16384 | 62.18 | 111.82 | 1.80 |
| RX 6800 | internlm2 ?B Q4_0 | q4_0 | 2 | pp16384 | 111.85 | 213.92 | 1.91 |
| RX 6800 | internlm2 ?B Q4_0 | q4_0 | 4 | pp16384 | 88.88 | 324.86 | 3.65 |
| RX 6800 | internlm2 ?B Q4_0 | q4_0 | 8 | pp16384 | 104.07 | 411.65 | 3.96 |
| RX 6800 | internlm2 ?B Q4_0 | q4_1 | 1 | pp16384 | 62.67 | 113.31 | 1.81 |
| RX 6800 | internlm2 ?B Q4_0 | q4_1 | 2 | pp16384 | 113.07 | 221.49 | 1.96 |
| RX 6800 | internlm2 ?B Q4_0 | q4_1 | 4 | pp16384 | 93.61 | 341.20 | 3.64 |
| RX 6800 | internlm2 ?B Q4_0 | q4_1 | 8 | pp16384 | 103.30 | 431.99 | 4.18 |
| RX 6800 | internlm2 ?B Q4_0 | q5_0 | 1 | pp16384 | 52.29 | 84.82 | 1.62 |
| RX 6800 | internlm2 ?B Q4_0 | q5_0 | 2 | pp16384 | 73.98 | 151.41 | 2.05 |
| RX 6800 | internlm2 ?B Q4_0 | q5_0 | 4 | pp16384 | 93.85 | 212.44 | 2.26 |
| RX 6800 | internlm2 ?B Q4_0 | q5_0 | 8 | pp16384 | 99.53 | 262.80 | 2.64 |
| RX 6800 | internlm2 ?B Q4_0 | q5_1 | 1 | pp16384 | 40.27 | 85.17 | 2.11 |
| RX 6800 | internlm2 ?B Q4_0 | q5_1 | 2 | pp16384 | 102.04 | 158.21 | 1.55 |
| RX 6800 | internlm2 ?B Q4_0 | q5_1 | 4 | pp16384 | 114.37 | 226.57 | 1.98 |
| RX 6800 | internlm2 ?B Q4_0 | q5_1 | 8 | pp16384 | 101.26 | 282.50 | 2.79 |
| RX 6800 | internlm2 ?B Q4_0 | q8_0 | 1 | pp16384 | 58.37 | 106.93 | 1.83 |
| RX 6800 | internlm2 ?B Q4_0 | q8_0 | 2 | pp16384 | 115.62 | 216.44 | 1.87 |
| RX 6800 | internlm2 ?B Q4_0 | q8_0 | 4 | pp16384 | 95.50 | 330.94 | 3.47 |
| RX 6800 | internlm2 ?B Q4_0 | q8_0 | 8 | pp16384 | 104.95 | 422.05 | 4.02 |
| RX 6800 | llama 1B Q4_0 | f16 | 1 | pp16384 | 107.18 | 181.79 | 1.70 |
| RX 6800 | llama 1B Q4_0 | f16 | 2 | pp16384 | 137.38 | 315.16 | 2.29 |
| RX 6800 | llama 1B Q4_0 | f16 | 4 | pp16384 | 121.12 | 424.00 | 3.50 |
| RX 6800 | llama 1B Q4_0 | f16 | 8 | pp16384 | 115.33 | 561.43 | 4.87 |

In short: a small performance boost on modern NVIDIA GPUs with a quantized KV cache at batch sizes 1-2, a moderate boost on older NVIDIA GPUs at batch sizes 1-8, and a large boost on AMD GPUs at batch sizes 1-8.

github-actions bot added the Nvidia GPU, python, and ggml labels on Sep 23, 2025
@JohannesGaessler
Collaborator Author

Context for the table: LLaMA 3.2 has a head size of 64, InternLM 2 has a head size of 128, and Gemma has a head size of 256. I chose these models because they cover these head sizes and are small enough that a full benchmark run takes at most ~1 hour.

Also, after this PR has been merged it would be fine to pad the KV cache to multiples of 128 rather than 256.

@JohannesGaessler
Collaborator Author

> Also, after this PR has been merged it would be fine to pad the KV cache to multiples of 128 rather than 256.

Actually, there is still an issue with the WMMA kernel (used for Volta and for rocWMMA) assuming a padding of 256, but it should be resolvable with manageable effort. Long-term I want to completely replace the WMMA kernel with the mma kernel.
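
To make the padding requirement concrete (illustrative helper only; the real chunk-size constants live in the ggml CUDA FlashAttention sources): the FlashAttention kernels walk the KV cache in fixed-size chunks, so the number of KV cells has to be rounded up to a multiple of the chunk size, and a smaller chunk means less padded work per sequence.

```cuda
// pad_to_multiple is an illustrative helper doing what ggml's GGML_PAD macro
// does: round x up to a multiple of n (n must be a power of two).
static __host__ __device__ int pad_to_multiple(int x, int n) {
    return (x + n - 1) & ~(n - 1);
}

// e.g. for a KV cache holding 16385 cells:
//   pad_to_multiple(16385, 256) == 16640  // WMMA kernel: up to 255 padded cells
//   pad_to_multiple(16385, 128) == 16512  // with 128:    up to 127 padded cells
```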
