-
Notifications
You must be signed in to change notification settings - Fork 248
[CK_BUILDER] First fwd convolution builder implementation #3070
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 24 commits
c632bba
db25611
121884d
c9466c8
51b76a7
79f057b
63a9d9f
289990f
dd7a6ed
fc258eb
9c0fdff
25837b4
3aaf8b9
7b2a622
11e71ab
7b89486
16df5ba
c6a1fa4
6cf8cc1
c76954b
c3f5097
28f6707
fc5caa1
37e5aee
6ade5a1
806ddac
2f2e86e
89b7954
faf7811
273d50a
5f60df4
275e688
9d3f88c
c388a87
3a33509
d3fce7b
b2a13a3
9c5f262
7987f07
7df49a8
60b265b
6db2117
90357f6
f005d65
2e9d840
3385dc2
0cea23e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,116 @@ | ||
| // SPDX-License-Identifier: MIT | ||
| // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include "ck/utility/sequence.hpp" | ||
| #include "ck_tile/builder/types.hpp" | ||
|
|
||
| namespace ck_tile::builder { | ||
|
|
||
| // Convert a static array to a sequence | ||
| // Usage example: | ||
| // static constexpr std::vector arr {1, 2, 3}; | ||
| // using seq = to_sequence_v<arr>; // seq is ck::Sequence<1, 2, 3> | ||
| template <typename T, const T& Arr> | ||
| struct to_sequence_t | ||
| { | ||
| private: | ||
| template <std::size_t... Is> | ||
| static auto get_sequence_type(std::index_sequence<Is...>) -> ck::Sequence<Arr[Is]...>; | ||
|
|
||
| // Helper method to handler the unusual .Size() method name in ck::Array. | ||
| static constexpr auto get_size(const auto& arr) | ||
| { | ||
| if constexpr(requires { arr.size(); }) | ||
| { | ||
| return arr.size(); | ||
| } | ||
| else | ||
| { | ||
| return arr.Size(); | ||
| } | ||
| } | ||
|
|
||
| public: | ||
| using value = decltype(get_sequence_type(std::make_index_sequence<get_size(Arr)>{})); | ||
| }; | ||
|
|
||
| template <auto& Arr> | ||
| using to_sequence_v = typename to_sequence_t<std::remove_cvref_t<decltype(Arr)>, Arr>::value; | ||
|
|
||
| // Wrapper function to make constexpr strings a structural type for NTTP. | ||
| template <size_t N> | ||
| struct StringLiteral | ||
| { | ||
| char data[N]; | ||
| constexpr StringLiteral(const char (&str)[N]) | ||
| { | ||
| for(size_t i = 0; i < N; ++i) | ||
| data[i] = str[i]; | ||
| } | ||
|
|
||
| constexpr bool operator==(const StringLiteral<N>& other) const | ||
| { | ||
| for(size_t i = 0; i < N; ++i) | ||
| { | ||
| if(data[i] != other.data[i]) | ||
| { | ||
| return false; | ||
| } | ||
| } | ||
| return true; | ||
| } | ||
| }; | ||
|
|
||
| // This is a C++17 deduction guide. It allows the compiler to automatically | ||
| // deduce the template argument `N` for `StringLiteral` from a string literal | ||
| // constructor argument. For example, you can write `StringLiteral s{"foo"};` | ||
| // instead of `StringLiteral<4> s{"foo"};`. | ||
| template <size_t N> | ||
| StringLiteral(const char (&)[N]) -> StringLiteral<N>; | ||
|
|
||
| // Helper to provide a readable error for unsupported enum values. | ||
| // The compiler will print the name of this struct in the error message, so | ||
| // the name of the enum value will appear instead of just its integer value. | ||
| template <auto T> | ||
| struct UnsupportedEnumValue | ||
| { | ||
| }; | ||
|
|
||
| // Helper functions to convert enums to strings | ||
| constexpr std::string_view ConvDirectionToString(ConvDirection dir) | ||
| { | ||
| switch(dir) | ||
| { | ||
| case ConvDirection::FORWARD: return "Forward"; | ||
| case ConvDirection::BACKWARD_DATA: return "Backward Data"; | ||
| case ConvDirection::BACKWARD_WEIGHT: return "Backward Weight"; | ||
| default: return "Unknown"; | ||
| } | ||
| } | ||
|
|
||
| constexpr std::string_view DataTypeToString(DataType dt) | ||
| { | ||
| switch(dt) | ||
| { | ||
| case DataType::FP16: return "FP16"; | ||
| case DataType::FP32: return "FP32"; | ||
| case DataType::BF16: return "BF16"; | ||
| case DataType::FP8: return "FP8"; | ||
| case DataType::I8: return "I8"; | ||
| default: return "Unknown"; | ||
| } | ||
| } | ||
|
|
||
| constexpr std::string_view LayoutToString(GroupConvLayout layout) | ||
| { | ||
| switch(layout) | ||
| { | ||
| case GroupConvLayout::CHANNELS_FIRST: return "Channels-first (NCHW)"; | ||
|
||
| case GroupConvLayout::CHANNELS_LAST: return "Channels-last (NHWC)"; | ||
| default: return "Unknown"; | ||
| } | ||
| } | ||
|
|
||
| } // namespace ck_tile::builder | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,139 @@ | ||
| // SPDX-License-Identifier: MIT | ||
| // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <type_traits> | ||
| #include <concepts> | ||
| #include <array> | ||
|
|
||
| #include "ck_tile/builder/types.hpp" | ||
|
|
||
| namespace ck_tile::builder { | ||
|
|
||
| /********************************************************************/ | ||
| /* Descriptors for individual elements of the algorithm description */ | ||
| /********************************************************************/ | ||
|
|
||
| // Concept for thread block dimensions for a GEMM problem. | ||
| template <typename T> | ||
| concept ThreadBlockDescriptor = requires(T t) { | ||
| { t.block_size } -> std::convertible_to<size_t>; | ||
| { t.tile_size.m } -> std::convertible_to<size_t>; | ||
| { t.tile_size.n } -> std::convertible_to<size_t>; | ||
| { t.tile_size.k } -> std::convertible_to<size_t>; | ||
| }; | ||
|
|
||
| // Concept for parameters that describe a gridwise GEMM problem. | ||
| template <typename T> | ||
| concept GridwiseGemmDescriptor = requires(T t) { | ||
| { t.ak1 } -> std::convertible_to<size_t>; | ||
| { t.bk1 } -> std::convertible_to<size_t>; | ||
| { t.m_per_xdl } -> std::convertible_to<size_t>; | ||
| { t.n_per_xdl } -> std::convertible_to<size_t>; | ||
| { t.m_xdl_per_wave } -> std::convertible_to<size_t>; | ||
| { t.n_xdl_per_wave } -> std::convertible_to<size_t>; | ||
| }; | ||
|
|
||
| // Concept for convolution input block transfer. | ||
| template <typename T> | ||
| concept InputBlockTransferDescriptor = requires(T t) { | ||
| { t.k0 } -> std::convertible_to<size_t>; | ||
| { t.m_n } -> std::convertible_to<size_t>; | ||
| { t.k1 } -> std::convertible_to<size_t>; | ||
| }; | ||
|
|
||
| // Concept for output block transfer. | ||
| template <typename T> | ||
| concept OutputBlockTransferDescriptor = requires(T t) { | ||
| { t.m_block } -> std::convertible_to<size_t>; | ||
| { t.m_wave_per_xdl } -> std::convertible_to<size_t>; | ||
| { t.n_block } -> std::convertible_to<size_t>; | ||
| { t.n_wave_per_xdl } -> std::convertible_to<size_t>; | ||
| }; | ||
|
|
||
| // Concept for the convolution input vector transfer. | ||
| template <typename T> | ||
| concept InputVectorTransferDescriptor = requires(T t) { | ||
| { t.src_vector_dim } -> std::convertible_to<size_t>; | ||
| { t.src_scalar_per_vector } -> std::convertible_to<size_t>; | ||
vpietila-amd marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| { t.dest_scalar_per_vector_k1 } -> std::convertible_to<size_t>; | ||
aosewski marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| { t.add_extra } -> std::convertible_to<bool>; | ||
|
||
| }; | ||
|
|
||
| // Concepts for the convolution output vector transfer. | ||
| template <typename T> | ||
| concept OutputVectorTransferDescriptor = requires(T t) { | ||
| { t.m_xdl_per_wave_per_shuffle } -> std::convertible_to<size_t>; | ||
| { t.n_xdl_per_wave_per_shuffle } -> std::convertible_to<size_t>; | ||
| { t.scalar_per_vector } -> std::convertible_to<size_t>; | ||
| }; | ||
|
|
||
| // Concept for the thread cluster access order | ||
| template <typename T> | ||
| concept AccessOrderDescriptor = requires(T t) { | ||
| { t.order } -> std::convertible_to<std::array<size_t, 3>>; | ||
| }; | ||
|
|
||
| // No requirements yet for a ConvAlogorithm concept. | ||
| template <typename T> | ||
| concept ConvAlgorithmDescriptor = std::is_class_v<T>; | ||
|
|
||
| /******************************************** */ | ||
| /* Requirements for the algorithm description */ | ||
| /******************************************** */ | ||
|
|
||
| // Concept to check if struct specifies thread block info. | ||
| template <typename T> | ||
| concept SpecifiesThreadBlock = requires { | ||
| { T::thread_block } -> ThreadBlockDescriptor; | ||
| }; | ||
|
|
||
| // Concept to check if a struct specifies gridwise GEMM info. | ||
| template <typename T> | ||
| concept SpecifiesGridwiseGemm = requires { | ||
| { T::tuning_params } -> GridwiseGemmDescriptor; | ||
| }; | ||
|
|
||
| // Concept to check if a struct specifies convolution input and output block transfer info. | ||
| template <typename T> | ||
| concept SpecifiesBlockTransfer = requires(T t) { | ||
| { T::block_transfer.thread_cluster_dims_a } -> InputBlockTransferDescriptor; | ||
| { T::block_transfer.thread_cluster_dims_b } -> InputBlockTransferDescriptor; | ||
| { T::block_transfer.thread_cluster_dims_c } -> OutputBlockTransferDescriptor; | ||
| }; | ||
|
|
||
| // Concept to check if a struct specifies block vector transfer info. | ||
| template <typename T> | ||
| concept SpecifiesBlockVectorTransfer = requires(T t) { | ||
| { T::block_transfer.vector_transfer_a } -> InputVectorTransferDescriptor; | ||
| { T::block_transfer.vector_transfer_b } -> InputVectorTransferDescriptor; | ||
| { T::block_transfer.vector_transfer_c } -> OutputVectorTransferDescriptor; | ||
| }; | ||
|
|
||
| // Concept to check if a struct specifies thread cluster access order info. | ||
| template <typename T> | ||
| concept SpecifiesThreadClusterAccessOrder = requires(T t) { | ||
| { T::block_transfer.thread_cluster_access_order_a } -> AccessOrderDescriptor; | ||
| { T::block_transfer.thread_cluster_access_order_b } -> AccessOrderDescriptor; | ||
| }; | ||
|
|
||
| // Concept to check if a struct specifies source access order info. | ||
| template <typename T> | ||
| concept SpecifiesSourceAccessOrder = requires(T t) { | ||
| { T::block_transfer.src_access_order_a } -> AccessOrderDescriptor; | ||
| { T::block_transfer.src_access_order_b } -> AccessOrderDescriptor; | ||
| }; | ||
|
|
||
| // Concept to check if struct specifies block_gemm_pipeline_version. | ||
| template <typename T> | ||
| concept SpecifiesGemmPipelineVersion = requires { | ||
| { T::pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>; | ||
| }; | ||
|
|
||
| template <typename T> | ||
| concept SpecifiesFwdConcSpecialization = requires { | ||
| { T::fwd_specialization } -> std::convertible_to<ConvFwdSpecialization>; | ||
| }; | ||
|
|
||
| } // namespace ck_tile::builder | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| // SPDX-License-Identifier: MIT | ||
| // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <type_traits> | ||
| #include <concepts> | ||
|
|
||
| namespace ck_tile::builder { | ||
|
|
||
| // Limits for input vector transfer. | ||
| template <auto Value> | ||
| concept InputVectorTransferLimits = requires { | ||
| requires Value.src_vector_dim > 0 && Value.src_scalar_per_vector > 0 && | ||
| Value.dest_scalar_per_vector_k1 > 0; | ||
| }; | ||
|
|
||
| // Limits for output vector transfer. | ||
| template <auto Value> | ||
| concept OutputVectorTransferLimits = requires { | ||
| requires Value.scalar_per_vector > 0 && Value.m_xdl_per_wave_per_shuffle > 0 && | ||
| Value.n_xdl_per_wave_per_shuffle > 0; | ||
| }; | ||
|
|
||
| // Limits for access order. Must be a permutation of {0, 1, 2}. | ||
| template <auto Value> | ||
| concept AccessOrderLimits = requires { | ||
| requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[1] != Value[2]) && | ||
| (Value[0] >= 0 && Value[0] < 3) && (Value[1] >= 0 && Value[1] < 3) && | ||
| (Value[2] >= 0 && Value[2] < 3)); | ||
| }; | ||
|
|
||
| } // namespace ck_tile::builder |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| // SPDX-License-Identifier: MIT | ||
| // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. | ||
|
|
||
| #pragma once | ||
|
|
||
| #include <concepts> | ||
| #include <type_traits> | ||
|
|
||
| #include "ck_tile/builder/conv_factory.hpp" | ||
| #include "ck_tile/builder/versions.hpp" | ||
|
|
||
| namespace ck_tile::builder { | ||
|
|
||
| /** | ||
| * @brief Top-level builder for creating convolution kernel instances. | ||
| * | ||
| * This struct serves as the main entry point for generating a convolution kernel. | ||
| * It uses a factory pattern based on the provided signature, algorithm, and version | ||
| * to construct the appropriate kernel instance. | ||
| * | ||
| * @tparam SIGNATURE The convolution signature, which describes the mathematical functionality of | ||
| * the algorithm (e.g., data types, layouts, direction). | ||
| * @tparam ALGORITHM The specific convolution algorithm to be used for the implementation. | ||
| * @tparam VERSION The version of the builder implementation. | ||
| */ | ||
| template <ConvSignatureDescriptor auto SIGNATURE, | ||
| ConvAlgorithmDescriptor auto ALGORITHM, | ||
| StringLiteral VERSION = LATEST_API_VERSION> | ||
| requires SupportedVersion<VERSION> && ValidConvSignature<SIGNATURE> | ||
| struct ConvBuilder | ||
| { | ||
| static constexpr auto kVersion = VERSION; | ||
| using Factory = ConvFactory<SIGNATURE, ALGORITHM, VERSION>; | ||
| // Output: The kernel class. | ||
| using Instance = Factory::Instance; | ||
| }; | ||
|
|
||
| } // namespace ck_tile::builder |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is "S8" more common (signed eight bit integer)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed to
S8. Although we are not yet using the type.