From 22d64c5fb69629b15c1aedcb1db059f2aa45b7f9 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Tue, 17 Oct 2023 07:43:30 -0700 Subject: [PATCH 01/11] Remove mutable joint_matrix for cuda. Added new supported mma builtins where C/D types differ. Signed-off-by: JackAKirk --- .../ext/oneapi/matrix/matrix-tensorcores.hpp | 138 ++++++++++++------ .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 4 +- 2 files changed, 92 insertions(+), 50 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp index 94ae318540012..dac2497dd5850 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp @@ -357,55 +357,55 @@ template void store_layoutT( - joint_matrix_cuda< + const joint_matrix_cuda< T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows, NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src, multi_ptr dst, size_t stride) { if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (std::is_same_v) { __hmma_m16n16k16_st_c_f32(dst.get(), - reinterpret_cast(&src.wi_marray), + &src.wi_marray[0], stride, get_layout_id()); } else if constexpr (std::is_same_v) { __imma_m16n16k16_st_c_i32(dst.get(), - reinterpret_cast(&src.wi_marray), + &src.wi_marray[0], stride, get_layout_id()); } else if constexpr (std::is_same_v) { __hmma_m16n16k16_st_c_f16(reinterpret_cast(dst.get()), - reinterpret_cast(&src.wi_marray), + reinterpret_cast(&src.wi_marray[0]), stride, get_layout_id()); } } else if constexpr (NumRows == 8 && NumCols == 32) { if constexpr (std::is_same_v) { __hmma_m8n32k16_st_c_f32(dst.get(), - reinterpret_cast(&src.wi_marray), + &src.wi_marray[0], stride, get_layout_id()); } else if constexpr (std::is_same_v) { __imma_m8n32k16_st_c_i32(dst.get(), - reinterpret_cast(&src.wi_marray), + &src.wi_marray[0], stride, get_layout_id()); } else if constexpr (std::is_same_v) { __hmma_m8n32k16_st_c_f16(reinterpret_cast(dst.get()), - reinterpret_cast(&src.wi_marray), + reinterpret_cast(&src.wi_marray[0]), stride, get_layout_id()); } } else if constexpr (NumRows == 32 && NumCols == 8) { if constexpr (std::is_same_v) { __hmma_m32n8k16_st_c_f32(dst.get(), - reinterpret_cast(&src.wi_marray), + &src.wi_marray[0], stride, get_layout_id()); } else if constexpr (std::is_same_v) { __imma_m32n8k16_st_c_i32(dst.get(), - reinterpret_cast(&src.wi_marray), + &src.wi_marray[0], stride, get_layout_id()); } else if constexpr (std::is_same_v) { __hmma_m32n8k16_st_c_f16(reinterpret_cast(dst.get()), - reinterpret_cast(&src.wi_marray), + reinterpret_cast(&src.wi_marray[0]), stride, get_layout_id()); } } else if constexpr (std::is_same_v) { __dmma_m8n8k4_st_c_f64(dst.get(), - reinterpret_cast(&src.wi_marray), stride, + &src.wi_marray[0], stride, get_layout_id()); } } @@ -413,7 +413,7 @@ void store_layoutT( template void joint_matrix_store_cuda( - joint_matrix_cuda< + const joint_matrix_cuda< T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows, NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src, multi_ptr dst, size_t stride, @@ -465,8 +465,8 @@ constexpr int get_layout_pair_id< } template < - typename Tm, typename Tc, std::size_t M, std::size_t K, std::size_t N, - sycl::ext::oneapi::experimental::matrix::layout LayoutA, + typename Tm, typename Tc, typename Td, std::size_t M, std::size_t K, + std::size_t N, sycl::ext::oneapi::experimental::matrix::layout LayoutA, 
sycl::ext::oneapi::experimental::matrix::layout LayoutB, std::enable_if_t< (LayoutA == @@ -480,13 +480,13 @@ template < bool> = true> void joint_matrix_mad_cuda( joint_matrix_cuda< - Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, + Td, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D, - joint_matrix_cuda &A, - joint_matrix_cuda &B, - joint_matrix_cuda< + const joint_matrix_cuda &A, + const joint_matrix_cuda &B, + const joint_matrix_cuda< Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &C) { if constexpr (M == 16 && N == 16 && K == 16) { @@ -506,16 +506,29 @@ void joint_matrix_mad_cuda( auto ptrA = reinterpret_cast(&A.wi_marray); auto ptrB = reinterpret_cast(&B.wi_marray); if constexpr (std::is_same_v) { - __hmma_m16n16k16_mma_f32f32( - reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), - get_layout_pair_id(), 0); - + if constexpr (std::is_same::value) { + __hmma_m16n16k16_mma_f32f32( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } else { + __hmma_m16n16k16_mma_f16f32( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } } else if constexpr (std::is_same_v) { - __hmma_m16n16k16_mma_f16f16( - reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), - get_layout_pair_id(), 0); + if constexpr (std::is_same::value) { + __hmma_m16n16k16_mma_f32f16( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } else { + __hmma_m16n16k16_mma_f16f16( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } } } else if constexpr (std::is_same_v) { __mma_bf16_m16n16k16_mma_f32( @@ -542,15 +555,29 @@ void joint_matrix_mad_cuda( auto ptrA = reinterpret_cast(&A.wi_marray); auto ptrB = reinterpret_cast(&B.wi_marray); if constexpr (std::is_same_v) { - __hmma_m8n32k16_mma_f32f32( - reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), - get_layout_pair_id(), 0); + if constexpr (std::is_same::value) { + __hmma_m8n32k16_mma_f32f32( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } else { + __hmma_m8n32k16_mma_f16f32( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } } else if constexpr (std::is_same_v) { - __hmma_m8n32k16_mma_f16f16( - reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), - get_layout_pair_id(), 0); + if constexpr (std::is_same::value) { + __hmma_m8n32k16_mma_f32f16( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } else { + __hmma_m8n32k16_mma_f16f16( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } } } else if constexpr (std::is_same_v) { __mma_bf16_m8n32k16_mma_f32( @@ -581,25 +608,40 @@ void joint_matrix_mad_cuda( reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } else if constexpr (std::is_same_v) { + auto ptrA = reinterpret_cast(&A.wi_marray); auto ptrB = reinterpret_cast(&B.wi_marray); if constexpr (std::is_same_v) { - __hmma_m32n8k16_mma_f32f32( - reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), - 
get_layout_pair_id(), 0); + if constexpr (std::is_same::value) { + __hmma_m32n8k16_mma_f32f32( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } else { + __hmma_m32n8k16_mma_f16f32( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } } else if constexpr (std::is_same_v) { - __hmma_m32n8k16_mma_f16f16( - reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), - get_layout_pair_id(), 0); + if constexpr (std::is_same::value) { + __hmma_m32n8k16_mma_f32f16( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } else { + __hmma_m32n8k16_mma_f16f16( + reinterpret_cast(&D.wi_marray), ptrA, ptrB, + reinterpret_cast(&C.wi_marray), + get_layout_pair_id(), 0); + } } } } else if constexpr (M == 16 && N == 16 && K == 8) { __mma_tf32_m16n16k8_mma_f32(reinterpret_cast(&D.wi_marray), - reinterpret_cast(&A.wi_marray), - reinterpret_cast(&B.wi_marray), - reinterpret_cast(&C.wi_marray), + reinterpret_cast(&A.wi_marray), + reinterpret_cast(&B.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } else if constexpr (std::is_same_v) { __dmma_m8n8k4_mma_f64(reinterpret_cast(&D.wi_marray), diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 327e1e326f108..84f6aa577f868 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -40,7 +40,7 @@ struct joint_matrix { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) - mutable sycl::ext::oneapi::detail::joint_matrix_cuda cuda_impl; #elif defined(__SPIR__) @@ -373,7 +373,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_mad( #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) if constexpr (std::is_same::value) { - sycl::ext::oneapi::detail::joint_matrix_mad_cuda( D.cuda_impl, A.cuda_impl, B.cuda_impl, C.cuda_impl); } else { From 476a6b4cbd52214269259b79ae4fe2fd24fe496b Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Tue, 17 Oct 2023 07:45:04 -0700 Subject: [PATCH 02/11] Removed get_wi_data() for cuda backend. 
Signed-off-by: JackAKirk --- .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 76 ------------------- 1 file changed, 76 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 84f6aa577f868..866bfb70d5a1d 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -68,82 +68,6 @@ struct joint_matrix { #endif }; -#ifdef __SYCL_DEVICE_ONLY__ -template -class wi_data { - - joint_matrix &jm; - - wi_data(joint_matrix &_jm) : jm(_jm){}; - - template - friend decltype(auto) - get_wi_data(Grp, - joint_matrix &); - -public: - size_t length() { -#if defined(__NVPTX__) - return jm.cuda_impl.wi_marray.size(); -#endif - }; - - decltype(auto) operator[](size_t i) { -#if defined(__NVPTX__) - return (jm.cuda_impl.wi_marray[i]); -#else - std::ignore = i; -#endif - }; -}; -#else -template class wi_data { - marray &data; - wi_data(marray &wi_marray) : data(wi_marray){}; - template - friend decltype(auto) - get_wi_data(Grp, - joint_matrix &); - -public: - size_t length() { return data.size(); }; - - type &operator[](size_t i) { return data[i]; }; -}; -#endif - -template -#if defined(__SYCL_DEVICE_ONLY__) -#if defined(__NVPTX__) -__SYCL2020_DEPRECATED("get_wi_data() is deprecated for CUDA backend. Please " - "use joint_matrix_apply() instead.") -#else -__attribute__((unavailable("get_wi_data() has been removed from the API and " - "replaced with joint_matrix_apply!"))) -#endif -#endif -inline __SYCL_ALWAYS_INLINE decltype(auto) - get_wi_data(Group sg, joint_matrix &jm) { -#if defined(__SYCL_DEVICE_ONLY__) - std::ignore = sg; - return wi_data(jm); -#else - std::ignore = sg; - std::ignore = jm; - if constexpr (std::is_same_v) { - marray unused{}; - return wi_data(unused); - } else { - marray unused{}; - return wi_data(unused); - } -#endif // defined(__SYCL_DEVICE_ONLY__) -} - template inline __SYCL_ALWAYS_INLINE void From 2d63b86b158fcd69073a5d553c910ab29dfd2c0c Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Tue, 17 Oct 2023 07:54:45 -0700 Subject: [PATCH 03/11] Updated the tests to support the changes, and test new cases. 
Signed-off-by: JackAKirk --- .../Matrix/joint_matrix_apply_cuda.hpp | 68 --------------- .../Matrix/joint_matrix_gemm_cuda.hpp | 86 +++++++++---------- .../Matrix/joint_matrix_tensorcores_sm70.cpp | 52 ++++++----- .../Matrix/joint_matrix_tensorcores_sm72.cpp | 40 +++------ .../Matrix/joint_matrix_tensorcores_sm80.cpp | 42 +++------ 5 files changed, 94 insertions(+), 194 deletions(-) diff --git a/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp b/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp index 8a6c6672e0a5b..68be1f341b6bf 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_apply_cuda.hpp @@ -100,71 +100,3 @@ void assert_ops_ref(T *C, const float ref) { std::numeric_limits::epsilon()); } } - -template -void matrix_verify_op(queue q, big_matrix &C, - const float ref, Operation Op) { - { - buffer bufC(C.get_data(), range<2>(N * nWGperDim, M * nWGperDim)); - - q.submit([&](handler &cgh) { - accessor accC(bufC, - cgh); - - cgh.parallel_for>( - r, [accC, - Op](nd_item<2> spmd_item) [[sycl::reqd_sub_group_size(SG_SZ)]] { - const auto global_idx = spmd_item.get_global_id(0); - const auto global_idy = spmd_item.get_global_id(1); - const auto sg_startx = global_idx - spmd_item.get_local_id(0); - const auto sg_starty = global_idy - spmd_item.get_local_id(1); - - auto sg = spmd_item.get_sub_group(); - - joint_matrix sub_a; - joint_matrix sub_b; - joint_matrix sub_c; - - joint_matrix_fill(sg, sub_a, 3); - joint_matrix_fill(sg, sub_b, 1); - joint_matrix_fill(sg, sub_c, -80); - - auto wi_slice_a = get_wi_data(sg, sub_a); - for (int i = 0; i < wi_slice_a.length(); i++) { - if constexpr (std::is_same_v) { - if (wi_slice_a[i]) { - if (wi_slice_a[i] > 2.0 || wi_slice_a[i] >= 3.0 || - wi_slice_a[i] < 4.0 || wi_slice_a[i] <= 3.0) { - T val = (wi_slice_a[i] != (2.0)) ? wi_slice_a[i] - : static_cast(2.0); - val = ((val) - (1)); - val = ((val) + (1)); - if (wi_slice_a[i] == (2.0)) { - val = ((val) - (2)); - val = ((val) * (3)); - val = ((val) / (2)); - - } else { - val = ((val) + (2)); - } - wi_slice_a[i] = val; - } - } - } else { - wi_slice_a[i] = Op(wi_slice_a[i], 2); - } - } - - joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); - - joint_matrix_store( - sg, sub_c, - accC.template get_multi_ptr() + - (sg_startx * M) * (N * nWGperDim) + sg_starty / SG_SZ * N, - (N * nWGperDim), layout::row_major); - }); // parallel for - }).wait(); - } - assert_ops_ref(C.get_data(), ref); -} diff --git a/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp b/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp index 219a3976f4c90..60c448c98faa8 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp @@ -30,37 +30,37 @@ constexpr float bf16_eps = 0.00390625; constexpr int N_THREADS_PER_MATRIX_OP = 32; // number of submatrices per row of accumulator ("C", "D") matrices. -constexpr int SUB_TILES_M = 2; +constexpr int SUB_TILES_M = 3; // number of submatrices per col of accumulator matrices. -constexpr int SUB_TILES_N = 3; +constexpr int SUB_TILES_N = 2; // number of submatrices per col of "A"/per row of "B", matrices. 
-constexpr int SUB_TILES_K = 4; +constexpr int SUB_TILES_K = 1; -template +template class TypeHelper; -template -using KernelName = class TypeHelper; +template +using KernelName = class TypeHelper; -template -T2 matrix_ref_mn(const int &m, const int &n, T1 *A, T1 *B, T2 *C) { - T2 res = C[m * Big_N + n]; +template +Tc matrix_ref_mn(const int &m, const int &n, Tm *A, Tm *B, Tc *C) { + Tc res = C[m * Big_N + n]; - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { for (int k = 0; k < Big_K; k++) res += A[m * Big_K + k] * B[k * Big_N + n]; } else { for (int k = 0; k < Big_K; k++) res += - static_cast(A[m * Big_K + k]) * static_cast(B[k * Big_N + n]); + static_cast(A[m * Big_K + k]) * static_cast(B[k * Big_N + n]); } return res; } -template > + typename T3 = std::remove_const_t> void test(queue &q) { // total number of M dimension matrix elements for the "Big matrix". constexpr auto Big_M = Sub_Tiles_M * M; @@ -69,27 +69,27 @@ void test(queue &q) { // total number of K dimension matrix elements for the "Big matrix". constexpr auto Big_K = Sub_Tiles_K * K; - std::remove_const_t A[Big_M * Big_K]; - std::remove_const_t B[Big_K * Big_N]; - std::remove_const_t C[Big_M * Big_N]; - std::remove_const_t D[Big_M * Big_N]; + std::remove_const_t A[Big_M * Big_K]; + std::remove_const_t B[Big_K * Big_N]; + std::remove_const_t C[Big_M * Big_N]; + Td D[Big_M * Big_N]; for (int i = 0; i < Big_M * Big_N; i++) { C[i] = 1; D[i] = 0; } - if constexpr (!std::is_same, bfloat16>::value) { + if constexpr (!std::is_same, bfloat16>::value) { for (int i = 0; i < Big_M * Big_K; i++) { - A[i] = i % 100; + A[i] = i % 3; } for (int i = 0; i < Big_K * Big_N; i++) { - B[i] = i % 100; + B[i] = i % 3; } } { - if constexpr (std::is_same, bfloat16>::value) { + if constexpr (std::is_same, bfloat16>::value) { buffer bufA(A, range<1>(Big_M * Big_K)); buffer bufB(B, range<1>(Big_K * Big_N)); @@ -97,7 +97,7 @@ void test(queue &q) { accessor accA(bufA, cgh); - cgh.parallel_for>( + cgh.parallel_for>( range<1>(Big_M * Big_K), [=](item<1> item) { auto i = item.get_linear_id(); accA[i] = 0.1f * (i % 10); @@ -107,7 +107,7 @@ void test(queue &q) { accessor accB(bufB, cgh); - cgh.parallel_for>( + cgh.parallel_for>( range<1>(Big_K * Big_N), [=](item<1> item) { auto i = item.get_linear_id(); accB[i] = 0.1f * (i % 10); @@ -115,23 +115,23 @@ void test(queue &q) { }); } - buffer bufA(A, range<1>(Big_M * Big_K)); - buffer bufB(B, range<1>(Big_K * Big_N)); - buffer bufC(C, range<1>(Big_M * Big_N)); - buffer, 1> bufD(D, range<1>(Big_M * Big_N)); + buffer bufA(A, range<1>(Big_M * Big_K)); + buffer bufB(B, range<1>(Big_K * Big_N)); + buffer bufC(C, range<1>(Big_M * Big_N)); + buffer bufD(D, range<1>(Big_M * Big_N)); q.submit([&](handler &cgh) { - accessor accA(bufA, cgh); - accessor accB(bufB, cgh); - accessor accC(bufC, cgh); - accessor, 1, access::mode::write, target::device> + accessor accA(bufA, cgh); + accessor accB(bufB, cgh); + accessor accC(bufC, cgh); + accessor accD(bufD, cgh); range<2> LocalRange = {1, N_THREADS_PER_MATRIX_OP}; range<2> GlobalRange = {Sub_Tiles_M, Sub_Tiles_N * N_THREADS_PER_MATRIX_OP}; - cgh.parallel_for>( + cgh.parallel_for>( nd_range<2>(GlobalRange, LocalRange), [=](nd_item<2> item) { sycl::sub_group sg = item.get_sub_group(); // row id of current submatrix of BIG C matrix @@ -143,9 +143,12 @@ void test(queue &q) { sub_a; joint_matrix sub_b; - joint_matrix, + joint_matrix, use::accumulator, M, N> sub_c; + joint_matrix + sub_d; joint_matrix_load( sg, sub_c, @@ -168,20 +171,15 @@ void test(queue &q) 
{ // round values to correct precision if using tf32 if constexpr (std::is_same::value) { - auto wi_size = get_wi_data(sg, sub_a).length(); - assert(wi_size == get_wi_data(sg, sub_b).length()); - for (auto i = 0; i < wi_size; ++i) { - get_wi_data(sg, sub_a)[i] = - round_to_tf32(get_wi_data(sg, sub_a)[i]); - get_wi_data(sg, sub_b)[i] = - round_to_tf32(get_wi_data(sg, sub_b)[i]); - } + auto round_lambda = [](auto &x) { x = round_to_tf32(x); }; + joint_matrix_apply(sg, sub_a, round_lambda); + joint_matrix_apply(sg, sub_b, round_lambda); } - joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); + joint_matrix_mad(sg, sub_d, sub_a, sub_b, sub_c); } joint_matrix_store( - sg, sub_c, + sg, sub_d, accD.template get_multi_ptr() + (m * M) * Big_N + n * N, Big_N, layout::row_major); @@ -192,7 +190,7 @@ void test(queue &q) { for (int m = 0; m < Big_M; m++) { for (int n = 0; n < Big_N; n++) { - if constexpr (std::is_same, bfloat16>::value) { + if constexpr (std::is_same, bfloat16>::value) { auto res_device = matrix_ref_mn(m, n, A, B, C); assert(fabs(2 * (D[m * Big_N + n] - res_device)) / (D[m * Big_N + n] + res_device) < diff --git a/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm70.cpp b/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm70.cpp index 6db5ac824a0d7..01ccec890828c 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm70.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm70.cpp @@ -23,29 +23,47 @@ int main() { if (computeCapability >= 7.0) { // A/B half, Accumulator float - test(Q); - test(Q); - test(Q); + test(Q); + test(Q); + test(Q); - test(Q); - test(Q); - test(Q); // A/B/Accumulator half - test(Q); - test(Q); - test(Q); + test(Q); + test(Q); + test(Q); - test(Q); - test(Q); - test(Q); + // A/B/D half, C float + test(Q); + test(Q); + test(Q); + + test(Q); + test(Q); + test(Q); + + // A/B/C half, D float + test(Q); + test(Q); + test(Q); + + test(Q); + test(Q); + test(Q); + auto apply_add = [](auto &x) { x = x + 2; }; float D[MATRIX_M][MATRIX_N]; big_matrix MD_f((float *)&D); @@ -53,16 +71,6 @@ int main() { // joint_matrix_apply tests matrix_verify_lambda(Q, MD_f, 0.0, apply_add); - - // get_wi_data() Deprecated tests - - matrix_verify_op(Q, MD_f, 0.0, std::plus{}); - matrix_verify_op(Q, MD_f, 0.0, Logical{}); - matrix_verify_op(Q, MD_f, 16.0, - std::multiplies{}); - matrix_verify_op(Q, MD_f, -56.0, - std::divides{}); - matrix_verify_op(Q, MD_f, -64.0, std::minus{}); } return 0; diff --git a/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm72.cpp b/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm72.cpp index d802f369e025d..923b14abaf5e4 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm72.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm72.cpp @@ -22,27 +22,27 @@ int main() { std::stof(Q.get_device().get_info()); if (computeCapability >= 7.2) { - test(Q); - test(Q); - test(Q); + test(Q); + test(Q); + test(Q); - test(Q); - test(Q); - test(Q); - test( + test( Q); - test(Q); - test(Q); + test(Q); + test(Q); - test(Q); - test(Q); - test(Q); auto apply_add = [](auto &x) { x = x + 2; }; @@ -54,22 +54,6 @@ int main() { matrix_verify_lambda(Q, MD_i, 0, apply_add); matrix_verify_lambda(Q, MD_i, 0, apply_add); - - // get_wi_data() Deprecated - - matrix_verify_op(Q, MD_i, 0, - std::plus{}); - matrix_verify_op(Q, MD_i, 16, - std::multiplies{}); - matrix_verify_op(Q, MD_i, -64, - std::minus{}); - matrix_verify_op(Q, MD_i, 0, - std::plus{}); - matrix_verify_op(Q, MD_i, 0.0, Logical{}); - matrix_verify_op(Q, MD_i, 16, - std::multiplies{}); - matrix_verify_op(Q, MD_i, -64, - 
std::minus{}); } return 0; diff --git a/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm80.cpp b/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm80.cpp index 0834eae9679b2..38d55c0a2e8a1 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm80.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm80.cpp @@ -22,25 +22,25 @@ int main() { std::stof(Q.get_device().get_info()); if (computeCapability >= 8.0) { - test(Q); - test(Q); + test(Q); - test(Q); - test(Q); - test(Q); + test(Q); + test(Q); + test(Q); - test(Q); - test(Q); - test(Q); // A/B tf32 - test(Q); - test(Q); float D[MATRIX_M][MATRIX_N]; @@ -54,28 +54,6 @@ int main() { matrix_verify_lambda(Q, MD_f, 0.0, apply_add); matrix_verify_lambda(Q, MD_d, -60.0, apply_add); - - // get_wi_data() Deprecated - - matrix_verify_op(Q, MD_f, 0.0, - std::plus{}); - matrix_verify_op(Q, MD_f, 0.0, Logical{}); - matrix_verify_op(Q, MD_f, 16.0, - std::multiplies{}); - matrix_verify_op(Q, MD_f, -56.0, - std::divides{}); - matrix_verify_op(Q, MD_f, -64.0, - std::minus{}); - - matrix_verify_op(Q, MD_d, -60.0, - std::plus{}); - matrix_verify_op(Q, MD_d, -60.0, Logical{}); - matrix_verify_op(Q, MD_d, -56.0, - std::multiplies{}); - matrix_verify_op(Q, MD_d, -74.0, - std::divides{}); - matrix_verify_op(Q, MD_d, -76.0, - std::minus{}); } return 0; }; From cfab4499adc852abe47e11bd5925fa6e93925e4e Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Tue, 17 Oct 2023 07:55:56 -0700 Subject: [PATCH 04/11] Required upstream fixes to correct signatures of const variables. Signed-off-by: JackAKirk --- clang/include/clang/Basic/BuiltinsNVPTX.def | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.def b/clang/include/clang/Basic/BuiltinsNVPTX.def index 6ecbd22e3fc38..1edb777959c61 100644 --- a/clang/include/clang/Basic/BuiltinsNVPTX.def +++ b/clang/include/clang/Basic/BuiltinsNVPTX.def @@ -2545,22 +2545,22 @@ TARGET_BUILTIN(__hmma_m16n16k16_ld_a, "vi*iC*UiIi", "", AND(SM_70,PTX60)) TARGET_BUILTIN(__hmma_m16n16k16_ld_b, "vi*iC*UiIi", "", AND(SM_70,PTX60)) TARGET_BUILTIN(__hmma_m16n16k16_ld_c_f16, "vi*iC*UiIi", "", AND(SM_70,PTX60)) TARGET_BUILTIN(__hmma_m16n16k16_ld_c_f32, "vf*fC*UiIi", "", AND(SM_70,PTX60)) -TARGET_BUILTIN(__hmma_m16n16k16_st_c_f16, "vi*i*UiIi", "", AND(SM_70,PTX60)) -TARGET_BUILTIN(__hmma_m16n16k16_st_c_f32, "vf*f*UiIi", "", AND(SM_70,PTX60)) +TARGET_BUILTIN(__hmma_m16n16k16_st_c_f16, "vi*iC*UiIi", "", AND(SM_70,PTX60)) +TARGET_BUILTIN(__hmma_m16n16k16_st_c_f32, "vf*fC*UiIi", "", AND(SM_70,PTX60)) TARGET_BUILTIN(__hmma_m32n8k16_ld_a, "vi*iC*UiIi", "", AND(SM_70,PTX61)) TARGET_BUILTIN(__hmma_m32n8k16_ld_b, "vi*iC*UiIi", "", AND(SM_70,PTX61)) TARGET_BUILTIN(__hmma_m32n8k16_ld_c_f16, "vi*iC*UiIi", "", AND(SM_70,PTX61)) TARGET_BUILTIN(__hmma_m32n8k16_ld_c_f32, "vf*fC*UiIi", "", AND(SM_70,PTX61)) -TARGET_BUILTIN(__hmma_m32n8k16_st_c_f16, "vi*i*UiIi", "", AND(SM_70,PTX61)) -TARGET_BUILTIN(__hmma_m32n8k16_st_c_f32, "vf*f*UiIi", "", AND(SM_70,PTX61)) +TARGET_BUILTIN(__hmma_m32n8k16_st_c_f16, "vi*iC*UiIi", "", AND(SM_70,PTX61)) +TARGET_BUILTIN(__hmma_m32n8k16_st_c_f32, "vf*fC*UiIi", "", AND(SM_70,PTX61)) TARGET_BUILTIN(__hmma_m8n32k16_ld_a, "vi*iC*UiIi", "", AND(SM_70,PTX61)) TARGET_BUILTIN(__hmma_m8n32k16_ld_b, "vi*iC*UiIi", "", AND(SM_70,PTX61)) TARGET_BUILTIN(__hmma_m8n32k16_ld_c_f16, "vi*iC*UiIi", "", AND(SM_70,PTX61)) TARGET_BUILTIN(__hmma_m8n32k16_ld_c_f32, "vf*fC*UiIi", "", AND(SM_70,PTX61)) -TARGET_BUILTIN(__hmma_m8n32k16_st_c_f16, "vi*i*UiIi", "", AND(SM_70,PTX61)) 
-TARGET_BUILTIN(__hmma_m8n32k16_st_c_f32, "vf*f*UiIi", "", AND(SM_70,PTX61)) +TARGET_BUILTIN(__hmma_m8n32k16_st_c_f16, "vi*iC*UiIi", "", AND(SM_70,PTX61)) +TARGET_BUILTIN(__hmma_m8n32k16_st_c_f32, "vf*fC*UiIi", "", AND(SM_70,PTX61)) TARGET_BUILTIN(__hmma_m16n16k16_mma_f16f16, "vi*iC*iC*iC*IiIi", "", AND(SM_70,PTX60)) TARGET_BUILTIN(__hmma_m16n16k16_mma_f32f16, "vf*iC*iC*iC*IiIi", "", AND(SM_70,PTX60)) From 96ab577c9031cf4650353345d2416c10dccf3017 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Tue, 17 Oct 2023 08:49:16 -0700 Subject: [PATCH 05/11] Fix Format. Signed-off-by: JackAKirk --- .../ext/oneapi/matrix/matrix-tensorcores.hpp | 54 ++++++------- .../sycl/ext/oneapi/matrix/matrix-unified.hpp | 7 +- .../Matrix/joint_matrix_gemm_cuda.hpp | 11 +-- .../Matrix/joint_matrix_tensorcores_sm70.cpp | 76 +++++++++++-------- .../Matrix/joint_matrix_tensorcores_sm72.cpp | 39 +++++----- .../Matrix/joint_matrix_tensorcores_sm80.cpp | 36 +++++---- 6 files changed, 120 insertions(+), 103 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp index dac2497dd5850..a74f87307483b 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp @@ -363,49 +363,45 @@ void store_layoutT( multi_ptr dst, size_t stride) { if constexpr (NumRows == 16 && NumCols == 16) { if constexpr (std::is_same_v) { - __hmma_m16n16k16_st_c_f32(dst.get(), - &src.wi_marray[0], - stride, get_layout_id()); + __hmma_m16n16k16_st_c_f32(dst.get(), &src.wi_marray[0], stride, + get_layout_id()); } else if constexpr (std::is_same_v) { - __imma_m16n16k16_st_c_i32(dst.get(), - &src.wi_marray[0], - stride, get_layout_id()); + __imma_m16n16k16_st_c_i32(dst.get(), &src.wi_marray[0], stride, + get_layout_id()); } else if constexpr (std::is_same_v) { - __hmma_m16n16k16_st_c_f16(reinterpret_cast(dst.get()), - reinterpret_cast(&src.wi_marray[0]), - stride, get_layout_id()); + __hmma_m16n16k16_st_c_f16( + reinterpret_cast(dst.get()), + reinterpret_cast(&src.wi_marray[0]), stride, + get_layout_id()); } } else if constexpr (NumRows == 8 && NumCols == 32) { if constexpr (std::is_same_v) { - __hmma_m8n32k16_st_c_f32(dst.get(), - &src.wi_marray[0], - stride, get_layout_id()); + __hmma_m8n32k16_st_c_f32(dst.get(), &src.wi_marray[0], stride, + get_layout_id()); } else if constexpr (std::is_same_v) { - __imma_m8n32k16_st_c_i32(dst.get(), - &src.wi_marray[0], - stride, get_layout_id()); + __imma_m8n32k16_st_c_i32(dst.get(), &src.wi_marray[0], stride, + get_layout_id()); } else if constexpr (std::is_same_v) { - __hmma_m8n32k16_st_c_f16(reinterpret_cast(dst.get()), - reinterpret_cast(&src.wi_marray[0]), - stride, get_layout_id()); + __hmma_m8n32k16_st_c_f16( + reinterpret_cast(dst.get()), + reinterpret_cast(&src.wi_marray[0]), stride, + get_layout_id()); } } else if constexpr (NumRows == 32 && NumCols == 8) { if constexpr (std::is_same_v) { - __hmma_m32n8k16_st_c_f32(dst.get(), - &src.wi_marray[0], - stride, get_layout_id()); + __hmma_m32n8k16_st_c_f32(dst.get(), &src.wi_marray[0], stride, + get_layout_id()); } else if constexpr (std::is_same_v) { - __imma_m32n8k16_st_c_i32(dst.get(), - &src.wi_marray[0], - stride, get_layout_id()); + __imma_m32n8k16_st_c_i32(dst.get(), &src.wi_marray[0], stride, + get_layout_id()); } else if constexpr (std::is_same_v) { - __hmma_m32n8k16_st_c_f16(reinterpret_cast(dst.get()), - reinterpret_cast(&src.wi_marray[0]), - stride, get_layout_id()); + 
__hmma_m32n8k16_st_c_f16( + reinterpret_cast(dst.get()), + reinterpret_cast(&src.wi_marray[0]), stride, + get_layout_id()); } } else if constexpr (std::is_same_v) { - __dmma_m8n8k4_st_c_f64(dst.get(), - &src.wi_marray[0], stride, + __dmma_m8n8k4_st_c_f64(dst.get(), &src.wi_marray[0], stride, get_layout_id()); } } diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp index 866bfb70d5a1d..92cc593a9a7a2 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp @@ -40,8 +40,7 @@ struct joint_matrix { #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) - sycl::ext::oneapi::detail::joint_matrix_cuda + sycl::ext::oneapi::detail::joint_matrix_cuda cuda_impl; #elif defined(__SPIR__) __spv::__spirv_JointMatrixINTEL< @@ -297,8 +296,8 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_mad( #if defined(__SYCL_DEVICE_ONLY__) #if defined(__NVPTX__) if constexpr (std::is_same::value) { - sycl::ext::oneapi::detail::joint_matrix_mad_cuda( + sycl::ext::oneapi::detail::joint_matrix_mad_cuda( D.cuda_impl, A.cuda_impl, B.cuda_impl, C.cuda_impl); } else { assert(false && "Ta != Tb : In the CUDA backend joint_matrix_mad " diff --git a/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp b/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp index 60c448c98faa8..fe5b110864e6b 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp +++ b/sycl/test-e2e/Matrix/joint_matrix_gemm_cuda.hpp @@ -58,8 +58,8 @@ Tc matrix_ref_mn(const int &m, const int &n, Tm *A, Tm *B, Tc *C) { return res; } -template > void test(queue &q) { // total number of M dimension matrix elements for the "Big matrix". @@ -124,8 +124,7 @@ void test(queue &q) { accessor accA(bufA, cgh); accessor accB(bufB, cgh); accessor accC(bufC, cgh); - accessor - accD(bufD, cgh); + accessor accD(bufD, cgh); range<2> LocalRange = {1, N_THREADS_PER_MATRIX_OP}; range<2> GlobalRange = {Sub_Tiles_M, @@ -146,9 +145,7 @@ void test(queue &q) { joint_matrix, use::accumulator, M, N> sub_c; - joint_matrix - sub_d; + joint_matrix sub_d; joint_matrix_load( sg, sub_c, diff --git a/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm70.cpp b/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm70.cpp index 01ccec890828c..e7e7dea389ca7 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm70.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm70.cpp @@ -23,46 +23,62 @@ int main() { if (computeCapability >= 7.0) { // A/B half, Accumulator float - test(Q); - test(Q); - test(Q); - - test(Q); - test(Q); - test(Q); + test( + Q); + test( + Q); + test( + Q); + + test(Q); + test(Q); + test(Q); // A/B/Accumulator half - test(Q); + test( + Q); test(Q); test(Q); - test(Q); + test(Q); + test(Q); + + // A/B/D half, C float + test( + Q); + test( + Q); + test( + Q); + + test(Q); - test(Q); - test(Q); - // A/B/D half, C float - test(Q); - test(Q); - test(Q); - - test(Q); - test(Q); - test(Q); - // A/B/C half, D float - test(Q); - test(Q); - test(Q); - - test(Q); - test(Q); - test(Q); + test( + Q); + test( + Q); + test( + Q); + + test(Q); + test(Q); + test(Q); auto apply_add = [](auto &x) { x = x + 2; }; float D[MATRIX_M][MATRIX_N]; diff --git a/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm72.cpp b/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm72.cpp index 923b14abaf5e4..cea15392408cc 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm72.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm72.cpp @@ -22,28 +22,33 @@ int main() { 
std::stof(Q.get_device().get_info()); if (computeCapability >= 7.2) { - test(Q); - test(Q); - test(Q); - - test(Q); - test(Q); - test(Q); + test(Q); - test( - Q); - test(Q); - test(Q); + test(Q); + test(Q); + test(Q); - test(Q); - test(Q); + test(Q); - test(Q); + test(Q); + + test(Q); + test(Q); + test(Q); auto apply_add = [](auto &x) { x = x + 2; }; diff --git a/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm80.cpp b/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm80.cpp index 38d55c0a2e8a1..2a0731d9b988e 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm80.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm80.cpp @@ -22,26 +22,30 @@ int main() { std::stof(Q.get_device().get_info()); if (computeCapability >= 8.0) { - test(Q); - test(Q); - - test(Q); - test(Q); - test(Q); - - test(Q); - test(Q); - test(Q); + test(Q); + test(Q); + + test(Q); + test(Q); + test(Q); + + test(Q); + test(Q); + test(Q); // A/B tf32 test(Q); - test(Q); + test(Q); float D[MATRIX_M][MATRIX_N]; big_matrix MD_f((float *)&D); From f4e7f42297965a3494c99a6f8762860ad125fdbd Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Tue, 17 Oct 2023 09:28:04 -0700 Subject: [PATCH 06/11] Update cuda mma ops support table. Signed-off-by: JackAKirk --- .../sycl_ext_oneapi_matrix.asciidoc | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc b/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc index cb51dfe295715..ac50150412c7f 100644 --- a/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc +++ b/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc @@ -918,7 +918,6 @@ The complete set of matrix data types and shapes that are supported by the `ext_oneapi_cuda` backend are represented in the following table. In this architecture's implementation, the type of the A matrix must be the same as the type of the B -matrix. Also, the type of the C matrix must be the same as the type of the D matrix. IMPORTANT: When compiling for the `ext_oneapi_cuda` backend the target @@ -933,29 +932,37 @@ supported parameter combination is specified in the following table. 
[frame="none",options="header"] |====================== -| A and B type | C and D type | M | N | K | Minimum Compute Capability -.3+| `matrix_type::fp16` .3+| `matrix_type::fp32` +| A and B type | C type | D type | M | N | K | Minimum Compute Capability +.3+| `matrix_type::fp16` .3+| `matrix_type::fp32` .3+| `matrix_type::fp32` |16 |16 |16 .6+| sm_70 |8 |32 |16 |32 |8 |16 -.3+| `matrix_type::fp16` .3+| `matrix_type::fp16` +.3+| `matrix_type::fp16` .3+| `matrix_type::fp16` .3+| `matrix_type::fp16` |16 |16 |16 |8 |32 |16 |32 |8 |16 -.3+| `matrix_type::sint8` .3+| `matrix_type::sint32` +.3+| `matrix_type::fp16` .3+| `matrix_type::fp32` .3+| `matrix_type::fp16` +|16 |16 |16 .6+| sm_70 +|8 |32 |16 +|32 |8 |16 +.3+| `matrix_type::fp16` .3+| `matrix_type::fp32` .3+| `matrix_type::fp16` +|16 |16 |16 +|8 |32 |16 +|32 |8 |16 +.3+| `matrix_type::sint8` .3+| `matrix_type::sint32` .3+| `matrix_type::sint32` |16 |16 |16 .6+| sm_72 |8 |32 |16 |32 |8 |16 -.3+|`matrix_type::uint8` .3+|`matrix_type::sint32` +.3+|`matrix_type::uint8` .3+|`matrix_type::sint32` .3+|`matrix_type::sint32` |16 |16 |16 |8 |32 |16 |32 |8 |16 -| `matrix_type::tf32` | `matrix_type::fp32` |16 |16 |8 .5+| sm_80 -.3+|`matrix_type::bf16` .3+| `matrix_type::fp32` +| `matrix_type::tf32` | `matrix_type::fp32` | `matrix_type::fp32` |16 |16 |8 .5+| sm_80 +.3+|`matrix_type::bf16` .3+| `matrix_type::fp32` .3+| `matrix_type::fp32` |16 |16 |16 |8 |32 |16 |32 |8 |16 -| `matrix_type::fp64` | `matrix_type::fp64` |8 |8 |4 +| `matrix_type::fp64` | `matrix_type::fp64` | `matrix_type::fp64` |8 |8 |4 |====================== IMPORTANT: The `stride` argument to `joint_matrix_load` and From 39cd26dc01f3e00cd7858d23fc8e04ea8088069f Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Tue, 17 Oct 2023 09:32:04 -0700 Subject: [PATCH 07/11] Correct the cuda support table from the last commit. Signed-off-by: JackAKirk --- .../sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc b/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc index ac50150412c7f..aa103ebb3d282 100644 --- a/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc +++ b/sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc @@ -934,7 +934,7 @@ supported parameter combination is specified in the following table. |====================== | A and B type | C type | D type | M | N | K | Minimum Compute Capability .3+| `matrix_type::fp16` .3+| `matrix_type::fp32` .3+| `matrix_type::fp32` -|16 |16 |16 .6+| sm_70 +|16 |16 |16 .12+| sm_70 |8 |32 |16 |32 |8 |16 .3+| `matrix_type::fp16` .3+| `matrix_type::fp16` .3+| `matrix_type::fp16` @@ -942,10 +942,10 @@ supported parameter combination is specified in the following table. |8 |32 |16 |32 |8 |16 .3+| `matrix_type::fp16` .3+| `matrix_type::fp32` .3+| `matrix_type::fp16` -|16 |16 |16 .6+| sm_70 +|16 |16 |16 |8 |32 |16 |32 |8 |16 -.3+| `matrix_type::fp16` .3+| `matrix_type::fp32` .3+| `matrix_type::fp16` +.3+| `matrix_type::fp16` .3+| `matrix_type::fp16` .3+| `matrix_type::fp32` |16 |16 |16 |8 |32 |16 |32 |8 |16 From b337e589a406d53d09485cf2f944d32a030748f7 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Wed, 18 Oct 2023 01:41:51 -0700 Subject: [PATCH 08/11] Removed unnecessary get_wi_data() from tf32 device code check test. 
Signed-off-by: JackAKirk --- .../cuda/matrix/matrix-nvptx-tf32-test.cpp | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp index f47a701fe7bc6..52e55a21c7f91 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp @@ -79,15 +79,6 @@ int main() { sg, sub_c, accC.template get_multi_ptr(), N, layout::row_major); - // Round a, b to tf32 - for (auto i = 0; i < 4; ++i) - get_wi_data(sg, sub_a)[i] = - round_to_tf32(get_wi_data(sg, sub_a)[i]); - - for (auto i = 0; i < 4; ++i) - get_wi_data(sg, sub_b)[i] = - round_to_tf32(get_wi_data(sg, sub_b)[i]); - joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}} joint_matrix_store( @@ -128,15 +119,6 @@ int main() { sg, sub_c, accC.template get_multi_ptr(), N, layout::col_major); - // Round a, b to tf32 - for (auto i = 0; i < 4; ++i) - get_wi_data(sg, sub_a)[i] = - round_to_tf32(get_wi_data(sg, sub_a)[i]); - - for (auto i = 0; i < 4; ++i) - get_wi_data(sg, sub_b)[i] = - round_to_tf32(get_wi_data(sg, sub_b)[i]); - joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16) joint_matrix_store( From d57c448f9e5217a16745b28236c8eec68f4f7d3e Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Thu, 19 Oct 2023 01:57:22 -0700 Subject: [PATCH 09/11] Small fixes. Signed-off-by: JackAKirk --- .../sycl/ext/oneapi/matrix/matrix-tensorcores.hpp | 2 +- .../Matrix/joint_matrix_tensorcores_sm70.cpp | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp index a74f87307483b..e7594a4d890a2 100644 --- a/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp +++ b/sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp @@ -623,7 +623,7 @@ void joint_matrix_mad_cuda( if constexpr (std::is_same::value) { __hmma_m32n8k16_mma_f32f16( reinterpret_cast(&D.wi_marray), ptrA, ptrB, - reinterpret_cast(&C.wi_marray), + reinterpret_cast(&C.wi_marray), get_layout_pair_id(), 0); } else { __hmma_m32n8k16_mma_f16f16( diff --git a/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm70.cpp b/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm70.cpp index e7e7dea389ca7..23dffcb744f59 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm70.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm70.cpp @@ -58,11 +58,11 @@ int main() { test( Q); - test(Q); - test(Q); - test(Q); // A/B/C half, D float @@ -73,11 +73,11 @@ int main() { test( Q); - test(Q); - test(Q); - test(Q); auto apply_add = [](auto &x) { x = x + 2; }; From 93637f64f9927faa5893243496ed47f739a7d276 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Thu, 19 Oct 2023 02:03:15 -0700 Subject: [PATCH 10/11] Fixed format. 
Signed-off-by: JackAKirk --- .../Matrix/joint_matrix_tensorcores_sm70.cpp | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm70.cpp b/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm70.cpp index 23dffcb744f59..f28372b6277dc 100644 --- a/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm70.cpp +++ b/sycl/test-e2e/Matrix/joint_matrix_tensorcores_sm70.cpp @@ -58,12 +58,12 @@ int main() { test( Q); - test(Q); - test(Q); - test(Q); + test(Q); + test(Q); + test(Q); // A/B/C half, D float test( @@ -73,12 +73,12 @@ int main() { test( Q); - test(Q); - test(Q); - test(Q); + test(Q); + test(Q); + test(Q); auto apply_add = [](auto &x) { x = x + 2; }; float D[MATRIX_M][MATRIX_N]; From 675836259697eec0fc06629a9a769c1ced24b269 Mon Sep 17 00:00:00 2001 From: JackAKirk Date: Thu, 19 Oct 2023 09:38:06 -0700 Subject: [PATCH 11/11] Added a single dev code check for round_to_tf32. Signed-off-by: JackAKirk --- .../check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp index 52e55a21c7f91..e2ae423f04c7d 100644 --- a/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp +++ b/sycl/test/check_device_code/cuda/matrix/matrix-nvptx-tf32-test.cpp @@ -79,6 +79,10 @@ int main() { sg, sub_c, accC.template get_multi_ptr(), N, layout::row_major); + auto round_lambda = [](auto &x) { x = round_to_tf32(x); }; + //CHECK-OPAQUE: tail call i32 @llvm.nvvm.f2tf32.rna(float %{{.*}}) + joint_matrix_apply(sg, sub_a, round_lambda); + joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c); //CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 {{.*}} joint_matrix_store(
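
Editor's note (not part of the patch series): the sketch below illustrates the usage these patches enable, based on the e2e tests changed above. It is a hedged example, not code from the PR; the queue/buffer plumbing, the 16x16x16 tile shape, the single-sub-group nd_range, and the accessor/get_multi_ptr details are illustrative assumptions. It shows a joint_matrix_mad call whose C accumulator (half) and D result (float) have different element types, which is the combination the new __hmma_*_mma_f32f16 / __hmma_*_mma_f16f32 builtins in patch 01 support.

#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

constexpr int M = 16, N = 16, K = 16;

// D (float) = A (half) x B (half) + C (half): the accumulator input and the
// result may now use different element types on the CUDA backend.
// Buffers are assumed to hold at least M*K, K*N, M*N and M*N elements.
void mad_mixed_accumulator(queue &q, buffer<half, 1> &bufA,
                           buffer<half, 1> &bufB, buffer<half, 1> &bufC,
                           buffer<float, 1> &bufD) {
  q.submit([&](handler &cgh) {
     accessor accA{bufA, cgh, read_only};
     accessor accB{bufB, cgh, read_only};
     accessor accC{bufC, cgh, read_only};
     accessor accD{bufD, cgh, write_only};

     // One sub-group of 32 work-items computes a single 16x16x16 MMA tile.
     cgh.parallel_for(
         nd_range<2>({1, 32}, {1, 32}),
         [=](nd_item<2> item) [[sycl::reqd_sub_group_size(32)]] {
           auto sg = item.get_sub_group();

           joint_matrix<sub_group, half, use::a, M, K, layout::row_major> sub_a;
           joint_matrix<sub_group, half, use::b, K, N, layout::row_major> sub_b;
           joint_matrix<sub_group, half, use::accumulator, M, N> sub_c;
           joint_matrix<sub_group, float, use::accumulator, M, N> sub_d;

           joint_matrix_load(sg, sub_a,
                             accA.get_multi_ptr<access::decorated::no>(), K);
           joint_matrix_load(sg, sub_b,
                             accB.get_multi_ptr<access::decorated::no>(), N);
           joint_matrix_load(sg, sub_c,
                             accC.get_multi_ptr<access::decorated::no>(), N,
                             layout::row_major);

           // Distinct C/D types; on sm_70+ this maps to the mixed-precision
           // mma builtins added in patch 01.
           joint_matrix_mad(sg, sub_d, sub_a, sub_b, sub_c);

           joint_matrix_store(sg, sub_d,
                              accD.get_multi_ptr<access::decorated::no>(), N,
                              layout::row_major);
         });
   }).wait();
}

For tf32 inputs, the per-element rounding that the removed get_wi_data() loops used to perform is now written with joint_matrix_apply, as in the sm_80 test and the device-code check added in patch 11: joint_matrix_apply(sg, sub_a, [](auto &x) { x = round_to_tf32(x); });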