-
Couldn't load subscription status.
- Fork 247
[CK_TILE] Stream-K operator() Reboot #3064
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR introduces a refactored Stream-K operator implementation ("reboot" namespace) that separates Persistent and Non-Persistent data parallel execution paths. The changes build upon a previous PR by factoring the Stream-K algorithm into BaseGemm() and StreamKGemm() functions, with dedicated operator() implementations for each execution mode.
Key changes:
- Added new
BaseGemm()andStreamKGemm()methods to separate standard GEMM execution from Stream-K scheduling logic - Implemented separate operator() functions for Persistent and Non-Persistent modes using template SFINAE
- Added
estimate_num_wgs_per_tile()to the tile partitioner for better test accuracy calculations
Reviewed Changes
Copilot reviewed 20 out of 20 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp | Added test configuration for SK-only case with 2 workgroups per tile |
| test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp | Added unit tests for estimate_num_wgs_per_tile() method |
| test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_util.hpp | Added test utility class and helper functions for reboot Stream-K tests |
| test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_util.cpp | Implemented get_cu_count() helper function |
| test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_types.hpp | Defined type aliases and test parameter tuples for FP16/BF16 tests |
| test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_smoke_cases.inc | Added smoke test cases for edge cases, DP-only, and SK-only scenarios |
| test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_extended_cases.inc | Added extended test case for DP + 2-Tile SK scenario |
| test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_reboot_fp16_persistent.cpp | FP16 persistent mode smoke tests |
| test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp | FP16 non-persistent mode smoke tests |
| test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_reboot_bf16_persistent.cpp | BF16 persistent mode smoke tests |
| test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp | BF16 non-persistent mode smoke tests |
| test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_reboot_fp16_persistent.cpp | FP16 persistent mode extended tests |
| test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp | FP16 non-persistent mode extended tests |
| test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_reboot_bf16_persistent.cpp | BF16 persistent mode extended tests |
| test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp | BF16 non-persistent mode extended tests |
| test/ck_tile/gemm_streamk/CMakeLists.txt | Added build targets for smoke and extended tests |
| test/CMakeLists.txt | Added extended tests to regression test list |
| include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp | Implemented estimate_num_wgs_per_tile() method |
| include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp | Added declaration and PERSISTENT constants to tile partitioner |
| include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp | Added reboot namespace with new StreamKKernel implementation |
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_extended_cases.inc
Outdated
Show resolved
Hide resolved
|
@arai713 I've started a new build so we can check CI. Once that passes, I will approve! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM Nice work!
23fb579 to
cf78dce
Compare
This change implements an operator() function in the reboot::StreamKKernel class that is enabled when the Persistent flag is set to true. In this case, the data-parallel portion and the Stream-K portion of the kernel are fully persistent. The changes were made in the reboot namespace. A future PR will remove the old Stream-K kernel class and remove the reboot namespace.
This change contains the inital test suite for the Persitent Stream-K Kernel. The files contain "reboot" in the name; a future PR will remove tests for the old Stream-K Kernel and remove the "reboot" naming. A future commit will add tests for the non-persistent kernel. Also added estimate_num_wgs_per_tile to the StreamKTilePartitionerBase class. This allows us to estimate the number of accumulations done per macro tile in C to use during validation when computing relative and absolute tolerance.
This code is adding the operator() function for the Non-Persistent Stream-K kernel. Persistency of the kernel is determined through a template argument. The Non-Persistent kernel will allocate additional workgroups for the data parallel section, leading to a different structure for processing the data parallel and Stream-K sections. There has been an addition to the TilePartitioner to get access to the whether Persistent has been set to true or false in the StreamKKernel.
This commit makes the following changes: - Update test cases to determine M, N, and K based on the number of CUs. This ensures that each test case is one of Edge Case, SK Only, DP Only, or DP + 2 Tile SK regardless of the architecture. - Since the DP + 2 Tile SK test case takes long to run, this change moves this case into a separate .inc file and labels it as an extended test. - Since the extended test takes > 30 seconds to run, this test is added to the list of regression tests.
Co-authored-by: Copilot <[email protected]>
Removed const volatile for typenames Set up alias for is_tuple_t Naming changes for clarity: GemmCommon -> BaseGemm Moved std::enable_if_t out of template parameters and changed to a return type for operator() Added constructor for StreamKKernelArgs to clarify UniversalGemm inheritance
cf78dce to
cc6e4cc
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM Nice work!
Proposed changes
These changes follows a previous PR which introduced a new Stream-K Tile Partitioner that is better aligned with the original implementation detailed in the Stream-K paper. This PR, done in collaboration with @ecamartins, follow that work and introduce two new operator() functions to accommodate Persistent and Non-Persistent data parallel sections of the Stream-K kernel. They are all added into a new namespace called reboot, which will be removed along with the older implementation of the Stream-K operator() subsequently.
Previously we had only implemented the Non-Persistent version, and all the code for the Stream-K algorithm was in a single operator() function. Now we factor the code into GemmCommon() and StreamKGemm(): GemmCommon runs the standard UniversalGemm, while StreamKGemm implements Stream-K scheduling and calls GemmCommon. There are 2 operator() functions: one for Persistent, one for Non-Persistent. This is controlled by the boolean template argument Persistent. Based on if Persistent is true/false, each operator() schedules and calls GemmCommon and StreamKGemm accordingly
Unit tests have been added for fp16/bf16 datatypes. The cases for DP + 2 Tile SK have been added as a regression test, as they took greater than 30 seconds to run.
Note: After addressing review comments, the function name
GemmCommonwas changed toBaseGemm.Checklist
Please put an
xinto the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.clang-formaton all changed files