diff --git a/CHANGELOG.md b/CHANGELOG.md
index 28bcaae5b69..6fce19483ce 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,6 +20,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
 * Added tensor-wise quantization for CK_TILE GEMM.
 * Added support for batched contraction kernel.
 * Added pooling kernel in CK_TILE
+* Added top-k sigmoid kernel in CK_TILE
 
 ### Changed
 
diff --git a/example/ck_tile/09_topk_softmax/topk_softmax.cpp b/example/ck_tile/09_topk_softmax/topk_softmax.cpp
index 0487bd05d23..400329986a4 100644
--- a/example/ck_tile/09_topk_softmax/topk_softmax.cpp
+++ b/example/ck_tile/09_topk_softmax/topk_softmax.cpp
@@ -83,6 +83,26 @@ auto reference_topk_softmax(const ck_tile::HostTensor<InputType>& x,
     reference_topk(y, y_values, y_indices, k, dim, largest, sorted);
 }
 
+template <typename InputType, typename WeightType, typename IndexType>
+auto reference_topk_sigmoid(const ck_tile::HostTensor<InputType>& x,
+                            ck_tile::HostTensor<WeightType>& y_values,
+                            ck_tile::HostTensor<IndexType>& y_indices,
+                            ck_tile::index_t k,
+                            ck_tile::index_t dim = -1,
+                            bool largest         = true,
+                            bool sorted          = true)
+{
+    using namespace ck_tile;
+
+    // topk only - no need to apply the sigmoid first
+    auto x_fp32 = x.template CopyAsType<WeightType>();
+    reference_topk(x_fp32, y_values, y_indices, k, dim, largest, sorted);
+    // apply sigmoid
+    std::transform(y_values.begin(), y_values.end(), y_values.begin(), [](auto value) {
+        return WeightType(1) / (WeightType(1) + exp(-value));
+    });
+}
+
 // different threshold for different dtype
 template <typename DataType>
 auto get_elimit(std::string /*init_method*/)
@@ -133,7 +153,8 @@ auto create_args(int argc, char* argv[])
         .insert("warmup", "5", "number of iterations before benchmark the kernel")
         .insert("repeat", "20", "number of iterations to benchmark the kernel")
         .insert("json", "0", "0: No Json, 1: Dump Results in Json format")
-        .insert("jsonfile", "topk_softmax.json", "json file name to dump results");
+        .insert("jsonfile", "topk_softmax.json", "json file name to dump results")
+        .insert("activation", "softmax", "activation function to use: softmax or sigmoid");
 
     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);
@@ -154,6 +175,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
     int kname  = args.get_int("kname");
     int warmup = args.get_int("warmup");
     int repeat = args.get_int("repeat");
+    std::string activation = args.get_str("activation");
 
     if(stride_input < 0)
     {
@@ -204,7 +226,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
 
     x_dev.ToDevice(x_host.data());
 
-    topk_softmax_trait trait{input_prec, weight_prec, experts};
+    topk_softmax_trait trait{input_prec, weight_prec, experts, activation};
 
     topk_softmax_kargs karg{x_dev.GetDeviceBuffer(),
                             value_dev.GetDeviceBuffer(),
@@ -221,7 +243,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
                             warmup,
                             repeat};
     auto ms = topk_softmax(trait, karg, sc);
-    printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, ms:%f, ",
+    printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, activation:%s, ms:%f, ",
            input_prec.c_str(),
           weight_prec.c_str(),
           tokens,
@@ -229,6 +251,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
           topk,
           stride_input,
           stride_output,
+          activation.c_str(),
           ms);
     if(ms < 0)
         printf("not supported\n");
@@ -247,8 +270,20 @@ bool test_topk_softmax(ck_tile::ArgParser args)
 
         ck_tile::HostTensor<WeightType> value_ref({tokens, topk}, {stride_output, 1});
         ck_tile::HostTensor<IndexType> index_ref({tokens, topk}, {stride_output, 1});
 
-        reference_topk_softmax<InputType, WeightType, IndexType>(
-            x_host, value_ref, index_ref, topk);
+        if(activation == "softmax")
+        {
+            reference_topk_softmax<InputType, WeightType, IndexType>(
+                x_host, value_ref, index_ref, topk);
+        }
+        else if(activation == "sigmoid")
+        {
+            reference_topk_sigmoid<InputType, WeightType, IndexType>(
+                x_host, value_ref, index_ref, topk);
+        }
+        else
+        {
+            throw std::runtime_error("unsupported activation type: " + activation);
+        }
 
         auto [rtol, atol] = get_elimit<InputType>("");
         for(int i_t = 0; i_t < tokens; i_t++)
diff --git a/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp b/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp
index 6e6bb20020c..770468d36b9 100644
--- a/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp
+++ b/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp
@@ -3,27 +3,31 @@
 
 #include "topk_softmax_api.hpp"
 
-#define TOPK_SOFTMAX_DISPATCH(experts_)                                                           \
-    constexpr ck_tile::index_t ts_experts = experts_;                                             \
-    using ts_problem = ck_tile::                                                                  \
-        TopkSoftmaxWarpPerRowProblem<ts_input_type, ts_weight_type, ts_index_type, ts_experts>;   \
-    using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>;                       \
-                                                                                                  \
-    using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>;                                       \
-                                                                                                  \
-    auto kargs = kernel::MakeKargs(a);                                                            \
-                                                                                                  \
-    const dim3 grids  = kernel::GridSize(a);                                                      \
-    const dim3 blocks = kernel::BlockSize();                                                      \
-                                                                                                  \
-    float ave_time =                                                                              \
-        ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs));    \
-                                                                                                  \
+#define TOPK_SOFTMAX_DISPATCH(experts_, use_softmax_)                                             \
+    constexpr ck_tile::index_t ts_experts = experts_;                                             \
+    constexpr bool ts_use_softmax         = use_softmax_;                                         \
+    using ts_problem  = ck_tile::TopkSoftmaxWarpPerRowProblem<ts_input_type,                      \
+                                                              ts_weight_type,                     \
+                                                              ts_index_type,                      \
+                                                              ts_experts,                         \
+                                                              ts_use_softmax>;                    \
+    using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>;                       \
+                                                                                                  \
+    using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>;                                       \
+                                                                                                  \
+    auto kargs = kernel::MakeKargs(a);                                                            \
+                                                                                                  \
+    const dim3 grids  = kernel::GridSize(a);                                                      \
+    const dim3 blocks = kernel::BlockSize();                                                      \
+                                                                                                  \
+    float ave_time =                                                                              \
+        ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs));    \
+                                                                                                  \
     return ave_time;
 
 float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s)
 {
-    if(t.input_type == "fp16" && t.weight_type == "fp32")
+    if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "softmax")
     {
         using ts_input_type  = ck_tile::fp16_t;
         using ts_weight_type = float;
@@ -31,36 +35,36 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
 #if 1
         if(t.experts <= 8)
         {
-            TOPK_SOFTMAX_DISPATCH(8)
+            TOPK_SOFTMAX_DISPATCH(8, true)
         }
         else if(t.experts <= 16)
         {
-            TOPK_SOFTMAX_DISPATCH(16)
+            TOPK_SOFTMAX_DISPATCH(16, true)
         }
         else if(t.experts <= 32)
         {
-            TOPK_SOFTMAX_DISPATCH(32)
+            TOPK_SOFTMAX_DISPATCH(32, true)
         }
         else if(t.experts <= 64)
         {
-            TOPK_SOFTMAX_DISPATCH(64)
+            TOPK_SOFTMAX_DISPATCH(64, true)
         }
         else if(t.experts <= 128)
         {
-            TOPK_SOFTMAX_DISPATCH(128)
+            TOPK_SOFTMAX_DISPATCH(128, true)
         }
         else if(t.experts <= 192)
        {
-            TOPK_SOFTMAX_DISPATCH(192)
+            TOPK_SOFTMAX_DISPATCH(192, true)
         }
 #else
         if(t.experts <= 128)
         {
-            TOPK_SOFTMAX_DISPATCH(128)
+            TOPK_SOFTMAX_DISPATCH(128, true)
         }
 #endif
     }
-    else if(t.input_type == "bf16" && t.weight_type == "fp32")
+    else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "softmax")
     {
 #if 1
         using ts_input_type = ck_tile::bf16_t;
@@ -68,27 +72,96 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
         using ts_index_type = ck_tile::index_t;
         if(t.experts <= 8)
         {
-            TOPK_SOFTMAX_DISPATCH(8)
+            TOPK_SOFTMAX_DISPATCH(8, true)
         }
         else if(t.experts <= 16)
         {
-            TOPK_SOFTMAX_DISPATCH(16)
+            TOPK_SOFTMAX_DISPATCH(16, true)
         }
         else if(t.experts <= 32)
         {
-            TOPK_SOFTMAX_DISPATCH(32)
+            TOPK_SOFTMAX_DISPATCH(32, true)
         }
         else if(t.experts <= 64)
         {
-            TOPK_SOFTMAX_DISPATCH(64)
+            TOPK_SOFTMAX_DISPATCH(64, true)
         }
         else if(t.experts <= 128)
         {
-            TOPK_SOFTMAX_DISPATCH(128)
+            TOPK_SOFTMAX_DISPATCH(128, true)
         }
         else if(t.experts <= 192)
         {
-            TOPK_SOFTMAX_DISPATCH(192)
+            TOPK_SOFTMAX_DISPATCH(192, true)
+        }
+#endif
+    }
+    if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "sigmoid")
+    {
+        using ts_input_type  = ck_tile::fp16_t;
+        using ts_weight_type = float;
+        using ts_index_type  = ck_tile::index_t;
+#if 1
+        if(t.experts <= 8)
+        {
+            TOPK_SOFTMAX_DISPATCH(8, false)
+        }
+        else if(t.experts <= 16)
+        {
+            TOPK_SOFTMAX_DISPATCH(16, false)
+        }
+        else if(t.experts <= 32)
+        {
+            TOPK_SOFTMAX_DISPATCH(32, false)
+        }
+        else if(t.experts <= 64)
+        {
+            TOPK_SOFTMAX_DISPATCH(64, false)
+        }
+        else if(t.experts <= 128)
+        {
+            TOPK_SOFTMAX_DISPATCH(128, false)
+        }
+        else if(t.experts <= 192)
+        {
+            TOPK_SOFTMAX_DISPATCH(192, false)
+        }
+#else
+        if(t.experts <= 128)
+        {
+            TOPK_SOFTMAX_DISPATCH(128, false)
+        }
+#endif
+    }
+    else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "sigmoid")
+    {
+#if 1
+        using ts_input_type  = ck_tile::bf16_t;
+        using ts_weight_type = float;
+        using ts_index_type  = ck_tile::index_t;
+        if(t.experts <= 8)
+        {
+            TOPK_SOFTMAX_DISPATCH(8, false)
+        }
+        else if(t.experts <= 16)
+        {
+            TOPK_SOFTMAX_DISPATCH(16, false)
+        }
+        else if(t.experts <= 32)
+        {
+            TOPK_SOFTMAX_DISPATCH(32, false)
+        }
+        else if(t.experts <= 64)
+        {
+            TOPK_SOFTMAX_DISPATCH(64, false)
+        }
+        else if(t.experts <= 128)
+        {
+            TOPK_SOFTMAX_DISPATCH(128, false)
+        }
+        else if(t.experts <= 192)
+        {
+            TOPK_SOFTMAX_DISPATCH(192, false)
         }
 #endif
     }
diff --git a/example/ck_tile/09_topk_softmax/topk_softmax_api.hpp b/example/ck_tile/09_topk_softmax/topk_softmax_api.hpp
index 65651efa4d4..c98a887736f 100644
--- a/example/ck_tile/09_topk_softmax/topk_softmax_api.hpp
+++ b/example/ck_tile/09_topk_softmax/topk_softmax_api.hpp
@@ -12,6 +12,7 @@ struct topk_softmax_trait
     std::string input_type;
     std::string weight_type; // currently always float
     int experts;
+    std::string activation; // "softmax" or "sigmoid"
 };
 
 struct topk_softmax_kargs : public ck_tile::TopkSoftmaxHostArgs
diff --git a/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp b/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp
index e8727ea0659..019e940a339 100644
--- a/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp
+++ b/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp
@@ -21,7 +21,7 @@ struct TopkSoftmaxHostArgs
     index_t num_experts;
     index_t topk;
    index_t stride_input;  // row stride for input, at least experts
-    index_t stride_output; // row stride for output/indices, at least tpok
+    index_t stride_output; // row stride for output/indices, at least topk
 };
 
 template <typename Pipeline_>
@@ -45,7 +45,7 @@ struct TopkSoftmaxKernel
        index_t num_experts;
        index_t topk;
        index_t stride_input;  // row stride for input, at least experts
-        index_t stride_output; // row stride for output/indices, at least tpok
+        index_t stride_output; // row stride for output/indices, at least topk
    };
 
    using Kargs = TopkSoftmaxKargs;
diff --git a/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp
index d620d9bec9c..677263229b1 100644
--- a/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp
+++ b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp
@@ -90,6 +90,11 @@ struct TopkSoftmaxWarpPerRowPipeline
                 const auto current_expert = x_indices.at(number<1>{});
                 w_(idx) = current_expert >= experts ? -numeric<WeightType>::infinity() : w_(idx);
+                if constexpr(!Problem::ActivationIsSoftmax)
+                {
+                    // sigmoid can be pre-computed already here if not using softmax
+                    w_(idx) = WeightType(1) / (WeightType(1) + exp(-w_(idx)));
+                }
             };
             tile_sweeper ts{w_, w_f};
             ts();
@@ -97,10 +102,16 @@ struct TopkSoftmaxWarpPerRowPipeline
 #endif
         }();
 
-        // softmax
-        auto y = softmax(w);
-
-        topk(y, out_win, idx_win, k);
+        if constexpr(Problem::ActivationIsSoftmax)
+        {
+            auto y = softmax(w);
+            topk(y, out_win, idx_win, k);
+        }
+        else
+        {
+            // sigmoid was already pre-computed above, so only do topk now
+            topk(w, out_win, idx_win, k);
+        }
 
         // check exit
         if constexpr(Problem::LaunchType == 0)
diff --git a/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp
index 917096ad5e3..1dc7e9335e0 100644
--- a/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp
+++ b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp
@@ -13,10 +13,11 @@
 template <typename InputType_,
           typename WeightType_,
           typename IndexType_,
           index_t Experts_,
-          index_t IssuesPerCol_  = 2, // issue along col, to make sure block_reduce() OK
-          index_t BytesPerIssue_ = sizeof(InputType_),
-          index_t LaunchType_    = 0, // 0-streaming, >0, persistent #occupancy
-          index_t BlockSize_     = 256>
+          bool ActivationIsSoftmax_ = true, // false: sigmoid
+          index_t IssuesPerCol_     = 2,    // issue along col, to make sure block_reduce() OK
+          index_t BytesPerIssue_    = sizeof(InputType_),
+          index_t LaunchType_       = 0,    // 0-streaming, >0, persistent #occupancy
+          index_t BlockSize_        = 256>
 struct TopkSoftmaxWarpPerRowProblem
 {
     // TODO: this kernel only support warp per row
@@ -31,6 +32,8 @@ struct TopkSoftmaxWarpPerRowProblem
     static constexpr index_t BlockSize = BlockSize_;
     static constexpr index_t WarpSize  = get_warp_size();
 
+    static constexpr bool ActivationIsSoftmax = ActivationIsSoftmax_;
+
     static_assert(BytesPerIssue % sizeof(InputType) == 0);
     static constexpr index_t VectorSize = BytesPerIssue / sizeof(InputType);
     static_assert(Experts % VectorSize == 0);
diff --git a/test/ck_tile/topk_softmax/test_topk_softmax.hpp b/test/ck_tile/topk_softmax/test_topk_softmax.hpp
index 1bb400ad07a..73f07035348 100644
--- a/test/ck_tile/topk_softmax/test_topk_softmax.hpp
+++ b/test/ck_tile/topk_softmax/test_topk_softmax.hpp
@@ -39,6 +39,26 @@ auto reference_topk_softmax(const ck_tile::HostTensor<InputType>& x,
     reference_topk(y, y_values, y_indices, k, dim, largest, sorted);
 }
 
+template <typename InputType, typename WeightType, typename IndexType>
+auto reference_topk_sigmoid(const ck_tile::HostTensor<InputType>& x,
+                            ck_tile::HostTensor<WeightType>& y_values,
+                            ck_tile::HostTensor<IndexType>& y_indices,
+                            ck_tile::index_t k,
+                            ck_tile::index_t dim = -1,
+                            bool largest         = true,
+                            bool sorted          = true)
+{
+    using namespace ck_tile;
+
+    // topk only - no need to apply the sigmoid first
+    auto x_fp32 = x.template CopyAsType<WeightType>();
+    reference_topk(x_fp32, y_values, y_indices, k, dim, largest, sorted);
+    // apply sigmoid
+    std::transform(y_values.begin(), y_values.end(), y_values.begin(), [](auto value) {
+        return WeightType(1) / (WeightType(1) + exp(-value));
+    });
+}
+
 // different threshold for different dtype
 template <typename DataType>
 auto get_elimit(std::string /*init_method*/)
@@ -87,7 +107,8 @@ auto create_args(int argc, char* argv[])
         .insert("seed", "-1", "seed to be used, -1 means random every time")
         .insert("kname", "0", "when set to 1 it will print kernel name")
         .insert("warmup", "5", "number of iterations before benchmark the kernel")
-        .insert("repeat", "20", "number of iterations to benchmark the kernel");
+        .insert("repeat", "20", "number of iterations to benchmark the kernel")
+        .insert("activation", "softmax", "activation function to use: softmax or sigmoid");
 
     bool result = arg_parser.parse(argc, argv);
     return std::make_tuple(result, arg_parser);
@@ -108,6 +129,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
     int kname  = args.get_int("kname");
     int warmup = args.get_int("warmup");
     int repeat = args.get_int("repeat");
+    std::string activation = args.get_str("activation");
 
     if(stride_input < 0)
     {
@@ -158,7 +180,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
 
     x_dev.ToDevice(x_host.data());
 
-    topk_softmax_trait trait{input_prec, weight_prec, experts};
+    topk_softmax_trait trait{input_prec, weight_prec, experts, activation};
 
     topk_softmax_kargs karg{x_dev.GetDeviceBuffer(),
                             value_dev.GetDeviceBuffer(),
@@ -175,7 +197,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
                             warmup,
                             repeat};
     auto ms = topk_softmax(trait, karg, sc);
-    printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, ms:%f, ",
+    printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, activation:%s, ms:%f, ",
           input_prec.c_str(),
           weight_prec.c_str(),
           tokens,
@@ -183,6 +205,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
           topk,
           stride_input,
           stride_output,
+          activation.c_str(),
           ms);
     if(ms < 0)
         printf("not supported\n");
@@ -201,8 +224,20 @@ bool test_topk_softmax(ck_tile::ArgParser args)
 
         ck_tile::HostTensor<WeightType> value_ref({tokens, topk}, {stride_output, 1});
         ck_tile::HostTensor<IndexType> index_ref({tokens, topk}, {stride_output, 1});
 
-        reference_topk_softmax<InputType, WeightType, IndexType>(
-            x_host, value_ref, index_ref, topk);
+        if(activation == "softmax")
+        {
+            reference_topk_softmax<InputType, WeightType, IndexType>(
+                x_host, value_ref, index_ref, topk);
+        }
+        else if(activation == "sigmoid")
+        {
+            reference_topk_sigmoid<InputType, WeightType, IndexType>(
+                x_host, value_ref, index_ref, topk);
+        }
+        else
+        {
+            throw std::runtime_error("unsupported activation type: " + activation);
+        }
 
         auto [rtol, atol] = get_elimit<InputType>("");
         for(int i_t = 0; i_t < tokens; i_t++)
@@ -255,7 +290,10 @@ int run_gemm_combinations(std::string const& data_type)
         {"-t=71", "-e=11", "-k=11", "-st_i=30", "-st_o=12"},
         {"-t=1", "-e=1", "-k=1"},
         {"-t=99", "-e=2", "-k=1", "-st_i=11", "-st_o=5"},
-        {"-t=333", "-e=99", "-k=13", "-st_i=191", "-st_o=17"}};
+        {"-t=333", "-e=99", "-k=13", "-st_i=191", "-st_o=17"},
+        {"-t=20", "-e=5", "-k=2", "-activation=sigmoid"},
+        {"-t=220", "-e=9", "-k=3", "-activation=sigmoid"},
+        {"-t=500", "-e=21", "-k=13", "-activation=sigmoid"}};
 
     bool result = true;
     std::string pr_i = "-pr_i=" + data_type;
diff --git a/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp b/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp
index 7c90c8200c6..e06935354b0 100644
--- a/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp
+++ b/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp
@@ -3,27 +3,31 @@
 
 #include "test_topk_softmax_api.hpp"
 
-#define TOPK_SOFTMAX_DISPATCH(experts_)                                                           \
-    constexpr ck_tile::index_t ts_experts = experts_;                                             \
-    using ts_problem = ck_tile::                                                                  \
-        TopkSoftmaxWarpPerRowProblem<ts_input_type, ts_weight_type, ts_index_type, ts_experts>;   \
-    using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>;                       \
-                                                                                                  \
-    using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>;                                       \
-                                                                                                  \
-    auto kargs = kernel::MakeKargs(a);                                                            \
-                                                                                                  \
-    const dim3 grids  = kernel::GridSize(a);                                                      \
-    const dim3 blocks = kernel::BlockSize();                                                      \
-                                                                                                  \
-    float ave_time =                                                                              \
-        ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs));    \
-                                                                                                  \
+#define TOPK_SOFTMAX_DISPATCH(experts_, use_softmax_)                                             \
+    constexpr ck_tile::index_t ts_experts = experts_;                                             \
+    constexpr bool ts_use_softmax         = use_softmax_;                                         \
+    using ts_problem  = ck_tile::TopkSoftmaxWarpPerRowProblem<ts_input_type,                      \
+                                                              ts_weight_type,                     \
+                                                              ts_index_type,                      \
+                                                              ts_experts,                         \
+                                                              ts_use_softmax>;                    \
+    using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>;                       \
+                                                                                                  \
+    using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>;                                       \
+                                                                                                  \
+    auto kargs = kernel::MakeKargs(a);                                                            \
+                                                                                                  \
+    const dim3 grids  = kernel::GridSize(a);                                                      \
+    const dim3 blocks = kernel::BlockSize();                                                      \
+                                                                                                  \
+    float ave_time =                                                                              \
+        ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs));    \
+                                                                                                  \
     return ave_time;
 
 float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s)
 {
-    if(t.input_type == "fp16" && t.weight_type == "fp32")
+    if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "softmax")
     {
         using ts_input_type  = ck_tile::fp16_t;
         using ts_weight_type = float;
@@ -31,36 +35,36 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
 #if 1
         if(t.experts <= 8)
         {
-            TOPK_SOFTMAX_DISPATCH(8)
+            TOPK_SOFTMAX_DISPATCH(8, true)
         }
         else if(t.experts <= 16)
         {
-            TOPK_SOFTMAX_DISPATCH(16)
+            TOPK_SOFTMAX_DISPATCH(16, true)
         }
         else if(t.experts <= 32)
         {
-            TOPK_SOFTMAX_DISPATCH(32)
+            TOPK_SOFTMAX_DISPATCH(32, true)
         }
         else if(t.experts <= 64)
         {
-            TOPK_SOFTMAX_DISPATCH(64)
+            TOPK_SOFTMAX_DISPATCH(64, true)
        }
         else if(t.experts <= 128)
         {
-            TOPK_SOFTMAX_DISPATCH(128)
+            TOPK_SOFTMAX_DISPATCH(128, true)
         }
         else if(t.experts <= 192)
         {
-            TOPK_SOFTMAX_DISPATCH(192)
+            TOPK_SOFTMAX_DISPATCH(192, true)
         }
 #else
         if(t.experts <= 128)
         {
-            TOPK_SOFTMAX_DISPATCH(128)
+            TOPK_SOFTMAX_DISPATCH(128, true)
         }
 #endif
     }
-    else if(t.input_type == "bf16" && t.weight_type == "fp32")
+    else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "softmax")
     {
 #if 1
         using ts_input_type = ck_tile::bf16_t;
@@ -68,27 +72,96 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
         using ts_index_type = ck_tile::index_t;
         if(t.experts <= 8)
         {
-            TOPK_SOFTMAX_DISPATCH(8)
+            TOPK_SOFTMAX_DISPATCH(8, true)
         }
         else if(t.experts <= 16)
         {
-            TOPK_SOFTMAX_DISPATCH(16)
+            TOPK_SOFTMAX_DISPATCH(16, true)
         }
         else if(t.experts <= 32)
         {
-            TOPK_SOFTMAX_DISPATCH(32)
+            TOPK_SOFTMAX_DISPATCH(32, true)
         }
         else if(t.experts <= 64)
         {
-            TOPK_SOFTMAX_DISPATCH(64)
+            TOPK_SOFTMAX_DISPATCH(64, true)
         }
         else if(t.experts <= 128)
         {
-            TOPK_SOFTMAX_DISPATCH(128)
+            TOPK_SOFTMAX_DISPATCH(128, true)
         }
         else if(t.experts <= 192)
         {
-            TOPK_SOFTMAX_DISPATCH(192)
+            TOPK_SOFTMAX_DISPATCH(192, true)
+        }
+#endif
+    }
+    if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "sigmoid")
+    {
+        using ts_input_type  = ck_tile::fp16_t;
+        using ts_weight_type = float;
+        using ts_index_type  = ck_tile::index_t;
+#if 1
+        if(t.experts <= 8)
+        {
+            TOPK_SOFTMAX_DISPATCH(8, false)
+        }
+        else if(t.experts <= 16)
+        {
+            TOPK_SOFTMAX_DISPATCH(16, false)
+        }
+        else if(t.experts <= 32)
+        {
+            TOPK_SOFTMAX_DISPATCH(32, false)
+        }
+        else if(t.experts <= 64)
+        {
+            TOPK_SOFTMAX_DISPATCH(64, false)
+        }
+        else if(t.experts <= 128)
+        {
+            TOPK_SOFTMAX_DISPATCH(128, false)
+        }
+        else if(t.experts <= 192)
+        {
+            TOPK_SOFTMAX_DISPATCH(192, false)
+        }
+#else
+        if(t.experts <= 128)
+        {
+            TOPK_SOFTMAX_DISPATCH(128, false)
+        }
+#endif
+    }
+    else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "sigmoid")
+    {
+#if 1
+        using ts_input_type  = ck_tile::bf16_t;
+        using ts_weight_type = float;
+        using ts_index_type  = ck_tile::index_t;
+        if(t.experts <= 8)
+        {
+            TOPK_SOFTMAX_DISPATCH(8, false)
+        }
+        else if(t.experts <= 16)
+        {
+            TOPK_SOFTMAX_DISPATCH(16, false)
+        }
+        else if(t.experts <= 32)
+        {
+            TOPK_SOFTMAX_DISPATCH(32, false)
+        }
+        else if(t.experts <= 64)
+        {
+            TOPK_SOFTMAX_DISPATCH(64, false)
+        }
+        else if(t.experts <= 128)
+        {
+            TOPK_SOFTMAX_DISPATCH(128, false)
+        }
+        else if(t.experts <= 192)
+        {
+            TOPK_SOFTMAX_DISPATCH(192, false)
         }
 #endif
     }
diff --git a/test/ck_tile/topk_softmax/test_topk_softmax_api.hpp b/test/ck_tile/topk_softmax/test_topk_softmax_api.hpp
index 65651efa4d4..c98a887736f 100644
--- a/test/ck_tile/topk_softmax/test_topk_softmax_api.hpp
+++ b/test/ck_tile/topk_softmax/test_topk_softmax_api.hpp
@@ -12,6 +12,7 @@ struct topk_softmax_trait
     std::string input_type;
     std::string weight_type; // currently always float
     int experts;
+    std::string activation; // "softmax" or "sigmoid"
 };
 
 struct topk_softmax_kargs : public ck_tile::TopkSoftmaxHostArgs
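
Note on the reference implementation: because sigmoid is applied element-wise and is strictly increasing, reference_topk_sigmoid can run top-k on the raw logits first and apply sigmoid only to the k selected values; softmax, by contrast, normalizes over the whole row and must be applied before selection. Below is a minimal standalone C++ sketch of that equivalence. It is illustrative only and not part of the patch; the helper name topk_indices and the file layout are hypothetical.

// Standalone check (hypothetical file, not part of the patch): top-k on raw
// logits followed by sigmoid on the k winners matches sigmoid-then-top-k,
// because sigmoid is strictly monotonic.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include <vector>

// indices of the k largest elements, in descending order of value
static std::vector<int> topk_indices(const std::vector<float>& x, int k)
{
    std::vector<int> idx(x.size());
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(
        idx.begin(), idx.begin() + k, idx.end(), [&](int a, int b) { return x[a] > x[b]; });
    idx.resize(k);
    return idx;
}

int main()
{
    const std::vector<float> logits = {0.3f, -1.2f, 2.5f, 0.9f, -0.4f};
    const int k                     = 2;

    // path 1 (what reference_topk_sigmoid does): select first, activate the winners
    std::vector<int> sel = topk_indices(logits, k);
    std::vector<float> weights;
    for(int i : sel)
        weights.push_back(1.0f / (1.0f + std::exp(-logits[i])));

    // path 2: activate the whole row first, then select
    std::vector<float> act(logits.size());
    std::transform(logits.begin(), logits.end(), act.begin(), [](float v) {
        return 1.0f / (1.0f + std::exp(-v));
    });
    std::vector<int> sel2 = topk_indices(act, k);

    // same experts and same weights either way (inputs here have no ties)
    for(int i = 0; i < k; ++i)
        assert(sel[i] == sel2[i] && weights[i] == act[sel2[i]]);
    return 0;
}

For ties the two paths may order equal-valued experts differently, which is why the example uses distinct logits; the kernel itself is exercised end-to-end via the new -activation=sigmoid cases added to run_gemm_combinations above.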