
Commit 3dde689

sarithad-meta authored and facebook-github-bot committed
Add Paged Attention to FMHA Cutlass Blackwell Forward kernel for fixed length (pytorch#4999)
Summary:
X-link: facebookresearch/FBGEMM#2013

Added paged attention support to the FMHA FWD Blackwell kernel:
1. Added support for the fixed-length case.
2. Added support for 2 cases: a) page_block_size = N tile size, b) page_block_size > N.
3. Added a unit test, test_paged_forward.

Next steps:
1. Test the performance of the fixed-length case.
2. Add support for the variable-length case to the FWD kernel.
3. Add support for small page sizes to the FWD kernel.
4. Add paged attention support for decode.

Reviewed By: Aya-ZIbra, sijiac

Differential Revision: D84023396
1 parent b0d84b6 commit 3dde689
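
The two supported paged layouts are defined relative to the kernel's K-dimension tile (blockN; the dispatch code below checks K length against 128). A minimal sketch of the implied constraint, assuming blockN = 128 and assuming case (b) requires a multiple of the tile size, neither of which this commit states explicitly:

N_TILE = 128  # assumed kernel N tile size (the dispatch below checks seq_k % 128)

def page_block_size_supported(page_block_size: int) -> bool:
    # Case (a): page equals the N tile.
    # Case (b): page larger than the N tile; the multiple-of-tile
    # requirement here is an assumption, not stated in the commit.
    return page_block_size == N_TILE or (
        page_block_size > N_TILE and page_block_size % N_TILE == 0
    )

Pages smaller than the N tile are explicitly deferred to a later change ("small page sizes" in the next steps).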

File tree

9 files changed: +616, -49 lines changed

fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py

Lines changed: 12 additions & 0 deletions
@@ -61,6 +61,8 @@ def _cutlass_blackwell_fmha_forward(
     softmax_scale: float | None = None,
     causal: bool = False,
     seqlen_kv: torch.Tensor | None = None,
+    page_table: torch.Tensor | None = None,
+    seqlen_k: int | None = None,
     window_left: int = -1,
     window_right: int = -1,
     bottom_right: bool = True,
@@ -79,6 +81,8 @@ def _cutlass_blackwell_fmha_forward(
         softmax_scale=softmax_scale,
         causal=causal,
         seqlen_kv=seqlen_kv,
+        page_table=page_table,
+        seqlen_k=seqlen_k,
         window_size_left=window_left,
         window_size_right=window_right,
         bottom_right=bottom_right,
@@ -171,6 +175,8 @@ def forward( # type: ignore
         max_seq_len_q: Optional[int] = None,
         max_seq_len_k: Optional[int] = None,
         seqlen_kv: Optional[torch.Tensor] = None,
+        page_table: Optional[torch.Tensor] = None,
+        seqlen_k: Optional[int] = None,
         window_size: tuple[int, int] = (-1, -1),
         bottom_right: bool = True,
         deterministic: bool = False,
@@ -220,6 +226,8 @@ def forward( # type: ignore
             softmax_scale,
             causal,
             seqlen_kv,
+            page_table,
+            seqlen_k,
             window_left,
             window_right,
             bottom_right,
@@ -293,6 +301,8 @@ def cutlass_blackwell_fmha_func(
     max_seq_len_q: int | None = None,
     max_seq_len_k: int | None = None,
     seqlen_kv: torch.Tensor | None = None,
+    page_table: torch.Tensor | None = None,
+    seqlen_k: int | None = None,
     window_size: tuple[int, int] | None = (-1, -1),
     bottom_right: bool = True,
     deterministic: bool = False,
@@ -308,6 +318,8 @@ def cutlass_blackwell_fmha_func(
         max_seq_len_q,
         max_seq_len_k,
         seqlen_kv,
+        page_table,
+        seqlen_k,
         window_size,
         bottom_right,
         deterministic,
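
For orientation, a hedged usage sketch of the extended Python entry point. The argument names page_table and seqlen_k come from the diff above; the shapes, dtypes, the causal keyword, the identity page table, and the import path are illustrative assumptions, not part of this commit:

import torch
# Assumed import path, abbreviated from the file above.
from fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha.cutlass_blackwell_fmha_interface import (
    cutlass_blackwell_fmha_func,
)

B, H, D = 2, 8, 128                    # batch, KV heads, head dim (illustrative)
seqlen_q, seqlen_k_logical = 128, 256  # fixed lengths (illustrative)
page_block_size = 128                  # case (a): equal to the N tile
pages_per_seq = seqlen_k_logical // page_block_size

q = torch.randn(B, seqlen_q, H, D, device="cuda", dtype=torch.bfloat16)
# Paged KV cache: (num_blocks, page_block_size, num_KV_heads, head_dim).
k_cache = torch.randn(
    B * pages_per_seq, page_block_size, H, D, device="cuda", dtype=torch.bfloat16
)
v_cache = torch.randn_like(k_cache)
# (batch_size, max_num_pages_per_seq): physical block id of each logical page;
# an identity mapping here, purely for illustration.
page_table = torch.arange(
    B * pages_per_seq, device="cuda", dtype=torch.int32
).view(B, pages_per_seq)

out = cutlass_blackwell_fmha_func(
    q,
    k_cache,
    v_cache,
    causal=True,
    page_table=page_table,
    seqlen_k=seqlen_k_logical,  # logical KV length; this commit covers fixed length only
)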

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd.cu

Lines changed: 16 additions & 3 deletions
@@ -4,18 +4,26 @@
 
 std::tuple<at::Tensor, at::Tensor> dispatch_fmha_fwd(
     const at::Tensor& q,
-    const at::Tensor& k,
-    const at::Tensor& v,
+    const at::Tensor& k, // (batch_size, KV_seqlen, num_KV_heads, head_dim) if non-paged or (num_blocks, page_block_size, num_KV_heads, head_dim) if paged
+    const at::Tensor& v, // (batch_size, KV_seqlen, num_KV_heads, head_dim) if non-paged or (num_blocks, page_block_size, num_KV_heads, head_dim) if paged
     const std::optional<at::Tensor>& cu_seqlens_q,
     const std::optional<at::Tensor>& cu_seqlens_k,
     std::optional<int64_t> max_seq_len_q,
     std::optional<int64_t> max_seq_len_k,
     std::optional<double> softmax_scale,
     bool causal,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table, // dim: (batch_size, max_num_pages_per_seq), null if non-paged
+    std::optional<int64_t> seqlen_k,
     int64_t window_size_left,
     int64_t window_size_right,
     bool bottom_right) {
+
+  bool kIsPaged = false;
+  if (page_table && page_table->defined()) {
+    kIsPaged = true;
+  }
+
   // Handle local attention parameters
   bool local = (window_size_left >= 0 || window_size_right >= 0);
   if (local) {
@@ -60,6 +68,8 @@ std::tuple<at::Tensor, at::Tensor> dispatch_fmha_fwd(
         max_seq_len_k,
         softmax_scale,
         seqlen_kv,
+        page_table,
+        seqlen_k,
         window_size_left,
         window_size_right);
   };
@@ -94,6 +104,7 @@ std::tuple<at::Tensor, at::Tensor> dispatch_fmha_fwd(
   };
 
   auto dispatch_mask = [&](auto varlen) {
+    int seq_k = kIsPaged ? static_cast<int>(*seqlen_k) : varlen ? k.size(0) : k.size(1);
     if (causal) {
       if (bottom_right) {
         return dispatch_head_dim(varlen, CausalMask</*kIsQBegin=*/false>{});
@@ -106,7 +117,7 @@ std::tuple<at::Tensor, at::Tensor> dispatch_fmha_fwd(
       } else {
         return dispatch_head_dim(varlen, LocalMask</*kIsQBegin=*/true>{});
       }
-    } else if (varlen || k.size(1) % 128 != 0) {
+    } else if (varlen || seq_k % 128 != 0) {
       // Use the residual mask for varlen or when K seqlen is not multiple of
       // blockN
       return dispatch_head_dim(varlen, ResidualMask{});
@@ -138,6 +149,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       " float? softmax_scale=None, "
      " bool causal=False, "
       " Tensor? seqlen_kv=None, "
+      " Tensor? page_table=None, "
+      " int? seqlen_k=None, "
       " int window_size_left=-1, "
       " int window_size_right=-1, "
       " bool bottom_right=True"

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd_bf16_inst.cu

Lines changed: 20 additions & 0 deletions
@@ -23,6 +23,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -42,6 +44,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -62,6 +66,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -81,6 +87,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -101,6 +109,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -120,6 +130,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -140,6 +152,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -159,6 +173,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -179,6 +195,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -198,6 +216,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd_fp16_inst.cu

Lines changed: 20 additions & 0 deletions
@@ -23,6 +23,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -42,6 +44,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -62,6 +66,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -81,6 +87,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -101,6 +109,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -120,6 +130,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -140,6 +152,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -159,6 +173,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -179,6 +195,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -198,6 +216,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd_fp8_inst.cu

Lines changed: 12 additions & 0 deletions
@@ -23,6 +23,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -42,6 +44,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -62,6 +66,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -81,6 +87,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -101,6 +109,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -120,6 +130,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<const at::Tensor>& seqlen_kv,
+    const std::optional<const at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
