Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 24 additions & 19 deletions example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,13 @@ constexpr ck::index_t ScaleBlockSize = 32; // scaling block
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MPerBlock = 32;
static constexpr ck::index_t NPerBlock = 64;
static constexpr ck::index_t BlockSize = 256;
static constexpr ck::index_t BlockSize = 128;
static constexpr bool MulRoutedWeight = true;

// return ck_moe_stage1_gemm<{A0DataType}, {B0DataType}, {AccDataType}, {EDataType}, {CDEElementOp}, V3, 256,
// 64, 128, 128/sizeof({A0DataType}), 1, 4, {Nswizzle}, {Quant} == static_cast<int>(QuantType::per_Tensor), {MulRoutedWeight}, {ActOP}>;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBNS<
A0Layout, B0Layout, DsLayout, ELayout,
Expand All @@ -162,10 +164,10 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
MPerBlock, NPerBlock, KPerBlock,
16, 16,
16, 16,
4, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
2, 2,
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
2, 2, S<1, 32, 1, 4>, S<8, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3,
ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on
Expand All @@ -178,16 +180,12 @@ int main(int argc, char* argv[])

// per expert:
// GEMM shape
constexpr ck::index_t sorted_tile_num = 13;
constexpr ck::index_t valid_tile_num = sorted_tile_num;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;

ck::index_t N = 4096;
ck::index_t K = 6144;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t topk = 2;
ck::index_t N = 256;
ck::index_t K = 7168;
ck::index_t experts = 256;
ck::index_t tokens = 64;
ck::index_t topk = 8;

if(argc == 1)
{
Expand Down Expand Up @@ -223,6 +221,10 @@ int main(int argc, char* argv[])
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
};

ck::index_t sorted_tile_num = experts > tokens * topk ? experts : tokens * topk;
ck::index_t valid_tile_num = sorted_tile_num;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
Expand All @@ -246,14 +248,16 @@ int main(int argc, char* argv[])

for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
expert_ids.mData[i] = i / (valid_tile_num / experts);
}

int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
int tokenid = 0;

for(int i = 0; i < sorted_size; i++)
{
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile)
if(tile_off < token_per_tile && tokenid < tokens * topk)
{
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++;
Expand Down Expand Up @@ -461,10 +465,11 @@ int main(int argc, char* argv[])
std::size_t(2) * tokens * N * 2 * topk * K +
std::size_t(2) * tokens * N * 2 * topk * K / ScaleBlockSize;

int valid_expert = tokens * topk < experts? tokens * topk : experts;
std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * topk * K +
sizeof(B0DataType) / 2 * K * N * 2 * experts +
sizeof(B0DataType) * K * N * 2 / 2 * valid_expert +
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
sizeof(XDataType) * K / ScaleBlockSize * N * 2 * experts +
sizeof(XDataType) * K / ScaleBlockSize * N * 2 * valid_expert +
sizeof(EDataType) * tokens * topk * N;

float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
Expand Down
45 changes: 22 additions & 23 deletions example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,10 @@ static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecializatio

constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
constexpr ck::index_t KPerBlock = 512 / DataPackedSize; // 512 f4 = 256 fp4x2
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MPerBlock = 32;
static constexpr bool MulRoutedWeight = true;

// clang-format off
Expand All @@ -190,12 +190,12 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffl
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, 256,
MPerBlock, 64, KPerBlock,
MPerBlock, 128, KPerBlock,
16, 16,
16, 16,
4, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 2,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on
Expand All @@ -208,16 +208,11 @@ int main(int argc, char* argv[])

// per expert:
// GEMM shape
constexpr ck::index_t sorted_tile_num = 13;
constexpr ck::index_t valid_tile_num = sorted_tile_num;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;

ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t topk = 2;
ck::index_t N = 256;
ck::index_t K = 7168;
ck::index_t experts = 256;
ck::index_t tokens = 64;
ck::index_t topk = 8;

if(argc == 1)
{
Expand Down Expand Up @@ -253,6 +248,10 @@ int main(int argc, char* argv[])
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
};

ck::index_t sorted_tile_num = experts > tokens * topk ? experts : tokens * topk;
ck::index_t valid_tile_num = sorted_tile_num;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
Expand All @@ -273,17 +272,18 @@ int main(int argc, char* argv[])
printf("err config, tokens * topk > valid_size\n");
exit(-1);
}

for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
expert_ids.mData[i] = i / (valid_tile_num / experts);
}

int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
int tokenid = 0;

