22 changes: 15 additions & 7 deletions example/91_tile_program/fmha/fmha_fwd.cpp
@@ -49,7 +49,7 @@ auto create_args(int argc, char* argv[])
.insert("s_k", "0", "seqlen_k, 0 means equal to s")
.insert("d", "128", "head dim for q, k")
.insert("d_v", "0", "head dim for v, 0 means equal to d")
.insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(seqlen)")
.insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)")
.insert("descale_q", "1", "scale factor for fp8 quantization")
.insert("descale_k", "1", "scale factor for fp8 quantization")
.insert("descale_v", "1", "scale factor for fp8 quantization")
@@ -68,6 +68,7 @@ auto create_args(int argc, char* argv[])
"'g:y,x', generic attention mask coordinate with y/x size\n")
.insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)")
.insert("lse", "0", "0 not store lse, 1 store lse")
.insert("kname", "0", "if set to 1 will print kernel name")
.insert("init", "1", "init method. 0:random int, 1:random float, 2:trig float")
.insert("seed",
"11939",
@@ -157,8 +158,9 @@ bool run(const ArgParser& arg_parser)

int stream_warmup = env_get_int("CK_WARMUP", 5);
int stream_repeat = env_get_int("CK_REPEAT", 20);
bool kname = arg_parser.get_bool("kname");

StreamConfig stream_config{nullptr, true, 0, stream_warmup, stream_repeat};
StreamConfig stream_config{nullptr, true, kname ? 1 : 0, stream_warmup, stream_repeat};

const auto [seqlens_q, seqstart_q_host] = generate_seqlens_seqstarts_q(mode, batch, seqlen_q);
const std::vector<int32_t> seqstart_k_host =
@@ -296,9 +298,15 @@ bool run(const ArgParser& arg_parser)
<< ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << use_bias
<< ", lse:" << lse << ", mask:" << mask << ", v:" << vlayout << std::flush;

auto fmha_traits = fmha_fwd_traits{
hdim_q, data_type, mode == mode_enum::group, is_v_rowmajor, mask.type, use_bias, lse};
auto fmha_args = fmha_fwd_args{q_buf.GetDeviceBuffer(),
auto fmha_traits = fmha_fwd_traits{hdim_q,
hdim_v,
data_type,
mode == mode_enum::group,
is_v_rowmajor,
mask.type,
use_bias,
lse};
auto fmha_args = fmha_fwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
bias_buf.GetDeviceBuffer(),
@@ -440,11 +448,11 @@ bool run(const ArgParser& arg_parser)

auto [rtol, atol] = get_elimit<DataType>(init_method);
bool cur_pass = ck::utils::check_err(
o_host_result, o_host_ref, std::string("O Error: Incorrect results!"), rtol, atol);
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
pass &= cur_pass;
if(!cur_pass)
{
std::cerr << "O mismatch found at batch: " << wb << std::endl
std::cerr << "OUT mismatch found at batch: " << wb << std::endl
<< "\tseqlen_q: " << real_seqlen_q << std::endl
<< "\tseqlen_k: " << real_seqlen_k << std::endl
<< "\tseqstart_q: " << seqstart_q_host << std::endl
42 changes: 32 additions & 10 deletions example/91_tile_program/fmha/fmha_fwd.hpp
@@ -290,23 +290,43 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.mask_x);
}

// this is internal API, will be generated across different files to speedup compile
// this is used to pattern-match the internal kernel implementation, not to instantiate the kernel
template <ck::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck::index_t kM0_,
ck::index_t kN0_,
ck::index_t kK0_,
ck::index_t kN1_,
ck::index_t kK1_,
ck::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
typename FmhaMask_,
bool kHasBias_,
bool kStoreLse_>
bool kStoreLse_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_>
struct fmha_fwd_traits_
{
static constexpr ck::index_t HDim = HDim_;
using DataType = ck::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
using FmhaMask = ck::remove_cvref_t<FmhaMask_>;
static constexpr bool kHasBias = kHasBias_;
static constexpr bool kStoreLse = kStoreLse_;
static constexpr ck::index_t HDim = HDim_;
using DataType = ck::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr ck::index_t kM0 = kM0_;
static constexpr ck::index_t kN0 = kN0_;
static constexpr ck::index_t kK0 = kK0_;
static constexpr ck::index_t kN1 = kN1_;
static constexpr ck::index_t kK1 = kK1_;
static constexpr ck::index_t kK0BlockLength = kK0BlockLength_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
using FmhaMask = ck::remove_cvref_t<FmhaMask_>;
static constexpr bool kHasBias = kHasBias_;
static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
};

template <typename Traits_>
@@ -315,12 +335,14 @@ float fmha_fwd_(const StreamConfig&, fmha_fwd_args);
// This is the public API, will be generated by script
struct fmha_fwd_traits
{
int hdim;
int hdim_q;
int hdim_v;
std::string data_type;
bool is_group_mode;
bool is_v_rowmajor;
mask_enum mask_type;
bool has_bias;
bool has_lse;
// TODO: padding check is inside this api
};
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const StreamConfig&);
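As the comment above notes, the internal fmha_fwd_traits_ is only pattern-matched by generated translation units, one explicit instantiation of fmha_fwd_ per variant. A minimal sketch of what such a generated file might look like; the tile sizes, data type, and DummyMask placeholder below are illustrative assumptions, not values taken from this PR:

```cpp
// Sketch only: how a generate.py-style .cpp might pin one kernel variant.
// All concrete values and the DummyMask type are assumptions for illustration.
#include "fmha_fwd.hpp"
#include "ck/utility/data_type.hpp"

struct DummyMask { static constexpr bool IsMasking = false; }; // hypothetical stand-in

using trait_d128_fp16_batch = fmha_fwd_traits_<128, ck::half_t, /*kIsGroupMode=*/false,
    /*kM0=*/128, /*kN0=*/128, /*kK0=*/32, /*kN1=*/128, /*kK1=*/32, /*kK0BlockLength=*/128,
    /*kIsVLayoutRowMajor=*/true, DummyMask, /*kHasBias=*/false, /*kStoreLse=*/false,
    /*kPadS=*/false, /*kPadSK=*/false, /*kPadD=*/false, /*kPadDv=*/false>;

// the generated file would then provide the matching explicit instantiation, e.g.
// template float fmha_fwd_<trait_d128_fp16_batch>(const StreamConfig&, fmha_fwd_args);
static_assert(trait_d128_fp16_batch::kK0BlockLength == 128, "sanity check on the encoded shape");
```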
27 changes: 20 additions & 7 deletions example/91_tile_program/fmha/fmha_fwd_epilogue.hpp
@@ -7,19 +7,23 @@
#include "ck/tile_program/tile/store_tile.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"

template <typename OaccDataType_, typename ODataType_>
template <typename OaccDataType_, typename ODataType_, bool kPadSeqLenQ_, bool kPadHeadDimV_>
struct FmhaFwdEpilogueProblem
{
using OaccDataType = ck::remove_cvref_t<OaccDataType_>;
using ODataType = ck::remove_cvref_t<ODataType_>;
using OaccDataType = ck::remove_cvref_t<OaccDataType_>;
using ODataType = ck::remove_cvref_t<ODataType_>;
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
};

template <typename Problem_, typename Policy_ = void>
struct FmhaFwdEpilogue
{
using Problem = ck::remove_cvref_t<Problem_>;
using OaccDataType = ck::remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = ck::remove_cvref_t<typename Problem::ODataType>;
using Problem = ck::remove_cvref_t<Problem_>;
using OaccDataType = ck::remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = ck::remove_cvref_t<typename Problem::ODataType>;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;

__host__ __device__ static constexpr ck::index_t GetSmemSize() { return 0; }

@@ -29,6 +33,15 @@ struct FmhaFwdEpilogue
using namespace ck;
using namespace ck::tile_program;

store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
// TODO: this is ugly
if constexpr(kPadSeqLenQ || kPadHeadDimV)
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
buffer_store_fence();
}
else
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
}
};
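For reference, a short sketch (not part of the diff) of how the two new pad flags propagate from the epilogue problem into the epilogue; the fp32 accumulation/output types chosen here are assumptions:

```cpp
// Illustrative only: selecting the raw-store path for a padded seqlen_q case.
// The data types are assumptions, not taken from this PR.
#include "fmha_fwd_epilogue.hpp"

using ExampleEpilogueProblem =
    FmhaFwdEpilogueProblem<float, float, /*kPadSeqLenQ_=*/true, /*kPadHeadDimV_=*/false>;
using ExampleEpilogue = FmhaFwdEpilogue<ExampleEpilogueProblem>;

static_assert(ExampleEpilogue::kPadSeqLenQ && !ExampleEpilogue::kPadHeadDimV,
              "padding flags are forwarded from the problem to the epilogue");
```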
53 changes: 47 additions & 6 deletions example/91_tile_program/fmha/fmha_fwd_kernel.hpp
@@ -4,6 +4,7 @@
#pragma once

#include <type_traits>
#include <string>

#include "ck/utility/common_header.hpp"
#include "ck/tensor/tensor_view.hpp"
@@ -44,6 +45,46 @@ struct FmhaFwdKernel
using FmhaMask = ck::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kHasMask = FmhaMask::IsMasking;

// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck::half_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck::bhalf_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck::f8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on

__host__ static std::string GetName()
Collaborator: I think GetName() can be implemented by calling something like miopen::get_type_name(). And we need another name for this function (maybe GetEncodedName()?)

Collaborator (Author): The purpose of the name inside C++ is to print something that helps with debugging; it is not the symbol name (though we could mock a symbol name inside generate.py). The name should carry all the information needed to distinguish between different kinds of kernels, so it can end up fairly long. The advantage of keeping this inside the kernel template is that it can be reused even when not using our generate.py system. And yes, using GetEncodedName() is OK.

{
// sync with generate.py
// clang-format off
using bfs = typename FmhaPipeline::BlockFmhaShape;
using gbr = typename bfs::Gemm0BlockWarps;
using gwt = typename bfs::Gemm0WarpTile;
#define _SS_ std::string
#define _TS_ std::to_string
auto pn = [&] () {
std::string n;
if (kPadSeqLenQ) n += "s";
if (kPadSeqLenK) n += "sk";
if (kPadHeadDimQ) n += "d";
if (kPadHeadDimV) n += "dv";
return n.empty() ? n : std::string("p") + n; }();
return
_SS_("fmha_fwd_d") + _TS_(bfs::kK0BlockLength) + "_" + _SS_(t2s<QDataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + "_" +
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
"r" + _TS_(gbr::At(ck::Number<0>{})) + "x" + _TS_(gbr::At(ck::Number<1>{})) + "x" + _TS_(gbr::At(ck::Number<2>{})) + "_" +
"w" + _TS_(gwt::At(ck::Number<0>{})) + "x" + _TS_(gwt::At(ck::Number<1>{})) + "x" + _TS_(gwt::At(ck::Number<2>{})) + "_" +
"o" + _TS_(kBlockPerCu) + "_" + _SS_(FmhaPipeline::name) + "_" +
"v" + (ck::is_same_v<VLayout, ck::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(kHasBias ? "_bias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" );
#undef _SS_
#undef _TS_
// clang-format on
}
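As discussed in the review thread above, the encoded name is meant for debug printing rather than for symbol generation. A hypothetical host-side call site follows; the helper function and the sample output string are illustrative assumptions only:

```cpp
// Hypothetical debug helper (assumption, not part of this PR): Kernel stands for
// any concrete FmhaFwdKernel<...> instantiation.
#include <iostream>

template <typename Kernel>
void print_kernel_name()
{
    std::cout << Kernel::GetName() << '\n';
    // might print something like (illustrative only):
    //   fmha_fwd_d128_fp16_batch_b128x128x32x128x32x128_r4x1x1_w32x32x16_o1_qr_vr_bias_lse
}
```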

template <ck::index_t I> // to avoid the duplicated base class problem, introduce a template arg
struct FmhaFwdEmptyKargs
{
@@ -447,7 +488,7 @@ struct FmhaFwdKernel
q_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_q, 1),
Number<32>{},
Number<FmhaPipeline::kAlignmentQ>{},
Number<1>{});
if constexpr(FmhaPipeline::kQLoadOnce)
{
@@ -469,7 +510,7 @@
k_ptr,
make_tuple(kargs.seqlen_k, kargs.hdim_q),
make_tuple(kargs.stride_k, 1),
Number<32>{},
Number<FmhaPipeline::kAlignmentK>{},
Number<1>{});

return pad_tensor_view(
@@ -484,7 +525,7 @@
v_ptr,
make_tuple(kargs.seqlen_k, kargs.hdim_v),
make_tuple(kargs.stride_v, 1),
Number<32>{},
Number<FmhaPipeline::kAlignmentV>{},
Number<1>{});

const auto v_dram_transposed =
Expand All @@ -505,7 +546,7 @@ struct FmhaFwdKernel
v_ptr,
make_tuple(kargs.hdim_v, kargs.seqlen_k),
make_tuple(kargs.stride_v, 1),
Number<32>{},
Number<FmhaPipeline::kAlignmentV>{},
Number<1>{});

return pad_tensor_view(
@@ -551,7 +592,7 @@ struct FmhaFwdKernel
bias_ptr,
make_tuple(kargs.seqlen_q, kargs.seqlen_k),
make_tuple(kargs.stride_bias, 1),
Number<32>{},
Number<FmhaPipeline::kAlignmentBias>{},
Number<1>{});

return pad_tensor_view(bias_dram_naive,
@@ -636,7 +677,7 @@ struct FmhaFwdKernel
o_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_v),
make_tuple(kargs.stride_o, 1),
Number<32>{},
Number<FmhaPipeline::kAlignmentO>{},
Number<1>{});

return pad_tensor_view(