[SYCL][Matrix] syntax changes as preparation before moving joint matrix from experimental namespace #11215
@@ -29,10 +29,6 @@
 namespace sycl {
 inline namespace _V1 {
 namespace ext {
-namespace intel::experimental::matrix::layout {
-constexpr sycl::ext::oneapi::experimental::matrix::layout packed =
-    static_cast<sycl::ext::oneapi::experimental::matrix::layout>(2);
-}
 namespace oneapi {
 namespace experimental {
 namespace matrix {

@@ -48,8 +44,7 @@ template <layout Layout> struct spv_matrix_layout_traits {
 SPV_MATRIX_LAYOUT_TRAITS(layout::row_major, __spv::MatrixLayout::RowMajor)
 SPV_MATRIX_LAYOUT_TRAITS(layout::col_major, __spv::MatrixLayout::ColumnMajor)
-SPV_MATRIX_LAYOUT_TRAITS(sycl::ext::intel::experimental::matrix::layout::packed,
-                         __spv::MatrixLayout::Packed)
+SPV_MATRIX_LAYOUT_TRAITS(layout::ext_intel_packed, __spv::MatrixLayout::Packed)
 SPV_MATRIX_LAYOUT_TRAITS(layout::dynamic, __spv::MatrixLayout::Dynamic)

 template <use Use> struct spv_matrix_use_traits {

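Note: the user-visible part of this rename is only the spelling of the packed layout, which moves from a constant in ext::intel::experimental::matrix::layout to the ext_intel_packed enumerator of the oneapi layout enum. A minimal kernel-side sketch of the new spelling (element type, matrix shape, and the variable name tB are illustrative assumptions, not code from this PR):

// Assumed to live inside a kernel body; sg is the sub-group.
using namespace sycl::ext::oneapi::experimental::matrix;
// before: layout argument spelled sycl::ext::intel::experimental::matrix::layout::packed
// after:
joint_matrix<sycl::sub_group, int8_t, use::b, /*K=*/32, /*N=*/16,
             layout::ext_intel_packed>
    tB;
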
@@ -94,10 +89,6 @@ struct jm_type_interpretation_helper_trait<
   using element_type = sycl::ext::oneapi::experimental::matrix::precision::tf32;
   using storage_element_type = float;
 };
-} // namespace detail
-} // namespace oneapi
-
-namespace intel::experimental::matrix {

 using namespace sycl::ext::oneapi::experimental::matrix;
 // Begin wi_element definition

@@ -121,12 +112,12 @@ class wi_element {
             std::size_t i)
       : M(Mat), idx(i) {}

-  inline __SYCL_ALWAYS_INLINE std::tuple<uint32_t, uint32_t> get_coord() {
+  inline __SYCL_ALWAYS_INLINE std::tuple<size_t, size_t> get_coord() {
 #if defined(__SYCL_DEVICE_ONLY__)
     __ocl_vec_t<uint32_t, 2> coord =
         __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
-    const uint32_t row = coord[0];
-    const uint32_t col = coord[1];
+    const size_t row = coord[0];
+    const size_t col = coord[1];
     return std::make_tuple(row, col);
 #else
     throw runtime_error("joint matrix is not supported on host device.",

@@ -196,7 +187,7 @@ class wi_element {

 #if __SYCL_DEVICE_ONLY__
 #define OP(op) \
-  template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
+  template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
     M.spvm = __spirv_VectorInsertDynamic( \
         M.spvm, \
         static_cast<storage_element_type>( \

@@ -211,7 +202,7 @@ class wi_element {
   }
 #else // __SYCL_DEVICE_ONLY__
 #define OP(op) \
-  template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
+  template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
     (void)rhs; \
     throw runtime_error("joint matrix is not supported on host device.", \
                         PI_ERROR_INVALID_DEVICE); \

@@ -315,7 +306,7 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,

 #if __SYCL_DEVICE_ONLY__
 #define OP(opassign, op) \
-  wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 &rhs) { \
+  wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
     M.spvm = __spirv_VectorInsertDynamic( \
         M.spvm, \
         __spirv_VectorExtractDynamic< \

@@ -328,7 +319,7 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
   }
 #else // __SYCL_DEVICE_ONLY__
 #define OP(opassign, op) \
-  wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 &rhs) { \
+  wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
     (void)rhs; \
     throw runtime_error("joint matrix is not supported on host device.", \
                         PI_ERROR_INVALID_DEVICE); \

@@ -479,7 +470,10 @@ get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix<
 }

 // End wi_data definition
+} // namespace detail
+} // namespace oneapi
+
+namespace intel::experimental::matrix {
 template <
     typename Group, typename T, typename Tp,
     sycl::ext::oneapi::experimental::matrix::use Use, size_t NumRows,

@@ -490,7 +484,7 @@ template <
         bool> = true>
 inline __SYCL_ALWAYS_INLINE void
 joint_matrix_store(Group sg,
-                   sycl::ext::oneapi::experimental::matrix::joint_matrix<
+                   const sycl::ext::oneapi::experimental::matrix::joint_matrix<
                        Group, Tp, Use, NumRows, NumCols, Layout> &src,
                    multi_ptr<T, Space, IsDecorated> dst, size_t stride) {
 #if defined(__SYCL_DEVICE_ONLY__)

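Note: taking src by const reference does not change call sites; this overload in ext::intel::experimental::matrix is still the one used for use::a / use::b matrices. A minimal kernel-side sketch of such a call (the pointer pB, the stride value, and the matrix shape are illustrative assumptions):

// Assumed kernel-side context: sg is the sub-group, pB is a
// multi_ptr<int8_t, access::address_space::global_space> to a packed B tile.
using namespace sycl::ext::oneapi::experimental::matrix;
joint_matrix<sycl::sub_group, int8_t, use::b, 32, 16, layout::ext_intel_packed> tB;
joint_matrix_load(sg, tB, pB, /*stride=*/64);
// src now binds as const joint_matrix &; the call itself is unchanged:
sycl::ext::intel::experimental::matrix::joint_matrix_store(sg, tB, pB, /*stride=*/64);
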
@@ -528,6 +522,43 @@ joint_matrix_store
                       PI_ERROR_INVALID_DEVICE);
 #endif // defined(__SYCL_DEVICE_ONLY__)
 }
+
+template <typename Group, typename T,
+          sycl::ext::oneapi::experimental::matrix::use Use, size_t M, size_t N,
+          sycl::ext::oneapi::experimental::matrix::layout Layout, typename F>
+inline __SYCL_ALWAYS_INLINE void joint_matrix_apply(
+    Group sg,
+    sycl::ext::oneapi::experimental::matrix::joint_matrix<Group, T, Use, M, N,
+                                                          Layout> &jm,
+    F &&lambda) {
+#if defined(__SYCL_DEVICE_ONLY__)
+#if defined(__NVPTX__)
+  std::ignore = sg;
+  for (int i = 0; i < jm.cuda_impl.wi_marray.size(); i++) {
+    lambda(jm.cuda_impl.wi_marray[i]);
+  }
+#else // NVPTX
+  using storage_element_type =
+      typename oneapi::detail::jm_type_interpretation_helper_trait<
+          T>::storage_element_type;
+  auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jm);
+  for (int i = 0; i < wi_data_c.length(); i++) {
+    storage_element_type element = wi_data_c[i];
+    auto [row, col] = wi_data_c[i].get_coord();
+    lambda(element, row, col);
+    wi_data_c[i] = element;
+  }
+#endif
+#else
+  std::ignore = sg;
+  std::ignore = jm;
+  std::ignore = lambda;
+  throw runtime_error("joint matrix is not supported on host device.",
+                      PI_ERROR_INVALID_DEVICE);
+#endif
+  return;
+}

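Note: together with the get_coord() change above, this new intel-namespace overload is what exposes per-element coordinates to user code. A minimal kernel-side usage sketch (the accumulator C, its shape, and the alpha scalar are illustrative assumptions, not code from this PR):

// Assumed kernel-side context: sg is the sub-group, alpha a captured float.
using namespace sycl::ext::oneapi::experimental::matrix;
joint_matrix<sycl::sub_group, float, use::accumulator, 8, 16, layout::dynamic> C;
// ... C produced by joint_matrix_mad ...
sycl::ext::intel::experimental::matrix::joint_matrix_apply(
    sg, C, [=](float &element, size_t row, size_t col) {
      if (row == col)   // e.g. scale only the diagonal elements
        element *= alpha;
    });
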
@@ -61,19 +61,8 @@ struct joint_matrix {
   }
 #ifdef __SYCL_DEVICE_ONLY__
 #if defined(__SPIR__)
-  // Generate a non-trivial assignment operator and copy c'tor that prevents
-  // memcpy from being generated.
-  // TODO: to remove, when either IGC can handle alloca JointMatrix or
-  // combination of InstCombine + SROA + mem2reg can remove it
-  joint_matrix(const joint_matrix &other) {
-    spvm = other.spvm;
-    return *this;
-  }
-
-  joint_matrix &operator=(const joint_matrix &rhs) {
-    spvm = rhs.spvm;
-    return *this;
-  }
+  joint_matrix(const joint_matrix &other) = delete;
+  joint_matrix &operator=(const joint_matrix &rhs) = delete;
 #endif // defined(__SPIR__)
 #endif
 };

@@ -98,19 +87,21 @@ class wi_data {
 #if defined(__NVPTX__)
     return jm.cuda_impl.wi_marray.size();
 #else
-    throw runtime_error("get_wi_data is available using: "
-                        "ext::intel::experimental::matrix::get_wi_data.",
-                        PI_ERROR_INVALID_DEVICE);
+    throw runtime_error(
+        "get_wi_data is available using: ext::oneapi::detail::get_wi_data, but "
+        "intel users are expected to use joint_matrix_copy instead.",
+        PI_ERROR_INVALID_DEVICE);
 #endif
   };

   decltype(auto) operator[](size_t i) {
 #if defined(__NVPTX__)
     return (jm.cuda_impl.wi_marray[i]);
 #else
-    throw runtime_error("get_wi_data is available using: "
-                        "ext::intel::experimental::matrix::get_wi_data.",
-                        PI_ERROR_INVALID_DEVICE);
+    throw runtime_error(
+        "get_wi_data is available using: ext::oneapi::detail::get_wi_data, but "
+        "intel users are expected to use joint_matrix_copy instead.",
+        PI_ERROR_INVALID_DEVICE);
 #endif
   };
 };

@@ -138,9 +129,9 @@ template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
 __SYCL2020_DEPRECATED("get_wi_data() is deprecated for CUDA backend. Please "
                       "use joint_matrix_apply() instead.")
 #else
-__attribute__((unavailable(
-    "get_wi_data can't be used on intel device, please use "
-    "sycl::ext::intel::experimental::matrix::get_wi_data instead!")))
+__attribute__((
+    unavailable("get_wi_data can't be used on intel device, please use "
+                "joint_matrix_apply instead!")))
 #endif
 #endif
 inline __SYCL_ALWAYS_INLINE decltype(auto)

Review thread on the __SYCL2020_DEPRECATED line:
- We should remove this. It is not really deprecated: joint_matrix is experimental, so we can simply remove APIs. "Deprecated" means the APIs still exist and implementations maintain them. In the case of get_wi_data, it is replaced by joint_matrix_apply.
- Was this addressed?
- This will be addressed by @JackAKirk among other CUDA changes in a separate PR.
- Yes, I will make this change as soon as this PR is merged.

@@ -176,7 +167,7 @@ joint_matrix_apply(Group sg, joint_matrix<Group, T, Use, M, N, Layout> &jm,
   using storage_element_type =
       typename oneapi::detail::jm_type_interpretation_helper_trait<
          T>::storage_element_type;
-  auto wi_data_c = sycl::ext::intel::experimental::matrix::get_wi_data(sg, jm);
+  auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jm);
   for (int i = 0; i < wi_data_c.length(); i++) {
     storage_element_type element = wi_data_c[i];
     lambda(element);

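Note: as the review thread above says, this element-only joint_matrix_apply in the oneapi namespace is the portable replacement for get_wi_data-based loops. A minimal kernel-side usage sketch (the matrix shape and the ReLU-style operation are illustrative assumptions):

// Assumed kernel-side context: sg is the sub-group.
using namespace sycl::ext::oneapi::experimental::matrix;
joint_matrix<sycl::sub_group, float, use::accumulator, 8, 16, layout::dynamic> C;
joint_matrix_fill(sg, C, -1.0f);
joint_matrix_apply(sg, C, [](float &x) {
  x = sycl::fmax(x, 0.0f); // element-wise op over this work-item's slice
});
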
@@ -262,7 +253,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load(
         Ptr, stride, __spv::MatrixLayout::ColumnMajor,
         spv_scope_traits<Group>::value);
     break;
-  case sycl::ext::intel::experimental::matrix::layout::packed:
+  case layout::ext_intel_packed:
     res.spvm = __spirv_JointMatrixLoadINTEL<
         DecorT, S, NumRows, NumCols,
         spv_matrix_use_traits<use::accumulator>::value,

@@ -327,8 +318,9 @@ template <typename Group, typename T, size_t NumRows, size_t NumCols,
           access::address_space Space, access::decorated IsDecorated>
 inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
     Group sg,
-    joint_matrix<Group, T, use::accumulator, NumRows, NumCols,
-                 sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
+    const joint_matrix<Group, T, use::accumulator, NumRows, NumCols,
+                       sycl::ext::oneapi::experimental::matrix::layout::dynamic>
+        &src,
     multi_ptr<T, Space, IsDecorated> dst, size_t stride,
     sycl::ext::oneapi::experimental::matrix::layout Layout) {
 #if defined(__SYCL_DEVICE_ONLY__)

@@ -361,7 +353,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
         Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor,
         spv_scope_traits<Group>::value);
     break;
-  case sycl::ext::intel::experimental::matrix::layout::packed:
+  case layout::ext_intel_packed:
     __spirv_JointMatrixStoreINTEL<
         DecorT, T, NumRows, NumCols,
         spv_matrix_use_traits<use::accumulator>::value,

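Note: for accumulator load/store the layout is a runtime argument, so existing call sites change only the enumerator spelling, and src now binds by const reference. A minimal kernel-side sketch (the pointer pC, stride, and shape are illustrative assumptions; a packed accumulator remains an Intel-specific case):

// Assumed kernel-side context: sg is the sub-group, pC is a
// multi_ptr<int32_t, access::address_space::global_space>.
using namespace sycl::ext::oneapi::experimental::matrix;
joint_matrix<sycl::sub_group, int32_t, use::accumulator, 8, 16, layout::dynamic> C;
// before: ext::intel::experimental::matrix::layout::packed
joint_matrix_load(sg, C, pC, /*stride=*/16, layout::ext_intel_packed);
// src is now taken by const reference; the call itself is unchanged:
joint_matrix_store(sg, C, pC, /*stride=*/16, layout::ext_intel_packed);
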
@@ -382,53 +374,79 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store(

-template <typename Group, typename Ta, typename Tb, typename Tc, std::size_t M,
-          std::size_t K, std::size_t N, layout LayoutA, layout LayoutB>
-inline __SYCL_ALWAYS_INLINE
-    joint_matrix<Group, Tc, use::accumulator, M, N,
-                 sycl::ext::oneapi::experimental::matrix::layout::dynamic>
-    joint_matrix_mad(
-        Group sg, joint_matrix<Group, Ta, use::a, M, K, LayoutA> &A,
-        joint_matrix<Group, Tb, use::b, K, N, LayoutB> &B,
-        joint_matrix<Group, Tc, use::accumulator, M, N,
-                     sycl::ext::oneapi::experimental::matrix::layout::dynamic>
-            &C) {
+template <typename Group, typename Ta, typename Tb, typename Tc, typename Td,
+          std::size_t M, std::size_t K, std::size_t N, layout LayoutA,
+          layout LayoutB>
+inline __SYCL_ALWAYS_INLINE void joint_matrix_mad(
+    Group sg,
+    joint_matrix<Group, Td, use::accumulator, M, N,
+                 sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D,
+    const joint_matrix<Group, Ta, use::a, M, K, LayoutA> &A,
+    const joint_matrix<Group, Tb, use::b, K, N, LayoutB> &B,
+    const joint_matrix<Group, Tc, use::accumulator, M, N,
+                       sycl::ext::oneapi::experimental::matrix::layout::dynamic>
+        &C) {
 #if defined(__SYCL_DEVICE_ONLY__)
 #if defined(__NVPTX__)
   std::ignore = sg;
   if constexpr (std::is_same<Ta, Tb>::value) {
-    joint_matrix<Group, Tc, use::accumulator, M, N,
-                 sycl::ext::oneapi::experimental::matrix::layout::dynamic>
-        D;
     sycl::ext::oneapi::detail::joint_matrix_mad_cuda<Ta, Tc, M, K, N, LayoutA,
                                                      LayoutB>(
         D.cuda_impl, A.cuda_impl, B.cuda_impl, C.cuda_impl);
-    return D;
   } else {
     assert(false && "Ta != Tb : In the CUDA backend joint_matrix_mad "
                     "requires that joint_matrix data types Ta and Tb match");
   }
 #else
-  joint_matrix<Group, Tc, use::accumulator, M, N, layout::dynamic> res;
   if constexpr (std::is_same<Ta, uint16_t>::value &&
                 std::is_same<Tb, uint16_t>::value &&
                 std::is_same<Tc, float>::value)
-    res.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
+    D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
   else if constexpr (std::is_unsigned<Ta>::value && std::is_unsigned<Tb>::value)
-    res.spvm = __spirv_JointMatrixUUMadINTEL(A.spvm, B.spvm, C.spvm);
+    D.spvm = __spirv_JointMatrixUUMadINTEL(A.spvm, B.spvm, C.spvm);
   else if constexpr (std::is_signed<Ta>::value && std::is_unsigned<Tb>::value)
-    res.spvm = __spirv_JointMatrixSUMadINTEL(A.spvm, B.spvm, C.spvm);
+    D.spvm = __spirv_JointMatrixSUMadINTEL(A.spvm, B.spvm, C.spvm);
   else if constexpr (std::is_unsigned<Ta>::value && std::is_signed<Tb>::value)
-    res.spvm = __spirv_JointMatrixUSMadINTEL(A.spvm, B.spvm, C.spvm);
+    D.spvm = __spirv_JointMatrixUSMadINTEL(A.spvm, B.spvm, C.spvm);
   else
-    res.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
-  return res;
+    D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
 #endif // defined(__NVPTX__)
 #else
   std::ignore = sg;
   std::ignore = A;
   std::ignore = B;
   std::ignore = C;
+  std::ignore = D;
   throw runtime_error("joint matrix is not supported on host device.",
                       PI_ERROR_INVALID_DEVICE);
 #endif // defined(__SYCL_DEVICE_ONLY__)
 }

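Note: joint_matrix_mad no longer returns the result by value; the destination accumulator D is passed as the first matrix argument and written in place, and A/B/C become const. A minimal before/after call-site sketch (element types and shapes are illustrative assumptions):

// Assumed kernel-side context: sg is the sub-group; A, B declared with use::a /
// use::b, and C, D as use::accumulator matrices of matching shape.
// before this PR: C = joint_matrix_mad(sg, A, B, C);
// after this PR:
joint_matrix_mad(sg, D, A, B, C); // D = A * B + C, accumulator written in place
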
+
+template <typename Group, typename T1, typename T2, size_t Rows, size_t Cols,
+          use Use1, use Use2, layout Layout1, layout Layout2>
+void joint_matrix_copy(
+    Group sg, joint_matrix<Group, T1, Use1, Rows, Cols, Layout1> &src,
+    joint_matrix<Group, T2, Use2, Rows, Cols, Layout2> &dst) {
+#if defined(__SYCL_DEVICE_ONLY__)
+#if defined(__NVPTX__)
+  std::ignore = sg;
+  for (int i = 0; i < src.cuda_impl.wi_marray.size(); i++) {
+    dst.cuda_impl.wi_marray[i] = src.cuda_impl.wi_marray[i];
+  }
+#else
+  using storage_element_type =
+      typename oneapi::detail::jm_type_interpretation_helper_trait<
+          T2>::storage_element_type;
+  auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, src);
+  auto wi_data_dst = sycl::ext::oneapi::detail::get_wi_data(sg, dst);
+  for (int i = 0; i < wi_data_c.length(); i++) {
+    wi_data_dst[i] = static_cast<storage_element_type>(wi_data_c[i]);
+  }
+#endif // defined(__NVPTX__)
+#else
+  std::ignore = sg;
+  std::ignore = dst;
+  std::ignore = src;
+  throw runtime_error("joint matrix is not supported on host device.",
+                      PI_ERROR_INVALID_DEVICE);
+#endif // defined(__SYCL_DEVICE_ONLY__)
+}

Review comment on the joint_matrix_copy template parameters: this is to be consistent with the spec. Also, depending on the matrix, the dimensions can be N and K instead of M and N, etc., so to avoid ambiguity I suggest not using M and N for an API that can be applied to matrices with a different use.
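Note: joint_matrix_copy is the user-facing replacement both for get_wi_data-based element copies and for the now-deleted joint_matrix copy constructor/assignment. A minimal kernel-side sketch (shape and element type are illustrative assumptions):

// Assumed kernel-side context: sg is the sub-group.
using namespace sycl::ext::oneapi::experimental::matrix;
joint_matrix<sycl::sub_group, float, use::accumulator, 8, 16, layout::dynamic> C;
joint_matrix<sycl::sub_group, float, use::accumulator, 8, 16, layout::dynamic> C_copy;
// ... compute into C with joint_matrix_mad ...
// "C_copy = C;" is deleted above; copy element-wise instead
// (T1 and T2 may differ; values go through a static_cast in the shown implementation):
joint_matrix_copy(sg, C, C_copy);
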