for(int i = 0; i < sorted_size; i++)
{
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile)
if(tile_off < token_per_tile && tokenid < tokens * topk)
{
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++;
Expand All @@ -293,7 +293,6 @@ int main(int argc, char* argv[])
sorted_token_ids.mData[i] = tokens;
}
}

Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
Tensor<XDataType> a1_t_k(HostTensorDescriptor(
{tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Expand Down Expand Up @@ -492,11 +491,11 @@ int main(int argc, char* argv[])
// FMA * tokens * N * (Gate+Up) * topk * (K/BlockScale)
std::size_t(2) * tokens * N * 2 * topk * K +
std::size_t(2) * tokens * N * 2 * topk * K / ScaleBlockSize;

int valid_expert = tokens * topk < experts? tokens * topk : experts;
std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * topk * K +
sizeof(B0DataType) / 2 * K * N * 2 * experts +
sizeof(B0DataType) * K * N * 2 / 2 * valid_expert +
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
sizeof(XDataType) * K / ScaleBlockSize * N * 2 * experts +
sizeof(XDataType) * K / ScaleBlockSize * N * 2 * valid_expert +
sizeof(EDataType) * tokens * topk * N;

float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
Expand Down
54 changes: 19 additions & 35 deletions example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4_bns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ constexpr ck::index_t DataPackedSize = 2; // Packed represent
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2

static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MPerBlock = 32;
static constexpr bool MulRoutedWeight = true;

// clang-format off
Expand All @@ -156,10 +156,10 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
MPerBlock, 128, KPerBlock,
16, 16,
16, 16,
4, 4,
2, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
2, 4, S<1, 4, 1, 64>, S<2, 1, 1, 1>,
2, 2, S<1, 4, 1, 64>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on

Expand All @@ -171,16 +171,11 @@ int main(int argc, char* argv[])

// per expert:
// GEMM shape
constexpr ck::index_t sorted_tile_num = 13;
constexpr ck::index_t valid_tile_num = sorted_tile_num;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;

ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t topk = 2;
ck::index_t N = 256;
ck::index_t K = 7168;
ck::index_t experts = 256;
ck::index_t tokens = 64;
ck::index_t topk = 8;

if(argc == 1)
{
Expand Down Expand Up @@ -216,6 +211,10 @@ int main(int argc, char* argv[])
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
};

ck::index_t sorted_tile_num = experts > tokens * topk ? experts : tokens * topk;
ck::index_t valid_tile_num = sorted_tile_num;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
Expand All @@ -231,34 +230,18 @@ int main(int argc, char* argv[])
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
max_token_id.mData[0] = valid_size;
// int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3};
int eids[sorted_tile_num]{};
for(int i = 0; i < sorted_tile_num; i++)
{
if(i < valid_tile_num)
{
eids[i] = (i * experts) / valid_tile_num;
}
else
{
eids[i] = 3;
}
expert_ids.mData[i] = i / (valid_tile_num / experts);
}

for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = eids[i];
}
if(tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
exit(-1);
}
int token_per_tile = tokens * topk / valid_tile_num;
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
int tokenid = 0;

for(int i = 0; i < sorted_size; i++)
{
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile)
if(tile_off < token_per_tile && tokenid < tokens * topk)
{
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++;
Expand Down Expand Up @@ -457,10 +440,11 @@ int main(int argc, char* argv[])
std::size_t flop = std::size_t(2) * tokens * topk * N * K +
std::size_t(2) * tokens * topk * N * K / ScaleBlockSize;

int valid_expert = tokens * topk < experts? tokens * topk : experts;
std::size_t num_btype =
sizeof(A0DataType) / 2 * tokens * K * topk + sizeof(B0DataType) / 2 * K * N * experts +
sizeof(A0DataType) * tokens * K * topk/ 2 + sizeof(B0DataType) * K * N * valid_expert / 2 +
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
sizeof(XDataType) * K / ScaleBlockSize * N * experts + sizeof(EDataType) * tokens * N;
sizeof(XDataType) * K / ScaleBlockSize * N * valid_expert + sizeof(EDataType) * tokens * N;

float tflops = static_cast<float>(flop) / 1.E9 / ave_time;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ constexpr ck::index_t DataPackedSize = 2; // Packed represent
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2

static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t MPerBlock = 32;
static constexpr bool MulRoutedWeight = true;

// clang-format off
Expand All @@ -189,7 +189,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
MPerBlock, 128, KPerBlock,
16, 16,
16, 16,
8, 2,
2, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 2, S<1, 4, 1, 64>, S<2, 1, 1, 1>,
Expand All @@ -204,16 +204,16 @@ int main(int argc, char* argv[])

// per expert:
// GEMM shape
constexpr ck::index_t sorted_tile_num = 13;
constexpr ck::index_t valid_tile_num = 13;
constexpr ck::index_t sorted_tile_num = 256;
constexpr ck::index_t valid_tile_num = 256;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;

ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t topk = 2;
ck::index_t N = 7168;
ck::index_t K = 256;
ck::index_t experts = 256;
ck::index_t tokens = 64;
ck::index_t topk = 8;

if(argc == 1)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -726,8 +726,8 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3<
});
});
});

HotLoopScheduler();
if constexpr(MPerBlock >= 64)
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
};

Expand Down
Loading
Loading