Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
472 changes: 472 additions & 0 deletions include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,11 @@ struct StreamKTilePartitionerBase
*/
CK_TILE_HOST_DEVICE index_t get_n() const noexcept;

/**
* @brief Returns an estimate of the number of workgroups writing to the same macro tile in C.
*/
CK_TILE_HOST index_t estimate_num_wgs_per_tile() const noexcept;

protected:
index_t num_tiles_;
index_t grid_;
Expand Down Expand Up @@ -246,6 +251,7 @@ struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true
ck_tile::index_t grid);

public:
static constexpr bool PERSISTENT = true;
/**
* @brief Calculates the launching grid size for the Stream-K kernel. In the Persistent
* case, no extra workgroups are allocated for the data parallel section, making the grid
Expand Down Expand Up @@ -292,6 +298,7 @@ struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, fals
ck_tile::index_t grid);

public:
static constexpr bool PERSISTENT = false;
/**
* @brief Calculates the launching grid size for the Stream-K kernel. In the Non-Persistent
* case, extra workgroups are allocated for the data parallel section, making the grid
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,27 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_n() c
return n_;
}

template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_HOST index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::estimate_num_wgs_per_tile()
const noexcept
{
// In the case of non-atomic reduction or data-parallel only, there will always be 1 workgroup
// writing final results to a given macro tile in C.
int num_wgs_per_tile = 1;

// Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C.
if(sk_ctas_ > 0 && ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
{
ck_tile::index_t iters_per_sk_cta_non_zero = ck_tile::max(iters_per_sk_cta_, 1);
// Estimate the number of workgroups per macro tile.
num_wgs_per_tile = (iters_per_tile_ / iters_per_sk_cta_non_zero) +
((iters_per_tile_ % iters_per_sk_cta_non_zero) != 0);
}

return std::max(num_wgs_per_tile, 1);
}

template <typename BlockGemmShapeType,
StreamKReductionStrategy ReductionStrategyType,
bool Persistent>
Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ set(REGRESSION_TESTS
test_ck_tile_fmha_fwd_bf16
test_ck_tile_fmha_fwd_fp16
test_ck_tile_fmha_fwd_fp8
test_ck_tile_streamk_reboot_extended
)

function(add_test_executable TEST_NAME)
Expand Down
12 changes: 12 additions & 0 deletions test/ck_tile/gemm_streamk/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,18 @@ if(GPU_TARGETS MATCHES "gfx9")
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
# )
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
add_gtest_executable(test_ck_tile_streamk_reboot_smoke
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_fp16_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_bf16_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp
test_gemm_streamk_reboot_util.cpp)
add_gtest_executable(test_ck_tile_streamk_reboot_extended
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_fp16_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_bf16_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp
test_gemm_streamk_reboot_util.cpp)
else()
message(DEBUG "Skipping test_ck_tile_streamk tests for current target")
endif()
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "test_gemm_streamk_reboot_types.hpp"
#include "test_gemm_streamk_reboot_util.hpp"
#include "gtest/gtest.h"

template <typename Tuple>
class TestCkTileStreamKRebootBf16NonPersistent : public TestCkTileStreamKReboot<Tuple>
{
};

#define TEST_SUITE_NAME TestCkTileStreamKRebootBf16NonPersistent

TYPED_TEST_SUITE(TestCkTileStreamKRebootBf16NonPersistent, KernelTypesStreamKBf16NonPersistent);

#include "test_gemm_streamk_reboot_extended_cases.inc"

#undef TEST_SUITE_NAME
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "test_gemm_streamk_reboot_types.hpp"
#include "test_gemm_streamk_reboot_util.hpp"
#include "gtest/gtest.h"

template <typename Tuple>
class TestCkTileStreamKRebootBf16Persistent : public TestCkTileStreamKReboot<Tuple>
{
};

#define TEST_SUITE_NAME TestCkTileStreamKRebootBf16Persistent

TYPED_TEST_SUITE(TestCkTileStreamKRebootBf16Persistent, KernelTypesStreamKBf16Persistent);

#include "test_gemm_streamk_reboot_extended_cases.inc"

#undef TEST_SUITE_NAME
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "test_gemm_streamk_reboot_types.hpp"
#include "test_gemm_streamk_reboot_util.hpp"
#include "gtest/gtest.h"

template <typename Tuple>
class TestCkTileStreamKRebootFp16NonPersistent : public TestCkTileStreamKReboot<Tuple>
{
};

#define TEST_SUITE_NAME TestCkTileStreamKRebootFp16NonPersistent

TYPED_TEST_SUITE(TestCkTileStreamKRebootFp16NonPersistent, KernelTypesStreamKFp16NonPersistent);

#include "test_gemm_streamk_reboot_extended_cases.inc"

#undef TEST_SUITE_NAME
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "test_gemm_streamk_reboot_types.hpp"
#include "test_gemm_streamk_reboot_util.hpp"
#include "gtest/gtest.h"

template <typename Tuple>
class TestCkTileStreamKRebootFp16Persistent : public TestCkTileStreamKReboot<Tuple>
{
};

#define TEST_SUITE_NAME TestCkTileStreamKRebootFp16Persistent

TYPED_TEST_SUITE(TestCkTileStreamKRebootFp16Persistent, KernelTypesStreamKFp16Persistent);

#include "test_gemm_streamk_reboot_extended_cases.inc"

#undef TEST_SUITE_NAME
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "test_gemm_streamk_reboot_types.hpp"
#include "test_gemm_streamk_reboot_util.hpp"
#include "gtest/gtest.h"

template <typename Tuple>
class TestCkTileStreamKRebootBf16NonPersistent : public TestCkTileStreamKReboot<Tuple>
{
};

#define TEST_SUITE_NAME TestCkTileStreamKRebootBf16NonPersistent

TYPED_TEST_SUITE(TestCkTileStreamKRebootBf16NonPersistent, KernelTypesStreamKBf16NonPersistent);

#include "test_gemm_streamk_reboot_smoke_cases.inc"

#undef TEST_SUITE_NAME
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "test_gemm_streamk_reboot_types.hpp"
#include "test_gemm_streamk_reboot_util.hpp"
#include "gtest/gtest.h"

template <typename Tuple>
class TestCkTileStreamKRebootBf16Persistent : public TestCkTileStreamKReboot<Tuple>
{
};

#define TEST_SUITE_NAME TestCkTileStreamKRebootBf16Persistent

TYPED_TEST_SUITE(TestCkTileStreamKRebootBf16Persistent, KernelTypesStreamKBf16Persistent);

#include "test_gemm_streamk_reboot_smoke_cases.inc"

#undef TEST_SUITE_NAME
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "test_gemm_streamk_reboot_types.hpp"
#include "test_gemm_streamk_reboot_util.hpp"
#include "gtest/gtest.h"

template <typename Tuple>
class TestCkTileStreamKRebootFp16NonPersistent : public TestCkTileStreamKReboot<Tuple>
{
};

#define TEST_SUITE_NAME TestCkTileStreamKRebootFp16NonPersistent

TYPED_TEST_SUITE(TestCkTileStreamKRebootFp16NonPersistent, KernelTypesStreamKFp16NonPersistent);

#include "test_gemm_streamk_reboot_smoke_cases.inc"

#undef TEST_SUITE_NAME
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "test_gemm_streamk_reboot_types.hpp"
#include "test_gemm_streamk_reboot_util.hpp"
#include "gtest/gtest.h"

template <typename Tuple>
class TestCkTileStreamKRebootFp16Persistent : public TestCkTileStreamKReboot<Tuple>
{
};

#define TEST_SUITE_NAME TestCkTileStreamKRebootFp16Persistent

TYPED_TEST_SUITE(TestCkTileStreamKRebootFp16Persistent, KernelTypesStreamKFp16Persistent);

#include "test_gemm_streamk_reboot_smoke_cases.inc"

#undef TEST_SUITE_NAME
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

TYPED_TEST(TEST_SUITE_NAME, StreamK_DP2TSK)
{
const ck_tile::index_t num_cu = get_cu_count();
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;

// For DP 2-Tile SK, there are 2 important terms:
// Term 1: (M_Tile * num_cu * 2) - This ensures we have at least 2 cycles that will fully
// saturate all CUs. This assumes tile sizes are large enough such that occupancy is 1.
// Term 2: (M_Tile * 2) - This ensures we have 1 cycle that does not fully saturate all CUs
// (i.e., we will have remainder tiles). This guarantees we have 1 full tile cycle plus
// remainder tiles for the 2 Tile SK portion; the rest of the tiles will fully saturate all CUs
// for the DP portion.
ck_tile::index_t M = (M_Tile * num_cu * 2) + (M_Tile * 2);
ck_tile::index_t N = N_Tile;
ck_tile::index_t K = 2048;

this->Run(M, N, K);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

TYPED_TEST(TEST_SUITE_NAME, StreamK_EdgeCase)
{
ck_tile::index_t M = 256;
ck_tile::index_t N = 256;
ck_tile::index_t K = 256;

this->Run(M, N, K);
}

TYPED_TEST(TEST_SUITE_NAME, StreamK_DPOnly)
{
const ck_tile::index_t num_cu = get_cu_count();
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;

// For DP only, we ensure that the number of tiles is a multiple of the number of CUs. This
// assumes tile sizes are large enough such that occupancy is 1.
ck_tile::index_t M = M_Tile * num_cu;
ck_tile::index_t N = N_Tile;
ck_tile::index_t K = K_Tile;

this->Run(M, N, K);
}

TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly)
{
const ck_tile::index_t num_cu = get_cu_count();
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;

// For SK only, we have 4 macro tiles in C. But, we need to make sure there is enough work along
// the K dimension to avoid falling into the edge case. Thus, we always have at least num_cu
// macro tiles in the K dimension. This assumes tile sizes are large enough such that occupancy
// is 1.
ck_tile::index_t M = M_Tile * 2;
ck_tile::index_t N = N_Tile * 2;
ck_tile::index_t K = K_Tile * num_cu;

this->Run(M, N, K);
}
56 changes: 56 additions & 0 deletions test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_types.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include <tuple>
#include <type_traits>

#include "gtest/gtest.h"

#include "ck_tile/host.hpp"

using F16 = ck_tile::half_t;
using F32 = float;
using BF16 = ck_tile::bf16_t;

using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

using Persistent = std::true_type;
using NonPersistent = std::false_type;

using I32 = ck_tile::number<32>;
using I256 = ck_tile::number<256>;

// clang-format off
using KernelTypesStreamKFp16Persistent = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent

std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>
>;

using KernelTypesStreamKBf16Persistent = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>
>;

using KernelTypesStreamKFp16NonPersistent = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent

std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>
>;

using KernelTypesStreamKBf16NonPersistent = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>
>;
// clang-format on
10 changes: 10 additions & 0 deletions test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_util.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#include "test_gemm_streamk_reboot_util.hpp"

ck_tile::index_t get_cu_count()
{
hipDeviceProp_t dev_prop;
hipDevice_t dev;
ck_tile::hip_check_error(hipGetDevice(&dev));
ck_tile::hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
return dev_prop.multiProcessorCount;
}
Loading