Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
955 commits
Select commit Hold shift + click to select a range
21ae9d9
Move ck decoder codes to xformers/csrc/attention/hip_decoder folder
qianfengz Sep 22, 2024
fb3628d
Sync to latest ck_tile commits for fixing NaN when seqlen_k == 0
qianfengz Sep 22, 2024
ffa9906
Separate the kernel/pipeline dispatch into two files for infer/forward
qianfengz Sep 22, 2024
221860e
Remove unused member variable in GroupedForwardParams
qianfengz Sep 22, 2024
6a07c16
delete autogenerated files
tenpercent Sep 24, 2024
74355e9
delete autogenerated files (2)
tenpercent Sep 25, 2024
0dbdc5f
Initial add support of fmha-forward splitk (compiling passed)
qianfengz Sep 23, 2024
6b8ddde
Add generated files under hip_decoder into gitignore list
qianfengz Sep 26, 2024
cf9be1c
apply black and fix lint
tenpercent Sep 26, 2024
08219dc
rewrite hipified split-k decoder invocation to ck-tile style
tenpercent Sep 25, 2024
f37fb3d
Merge pull request #25 from tenpercent/refactor-hip-decoder
tenpercent Sep 27, 2024
9c8d2f1
Force kPadSeqLenQ == true for grouped mode splitkv-combine kernel traits
qianfengz Sep 28, 2024
761e8a5
Add compile-time checking to save compile-time
qianfengz Sep 29, 2024
eb50024
add dockerfile
tenpercent Jul 3, 2024
669ee34
migrate base docker image to manylinux
tenpercent Oct 2, 2024
0a97ed6
build a wheel
tenpercent Oct 2, 2024
f8129c3
rename dockerfile
tenpercent Oct 2, 2024
a0221e5
lint
tenpercent Oct 2, 2024
9975759
Merge pull request #26 from ROCm/dockerfile
tenpercent Oct 2, 2024
4d2a37d
Try adding docker image build workflow
tenpercent Oct 3, 2024
eddb1ec
Merge branch 'develop' into dockerfile
tenpercent Oct 3, 2024
ea3b796
add newline
tenpercent Oct 3, 2024
983fc19
add newline
tenpercent Oct 3, 2024
d6ea535
Merge pull request #27 from ROCm/dockerfile
tenpercent Oct 3, 2024
eb986e1
Update README.md
tenpercent Oct 3, 2024
ac0b05c
Merge pull request #28 from ROCm/tenpercent-patch-1
tenpercent Oct 3, 2024
e391974
Remove directly including of <ck_tile/host.hpp>
qianfengz Oct 3, 2024
9d03beb
Synchronize with latest ck develop commit
qianfengz Oct 3, 2024
0a4d420
Remove the printing in attention_forward_generic_ck_tiled.cpp
qianfengz Oct 3, 2024
a1c788e
Tune the TilePartitioner for splitkv-combine kernel
qianfengz Oct 3, 2024
b1e5ee4
Use 64 as maximum possible number of splitkv
qianfengz Oct 3, 2024
a53ed75
Add environment variable to disable building fmha-fwd-splitkv
qianfengz Oct 4, 2024
772e8f6
Use 32 as maximum number of splits
qianfengz Oct 6, 2024
04bb150
Fix compilation errors due to CK interface change
poyenc Oct 6, 2024
7949da4
Determine kHasUnevenSplits at runtime
poyenc Oct 6, 2024
3a8d7cf
Determine kPadSeqLenK at runtime
poyenc Oct 7, 2024
28ac1ca
Let kPadSeqLenK be reversed value of kHasUnevenSplits
qianfengz Oct 7, 2024
00482a0
Merge branch 'develop' into add-splitkv
qianfengz Oct 7, 2024
b62e722
Merge pull request #29 from ROCm/add-splitkv
qianfengz Oct 7, 2024
e8143c3
Synchronize to latest ck develop commit for updates with regard to fm…
qianfengz Oct 8, 2024
7986c2c
fix build: stream type
tenpercent Oct 8, 2024
93524db
Merge pull request #30 from tenpercent/develop
qianfengz Oct 9, 2024
ee10600
Add support for fmha-bwd headdim-96
qianfengz Oct 11, 2024
c9fa526
Use kK2=96
qianfengz Oct 12, 2024
abc9361
Synchronize the change in ck-tile to rename kQKHeaddimForGemmN to kQK…
qianfengz Oct 14, 2024
5bb0542
Synchronize the change in ck-tile to replace kVHeaddimForGemmN by kVH…
qianfengz Oct 14, 2024
c5b594d
Simplify FmhaBwdPipelineEnumSelector templates
qianfengz Oct 14, 2024
dd6cf04
Merge branch 'develop' into bwd_hd96_perf
qianfengz Oct 14, 2024
723c420
Synchronize to latest ck_tile commit
qianfengz Oct 14, 2024
a15e559
Replace TileFmhaBwdTraits by TileFmhaTraits
qianfengz Oct 15, 2024
2773383
Relocate to ck_tile develop branch and synchronize to latest commits
qianfengz Oct 16, 2024
d4437ad
Merge pull request #31 from ROCm/bwd_hd96_perf
qianfengz Oct 16, 2024
f94fdfd
Remove using splitkv from fmha-fwd training path
qianfengz Oct 16, 2024
4b4327e
Revert "Remove using splitkv from fmha-fwd training path"
qianfengz Oct 17, 2024
bc107ad
Add kMaxSplits=8 support
qianfengz Oct 17, 2024
91e01f9
Add tile settings for splitkv kernel
qianfengz Oct 18, 2024
139334c
Use WarpTile 16x16x16 for fmha-fwd splitkv
qianfengz Oct 20, 2024
c553f1a
Add MaxSeqlenQ as parameter for creating tile shape settings
qianfengz Oct 20, 2024
eb4586e
Update in FmhaFwdSplitKVShape
qianfengz Oct 20, 2024
6b0fae2
Synchronize to the latest commit of ck_tile for split-kv support
qianfengz Oct 21, 2024
46bc17d
Merge pull request #32 from ROCm/splitkv_improve
qianfengz Oct 21, 2024
76b9738
Change the selection of Default2DEpilogue for Fwd SplitKV kernel to a…
qianfengz Oct 25, 2024
7243b49
Try to have kPadSeqLenK be false in splitkv dispatch
qianfengz Oct 25, 2024
6ffea6a
Revert "Try to have kPadSeqLenK be false in splitkv dispatch"
qianfengz Oct 26, 2024
5f1ec0c
Synchronize for latest splitkv support in ck-tile
qianfengz Oct 26, 2024
3437842
Use kSubQKHeaddim to replace kK0BlockLength
qianfengz Oct 27, 2024
6c8a8b4
Add headdim96 support for fmha-fwd
qianfengz Oct 28, 2024
cb58e69
Synchronize to latest commit in ck-tile
qianfengz Oct 28, 2024
7d8ced0
Reposition the composable_kernel_tiled submodule to latest ck develop…
qianfengz Oct 30, 2024
06b548c
Merge pull request #34 from ROCm/fwd_hd96_debug
qianfengz Oct 30, 2024
7f91bb1
Synchronize to latest ck_tile commit for some bug fixing in page-attn
qianfengz Nov 11, 2024
44b6def
Fix grad_k/grad_v strides
qianfengz Nov 13, 2024
b000bb3
Merge pull request #36 from ROCm/stride_fix
qianfengz Nov 13, 2024
bdfffaa
Synchronize to latest ck_tile commit for adding Paged-KVCache dependa…
qianfengz Nov 21, 2024
266e3c6
Let splitkv combine kernel not called when num_splits is 1
qianfengz Nov 22, 2024
273a892
Add supported for Paged-KVCache (PagedBlockDiagonalPaddedKeysMask pas…
qianfengz Nov 25, 2024
22df8c9
Add is_gappy indicator to let kernel have special treatment for seqst…
qianfengz Nov 25, 2024
e768502
Fix in _custom_mask_type of ck.py
qianfengz Nov 26, 2024
00c70d0
Add test_paged_attention_ck in tests/test_mem_eff_attention.py
qianfengz Nov 26, 2024
468c83f
position to the latest ck develop branch
qianfengz Nov 26, 2024
95460bc
Change to check causalmask type and window_size parameter together to…
qianfengz Nov 26, 2024
56dba6b
Merge pull request #37 from ROCm/add_paged_kvcache
qianfengz Nov 26, 2024
9ccc42f
bump python op maxk
tenpercent Nov 27, 2024
760cdcc
run codegen
tenpercent Nov 27, 2024
4de46f4
run codegen (1)
tenpercent Nov 27, 2024
89e8e91
add missing FmhaFwdBlockTile instance; handle 512 case when computing…
tenpercent Dec 2, 2024
f13d987
Initial adding support for splitkv smallq pipeline
qianfengz Dec 3, 2024
672617b
fix compile error in qr_ks_vs pipeline
tenpercent Dec 3, 2024
d7099cb
fix occupancy related compilation errors
tenpercent Dec 3, 2024
a198345
try adding qsksvs pipeline and stash the result
tenpercent Dec 4, 2024
580ec51
Synchronize to latest ck_tile commit to utilize the padding optimization
qianfengz Dec 6, 2024
8a45436
Merge pull request #40 from ROCm/optimize_padding
qianfengz Dec 6, 2024
a19d6a3
Resync to latest ck-tile commit for padding optimization
qianfengz Dec 6, 2024
e27b84c
Fix in batched_forward splitkv dispatch
qianfengz Dec 6, 2024
5041a12
Merge branch 'develop' into add_splitkv_smallq
qianfengz Dec 6, 2024
aee3570
Fix in batched_forward splitkv smallq dispatch
qianfengz Dec 6, 2024
be06c43
Update the splits selector and instances settings for splitkv-smallq …
qianfengz Dec 9, 2024
aff7bfd
Enable gemm-0 to use 16x16x16 warp-gemm
qianfengz Dec 10, 2024
1922015
enable offload compression
tenpercent Dec 4, 2024
2cc18ef
run black
tenpercent Dec 11, 2024
da455ec
fix merge conflict (1)
tenpercent Dec 11, 2024
21330ed
reset submodule
tenpercent Dec 11, 2024
e8946b2
cleanup
tenpercent Dec 11, 2024
7e92d1f
Merge remote-tracking branch 'origin/develop' into ci-fixes
tenpercent Dec 11, 2024
8b580f4
run black
tenpercent Dec 11, 2024
3f9a40b
Merge pull request #41 from ROCm/ci-fixes
tenpercent Dec 12, 2024
afdfa46
Synchronize to use the latest optimization for splitkv combine kernel
qianfengz Dec 13, 2024
1258328
Update in ck FwOp apply() to well utilize the group query support in…
qianfengz Dec 15, 2024
08edbf9
Update to let fmha infer kernel can select either 16x16 or 32x32 inst…
qianfengz Dec 16, 2024
57e157e
Remove the conditional compiling of using splitkv kernel
qianfengz Dec 16, 2024
c1647c7
Merge remote-tracking branch 'origin/develop' into hdim-512
tenpercent Dec 16, 2024
84d7253
Sync to the latest commit of the ck_tile branch
qianfengz Dec 17, 2024
97523dd
Sync to the latest commit of the ck_tile branch for updated pipeline …
qianfengz Dec 17, 2024
aa781c8
Update in the method for determining num_kv_splits
qianfengz Dec 17, 2024
e53d164
Update to the tile setting for splitkv-smallq headdim128
qianfengz Dec 17, 2024
1ae3de9
call qsksvs pipeline on either async or sync codepath in dispatch
tenpercent Dec 18, 2024
f10bc80
more pipeline changes
tenpercent Dec 18, 2024
83cabd4
update submodule
tenpercent Dec 18, 2024
53d4e0e
update headdim switch
tenpercent Dec 18, 2024
d0431e1
Update to the splitkv and splitkv-smallq selector
qianfengz Dec 18, 2024
5644f9f
fix kernel not being called
tenpercent Dec 18, 2024
bb703b5
test head dimension 512 for ckF
tenpercent Dec 18, 2024
2ea82a9
re-run generate_instances.py to please clang-format
tenpercent Dec 18, 2024
82ba746
run clang-format
tenpercent Dec 18, 2024
70d767d
run black
tenpercent Dec 18, 2024
6605ddb
Add ck in tests/test_mem_eff_attention.py::test_backward_gqa
qianfengz Dec 18, 2024
c1ab8e5
Re-position to latest develop branch and rename the SplitkvSmallq pip…
qianfengz Dec 20, 2024
4a5298c
Merge branch 'develop' into add_splitkv_smallq_nwarps
qianfengz Dec 20, 2024
73204e1
Merge pull request #44 from ROCm/add_splitkv_smallq_nwarps
qianfengz Dec 20, 2024
73d06c1
Replace the reshape() by flatten/unflatten in ck.py
qianfengz Dec 20, 2024
2980a55
Update ck.py to support expanded 5-D input for ck.FwOp
qianfengz Dec 20, 2024
84414b1
Fix in ck.py
qianfengz Dec 21, 2024
bf33926
Remove using partitioner for fmha kernels
qianfengz Dec 26, 2024
256d6a4
Add support for mqa_decoder optimization which merge Hq/Hkv with seql…
qianfengz Jan 7, 2025
23d7b1c
Synchronize to latest ck_tile commit which has changed GridSize() of …
qianfengz Jan 7, 2025
d66e7bf
Merge pull request #46 from ROCm/mqa_decoder_improve
qianfengz Jan 7, 2025
e07d13c
bump submodule
tenpercent Jan 7, 2025
bf78988
bump submodule to today's merge commit in ck
tenpercent Jan 7, 2025
8c28fdb
Merge remote-tracking branch 'origin/develop' into hdim-512
tenpercent Jan 7, 2025
40cbefb
refactor dispatch
tenpercent Jan 8, 2025
40f92e7
bump ck submodule to the current develop branch
tenpercent Jan 8, 2025
e4a7f3b
fix flake8 lint
tenpercent Jan 8, 2025
e5a43d4
Merge branch 'develop' into hdim-512
tenpercent Jan 8, 2025
cbe8e20
clang-format
tenpercent Jan 8, 2025
c2d9939
Removing the compressing of expanded 5D to 4D for xops.fmha.ck.FwOp
qianfengz Jan 9, 2025
58fa14a
Merge branch 'develop' of https://github.com/ROCm/xformers into develop
qianfengz Jan 9, 2025
cb60bad
wheels
johnnynunez Jan 9, 2025
b301741
Synchronize to latest ck_tile commit
qianfengz Jan 10, 2025
2f75f5a
Skip PagedBlockDiagonal attn_bias types for hdim-512
qianfengz Jan 10, 2025
acb58a5
Remove using DISABLE_HD256_HIP_FMHA env-variable and FMHA_SUPPORT_MAX…
qianfengz Jan 10, 2025
1887a33
Add using ENABLE_HD512_HIP_FMHA env-variable and FMHA_LIMIT_MAX_HEADD…
qianfengz Jan 10, 2025
a5c68d2
Update to the selector to explicitly use non-splitkv kernel for hdim-512
qianfengz Jan 10, 2025
73d7b78
Merge pull request #1 from ROCm/hdim-512-testing
tenpercent Jan 10, 2025
6da69d3
Update wheels.yml
johnnynunez Jan 10, 2025
eeb581f
Synchronize to latest ck commit
qianfengz Jan 13, 2025
701685c
Use 64x128 Gemm0 Tile and WarpGemm-16x16x16 for hdim-512
qianfengz Jan 13, 2025
fd11dbd
Merge pull request #48 from ROCm/hdim-512
qianfengz Jan 13, 2025
84883b5
Remove using splitkv kernel from fmha fwd training path
qianfengz Jan 13, 2025
2f66b19
Merge pull request #49 from ROCm/hack_test_backward
qianfengz Jan 13, 2025
be6f8c2
Add -Wc++11-narrowing to hip_fmha compiling options to avoid any erro…
qianfengz Jan 14, 2025
1f12982
Merge branch 'develop' into develop
johnnynunez Jan 15, 2025
e14bf36
Update wheels.yml
johnnynunez Jan 15, 2025
6213bf6
Disable PagedAttn bias types and hdim-512 for test_logsumexp
qianfengz Jan 15, 2025
028196d
Merge pull request #50 from ROCm/fix_test_logsumexp
qianfengz Jan 15, 2025
d6e7e4f
Merge branch 'develop' into develop
johnnynunez Jan 15, 2025
58c037b
Update wheels.yml
johnnynunez Jan 15, 2025
1dcb9d8
hotfix typo
tenpercent Jan 15, 2025
21ede52
Merge pull request #51 from tenpercent/develop
tenpercent Jan 15, 2025
433f4f9
Merge branch 'develop' into develop
tenpercent Jan 15, 2025
fdc222d
Use new pipeline assignment strategy and separate tile shape settings…
qianfengz Jan 16, 2025
4685c44
Merge pull request #47 from johnnynunez/develop
tenpercent Jan 16, 2025
6c78398
enable hdim=512 by default
tenpercent Jan 16, 2025
865e802
Merge branch 'develop' into develop
tenpercent Jan 16, 2025
beadd0b
Merge pull request #52 from tenpercent/develop
qianfengz Jan 17, 2025
0c85bee
Further update to build hdim-512 by default
qianfengz Jan 17, 2025
fdc410a
Merge pull request #53 from ROCm/further_fix
qianfengz Jan 17, 2025
9928374
Merge remote-tracking branch 'upstream/main' into merge_upstream
qianfengz Jan 17, 2025
9045af7
Merge pull request #54 from ROCm/merge_upstream
qianfengz Jan 17, 2025
5a74138
Merge branch 'develop' of https://github.com/ROCm/xformers into develop
qianfengz Jan 18, 2025
ddbe036
Use separate setting for qr_ks_vs and qr_ks_vs_async pipelines
qianfengz Jan 22, 2025
30702d7
Revert "Remove using splitkv kernel from fmha fwd training path"
qianfengz Jan 23, 2025
4ffcac0
Merge pull request #55 from ROCm/some_roll_back
qianfengz Jan 23, 2025
dd59c20
Merge branch 'develop' of https://github.com/ROCm/xformers into develop
qianfengz Jan 23, 2025
d77158c
Merge branch 'develop' into async_improvement
qianfengz Jan 23, 2025
45a4365
Ensure the qr_ks_vs pipeline is used when kHasDropout is true and MaxK…
qianfengz Jan 24, 2025
501a9ba
Tune the tile sizes for hdim-128 for qr_ks_vs_async pipeline
qianfengz Jan 27, 2025
46d424c
Synchronize to the change in ck_tile (QLoadOnece == false for qr_ks_v…
qianfengz Feb 2, 2025
dfb31aa
Use kM0 = 128 for hdim-96 when using qr_ks_vs_async pipeline
qianfengz Feb 2, 2025
b699977
Synchronize to the updated ck_tile commit
qianfengz Feb 4, 2025
e7146b6
Adjust the tile shape settings for hdim-128 and hdim-96
qianfengz Feb 7, 2025
4a68301
Adjust warp tile settings for hdim-128 and mtile-128
qianfengz Feb 7, 2025
3d4bac7
Tune the tile settings for hdim-96 and hdim-128
qianfengz Feb 9, 2025
f68019a
Tune the kPadSeqLenQ and kPadSeqLenK using in batched_infer and group…
qianfengz Feb 11, 2025
e2629db
Fix in ck.py to handle attn_bias types with 5-D bias tensor
qianfengz Feb 11, 2025
6c5a72a
Let ck_splitk_decoder to use ck_tile headers only
qianfengz Feb 13, 2025
0144584
Merge pull request #58 from ROCm/ck_tile_splitk_decoder_fix
qianfengz Feb 13, 2025
8c05a8e
Synchronize to the latest ck develop branch for solving a test failure
qianfengz Feb 14, 2025
551fd23
Synchronize to the latest ck develop branch for solving the test_page…
qianfengz Feb 17, 2025
e12bca7
remove import sys in generate_instances.py
qianfengz Feb 17, 2025
981d068
Tiny scripts update in ck.py
qianfengz Feb 17, 2025
f0029c7
Rename the qr_ks_vs_async pipeline to qr_ks_vs_whole_k_prefetch pipeline
qianfengz Feb 19, 2025
3caf1de
silence lint, sync with upstream
tenpercent Feb 20, 2025
3e79085
refactor attention inner product
tenpercent Feb 14, 2025
0b61542
Merge pull request #59 from ROCm/cktile-innerproduct
qianfengz Feb 21, 2025
b930f31
Synchronize to the updated ck_tile commit
qianfengz Feb 22, 2025
8ba5216
Tiny scripts update in ck.py
qianfengz Feb 17, 2025
5a9239f
Rename the ck_tile submodule branch and synchronize to latest commit
qianfengz Feb 24, 2025
9bfdf3c
Synchronize to the updated ck_tile commit
qianfengz Mar 2, 2025
9155e2d
Synchronize to the updated ck_tile commit
qianfengz Mar 2, 2025
f17389e
Merge branch 'develop' of https://github.com/ROCm/xformers into develop
qianfengz Mar 3, 2025
44186d7
Merge branch 'develop' into async_improvement
qianfengz Mar 3, 2025
c5620b0
Remove using ck_tiled_fmha_async_fwd_setting.h and sync to updated ck…
qianfengz Mar 3, 2025
15d66d8
Use qr_ks_vs_async pipeline for hdim-96
qianfengz Mar 3, 2025
1736dc7
Synchronize to the update ck_tile commit
qianfengz Mar 4, 2025
0909c26
Synchronize to the updated ck_tile commit
qianfengz Mar 4, 2025
552d821
Synchronize to the updated ck_tile commit
qianfengz Mar 4, 2025
03fed31
Synchronize to the updated ck_tile commit
qianfengz Mar 5, 2025
0ddd927
Re-position the ck_tiled submodule to develop branch
qianfengz Mar 7, 2025
46b35b7
Re-format .gitmodules
qianfengz Mar 7, 2025
4756a53
Merge pull request #61 from ROCm/async_improvement
qianfengz Mar 7, 2025
a0a401e
Let qualified cases with MTile=128 to use qr_ks_vs_async pipeline
qianfengz Mar 11, 2025
4b18a0c
remove legacy ck decoder
tenpercent Mar 14, 2025
f1becf7
add environment knob for turning off ck fmha
tenpercent Mar 20, 2025
da59ab4
Correct the condition for using merge_nhead_groups_seqlen_q
qianfengz Mar 25, 2025
ffe4808
Merge pull request #64 from ROCm/ck-fmha-enable-knob
qianfengz Apr 7, 2025
58e1103
Merge pull request #63 from ROCm/remove-legacy-decoder
qianfengz Apr 21, 2025
02e7602
Fix to make hip_fmha compilable on torch-2.8
qianfengz May 9, 2025
39addc8
Merge remote-tracking branch 'upstream/main' into develop
qianfengz May 9, 2025
8fdfa85
Add support of BlockDiagonalCausalLocalAttentionPaddedKeysMask with c…
qianfengz May 27, 2025
e8fdbba
Import all used attn_bias types and remove prefix when referring to t…
qianfengz May 28, 2025
89967ff
Remove efficient_attention_forward_decoder_ck from the interface decl…
qianfengz Jun 11, 2025
a893fd5
Return logsumexp as std::optional<at::Tensor> in efficient_attention_…
qianfengz Jun 24, 2025
8c203d8
Update ck pin (#66)
tenpercent Jun 24, 2025
8e2050c
Update and synchronize with the latest ck_tile kernel arguments chang…
qianfengz Jun 25, 2025
9e9fda3
add fmha grouped infer pagedkv dispatch
ltqin Jul 1, 2025
ebc7f5a
limit k==kv for pagedkv
ltqin Jul 2, 2025
60860a1
remove logits_soft_cap and paged limit
ltqin Jul 8, 2025
2ecbc3a
limit seq_len_q
ltqin Jul 9, 2025
a622744
Update to latest ck develop commit to include ck_tile PR-2405
qianfengz Jul 9, 2025
a969435
Renaming in pagedkv_dispatch
qianfengz Jul 9, 2025
8ab0b15
remove gappy constraints and set to support only pagedkv
ltqin Jul 11, 2025
47a2681
Synchronize to latest ck_tile(add pagedkv for large seq_len_q)
ltqin Jul 11, 2025
04c4ff7
Selecting MTile (128 or 64) for calling grouped_infer_pagedkv_mask_bi…
qianfengz Jul 11, 2025
f62f580
Merge pull request #68 from ROCm/ck_tile/kvcache_prefill
qianfengz Jul 11, 2025
3e13f86
Clarify the usage of param.use_split_kv and param.use_paged_kvcache
qianfengz Jul 11, 2025
25e836f
remove selecting mtile, just use 128
ltqin Jul 17, 2025
0fa8753
Merge pull request #69 from ROCm/ck_tile/kvcache_prefill_ch_tile_size
qianfengz Jul 17, 2025
bd9e97b
Merge branch 'main' into develop
qianfengz Jul 18, 2025
f92ee1a
Remove the checking of compute_logsumexp at the return of efficient_a…
qianfengz Jul 18, 2025
0d6ec71
Align some files with the upstream
qianfengz Jul 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion third_party/composable_kernel_tiled
Original file line number Diff line number Diff line change
Expand Up @@ -396,8 +396,7 @@ efficient_attention_forward_ck(

// 1) fmha fwd split-kv kernel does not support dropout
// 2) Paged-KVcache is only available from the split-kv kernel at present
p.use_split_kv =
(p.use_paged_kvcache || (!use_dropout && use_split_kv)) ? true : false;
p.use_split_kv = (!use_dropout && use_split_kv) ? true : false;

p.num_kv_splits = num_kv_splits;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ void run_grouped_forward_mask_bias_dropout_dispatch(
// (*) dropout
// (*) head dimension > 256
if constexpr (!kHasDropout) {
if (param.use_split_kv && MaxK <= 256) {
if ((param.use_split_kv || param.use_paged_kvcache) && MaxK <= 256) {
if constexpr (MaxK <= 256) {
if (use_splitkv_smallq(
param.max_seqlen_q, std::max(param.K, param.Kv))) {
Expand Down
21 changes: 16 additions & 5 deletions xformers/csrc/attention/hip_fmha/ck_tiled_fmha_grouped_infer.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "ck_tiled_fmha_fwd_setting.h"
#include "ck_tiled_fmha_fwd_splitkv_smallq_selector.h"
#include "ck_tiled_fmha_grouped_infer_dispatch.h"
#include "ck_tiled_fmha_grouped_infer_pagedkv_dispatch.h"
#include "ck_tiled_fmha_grouped_infer_splitkv_dispatch.h"
#include "ck_tiled_fmha_grouped_infer_splitkv_smallq_dispatch.h"
#include "ck_tiled_fmha_seqlen_q_switch.h"
Expand All @@ -27,7 +28,7 @@ void run_grouped_infer_mask_bias_dropout_dispatch(
// (*) dropout
// (*) head dimension > 256
if constexpr (!kHasDropout) {
if (param.use_split_kv && MaxK <= 256) {
if ((param.use_split_kv || param.use_paged_kvcache) && MaxK <= 256) {
if constexpr (MaxK <= 256) {
if (use_splitkv_smallq(
param.max_seqlen_q, std::max(param.K, param.Kv))) {
Expand All @@ -37,14 +38,24 @@ void run_grouped_infer_mask_bias_dropout_dispatch(
kHasBias,
MaxK>::Run(param, stream);
} else {
FMHA_FWD_SEQLEN_Q_SWITCH(param.max_seqlen_q, MaxSeqlenQ, [&] {
grouped_infer_splitkv_mask_bias_dropout_dispatch<
if (!param.use_split_kv && param.use_paged_kvcache &&
param.page_block_size >= 128) {
grouped_infer_pagedkv_mask_bias_dropout_dispatch<
ScalarType,
kHasMask,
kHasBias,
MaxK,
MaxSeqlenQ>::Run(param, stream);
});
128>::Run(param, stream);
} else {
FMHA_FWD_SEQLEN_Q_SWITCH(param.max_seqlen_q, MaxSeqlenQ, [&] {
grouped_infer_splitkv_mask_bias_dropout_dispatch<
ScalarType,
kHasMask,
kHasBias,
MaxK,
MaxSeqlenQ>::Run(param, stream);
});
}
}
} else {
// Unreachable. Do not instantiate split-kv pipelines with head
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/*
* Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once

#include <ck_tile/core/numeric/integer.hpp>
#include <ck_tile/host/kernel_launch.hpp>
#include <ck_tile/host/stream_config.hpp>
#include <ck_tile/ops/epilogue.hpp>
#include <ck_tile/ops/fmha.hpp>

#include "ck_tiled_bool_switch.h"
#include "ck_tiled_fmha_fwd_setting.h"
#include "ck_tiled_fmha_params.h"

// Dispatcher that instantiates and launches the ck_tile paged-KV FMHA
// forward kernel for grouped-mode (variable-sequence-length) inference.
//
// Template parameters:
//   ScalarType - element type of Q/K/V, mapped to kernel data types via
//                FmhaFwdTypeConfig
//   kHasMask   - whether an attention mask is applied
//   kHasBias   - whether an elementwise attention bias is applied
//   MaxK       - compile-time upper bound on the head dimension
//   MTile      - M-direction tile extent used to select the tile shape
template <
    typename ScalarType,
    bool kHasMask,
    bool kHasBias,
    ck_tile::index_t MaxK,
    ck_tile::index_t MTile>
struct grouped_infer_pagedkv_mask_bias_dropout_dispatch {
  // Attention variant: logits soft-cap disabled (false * LOGITS_SOFT_CAP
  // evaluates to 0), exp2-based softmax controlled by
  // CK_TILE_FMHA_FWD_FAST_EXP2.
  using fmha_variant = ck_tile::ComposedAttention<
      false * ck_tile::LOGITS_SOFT_CAP,
      CK_TILE_FMHA_FWD_FAST_EXP2>;

  // Tile shape selected from the (MaxK, MTile) pair.
  using FmhaTileShape = typename FmhaFwdShape<MaxK, MTile>::Type;

  // Pipeline problem descriptor, parameterized over the traits / mask /
  // output type chosen at runtime inside Run().
  template <
      typename FmhaFwdPagedKVTraits,
      typename FmhaMask,
      typename ODataType>
  using FmhaFwdPagedKVPipelineProblemTemp =
      ck_tile::BlockFmhaFwdPagedKVPipelineProblem<
          typename FmhaFwdTypeConfig<ScalarType>::QDataType,
          typename FmhaFwdTypeConfig<ScalarType>::KDataType,
          typename FmhaFwdTypeConfig<ScalarType>::VDataType,
          typename FmhaFwdTypeConfig<ScalarType>::SaccDataType,
          typename FmhaFwdTypeConfig<ScalarType>::SMPLComputeDataType,
          typename FmhaFwdTypeConfig<ScalarType>::BiasDataType,
          typename FmhaFwdTypeConfig<ScalarType>::LSEDataType,
          typename FmhaFwdTypeConfig<ScalarType>::PDataType,
          typename FmhaFwdTypeConfig<ScalarType>::OaccDataType,
          ODataType,
          FmhaTileShape,
          true, // kIsGroupMode
          fmha_variant,
          FmhaMask,
          FmhaFwdPagedKVTraits>;

  // Resolves the runtime padding flags into compile-time kernel traits and
  // launches the resulting kernel instantiation on `stream`.
  static void Run(GroupedForwardParams& param, hipStream_t stream) {
    {
      using FmhaMask = ck_tile::SimplifiedGenericAttentionMask<kHasMask>;

      // Occupancy hint forwarded into the traits; -1 here (presumably
      // "use the pipeline default" — confirm against ck_tile traits docs).
      constexpr ck_tile::index_t occupancy = -1;

      constexpr auto kBiasEnum = kHasBias
          ? ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS
          : ck_tile::BlockAttentionBiasEnum::NO_BIAS;

      // Both sequence dimensions are unconditionally padded in this
      // grouped-mode dispatch.
      constexpr bool kPadSeqLenQ = true;
      constexpr bool kPadSeqLenK = true;

      // Head-dim padding is only required when the actual head sizes are
      // not multiples of the corresponding tile extents.
      bool pad_headdim_q = !(param.K % FmhaTileShape::kSubQKHeaddim == 0);
      bool pad_headdim_v = !(param.Kv % FmhaTileShape::kN1 == 0);

      // Expand the two runtime booleans into the four compile-time
      // (kPadHeadDimQ, kPadHeadDimV) trait combinations.
      BOOL_SWITCH_2(
          pad_headdim_q, kPadHeadDimQ, pad_headdim_v, kPadHeadDimV, [&] {
            using FmhaTraits = ck_tile::TileFmhaFwdPagedKVTraits<
                kPadSeqLenQ,
                kPadSeqLenK,
                kPadHeadDimQ,
                kPadHeadDimV,
                false, // kHasLogitsSoftCap_
                kBiasEnum,
                false, // kHasBiasGrad place-holder
                false, // kStoreLSE (inference path: no logsumexp output)
                true, // kIsPagedKV
                false, // kDoFp8StaticQuant place-holder
                occupancy>;

            using ODataType = typename FmhaFwdTypeConfig<ScalarType>::ODataType;
            using FmhaPipelineProblem = FmhaFwdPagedKVPipelineProblemTemp<
                FmhaTraits,
                FmhaMask,
                ODataType>;

            using FmhaPipeline =
                ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS<FmhaPipelineProblem>;

            using FmhaEpilogue =
                ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
                    typename FmhaFwdTypeConfig<ScalarType>::OaccDataType,
                    ODataType,
                    false,
                    false>>;

            using FmhaKernel =
                ck_tile::FmhaFwdPagedKVKernel<FmhaPipeline, FmhaEpilogue>;

            RunWithFwdPagedKVKernel<FmhaKernel>(param, stream);
          });
    };
  };

  // Builds the kernel-argument struct from the forward params and launches
  // the fully-specialized paged-KV kernel on the given HIP stream.
  template <typename FmhaKernel>
  static void RunWithFwdPagedKVKernel(
      GroupedForwardParams& param,
      hipStream_t stream) {
    const auto kargs = [&] {
      return FmhaKernel::MakeKargs(
          param.q_ptr,
          param.k_ptr,
          param.v_ptr,
          param.attn_bias_ptr,
          nullptr, // lse_ptr (kStoreLSE is false; no logsumexp written)
          param.out_ptr, // o_ptr
          param.seqstart_q_dev_ptr,
          param.seqstart_k_dev_ptr,
          param.seqlen_k_dev_ptr,
          param.K, // hdim_q
          param.Kv, // hdim_v
          param.Hq, // nhead_q
          param.Hq / param.Hkv, // nhead_ratio_qk
          // Paged-KV arguments are zeroed/nulled when no paged cache is in
          // use, keeping the kargs well-formed either way.
          param.use_paged_kvcache ? param.block_table_ptr : nullptr,
          param.use_paged_kvcache ? param.batch_stride_block_table : 0,
          param.use_paged_kvcache ? param.page_block_size : 0,
          param.use_paged_kvcache ? param.is_gappy : false,
          param.scale,
          1.0f, // scale_p
          1.0f, // scale_o
          0, // logits_soft_cap (disabled, matching kHasLogitsSoftCap_=false)
          param.q_strides[0], // q, k, v, bias, out tensor seq-dim stride
          param.k_strides[0],
          param.v_strides[0],
          param.attn_bias_strides[2],
          param.out_strides[0],
          param.q_strides[1], // q, k, v, bias, lse, out tensor head-dim stride
          param.k_strides[1],
          param.v_strides[1],
          param.attn_bias_strides[1],
          0, // nhead_stride_lse
          param.out_strides[1],
          // batch_stride_k/v: seq-dim stride times the page size —
          // presumably the stride between consecutive KV-cache pages;
          // confirm against FmhaKernel::MakeKargs.
          param.use_paged_kvcache ? param.k_strides[0] * param.page_block_size
                                  : 0, // batch_stride_k
          param.use_paged_kvcache ? param.v_strides[0] * param.page_block_size
                                  : 0, // batch_stride_v
          (param.window_size > 0) ? param.window_size - 1
                                  : -1, // window_left_size
          (param.custom_mask_type == 0) ? -1 : 0, // window_right_size
          param.custom_mask_type,
          0); // min_seqlen_q
    }();

    // Grid covers (batches, query heads, query length, value head-dim);
    // the last argument tells the kernel whether per-batch K lengths are
    // supplied via seqlen_k_ptr.
    dim3 kGridSize = FmhaKernel::GridSize(
        param.num_batches,
        param.Hq,
        param.max_seqlen_q,
        param.Kv,
        kargs.seqlen_k_ptr != nullptr);
    constexpr dim3 kBlockSize = FmhaKernel::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = FmhaKernel::kBlockPerCu;

    // Fire-and-forget launch; stream_config{stream, false} requests no
    // timing. The return value (elapsed time) is intentionally discarded.
    (void)ck_tile::launch_kernel(
        ck_tile::stream_config{stream, false},
        ck_tile::make_kernel<kBlockSize.x, kBlockPerCu>(
            FmhaKernel{}, kGridSize, kBlockSize, 0, kargs));
  };
};
Loading