diff --git a/fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py b/fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py
index 6f7c684cb5..a5fe08bb02 100644
--- a/fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py
+++ b/fbgemm_gpu/experimental/gen_ai/gen_ai/attention/cutlass_blackwell_fmha/cutlass_blackwell_fmha_interface.py
@@ -61,6 +61,8 @@ def _cutlass_blackwell_fmha_forward(
     softmax_scale: float | None = None,
     causal: bool = False,
     seqlen_kv: torch.Tensor | None = None,
+    page_table: torch.Tensor | None = None,
+    seqlen_k: int | None = None,
     window_left: int = -1,
     window_right: int = -1,
     bottom_right: bool = True,
@@ -79,6 +81,8 @@ def _cutlass_blackwell_fmha_forward(
         softmax_scale=softmax_scale,
         causal=causal,
         seqlen_kv=seqlen_kv,
+        page_table=page_table,
+        seqlen_k=seqlen_k,
         window_size_left=window_left,
         window_size_right=window_right,
         bottom_right=bottom_right,
@@ -171,6 +175,8 @@ def forward(  # type: ignore
         max_seq_len_q: Optional[int] = None,
         max_seq_len_k: Optional[int] = None,
         seqlen_kv: Optional[torch.Tensor] = None,
+        page_table: Optional[torch.Tensor] = None,
+        seqlen_k: Optional[int] = None,
         window_size: tuple[int, int] = (-1, -1),
         bottom_right: bool = True,
         deterministic: bool = False,
@@ -220,6 +226,8 @@ def forward(  # type: ignore
             softmax_scale,
             causal,
             seqlen_kv,
+            page_table,
+            seqlen_k,
             window_left,
             window_right,
             bottom_right,
@@ -293,6 +301,8 @@ def cutlass_blackwell_fmha_func(
     max_seq_len_q: int | None = None,
     max_seq_len_k: int | None = None,
     seqlen_kv: torch.Tensor | None = None,
+    page_table: torch.Tensor | None = None,
+    seqlen_k: int | None = None,
     window_size: tuple[int, int] | None = (-1, -1),
     bottom_right: bool = True,
     deterministic: bool = False,
@@ -308,6 +318,8 @@ def cutlass_blackwell_fmha_func(
         max_seq_len_q,
         max_seq_len_k,
         seqlen_kv,
+        page_table,
+        seqlen_k,
         window_size,
         bottom_right,
         deterministic,
diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd.cu b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd.cu
index 5b63ae44e9..f0478fd00f 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd.cu
@@ -4,8 +4,8 @@ std::tuple<at::Tensor, at::Tensor> dispatch_fmha_fwd(
     const at::Tensor& q,
-    const at::Tensor& k,
-    const at::Tensor& v,
+    const at::Tensor& k, // (batch_size, KV_seqlen, num_KV_heads, head_dim) if non-paged or (num_blocks, page_block_size, num_KV_heads, head_dim) if paged
+    const at::Tensor& v, // (batch_size, KV_seqlen, num_KV_heads, head_dim) if non-paged or (num_blocks, page_block_size, num_KV_heads, head_dim) if paged
     const std::optional<at::Tensor>& cu_seqlens_q,
     const std::optional<at::Tensor>& cu_seqlens_k,
     std::optional<int64_t> max_seq_len_q,
@@ -13,9 +13,17 @@ std::tuple<at::Tensor, at::Tensor> dispatch_fmha_fwd(
     std::optional<double> softmax_scale,
     bool causal,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table, // dim: (batch_size, max_num_pages_per_seq), null if non-paged
+    std::optional<int64_t> seqlen_k,
     int64_t window_size_left,
     int64_t window_size_right,
     bool bottom_right) {
+
+  bool kIsPaged = false;
+  if (page_table && page_table->defined()) {
+    kIsPaged = true;
+  }
+
   // Handle local attention parameters
   bool local = (window_size_left >= 0 || window_size_right >= 0);
   if (local) {
@@ -60,6 +68,8 @@ std::tuple<at::Tensor, at::Tensor> dispatch_fmha_fwd(
         max_seq_len_k,
         softmax_scale,
         seqlen_kv,
+        page_table,
+        seqlen_k,
         window_size_left,
         window_size_right);
   };
@@ -94,6 +104,7 @@ std::tuple<at::Tensor, at::Tensor> dispatch_fmha_fwd(
   };
 
   auto dispatch_mask = [&](auto varlen) {
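+    // For the paged path, k.size(1) is the page block size rather than the
+    // true KV sequence length, so the mask choice below must read the KV
+    // length from seqlen_k instead.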
+    int seq_k = kIsPaged ? static_cast<int>(*seqlen_k) : varlen ? k.size(0) : k.size(1);
     if (causal) {
       if (bottom_right) {
         return dispatch_head_dim(varlen, CausalMask{});
@@ -106,7 +117,7 @@ std::tuple<at::Tensor, at::Tensor> dispatch_fmha_fwd(
       } else {
         return dispatch_head_dim(varlen, LocalMask{});
       }
-    } else if (varlen || k.size(1) % 128 != 0) {
+    } else if (varlen || seq_k % 128 != 0) {
       // Use the residual mask for varlen or when K seqlen is not multiple of
       // blockN
       return dispatch_head_dim(varlen, ResidualMask{});
@@ -138,6 +149,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       " float? softmax_scale=None, "
       " bool causal=False, "
       " Tensor? seqlen_kv=None, "
+      " Tensor? page_table=None, "
+      " int? seqlen_k=None, "
       " int window_size_left=-1, "
       " int window_size_right=-1, "
       " bool bottom_right=True"
diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd_bf16_inst.cu b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd_bf16_inst.cu
index e87cc211ff..15db6543c4 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd_bf16_inst.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd_bf16_inst.cu
@@ -23,6 +23,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -42,6 +44,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -62,6 +66,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -81,6 +87,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -101,6 +109,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -120,6 +130,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -140,6 +152,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -159,6 +173,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -179,6 +195,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -198,6 +216,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd_fp16_inst.cu b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd_fp16_inst.cu
index b16cf23ee1..64100afea2 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd_fp16_inst.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd_fp16_inst.cu
@@ -23,6 +23,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -42,6 +44,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -62,6 +66,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -81,6 +87,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -101,6 +109,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -120,6 +130,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -140,6 +152,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -159,6 +173,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -179,6 +195,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -198,6 +216,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd_fp8_inst.cu b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd_fp8_inst.cu
index a6ec5cf9de..908314044e 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd_fp8_inst.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd_fp8_inst.cu
@@ -23,6 +23,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -42,6 +44,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -62,6 +66,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -81,6 +87,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -101,6 +109,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
@@ -120,6 +130,8 @@ template std::tuple<at::Tensor, at::Tensor> fmha_fwd<
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right);
diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd_template.cuh b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd_template.cuh
index 88f4e0c90c..475cf0e33c 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd_template.cuh
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/blackwell_fmha_fwd_template.cuh
@@ -20,6 +20,8 @@ std::tuple<at::Tensor, at::Tensor> fmha_fwd(
     std::optional<int64_t> max_seq_len_k,
     const std::optional<double> softmax_scale,
     const std::optional<at::Tensor>& seqlen_kv,
+    const std::optional<at::Tensor>& page_table,
+    std::optional<int64_t> seqlen_k,
     const int window_size_left,
     const int window_size_right
 ) {
@@ -29,6 +31,11 @@ std::tuple<at::Tensor, at::Tensor> fmha_fwd(
   using ElementAccumulatorQK = float;
   using ElementAccumulatorPV = float;
 
+  bool kIsPaged = false;
+  if (page_table && page_table->defined()) {
+    kIsPaged = true;
+  }
+
   // Q K D (H_r H_k) B
   using ProblemShapeRegular =
       cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
@@ -84,7 +91,31 @@ std::tuple<at::Tensor, at::Tensor> fmha_fwd(
           StrideLSE>,
       TileScheduler>>;
 
-  if (kIsVarlen) {
+  if (kIsPaged && !kIsVarlen) {
+    TORCH_CHECK(
+        q.dim() == 4,
+        "Expect Q shape to be (batch_size, Q_seqlen, num_Q_heads, head_dim). ",
+        "Found shape ", q.sizes());
+    TORCH_CHECK(
+        k.dim() == 4,
+        "Expect K shape to be (num_blocks, page_block_size, num_KV_heads, head_dim). ",
+        "Found shape ", k.sizes());
+    TORCH_CHECK(
+        v.dim() == 4,
+        "Expect V shape to be (num_blocks, page_block_size, num_KV_heads, head_dim). ",
+        "Found shape ", v.sizes());
+    TORCH_CHECK(
+        page_table.value().dim() == 2,
+        "Expect page table shape to be (batch_size, max_num_blocks_per_batch). ",
+        "Found shape ", page_table.value().sizes());
+
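+    // A TMA copy brings in one (TileShapeN x head_dim) KV tile and never
+    // crosses a page boundary, so each page must hold a whole number of
+    // N tiles (see num_KV_tiles_per_page below).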
+    int tile_N = static_cast<int>(get<1>(TileShape{}).value);
+    TORCH_CHECK((k.size(1) % tile_N) == 0, "Page block size should be divisible by N tile size");
+    TORCH_CHECK((v.size(1) % tile_N) == 0, "Page block size should be divisible by N tile size");
+
+    TORCH_CHECK(seqlen_k.has_value(), "seqlen_k should be set");
+  }
+  else if (kIsVarlen) {
     TORCH_CHECK(
         q.dim() == 3,
         "Expect Q shape to be (total_Q_seqlen, num_Q_heads, head_dim) ",
         "Found shape ", q.sizes());
@@ -131,9 +162,16 @@ std::tuple<at::Tensor, at::Tensor> fmha_fwd(
   // SQ represents SumB(Q) for varlen (jagged len)
   int SQ = kIsVarlen ? q.size(0) : q.size(1);
-  int SK = kIsVarlen ? k.size(0) : k.size(1);
+  int SK = kIsPaged ? static_cast<int>(*seqlen_k) : kIsVarlen ? k.size(0) : k.size(1);
   int B = kIsVarlen ? cu_seqlens_q->size(0) - 1 : q.size(0);
 
+  // Parameters for paged attention.
+  int page_table_stride = kIsPaged ? page_table.value().size(1) : 0;
+  int num_blocks = kIsPaged ? k.size(0) : 1;
+  int page_block_size = kIsPaged ? k.size(1) : 1;
+  // num KV tiles > 1 within a page in the case of page_block_size > TileShapeN.
+  int num_KV_tiles_per_page = kIsPaged ? k.size(1) / (get<1>(TileShape{}).value) : 1;
+
   ProblemShapeType problem_shape;
   if constexpr (kIsVarlen) {
     problem_shape = cute::make_tuple(
@@ -153,7 +191,8 @@ std::tuple<at::Tensor, at::Tensor> fmha_fwd(
   // Reshape to get strides
   auto B_ = kIsVarlen ? 1 : B;
   auto q_ = q.reshape({B_, SQ, H_K, H_R, D});
-  auto k_ = k.reshape({B_, SK, H_K, 1, D}).expand({B_, SK, H_K, H_R, D});
+  auto k_ = (kIsPaged) ? k.reshape({num_blocks, page_block_size, H_K, 1, D}).expand({num_blocks, page_block_size, H_K, H_R, D})
+                       : k.reshape({B_, SK, H_K, 1, D}).expand({B_, SK, H_K, H_R, D});
   auto ndim = q_.dim();
 
   TORCH_CHECK(q_.stride(ndim - 1) == 1, "The head dim in Q must be contiguous");
@@ -174,6 +213,8 @@ std::tuple<at::Tensor, at::Tensor> fmha_fwd(
           static_cast<int>(q_.stride(0))));
 
   // K shape = (B, K, H_K, 1, D)
+  // Strides expressed in logical layout, (K, D, ((H_R, H_K), B)) if non-paged
+  // or (page_block_size, D, (H_R, H_K), num_blocks) if paged.
   StrideK stride_K = make_stride(
       static_cast<int>(k_.stride(1)),
       _1{},
@@ -224,6 +265,11 @@ std::tuple<at::Tensor, at::Tensor> fmha_fwd(
         static_cast<Element*>(q.data_ptr()), stride_Q,
         static_cast<Element*>(k.data_ptr()), stride_K,
         static_cast<Element*>(v.data_ptr()), stride_V,
+        kIsPaged
+            ? static_cast<int*>(page_table.value().data_ptr())
+            : nullptr,
+        page_table_stride, num_blocks,
+        page_block_size, num_KV_tiles_per_page,
         window_size_left,
         window_size_right
       },
       static_cast<float>(softmax_scale.value_or(0.0f)) /* softmax_scale */,
diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
index ba50c3bccf..2c6a2c482e 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp
@@ -257,10 +257,17 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
       PipelineKV& pipeline_kv,
       typename PipelineKV::PipelineState& pipeline_kv_producer_state) {
     Load load;
-    load.load(blk_coord, problem_shape, params.load, params_problem_shape,
-        storage,
-        pipeline_q, pipeline_q_producer_state,
-        pipeline_kv, pipeline_kv_producer_state);
+    if (params.load.page_table) {
+      load.load_paged(blk_coord, problem_shape, params.load, params_problem_shape,
+          storage,
+          pipeline_q, pipeline_q_producer_state,
+          pipeline_kv, pipeline_kv_producer_state);
+    } else {
+      load.load(blk_coord, problem_shape, params.load, params_problem_shape,
+          storage,
+          pipeline_q, pipeline_q_producer_state,
+          pipeline_kv, pipeline_kv_producer_state);
+    }
   }
 
   template
diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp
index 44dea47160..742e92a2e9 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp
@@ -65,6 +65,9 @@ struct Sm100FmhaLoadTmaWarpspecialized {
   using TileShapeQK = typename CollectiveMmaQK::TileShape;
   using TileShapePV = typename CollectiveMmaPV::TileShape;
 
+  using ProblemShapeK =
+      cute::tuple<int, int, int, cute::tuple<cute::tuple<int, int>, int>>;
+
   struct Arguments {
     const Element* ptr_Q;
     StrideQ dQ;
@@ -72,6 +75,11 @@ struct Sm100FmhaLoadTmaWarpspecialized {
     StrideK dK;
     const Element* ptr_V;
     StrideV dV;
+    const int* ptr_page_table;
+    int page_table_stride;
+    int num_blocks;
+    int page_block_size;
+    int num_KV_tiles_per_page;
 
     int window_size_left = -1;
     int window_size_right = -1;
@@ -85,6 +93,13 @@ struct Sm100FmhaLoadTmaWarpspecialized {
     TMA_Q tma_load_q;
     TMA_K tma_load_k;
     TMA_V tma_load_v;
+
+    const int* page_table;
+    int page_table_stride;
+    int num_blocks;
+    int page_block_size;
+    int num_KV_tiles_per_page;
+
     int window_size_left;
     int window_size_right;
   };
@@ -101,6 +116,8 @@ struct Sm100FmhaLoadTmaWarpspecialized {
     auto dQ = args.dQ;
     auto dK = args.dK;
     auto dV = args.dV;
+    bool kIsPaged = args.ptr_page_table ? true : false;
+
     // Local changes (D79534034)
     int get_0 = int(get<0>(problem_shape));
@@ -117,27 +134,76 @@ struct Sm100FmhaLoadTmaWarpspecialized {
       get_1 = get<1>(problem_shape).total_length;
     }
 
-    auto problem_shape_qk = make_tuple(get_0, get_1, get<2>(problem_shape), get<3>(problem_shape));
-
-    auto params_qk = CollectiveMmaQK::to_underlying_arguments(
-        problem_shape_qk,
-        typename CollectiveMmaQK::Arguments {
-          ptr_Q, dQ,
-          ptr_K, dK,
-        }, /*workspace=*/ nullptr);
+    TMA_Q tma_load_q;
+    TMA_K tma_load_k;
+    TMA_V tma_load_v;
 
-    auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk);
-    auto params_pv = CollectiveMmaPV::to_underlying_arguments(
-        problem_shape_pv,
-        typename CollectiveMmaPV::Arguments {
-          ptr_K, dK, // never used, dummy
-          ptr_V, select<1,0,2>(dV),
-        }, /*workspace=*/ nullptr);
+    if (kIsPaged) { // Paged Case
+      // Create TMA Atom/Descriptor for Q, K, V
+      // Q
+      Layout layout_Q = make_layout(select<0,2,3>(problem_shape), dQ);
+      Tensor mQ = make_tensor(make_gmem_ptr(ptr_Q), layout_Q);
+
+      auto cluster_layout_vmnk =
+          tiled_divide(make_layout(Shape<_1, _1, _1>{}),
+                       make_tile(typename CollectiveMmaQK::TiledMma::AtomThrID{}));
+      tma_load_q = make_tma_atom_A_sm100(
+          cute::SM90_TMA_LOAD{}, mQ, SmemLayoutQ{}(_, _, _, _0{}), TileShapeQK{},
+          typename CollectiveMmaQK::TiledMma{}, cluster_layout_vmnk);
+
+      // K
+      auto problem_shape_paged_k = make_tuple(get_0, get_1, get<2>(problem_shape), get<3>(problem_shape));
+      get<1>(problem_shape_paged_k) = args.page_block_size;
+      get<3, 1>(problem_shape_paged_k) = args.num_blocks;
+      Layout layout_k = make_layout(select<1,2,3>(problem_shape_paged_k), dK);
+      Tensor mK = make_tensor(make_gmem_ptr(ptr_K), layout_k);
+
+      tma_load_k = make_tma_atom_B_sm100(
+          cute::SM90_TMA_LOAD{}, mK, SmemLayoutK{}(_, _, _, _0{}), TileShapeQK{},
+          typename CollectiveMmaQK::TiledMma{}, cluster_layout_vmnk);
+
+      // V
+      auto problem_shape_paged_v = make_tuple(get_0, get<2>(problem_shape), get_1, get<3>(problem_shape));
+      get<2>(problem_shape_paged_v) = args.page_block_size;
+      get<3, 1>(problem_shape_paged_v) = args.num_blocks;
+      Layout layout_v = make_layout(select<1,2,3>(problem_shape_paged_v), select<1,0,2>(dV));
+      Tensor mV = make_tensor(make_gmem_ptr(ptr_V), layout_v);
+
+      tma_load_v = make_tma_atom_B_sm100(
+          cute::SM90_TMA_LOAD{}, mV, SmemLayoutV{}(_, _, _, _0{}), TileShapePV{},
+          typename CollectiveMmaPV::TiledMma{}, cluster_layout_vmnk);
+    } else {
+      auto problem_shape_qk = make_tuple(get_0, get_1, get<2>(problem_shape), get<3>(problem_shape));
+
+      auto params_qk = CollectiveMmaQK::to_underlying_arguments(
+          problem_shape_qk,
+          typename CollectiveMmaQK::Arguments {
+            ptr_Q, dQ,
+            ptr_K, dK,
+          }, /*workspace=*/ nullptr);
+
+      auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk);
+      auto params_pv = CollectiveMmaPV::to_underlying_arguments(
+          problem_shape_pv,
+          typename CollectiveMmaPV::Arguments {
+            ptr_K, dK, // never used, dummy
+            ptr_V, select<1,0,2>(dV),
+          }, /*workspace=*/ nullptr);
+
+      tma_load_q = params_qk.tma_load_a;
+      tma_load_k = params_qk.tma_load_b;
+      tma_load_v = params_pv.tma_load_b;
+    }
 
     return Params{
-        params_qk.tma_load_a,
-        params_qk.tma_load_b,
-        params_pv.tma_load_b,
+        tma_load_q,
+        tma_load_k,
+        tma_load_v,
+        args.ptr_page_table,
+        args.page_table_stride,
+        args.num_blocks,
+        args.page_block_size,
+        args.num_KV_tiles_per_page,
         args.window_size_left,
         args.window_size_right
     };
@@ -301,6 +367,161 @@ struct Sm100FmhaLoadTmaWarpspecialized {
       ++pipeline_kv_producer_state;
     }
   }
+
+  template <class BlkCoord, class ProblemShape, class ParamsProblemShape,
+            class TensorStorage, class PipelineQ, class PipelineKV>
+  CUTLASS_DEVICE void
+  load_paged(
+      BlkCoord const& blk_coord_in, ProblemShape const& problem_shape,
+      Params const& params, ParamsProblemShape const& params_problem_shape,
+      TensorStorage& storage,
+      PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state,
+      PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) {
+
+    BlkCoord blk_coord_q = blk_coord_in;
+    BlkCoord blk_coord_kv = blk_coord_in;
+
+    auto min_max = Mask(params.window_size_left, params.window_size_right).get_n_block_min_max(blk_coord_in, TileShape{}, problem_shape);
+    int n_block_min = get<0>(min_max);
+    int n_block_max = get<1>(min_max);
+
+    using X = Underscore;
+
+    // this one is only executed by one thread, no need to elect_one
+
+    // Q1, K1, Q2, V1, K2, V2, K3, V3, ...
+    // two pipes: Q and KV
+    // from Memory (prod) to TensorCore (cons)
+
+    // compute gQ, sQ
+    // we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1
+    ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0);
+    Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape));
+
+    int q_offs_0 = 0;
+    int q_offs_2_1 = 0;
+    int batch_idx = get<2, 1>(blk_coord_q);
+    if constexpr (is_variable_length_v<tuple_element_t<0, ParamsProblemShape>>) {
+      auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length;
+      q_offs_0 = cumulative_length_q[batch_idx];
+      get<2, 1>(blk_coord_q) = 0;
+    }
+
+    Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p);
+
+    Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{});
+    Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl);
+    Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{});
+    auto [tQgQ_qdl, tQsQ] = tma_partition(
+        params.tma_load_q, _0{}, make_layout(_1{}),
+        group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl)
+    );
+    Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q));
+
+    // compute gK, sK
+    ProblemShapeK problem_shape_k = problem_shape;
+    get<1>(problem_shape_k) = params.page_block_size;
+    get<3, 1>(problem_shape_k) = params.num_blocks;
+
+    Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_k));
+
+    Tensor gK_kdl = local_tile(mK_kdl_p, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
+    Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl);
+    Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{});
+    auto [tKgK_kdl, tKsK] = tma_partition(
+        params.tma_load_k, _0{}, make_layout(_1{}),
+        group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl)
+    );
+    auto tKgK = tKgK_kdl(_, _, _0{}, make_coord(get<0>(get<2>(blk_coord_kv)), _));
+
+    // compute gV, sV
+    ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0);
+    ProblemShapeK problem_shape_v = problem_shape;
+    get<1>(problem_shape_v) = params.page_block_size;
+    get<3, 1>(problem_shape_v) = params.num_blocks;
+    Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v));
+
+    Tensor gV_dkl = local_tile(mV_dkl_p, TileShapePV{}, make_coord(_, _, _), Step<X, _1, _1>{});
+    Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl);
+    Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{});
+    auto [tVgV_dkl, tVsV] = tma_partition(
+        params.tma_load_v, _0{}, make_layout(_1{}),
+        group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl)
+    );
+    auto tVgV = tVgV_dkl(_, _0{}, _, make_coord(get<0>(get<2>(blk_coord_kv)), _));
+
+    // blk_coord is decomposed in terms of TileShape, not TileShapeQK
+    // As such, it needs to be transformed as
+    // (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1)
+    //          b -> 2*a (Ki i even) 2*a+1 (Ki i odd)
+
+    uint32_t lane_predicate = cute::elect_one_sync();
+
+    // Q1
+    int q0_index = 2 * get<0>(blk_coord_q);
+    int q1_index = 2 * get<0>(blk_coord_q) + 1;
+    pipeline_q.producer_acquire(pipeline_q_producer_state);
+    if (lane_predicate) {
+      auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
+      copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index()));
+    }
+    ++pipeline_q_producer_state;
+
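+    // Translate the linear KV tile index k_index into a physical page through
+    // the page table. For example, with page_block_size = 256 and a 128-wide
+    // N tile, num_KV_tiles_per_page = 2, so k_index = 5 resolves to logical
+    // block 2, tile 1 within page page_table[batch_idx][2].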
+    // K1
+    int k_index = n_block_min;
+    int logical_block_id = k_index / params.num_KV_tiles_per_page;
+    int page_id = params.page_table[batch_idx * params.page_table_stride + logical_block_id];
+    int n_block = k_index % params.num_KV_tiles_per_page; // TODO: Replace %
+
+    pipeline_kv.producer_acquire(pipeline_kv_producer_state);
+    if (lane_predicate) {
+      auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
+      copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, n_block, page_id), tKsK(_, pipeline_kv_producer_state.index()));
+    }
+    ++pipeline_kv_producer_state;
+
+    // Q2
+    pipeline_q.producer_acquire(pipeline_q_producer_state);
+    if (lane_predicate) {
+      auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state);
+      copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index()));
+    }
+    ++pipeline_q_producer_state;
+
+    // V1
+    pipeline_kv.producer_acquire(pipeline_kv_producer_state);
+    if (lane_predicate) {
+      auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
+      copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, n_block, page_id), tVsV(_, pipeline_kv_producer_state.index()));
+    }
+    ++pipeline_kv_producer_state;
+    k_index += 1;
+
+    // loop:
+    for (; k_index < n_block_max; k_index += 1) {
+      logical_block_id = k_index / params.num_KV_tiles_per_page;
+      page_id = params.page_table[batch_idx * params.page_table_stride + logical_block_id];
+      n_block = k_index % params.num_KV_tiles_per_page; // TODO: Replace %
+
+      // Ki
+      pipeline_kv.producer_acquire(pipeline_kv_producer_state);
+      if (lane_predicate) {
+        auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
+        copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, n_block, page_id), tKsK(_, pipeline_kv_producer_state.index()));
+      }
+      ++pipeline_kv_producer_state;
+
+      // Vi
+      pipeline_kv.producer_acquire(pipeline_kv_producer_state);
+      if (lane_predicate) {
+        auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state);
+        copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, n_block, page_id), tVsV(_, pipeline_kv_producer_state.index()));
+      }
+      ++pipeline_kv_producer_state;
+    }
+  }
 };
 
 } // namespace cutlass::fmha::collective
diff --git a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py
index d2f03748e1..a9896b60e9 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py
@@ -4,11 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import math
 import random
 import unittest
-from typing import Optional
+from typing import cast, Optional
 
 import torch
+from einops import rearrange
 
 from fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha import (
     cutlass_blackwell_fmha_func,
@@ -122,6 +124,77 @@ def _generate_qkv(
             v = v.to(torch.float8_e4m3fn)
         return q, k, v
 
+    # Generates K and V for paged attention.
+    def _generate_qkv_paged(
+        self,
+        batch_size: int,
+        seqlen_q: int,
+        seqlen_k: int,
+        q_heads: int,
+        kv_heads: int,
+        head_dim: int,
+        page_block_size: int,
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> tuple[
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+        torch.Tensor,
+    ]:
+        num_blocks = math.ceil(seqlen_k / page_block_size) * batch_size
+        q = torch.randn(
+            batch_size,
+            seqlen_q,
+            q_heads,
+            head_dim,
+            dtype=dtype if dtype != torch.float8_e4m3fn else torch.float,
+            device=device,
+            requires_grad=True,
+        )
+        k_paged, v_paged = (
+            torch.randn(
+                num_blocks,
+                page_block_size,
+                kv_heads,
+                head_dim,
+                dtype=dtype if dtype != torch.float8_e4m3fn else torch.float,
+                device=device,
+                requires_grad=True,
+            )
+            for _ in range(2)
+        )
+        page_table = rearrange(
+            torch.randperm(num_blocks, dtype=torch.int32, device=device),
+            "(b nblocks) -> b nblocks",
+            b=batch_size,
+        )
+        if DEBUG:
+            print(f"page_table: {page_table.size()}")
+
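+        # Gathering the pages in page_table order reconstructs the contiguous
+        # K/V that the reference path consumes; slicing to seqlen_k drops the
+        # padding in each sequence's final page.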
", + "Found shape ", q.sizes()); + TORCH_CHECK( + k.dim() == 4, + "Expect K shape to be (num_blocks, page_block_size, num_KV_heads, head_dim) ", + "Found shape ", k.sizes()); + TORCH_CHECK( + v.dim() == 4, + "Expect V shape to be (num_blocks, page_block_size, num_KV_heads, head_dim) ", + "Found shape ", v.sizes()); + TORCH_CHECK( + page_table.value().dim() == 2, + "Expect page table shape to be (batch_size, max_num_blocks_per_batch)", + "Found shape ", page_table.value().sizes()); + + int tile_N = static_cast(get<1>(TileShape{}).value); + TORCH_CHECK((k.size(1) % tile_N) == 0, "Page Block Size should be divisible by N tile size"); + TORCH_CHECK((v.size(1) % tile_N) == 0, "Page Block Size should be divisible by N tile size"); + + TORCH_CHECK(seqlen_k.has_value(), "seqlen_k should be set"); + } + else if (kIsVarlen) { TORCH_CHECK( q.dim() == 3, "Expect Q shape to be (total_Q_seqlen, num_Q_heads, head_dim) ", @@ -131,9 +162,16 @@ std::tuple fmha_fwd( // SQ represents SumB(Q) for varlen (jagged len) int SQ = kIsVarlen ? q.size(0) : q.size(1); - int SK = kIsVarlen ? k.size(0) : k.size(1); + int SK = kIsPaged ? static_cast(*seqlen_k) : kIsVarlen ? k.size(0) : k.size(1); int B = kIsVarlen ? cu_seqlens_q->size(0) - 1 : q.size(0); + // Parameters for paged attention. + int page_table_stride = kIsPaged ? page_table.value().size(1) : 0; + int num_blocks = kIsPaged ? k.size(0) : 1; // num_blocks + int page_block_size = kIsPaged ? k.size(1) : 1; // page_block_size + // num KV tiles > 1 within a page in the case of page_block_size > TileShapeN. + int num_KV_tiles_per_page = kIsPaged ? k.size(1) / (get<1>(TileShape{}).value) : 1; + ProblemShapeType problem_shape; if constexpr (kIsVarlen) { problem_shape = cute::make_tuple( @@ -153,7 +191,8 @@ std::tuple fmha_fwd( // Reshape to get strides auto B_ = kIsVarlen ? 1 : B; auto q_ = q.reshape({B_, SQ, H_K, H_R, D}); - auto k_ = k.reshape({B_, SK, H_K, 1, D}).expand({B_, SK, H_K, H_R, D}); + auto k_ = (kIsPaged) ? k.reshape({num_blocks, page_block_size, H_K, 1, D}).expand({num_blocks, page_block_size, H_K, H_R, D}) + : k.reshape({B_, SK, H_K, 1, D}).expand({B_, SK, H_K, H_R, D}); auto ndim = q_.dim(); TORCH_CHECK(q_.stride(ndim - 1) == 1, "The head dim in Q must be contiguous"); @@ -174,6 +213,8 @@ std::tuple fmha_fwd( static_cast(q_.stride(0)))); // K shape = (B, K, H_K, 1, D) + // Strides expressed in logical layout, (K, D, ((H_R, H_K), B)) if non-paged + // or (page_block_size, D, (H_R, H_K), num_blocks) if paged. StrideK stride_K = make_stride( static_cast(k_.stride(1)), _1{}, @@ -224,6 +265,11 @@ std::tuple fmha_fwd( static_cast(q.data_ptr()), stride_Q, static_cast(k.data_ptr()), stride_K, static_cast(v.data_ptr()), stride_V, + kIsPaged + ? 
static_cast(page_table.value().data_ptr()) + : nullptr, + page_table_stride, num_blocks, + page_block_size, num_KV_tiles_per_page, window_size_left, window_size_right }, static_cast(softmax_scale.value_or(0.0f)) /* softmax_scale */, diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp index ba50c3bccf..2c6a2c482e 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp @@ -257,10 +257,17 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { Load load; - load.load(blk_coord, problem_shape, params.load, params_problem_shape, - storage, - pipeline_q, pipeline_q_producer_state, - pipeline_kv, pipeline_kv_producer_state); + if (params.load.page_table) { + load.load_paged(blk_coord, problem_shape, params.load, params_problem_shape, + storage, + pipeline_q, pipeline_q_producer_state, + pipeline_kv, pipeline_kv_producer_state); + } else { + load.load(blk_coord, problem_shape, params.load, params_problem_shape, + storage, + pipeline_q, pipeline_q_producer_state, + pipeline_kv, pipeline_kv_producer_state); + } } template diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp index 44dea47160..742e92a2e9 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp @@ -65,6 +65,9 @@ struct Sm100FmhaLoadTmaWarpspecialized { using TileShapeQK = typename CollectiveMmaQK::TileShape; using TileShapePV = typename CollectiveMmaPV::TileShape; + using ProblemShapeK = + cute::tuple, int>>; + struct Arguments { const Element* ptr_Q; StrideQ dQ; @@ -72,6 +75,11 @@ struct Sm100FmhaLoadTmaWarpspecialized { StrideK dK; const Element* ptr_V; StrideV dV; + const int* ptr_page_table; + int page_table_stride; + int num_blocks; + int page_block_size; + int num_KV_tiles_per_page; int window_size_left = -1; int window_size_right = -1; @@ -85,6 +93,13 @@ struct Sm100FmhaLoadTmaWarpspecialized { TMA_Q tma_load_q; TMA_K tma_load_k; TMA_V tma_load_v; + + const int* page_table; + int page_table_stride; + int num_blocks; + int page_block_size; + int num_KV_tiles_per_page; + int window_size_left; int window_size_right; }; @@ -101,6 +116,8 @@ struct Sm100FmhaLoadTmaWarpspecialized { auto dQ = args.dQ; auto dK = args.dK; auto dV = args.dV; + bool kIsPaged = args.ptr_page_table ? 
true : false; + // Local changes (D79534034) int get_0 = int(get<0>(problem_shape)); @@ -117,27 +134,76 @@ struct Sm100FmhaLoadTmaWarpspecialized { get_1 = get<1>(problem_shape).total_length; } - auto problem_shape_qk = make_tuple(get_0, get_1, get<2>(problem_shape), get<3>(problem_shape)); - - auto params_qk = CollectiveMmaQK::to_underlying_arguments( - problem_shape_qk, - typename CollectiveMmaQK::Arguments { - ptr_Q, dQ, - ptr_K, dK, - }, /*workspace=*/ nullptr); + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; - auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk); - auto params_pv = CollectiveMmaPV::to_underlying_arguments( - problem_shape_pv, - typename CollectiveMmaPV::Arguments { - ptr_K, dK, // never used, dummy - ptr_V, select<1,0,2>(dV), - }, /*workspace=*/ nullptr); + if (kIsPaged) { // Paged Case + //Create TMA Atom/Descriptor for Q, K, V + //Q + Layout layout_Q = make_layout(select<0,2,3>(problem_shape), dQ); + Tensor mQ = make_tensor(make_gmem_ptr(ptr_Q), layout_Q); + + auto cluster_layout_vmnk = + tiled_divide(make_layout(Shape<_1, _1, _1>{}), + make_tile(typename CollectiveMmaQK::TiledMma::AtomThrID{})); + tma_load_q = make_tma_atom_A_sm100( + cute::SM90_TMA_LOAD{}, mQ, SmemLayoutQ{}(_, _, _, _0{}), TileShapeQK{}, + typename CollectiveMmaQK::TiledMma{}, cluster_layout_vmnk); + + // K + auto problem_shape_paged_k = make_tuple(get_0, get_1, get<2>(problem_shape), get<3>(problem_shape)); + get<1> (problem_shape_paged_k) = args.page_block_size; + get<3, 1>(problem_shape_paged_k) = args.num_blocks; + Layout layout_k = make_layout(select<1,2,3>(problem_shape_paged_k), dK); + Tensor mK = make_tensor(make_gmem_ptr(ptr_K), layout_k); + + tma_load_k = make_tma_atom_B_sm100( + cute::SM90_TMA_LOAD{}, mK, SmemLayoutK{}(_, _, _, _0{}), TileShapeQK{}, + typename CollectiveMmaQK::TiledMma{}, cluster_layout_vmnk); + + // V + auto problem_shape_paged_v = make_tuple(get_0, get<2>(problem_shape), get_1, get<3>(problem_shape)); + get<2> (problem_shape_paged_v) = args.page_block_size; + get<3, 1>(problem_shape_paged_v) = args.num_blocks; + Layout layout_v = make_layout(select<1,2,3>(problem_shape_paged_v), select<1,0,2>(dV)); + Tensor mV = make_tensor(make_gmem_ptr(ptr_V), layout_v); + + tma_load_v = make_tma_atom_B_sm100( + cute::SM90_TMA_LOAD{}, mV, SmemLayoutV{}(_, _, _, _0{}), TileShapePV{}, + typename CollectiveMmaPV::TiledMma{}, cluster_layout_vmnk); + } else { + auto problem_shape_qk = make_tuple(get_0, get_1, get<2>(problem_shape), get<3>(problem_shape)); + + auto params_qk = CollectiveMmaQK::to_underlying_arguments( + problem_shape_qk, + typename CollectiveMmaQK::Arguments { + ptr_Q, dQ, + ptr_K, dK, + }, /*workspace=*/ nullptr); + + auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk); + auto params_pv = CollectiveMmaPV::to_underlying_arguments( + problem_shape_pv, + typename CollectiveMmaPV::Arguments { + ptr_K, dK, // never used, dummy + ptr_V, select<1,0,2>(dV), + }, /*workspace=*/ nullptr); + + tma_load_q = params_qk.tma_load_a; + tma_load_k = params_qk.tma_load_b; + tma_load_v = params_pv.tma_load_b; + } return Params{ - params_qk.tma_load_a, - params_qk.tma_load_b, - params_pv.tma_load_b, + tma_load_q, + tma_load_k, + tma_load_v, + args.ptr_page_table, + args.page_table_stride, + args.num_blocks, + args.page_block_size, + args.num_KV_tiles_per_page, args.window_size_left, args.window_size_right }; @@ -301,6 +367,161 @@ struct Sm100FmhaLoadTmaWarpspecialized { ++pipeline_kv_producer_state; } } + +template + CUTLASS_DEVICE void + load_paged( + BlkCoord const& 
blk_coord_in, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + BlkCoord blk_coord_q = blk_coord_in; + BlkCoord blk_coord_kv = blk_coord_in; + + auto min_max = Mask(params.window_size_left, params.window_size_right).get_n_block_min_max(blk_coord_in, TileShape{}, problem_shape); + int n_block_min = get<0>(min_max); + int n_block_max = get<1>(min_max); + + using X = Underscore; + + // this one is only executed by one thread, no need to elect_one + + // Q1, K1, Q2, V1, K2, V2, K3, V3, ... + // two pipes: Q and KV + // from Memory (prod) to TensorCore (cons) + + // compute gQ, sQ + // we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1 + ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0); + Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape)); + + int q_offs_0 = 0; + int q_offs_2_1 = 0; + int batch_idx = get<2, 1>(blk_coord_q); + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; + q_offs_0 = cumulative_length_q[batch_idx]; + get<2, 1>(blk_coord_q) = 0; + } + + Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, q_offs_2_1)), mQ_qdl_p); + + Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl); + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + auto [tQgQ_qdl, tQsQ] = tma_partition( + params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl) + ); + Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q)); + + // compute gK, sK + ProblemShapeK problem_shape_k = problem_shape; + get<1> (problem_shape_k) = params.page_block_size; + get<3, 1>(problem_shape_k) = params.num_blocks; + + Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_k)); + + Tensor gK_kdl = local_tile(mK_kdl_p, TileShapeQK{}, make_coord(_, _, _), Step{}); + Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + auto [tKgK_kdl, tKsK] = tma_partition( + params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl) + ); + + auto tKgK = tKgK_kdl(_, _, _0{}, make_coord(get<0>(get<2>(blk_coord_kv)), _)); + + + // compute gV, sV + ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0); + ProblemShapeK problem_shape_v = problem_shape; + get<1> (problem_shape_v) = params.page_block_size; + get<3, 1>(problem_shape_v) = params.num_blocks; + Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v)); + + Tensor gV_dkl = local_tile(mV_dkl_p, TileShapePV{}, make_coord(_, _, _), Step{}); + Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + auto [tVgV_dkl, tVsV] = tma_partition( + params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl) + ); + + auto tVgV = tVgV_dkl(_, _0{}, _, make_coord(get<0>(get<2>(blk_coord_kv)), _)); + + // blk_coord in decomposed in terms of TileShape, not TileShapeQK + // As such, it needs to be transformed as + // (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1) + // b -> 2*a (Ki i even) 2*a+1 (Ki i odd) + + 
uint32_t lane_predicate = cute::elect_one_sync(); + + // Q1 + int q0_index = 2 * get<0>(blk_coord_q); + int q1_index = 2 * get<0>(blk_coord_q) + 1; + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + // K1 + int k_index = n_block_min; + int logical_block_id = k_index / params.num_KV_tiles_per_page; + int page_id = params.page_table[batch_idx * params.page_table_stride + logical_block_id]; + int n_block = k_index % params.num_KV_tiles_per_page; // TODO: Replace % + + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, n_block, page_id), tKsK(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + + // Q2 + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + // V1 + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, n_block, page_id), tVsV(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + k_index += 1; + + // loop: + for (; k_index < n_block_max; k_index += 1) { + logical_block_id = k_index / params.num_KV_tiles_per_page; + page_id = params.page_table[batch_idx * params.page_table_stride + logical_block_id]; + n_block = k_index % params.num_KV_tiles_per_page; // TODO: Replace % + + // Ki + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, n_block, page_id), tKsK(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + + // Vi + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, n_block, page_id), tVsV(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + } + } }; } // namespace cutlass::fmha::collective diff --git a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py index d2f03748e1..a9896b60e9 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/attention/blackwell_fmha_test.py @@ -4,11 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import math import random import unittest -from typing import Optional +from typing import cast, Optional import torch +from einops import rearrange from fbgemm_gpu.experimental.gen_ai.attention.cutlass_blackwell_fmha import ( cutlass_blackwell_fmha_func, @@ -122,6 +124,77 @@ def _generate_qkv( v = v.to(torch.float8_e4m3fn) return q, k, v + # Generates K and V for paged attention. + def _generate_qkv_paged( + self, + batch_size: int, + seqlen_q: int, + seqlen_k: int, + q_heads: int, + kv_heads: int, + head_dim: int, + page_block_size: int, + device: torch.device, + dtype: torch.dtype, + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ]: + num_blocks = math.ceil(seqlen_k / page_block_size) * batch_size + q = torch.randn( + batch_size, + seqlen_q, + q_heads, + head_dim, + dtype=dtype if dtype != torch.float8_e4m3fn else torch.float, + device=device, + requires_grad=True, + ) + k_paged, v_paged = ( + torch.randn( + num_blocks, + page_block_size, + kv_heads, + head_dim, + dtype=dtype if dtype != torch.float8_e4m3fn else torch.float, + device=device, + requires_grad=True, + ) + for _ in range(2) + ) + page_table = rearrange( + torch.randperm(num_blocks, dtype=torch.int32, device=device), + "(b nblocks) -> b nblocks", + b=batch_size, + ) + if DEBUG: + print(f"page_table: {page_table.size()}") + + k = rearrange( + # pytorch 1.12 doesn't have indexing with int32 + k_paged[page_table.to(dtype=torch.long).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + + v = rearrange( + v_paged[page_table.to(dtype=torch.long).flatten()], + "(b nblocks) block_size ... -> b (nblocks block_size) ...", + b=batch_size, + )[:, :seqlen_k] + + if dtype == torch.float8_e4m3fn: + q = q.to(torch.float8_e4m3fn) + k = k.to(torch.float8_e4m3fn) + v = v.to(torch.float8_e4m3fn) + k_paged = k_paged.to(torch.float8_e4m3fn) + v_paged = v_paged.to(torch.float8_e4m3fn) + return q, k, v, k_paged, v_paged, page_table + def _execute_cutlass_blackwell_attn_dense( self, batch_size: int, @@ -130,12 +203,14 @@ def _execute_cutlass_blackwell_attn_dense( q_heads: int, kv_heads: int, head_dim: int, + page_block_size: int, dtype: torch.dtype, causal: bool, window_size: tuple[int, int], fwd_only: bool, deterministic: bool, sm_scale: Optional[float], + is_paged: Optional[bool], ) -> None: device = torch.accelerator.current_accelerator() assert device is not None @@ -144,17 +219,34 @@ def _execute_cutlass_blackwell_attn_dense( # Initialize deterministic variables out_d = None + out_paged: torch.Tensor | None = None + k_paged: torch.Tensor | None = None + v_paged: torch.Tensor | None = None + page_table: torch.Tensor | None = None - q, k, v = self._generate_qkv( - batch_size, - seqlen_q, - seqlen_k, - q_heads, - kv_heads, - head_dim, - device, - dtype, - ) + if is_paged: + q, k, v, k_paged, v_paged, page_table = self._generate_qkv_paged( + batch_size, + seqlen_q, + seqlen_k, + q_heads, + kv_heads, + head_dim, + page_block_size, + device, + dtype, + ) + else: + q, k, v = self._generate_qkv( + batch_size, + seqlen_q, + seqlen_k, + q_heads, + kv_heads, + head_dim, + device, + dtype, + ) # Initialize seqlen_kv for generation phase (seqlen_q == 1) seqlen_kv = None @@ -195,6 +287,21 @@ def _execute_cutlass_blackwell_attn_dense( ) # Run tested kernel + if is_paged: + assert k_paged is not None and v_paged is not None + out_paged = cutlass_blackwell_fmha_func( + q, + k_paged, + v_paged, + causal=causal, + window_size=window_size, + 
seqlen_kv=seqlen_kv, + page_table=page_table, + seqlen_k=seqlen_k, + deterministic=deterministic, + softmax_scale=sm_scale, + ) + out = cutlass_blackwell_fmha_func( q, k, @@ -202,25 +309,35 @@ def _execute_cutlass_blackwell_attn_dense( causal=causal, window_size=window_size, seqlen_kv=seqlen_kv, + page_table=None, + seqlen_k=seqlen_k, deterministic=deterministic, softmax_scale=sm_scale, ) + if DEBUG: print("cutlass_blackwell_fmha_func completed successfully!") # Follow FlashAttention's numerical evaluation # Compare outputs - self._allclose(out, out_ref, out_pt) + if is_paged: + # Compare paged output with both reference and non paged output + self._allclose(out_paged, out_ref, out_pt) + self._allclose(out_paged, out, out_pt) + else: + self._allclose(out, out_ref, out_pt) if deterministic: # Rerun the test. The outputs must be bit-wise exact out_d = cutlass_blackwell_fmha_func( q, - k, - v, + cast(torch.Tensor, k_paged) if is_paged else k, + cast(torch.Tensor, v_paged) if is_paged else v, causal=causal, window_size=window_size, seqlen_kv=seqlen_kv, + page_table=page_table if is_paged else None, + seqlen_k=seqlen_k, deterministic=deterministic, softmax_scale=sm_scale, ) @@ -272,12 +389,14 @@ def _execute_cutlass_blackwell_attn_varlen( q_heads: int, kv_heads: int, head_dim: int, + page_block_size: int, dtype: torch.dtype, causal: bool, window_size: tuple[int, int], fwd_only: bool, deterministic: bool, sm_scale: Optional[float], + is_paged: Optional[bool], ) -> None: device = torch.accelerator.current_accelerator() assert device is not None @@ -475,6 +594,7 @@ def test_decode( q_heads, kv_heads=num_groups if is_mqa else q_heads, head_dim=head_dim, + page_block_size=0, dtype=dtype, causal=causal, # Decode kernel does not support sliding window attention yet @@ -483,6 +603,7 @@ def test_decode( deterministic=False, # Decode kernel does not support sm_scale sm_scale=None, + is_paged=False, ) @skip_cuda_lt_sm100 @@ -729,12 +850,92 @@ def test_forward( q_heads=kv_heads * q_heads_per_kv_head, kv_heads=kv_heads, head_dim=head_dim, + page_block_size=0, + dtype=dtype, + causal=causal, + window_size=window_size, + fwd_only=True, + deterministic=False, + sm_scale=sm_scale, + is_paged=False, + ) + + @skip_cuda_lt_sm100 + @skip_rocm + @parameterized.expand( + [ + ( + seqlen_q, + offset_q, + batch_size, + causal, + is_gqa, + is_varlen, + kv_heads, + window_size, + head_dim, + sm_scale, + page_block_size, + ) + for seqlen_q, offset_q in [ + (101, 0), + (111, 2), + (256, 0), + (1024, 0), + (113, 90), + (128, 90), + (256, 90), + (256, 128), + (1024, 128), + ] + for batch_size in [1, 2, 8] + for causal in [False, True] + for is_gqa in [False, True] + for is_varlen in [ + False + ] # Variable length is not supported for paged attention. 
+ for kv_heads in [1, 2, 3, 4] + for window_size in [(-1, -1), (0, 0), (0, 128), (128, 0), (1024, 0)] + for head_dim in [64, 128] + for sm_scale in [None, 1.0 / head_dim] + for page_block_size in [128, 256] + ] + ) + def test_paged_forward( + self, + seqlen_q: int, + offset_q: int, + batch_size: int, + causal: bool, + is_gqa: bool, + is_varlen: bool, + kv_heads: int, + window_size: tuple[int, int], + head_dim: int, + sm_scale: Optional[float], + page_block_size: int, + dtype: torch.dtype = torch.bfloat16, + ) -> None: + seqlen_k = offset_q + seqlen_q + if seqlen_k > seqlen_q: + causal = True + + q_heads_per_kv_head = random.randint(2, 8) if is_gqa else 1 + self._execute_cutlass_blackwell_attn_dense( + batch_size, + seqlen_q, + seqlen_k, + q_heads=kv_heads * q_heads_per_kv_head, + kv_heads=kv_heads, + head_dim=head_dim, + page_block_size=page_block_size, dtype=dtype, causal=causal, window_size=window_size, fwd_only=True, deterministic=False, sm_scale=sm_scale, + is_paged=True, ) @skip_cuda_lt_sm100 @@ -828,10 +1029,12 @@ def test_backward( q_heads=kv_heads * q_heads_per_kv_head, kv_heads=kv_heads, head_dim=head_dim, + page_block_size=0, dtype=dtype, causal=causal, window_size=window_size, fwd_only=False, deterministic=deterministic, sm_scale=sm_scale, + is_paged=False, )