1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -20,6 +20,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added tensor-wise quantization for CK_TILE GEMM.
* Added support for batched contraction kernel.
* Added pooling kernel in CK_TILE
* Added top-k sigmoid kernel in CK_TILE

### Changed

45 changes: 40 additions & 5 deletions example/ck_tile/09_topk_softmax/topk_softmax.cpp
@@ -83,6 +83,26 @@ auto reference_topk_softmax(const ck_tile::HostTensor<InputType>& x,
reference_topk(y, y_values, y_indices, k, dim, largest, sorted);
}

template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
auto reference_topk_sigmoid(const ck_tile::HostTensor<InputType>& x,
ck_tile::HostTensor<WeightType>& y_values,
ck_tile::HostTensor<IndexType>& y_indices,
ck_tile::index_t k,
ck_tile::index_t dim = -1,
bool largest = true,
bool sorted = true)
{
using namespace ck_tile;

// sigmoid is strictly increasing, so topk can run on the raw logits first
auto x_fp32 = x.template CopyAsType<float>();
reference_topk(x_fp32, y_values, y_indices, k, dim, largest, sorted);
// apply sigmoid
std::transform(y_values.begin(), y_values.end(), y_values.begin(), [](auto value) {
return WeightType(1) / (WeightType(1) + exp(-value));
});
}

// different thresholds for different dtypes
template <typename DataType>
auto get_elimit(std::string /*init_method*/)
@@ -133,7 +153,8 @@ auto create_args(int argc, char* argv[])
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "topk_softmax.json", "json file name to dump results");
.insert("jsonfile", "topk_softmax.json", "json file name to dump results")
.insert("activation", "softmax", "activation function to use: softmax or sigmoid");

bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -154,6 +175,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)
int kname = args.get_int("kname");
int warmup = args.get_int("warmup");
int repeat = args.get_int("repeat");
std::string activation = args.get_str("activation");

if(stride_input < 0)
{
@@ -204,7 +226,7 @@ bool test_topk_softmax(ck_tile::ArgParser args)

x_dev.ToDevice(x_host.data());

topk_softmax_trait trait{input_prec, weight_prec, experts};
topk_softmax_trait trait{input_prec, weight_prec, experts, activation};

topk_softmax_kargs karg{x_dev.GetDeviceBuffer(),
value_dev.GetDeviceBuffer(),
@@ -221,14 +243,15 @@
warmup,
repeat};
auto ms = topk_softmax(trait, karg, sc);
printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, ms:%f, ",
printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, activation:%s, ms:%f, ",
input_prec.c_str(),
weight_prec.c_str(),
tokens,
experts,
topk,
stride_input,
stride_output,
activation.c_str(),
ms);
if(ms < 0)
printf("not supported\n");
@@ -247,8 +270,20 @@ bool test_topk_softmax(ck_tile::ArgParser args)
ck_tile::HostTensor<WeightType> value_ref({tokens, topk}, {stride_output, 1});
ck_tile::HostTensor<IndexType> index_ref({tokens, topk}, {stride_output, 1});

reference_topk_softmax<InputType, WeightType, IndexType>(
x_host, value_ref, index_ref, topk);
if(activation == "softmax")
{
reference_topk_softmax<InputType, WeightType, IndexType>(
x_host, value_ref, index_ref, topk);
}
else if(activation == "sigmoid")
{
reference_topk_sigmoid<InputType, WeightType, IndexType>(
x_host, value_ref, index_ref, topk);
}
else
{
throw std::runtime_error("unsupported activation type: " + activation);
}

auto [rtol, atol] = get_elimit<InputType>("");
for(int i_t = 0; i_t < tokens; i_t++)
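Note on the reference path above: `reference_topk_sigmoid` runs top-k on the raw logits and only then applies sigmoid. That is valid because sigmoid is strictly increasing, so it preserves the ordering that top-k selects on. A minimal standalone sketch of that property (not part of this PR; names and data are illustrative):

```cpp
// Sketch: top-k selection commutes with any strictly increasing map.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <numeric>
#include <vector>

int main()
{
    std::vector<float> logits = {0.3f, -1.2f, 2.5f, 0.9f};
    const int k = 2;
    auto sigmoid = [](float v) { return 1.0f / (1.0f + std::exp(-v)); };

    // top-k indices computed on raw logits
    std::vector<int> by_logit(logits.size());
    std::iota(by_logit.begin(), by_logit.end(), 0);
    std::partial_sort(by_logit.begin(), by_logit.begin() + k, by_logit.end(),
                      [&](int a, int b) { return logits[a] > logits[b]; });

    // top-k indices computed after the activation
    std::vector<int> by_sigmoid(logits.size());
    std::iota(by_sigmoid.begin(), by_sigmoid.end(), 0);
    std::partial_sort(by_sigmoid.begin(), by_sigmoid.begin() + k, by_sigmoid.end(),
                      [&](int a, int b) { return sigmoid(logits[a]) > sigmoid(logits[b]); });

    // monotonicity => the same elements are selected either way
    assert(std::equal(by_logit.begin(), by_logit.begin() + k, by_sigmoid.begin()));
}
```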
135 changes: 104 additions & 31 deletions example/ck_tile/09_topk_softmax/topk_softmax_api.cpp
@@ -3,92 +3,165 @@

#include "topk_softmax_api.hpp"

#define TOPK_SOFTMAX_DISPATCH(experts_) \
constexpr ck_tile::index_t ts_experts = experts_; \
using ts_problem = ck_tile:: \
TopkSoftmaxWarpPerRowProblem<ts_input_type, ts_weight_type, ts_index_type, ts_experts>; \
using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>; \
\
using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>; \
\
auto kargs = kernel::MakeKargs(a); \
\
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(); \
\
float ave_time = \
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \
\
#define TOPK_SOFTMAX_DISPATCH(experts_, use_softmax_) \
constexpr ck_tile::index_t ts_experts = experts_; \
constexpr bool ts_use_softmax = use_softmax_; \
using ts_problem = ck_tile::TopkSoftmaxWarpPerRowProblem<ts_input_type, \
ts_weight_type, \
ts_index_type, \
ts_experts, \
ts_use_softmax>; \
using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>; \
\
using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>; \
\
auto kargs = kernel::MakeKargs(a); \
\
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(); \
\
float ave_time = \
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \
\
return ave_time;

float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s)
{
if(t.input_type == "fp16" && t.weight_type == "fp32")
if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "softmax")
{
using ts_input_type = ck_tile::fp16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
#if 1
if(t.experts <= 8)
{
TOPK_SOFTMAX_DISPATCH(8)
TOPK_SOFTMAX_DISPATCH(8, true)
}
else if(t.experts <= 16)
{
TOPK_SOFTMAX_DISPATCH(16)
TOPK_SOFTMAX_DISPATCH(16, true)
}
else if(t.experts <= 32)
{
TOPK_SOFTMAX_DISPATCH(32)
TOPK_SOFTMAX_DISPATCH(32, true)
}
else if(t.experts <= 64)
{
TOPK_SOFTMAX_DISPATCH(64)
TOPK_SOFTMAX_DISPATCH(64, true)
}
else if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128)
TOPK_SOFTMAX_DISPATCH(128, true)
}
else if(t.experts <= 192)
{
TOPK_SOFTMAX_DISPATCH(192)
TOPK_SOFTMAX_DISPATCH(192, true)
}
#else
if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128)
TOPK_SOFTMAX_DISPATCH(128, true)
}
#endif
}
else if(t.input_type == "bf16" && t.weight_type == "fp32")
else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "softmax")
{
#if 1
using ts_input_type = ck_tile::bf16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
if(t.experts <= 8)
{
TOPK_SOFTMAX_DISPATCH(8)
TOPK_SOFTMAX_DISPATCH(8, true)
}
else if(t.experts <= 16)
{
TOPK_SOFTMAX_DISPATCH(16)
TOPK_SOFTMAX_DISPATCH(16, true)
}
else if(t.experts <= 32)
{
TOPK_SOFTMAX_DISPATCH(32)
TOPK_SOFTMAX_DISPATCH(32, true)
}
else if(t.experts <= 64)
{
TOPK_SOFTMAX_DISPATCH(64)
TOPK_SOFTMAX_DISPATCH(64, true)
}
else if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128)
TOPK_SOFTMAX_DISPATCH(128, true)
}
else if(t.experts <= 192)
{
TOPK_SOFTMAX_DISPATCH(192)
TOPK_SOFTMAX_DISPATCH(192, true)
}
#endif
}
else if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "sigmoid")
{
using ts_input_type = ck_tile::fp16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
#if 1
if(t.experts <= 8)
{
TOPK_SOFTMAX_DISPATCH(8, false)
}
else if(t.experts <= 16)
{
TOPK_SOFTMAX_DISPATCH(16, false)
}
else if(t.experts <= 32)
{
TOPK_SOFTMAX_DISPATCH(32, false)
}
else if(t.experts <= 64)
{
TOPK_SOFTMAX_DISPATCH(64, false)
}
else if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128, false)
}
else if(t.experts <= 192)
{
TOPK_SOFTMAX_DISPATCH(192, false)
}
#else
if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128, false)
}
#endif
}
else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "sigmoid")
{
#if 1
using ts_input_type = ck_tile::bf16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
if(t.experts <= 8)
{
TOPK_SOFTMAX_DISPATCH(8, false)
}
else if(t.experts <= 16)
{
TOPK_SOFTMAX_DISPATCH(16, false)
}
else if(t.experts <= 32)
{
TOPK_SOFTMAX_DISPATCH(32, false)
}
else if(t.experts <= 64)
{
TOPK_SOFTMAX_DISPATCH(64, false)
}
else if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128, false)
}
else if(t.experts <= 192)
{
TOPK_SOFTMAX_DISPATCH(192, false)
}
#endif
}
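The second macro argument is threaded through as the new `ActivationIsSoftmax` template parameter, so the activation choice is resolved at compile time while dtype and expert count are narrowed by the runtime ladder above. For orientation, `TOPK_SOFTMAX_DISPATCH(64, false)` inside the fp16/sigmoid branch expands to roughly the following (a sketch using the aliases defined in that branch, not a verbatim expansion):

```cpp
using ts_problem  = ck_tile::TopkSoftmaxWarpPerRowProblem<ck_tile::fp16_t,   // ts_input_type
                                                          float,             // ts_weight_type
                                                          ck_tile::index_t,  // ts_index_type
                                                          64,                // ts_experts
                                                          false>;            // ts_use_softmax
using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>;
using kernel      = ck_tile::TopkSoftmaxKernel<ts_pipeline>;
// MakeKargs/GridSize/BlockSize and the launch then proceed as in the macro body.
```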
1 change: 1 addition & 0 deletions example/ck_tile/09_topk_softmax/topk_softmax_api.hpp
@@ -12,6 +12,7 @@ struct topk_softmax_trait
std::string input_type;
std::string weight_type; // currently always float
int experts;
std::string activation; // "softmax" or "sigmoid"
};
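With the new field, a caller aggregate-initializes the trait in member order, mirroring the updated test code above; a minimal sketch (the values are illustrative):

```cpp
topk_softmax_trait trait{/*input_type=*/"fp16",
                         /*weight_type=*/"fp32",
                         /*experts=*/64,
                         /*activation=*/"sigmoid"};
```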

struct topk_softmax_kargs : public ck_tile::TopkSoftmaxHostArgs
@@ -21,7 +21,7 @@ struct TopkSoftmaxHostArgs
index_t num_experts;
index_t topk;
index_t stride_input; // row stride for input, at least experts
index_t stride_output; // row stride for output/indices, at least tpok
index_t stride_output; // row stride for output/indices, at least topk
};

template <typename Pipeline_>
@@ -45,7 +45,7 @@ struct TopkSoftmaxKernel
index_t num_experts;
index_t topk;
index_t stride_input; // row stride for input, at least experts
index_t stride_output; // row stride for output/indices, at least tpok
index_t stride_output; // row stride for output/indices, at least topk
};

using Kargs = TopkSoftmaxKargs;
@@ -90,17 +90,28 @@ struct TopkSoftmaxWarpPerRowPipeline
const auto current_expert = x_indices.at(number<1>{});
w_(idx) =
current_expert >= experts ? -numeric<WeightType>::infinity() : w_(idx);
if constexpr(!Problem::ActivationIsSoftmax)
{
// sigmoid is element-wise, so when not using softmax it can be applied right here
w_(idx) = WeightType(1) / (WeightType(1) + exp(-w_(idx)));
}
};
tile_sweeper ts{w_, w_f};
ts();
return w_;
#endif
}();

// softmax
auto y = softmax(w);

topk(y, out_win, idx_win, k);
if constexpr(Problem::ActivationIsSoftmax)
{
auto y = softmax(w);
topk(y, out_win, idx_win, k);
}
else
{
// sigmoid was already applied above, so only topk remains
topk(w, out_win, idx_win, k);
}

// check exit
if constexpr(Problem::LaunchType == 0)
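One more note on the masking above: lanes with `current_expert >= experts` are set to negative infinity before the activation, and both branches keep them unselectable. Under softmax they contribute zero weight, and under the new sigmoid path every finite logit maps into (0, 1) while the padded lanes map to 0, so top-k can never pick a padded lane as long as k does not exceed the number of real experts. In symbols:

$$
\sigma(-\infty) \;=\; \lim_{x\to-\infty} \frac{1}{1+e^{-x}} \;=\; 0,
\qquad
\sigma(x)\in(0,1)\ \text{for all finite } x.
$$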
@@ -13,10 +13,11 @@ template <typename InputType_,
typename WeightType_,
typename IndexType_,
index_t Experts_,
index_t IssuesPerCol_ = 2, // issue along col, to make sure block_reduce() OK
index_t BytesPerIssue_ = sizeof(InputType_),
index_t LaunchType_ = 0, // 0-streaming, >0, persistent #occupancy
index_t BlockSize_ = 256>
bool ActivationIsSoftmax_ = true, // false: sigmoid
index_t IssuesPerCol_ = 2, // issue along col, to make sure block_reduce() OK
index_t BytesPerIssue_ = sizeof(InputType_),
index_t LaunchType_ = 0, // 0-streaming, >0, persistent #occupancy
index_t BlockSize_ = 256>
struct TopkSoftmaxWarpPerRowProblem
{
// TODO: this kernel only supports warp per row
@@ -31,6 +32,8 @@ struct TopkSoftmaxWarpPerRowProblem
static constexpr index_t BlockSize = BlockSize_;
static constexpr index_t WarpSize = get_warp_size();

static constexpr bool ActivationIsSoftmax = ActivationIsSoftmax_;

static_assert(BytesPerIssue % sizeof(InputType) == 0);
static constexpr index_t VectorSize = BytesPerIssue / sizeof(InputType);
static_assert(Experts % VectorSize == 0);
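For completeness, the problem can also be instantiated directly with the new flag, leaving the trailing parameters at their defaults (`IssuesPerCol_ = 2`, `BytesPerIssue_ = sizeof(InputType_)`, `LaunchType_ = 0` streaming, `BlockSize_ = 256`). A sketch; the alias name is illustrative:

```cpp
using SigmoidProblem = ck_tile::TopkSoftmaxWarpPerRowProblem<ck_tile::fp16_t,
                                                             float,
                                                             ck_tile::index_t,
                                                             /*Experts_=*/64,
                                                             /*ActivationIsSoftmax_=*/false>;
static_assert(!SigmoidProblem::ActivationIsSoftmax, "sigmoid path selected");
```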