Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def _cutlass_blackwell_fmha_forward(
softmax_scale: float | None = None,
causal: bool = False,
seqlen_kv: torch.Tensor | None = None,
page_table: torch.Tensor | None = None,
seqlen_k: int | None = None,
window_left: int = -1,
window_right: int = -1,
bottom_right: bool = True,
Expand All @@ -79,6 +81,8 @@ def _cutlass_blackwell_fmha_forward(
softmax_scale=softmax_scale,
causal=causal,
seqlen_kv=seqlen_kv,
page_table=page_table,
seqlen_k=seqlen_k,
window_size_left=window_left,
window_size_right=window_right,
bottom_right=bottom_right,
Expand Down Expand Up @@ -171,6 +175,8 @@ def forward( # type: ignore
max_seq_len_q: Optional[int] = None,
max_seq_len_k: Optional[int] = None,
seqlen_kv: Optional[torch.Tensor] = None,
page_table: Optional[torch.Tensor] = None,
seqlen_k: Optional[int] = None,
window_size: tuple[int, int] = (-1, -1),
bottom_right: bool = True,
deterministic: bool = False,
Expand Down Expand Up @@ -220,6 +226,8 @@ def forward( # type: ignore
softmax_scale,
causal,
seqlen_kv,
page_table,
seqlen_k,
window_left,
window_right,
bottom_right,
Expand Down Expand Up @@ -293,6 +301,8 @@ def cutlass_blackwell_fmha_func(
max_seq_len_q: int | None = None,
max_seq_len_k: int | None = None,
seqlen_kv: torch.Tensor | None = None,
page_table: torch.Tensor | None = None,
seqlen_k: int | None = None,
window_size: tuple[int, int] | None = (-1, -1),
bottom_right: bool = True,
deterministic: bool = False,
Expand All @@ -308,6 +318,8 @@ def cutlass_blackwell_fmha_func(
max_seq_len_q,
max_seq_len_k,
seqlen_kv,
page_table,
seqlen_k,
window_size,
bottom_right,
deterministic,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,26 @@

std::tuple<at::Tensor, at::Tensor> dispatch_fmha_fwd(
const at::Tensor& q,
const at::Tensor& k,
const at::Tensor& v,
const at::Tensor& k, // (batch_size, KV_seqlen, num_KV_heads, head_dim) if non-paged or (num_blocks, page_block_size, num_KV_heads, head_dim) if paged
const at::Tensor& v, // (batch_size, KV_seqlen, num_KV_heads, head_dim) if non-paged or (num_blocks, page_block_size, num_KV_heads, head_dim) if paged
const std::optional<at::Tensor>& cu_seqlens_q,
const std::optional<at::Tensor>& cu_seqlens_k,
std::optional<int64_t> max_seq_len_q,
std::optional<int64_t> max_seq_len_k,
std::optional<double> softmax_scale,
bool causal,
const std::optional<at::Tensor>& seqlen_kv,
const std::optional<at::Tensor>& page_table, // dim: (batch_size, max_num_pages_per_seq) , null if non-paged
std::optional<int64_t> seqlen_k,
int64_t window_size_left,
int64_t window_size_right,
bool bottom_right) {

bool kIsPaged = false;
if (page_table && page_table->defined()) {
kIsPaged = true;
}

// Handle local attention parameters
bool local = (window_size_left >= 0 || window_size_right >= 0);
if (local) {
Expand Down Expand Up @@ -60,6 +68,8 @@ std::tuple<at::Tensor, at::Tensor> dispatch_fmha_fwd(
max_seq_len_k,
softmax_scale,
seqlen_kv,
page_table,
seqlen_k,
window_size_left,
window_size_right);
};
Expand Down Expand Up @@ -94,6 +104,7 @@ std::tuple<at::Tensor, at::Tensor> dispatch_fmha_fwd(
};

auto dispatch_mask = [&](auto varlen) {
int seq_k = kIsPaged ? static_cast<int>(*seqlen_k) : varlen ? k.size(0) : k.size(1);
if (causal) {
if (bottom_right) {
return dispatch_head_dim(varlen, CausalMask</*kIsQBegin=*/false>{});
Expand All @@ -106,7 +117,7 @@ std::tuple<at::Tensor, at::Tensor> dispatch_fmha_fwd(
} else {
return dispatch_head_dim(varlen, LocalMask</*kIsQBegin=*/true>{});
}
} else if (varlen || k.size(1) % 128 != 0) {
} else if (varlen || seq_k % 128 != 0) {
// Use the residual mask for varlen or when K seqlen is not multiple of
// blockN
return dispatch_head_dim(varlen, ResidualMask{});
Expand Down Expand Up @@ -138,6 +149,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
" float? softmax_scale=None, "
" bool causal=False, "
" Tensor? seqlen_kv=None, "
" Tensor? page_table=None, "
" int? seqlen_k=None, "
" int window_size_left=-1, "
" int window_size_right=-1, "
" bool bottom_right=True"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -42,6 +44,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -62,6 +66,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -81,6 +87,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -101,6 +109,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -120,6 +130,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -140,6 +152,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -159,6 +173,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -179,6 +195,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -198,6 +216,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -42,6 +44,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -62,6 +66,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -81,6 +87,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -101,6 +109,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -120,6 +130,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -140,6 +152,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -159,6 +173,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -179,6 +195,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -198,6 +216,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -42,6 +44,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -62,6 +66,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -81,6 +87,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -101,6 +109,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand All @@ -120,6 +130,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
std::optional<int64_t> max_seq_len_k,
const std::optional<double> softmax_scale,
const std::optional<const at::Tensor>& seqlen_kv,
const std::optional<const at::Tensor>& page_table,
std::optional<int64_t> seqlen_k,
const int window_size_left,
const int window_size_right);

Expand Down
Loading
Loading