Commit 3dde689
Add Paged Attention to FMHA Cutlass Blackwell Forward kernel for fixed length (pytorch#4999)
Summary:
X-link: facebookresearch/FBGEMM#2013
Added paged attention support to the FMHA FWD Blackwell kernel.
1. Added support for the fixed-length case.
2. Added support for two cases: a) page_block_size = N tile size; b) page_block_size > N tile size.
3. Added a unit test, test_paged_forward.
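
The two page-size cases can be illustrated with a small host-side sketch in Python. This is not the kernel code from this commit. The names `locate_kv_tile`, `page_table`, and `tile_n` are hypothetical, and the sketch assumes that `page_block_size` is an exact multiple of the N tile size, which the summary does not state.

```python
def locate_kv_tile(page_table, batch, tile_idx, tile_n, page_block_size):
    """Map a logical K/V tile to (physical_page, row_offset_within_page).

    Assumption (not from the commit): page_block_size is a multiple of tile_n,
    so a tile never straddles two pages.
    """
    if page_block_size < tile_n or page_block_size % tile_n != 0:
        raise ValueError("sketch assumes page_block_size is a multiple of tile_n")
    tiles_per_page = page_block_size // tile_n   # 1 when page_block_size == tile_n
    logical_page = tile_idx // tiles_per_page    # which page-table entry to read
    row_offset = (tile_idx % tiles_per_page) * tile_n
    return page_table[batch][logical_page], row_offset


# Hypothetical page table: batch 0 owns physical pages 2 and 0, batch 1 owns 1 and 3.
page_table = [[2, 0], [1, 3]]

# Case a) page_block_size == N tile size: one tile per page, offset is always 0.
case_a = locate_kv_tile(page_table, batch=0, tile_idx=1, tile_n=128, page_block_size=128)

# Case b) page_block_size > N tile size: several tiles share one page.
case_b = locate_kv_tile(page_table, batch=1, tile_idx=3, tile_n=128, page_block_size=256)
```

In case a) the tile index is used directly as the page-table index. In case b) the tile index is split into a page index and a row offset inside that page.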
Next steps:
1. Test the performance of the fixed-length case.
2. Add support for the variable-length case to the FWD kernel.
3. Add support for small page sizes to the FWD kernel.
4. Add paged attention support for decode.
Reviewed By: Aya-ZIbra, sijiac
Differential Revision: D840233961
parent b0d84b6, commit 3dde689
File tree
9 files changed, +616 -49 lines changed
- fbgemm_gpu/experimental/gen_ai
- gen_ai/attention/cutlass_blackwell_fmha
- src/attention/cuda/cutlass_blackwell_fmha
- collective
- test/attention
Lines changed: 12 additions & 0 deletions
Added at new-file lines 64-65, 84-85, 178-179, 229-230, 304-305, 321-322.
Lines changed: 16 additions & 3 deletions
Removed original lines 7-8 and 109; added at new-file lines 7-8, 16-17, 21-26, 71-72, 107, 120, 152-153.
Lines changed: 20 additions & 0 deletions
Added at new-file lines 26-27, 47-48, 69-70, 90-91, 112-113, 133-134, 155-156, 176-177, 198-199, 219-220.
Lines changed: 20 additions & 0 deletions
Added at new-file lines 26-27, 47-48, 69-70, 90-91, 112-113, 133-134, 155-156, 176-177, 198-199, 219-220.
Lines changed: 12 additions & 0 deletions
Added at new-file lines 26-27, 47-48, 69-70, 90-91, 112-113, 133-134.