Merged
Changes from 15 of 27 commits
b68aead
[Matrix] syntax changes as preparation before moving joint matrix from
yubingex007-a11y Sep 19, 2023
5fbb285
clang-format
yubingex007-a11y Sep 19, 2023
bf6cd56
fix typo: dest->dst
yubingex007-a11y Sep 19, 2023
b399041
fix testcase
yubingex007-a11y Sep 19, 2023
dae1ec6
fix mad bug
yubingex007-a11y Sep 19, 2023
4ec8360
fix cuda const joint_matrix_cuda
yubingex007-a11y Sep 19, 2023
a461cbb
fix const issue of jm_store_cuda
yubingex007-a11y Sep 19, 2023
5ff715b
fix const
yubingex007-a11y Sep 19, 2023
8ad7da9
lint
yubingex007-a11y Sep 19, 2023
26ea49d
address dounia's comments and roll back all the testcase changes
yubingex007-a11y Sep 21, 2023
a09a778
test changes: mov D in mad
yubingex007-a11y Sep 21, 2023
821fa89
testcase changes: ext_intel_layout
yubingex007-a11y Sep 21, 2023
a3921b5
testcase changes: wi_data=>jm_apply
yubingex007-a11y Sep 21, 2023
ef1bc67
lint
yubingex007-a11y Sep 21, 2023
f395199
Merge remote-tracking branch 'intel_llvm/sycl' into jm_syntax
yubingex007-a11y Sep 21, 2023
c71fee6
Merge remote-tracking branch 'intel_llvm/sycl' into jm_syntax
yubingex007-a11y Sep 22, 2023
8f2f197
handle cuda testcase compfail
yubingex007-a11y Sep 22, 2023
1411376
address dounia's comments
yubingex007-a11y Sep 22, 2023
95df3b1
lint
yubingex007-a11y Sep 22, 2023
fb1afdc
rm sycl/test/matrix/query-use.cpp
yubingex007-a11y Sep 22, 2023
11df531
fix x jm_mad in joint_matrix_bf16_fill_k_cache_impl.hpp
yubingex007-a11y Sep 25, 2023
a29e8f3
Merge remote-tracking branch 'intel_llvm/sycl' into jm_syntax
yubingex007-a11y Oct 9, 2023
a821107
address comments
yubingex007-a11y Oct 11, 2023
3f1b575
Merge remote-tracking branch 'intel_llvm/sycl' into jm_syntax
yubingex007-a11y Oct 11, 2023
1d091de
rm element_wise_irreg_sum_rows_impl.hpp
yubingex007-a11y Oct 11, 2023
1e20968
small fix
yubingex007-a11y Oct 11, 2023
1fe7fcd
small fix
yubingex007-a11y Oct 11, 2023
67 changes: 49 additions & 18 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp
@@ -29,10 +29,6 @@
namespace sycl {
inline namespace _V1 {
namespace ext {
namespace intel::experimental::matrix::layout {
constexpr sycl::ext::oneapi::experimental::matrix::layout packed =
static_cast<sycl::ext::oneapi::experimental::matrix::layout>(2);
}
namespace oneapi {
namespace experimental {
namespace matrix {
@@ -48,8 +44,7 @@ template <layout Layout> struct spv_matrix_layout_traits {

SPV_MATRIX_LAYOUT_TRAITS(layout::row_major, __spv::MatrixLayout::RowMajor)
SPV_MATRIX_LAYOUT_TRAITS(layout::col_major, __spv::MatrixLayout::ColumnMajor)
SPV_MATRIX_LAYOUT_TRAITS(sycl::ext::intel::experimental::matrix::layout::packed,
__spv::MatrixLayout::Packed)
SPV_MATRIX_LAYOUT_TRAITS(layout::ext_intel_packed, __spv::MatrixLayout::Packed)
SPV_MATRIX_LAYOUT_TRAITS(layout::dynamic, __spv::MatrixLayout::Dynamic)

template <use Use> struct spv_matrix_use_traits {
@@ -94,10 +89,6 @@ struct jm_type_interpretation_helper_trait<
using element_type = sycl::ext::oneapi::experimental::matrix::precision::tf32;
using storage_element_type = float;
};
} // namespace detail
} // namespace oneapi

namespace intel::experimental::matrix {

using namespace sycl::ext::oneapi::experimental::matrix;
// Begin wi_element definition
@@ -121,12 +112,12 @@ class wi_element {
std::size_t i)
: M(Mat), idx(i) {}

inline __SYCL_ALWAYS_INLINE std::tuple<uint32_t, uint32_t> get_coord() {
inline __SYCL_ALWAYS_INLINE std::tuple<size_t, size_t> get_coord() {
#if defined(__SYCL_DEVICE_ONLY__)
__ocl_vec_t<uint32_t, 2> coord =
__spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
const uint32_t row = coord[0];
const uint32_t col = coord[1];
const size_t row = coord[0];
const size_t col = coord[1];
return std::make_tuple(row, col);
#else
throw runtime_error("joint matrix is not supported on host device.",
@@ -196,7 +187,7 @@ class wi_element {

#if __SYCL_DEVICE_ONLY__
#define OP(op) \
template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
M.spvm = __spirv_VectorInsertDynamic( \
M.spvm, \
static_cast<storage_element_type>( \
@@ -211,7 +202,7 @@
}
#else // __SYCL_DEVICE_ONLY__
#define OP(op) \
template <typename T2> wi_element &operator op##=(const T2 &rhs) { \
template <typename T2> wi_element &operator op##=(const T2 & rhs) { \
(void)rhs; \
throw runtime_error("joint matrix is not supported on host device.", \
PI_ERROR_INVALID_DEVICE); \
@@ -315,7 +306,7 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,

#if __SYCL_DEVICE_ONLY__
#define OP(opassign, op) \
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 &rhs) { \
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
M.spvm = __spirv_VectorInsertDynamic( \
M.spvm, \
__spirv_VectorExtractDynamic< \
@@ -328,7 +319,7 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
}
#else // __SYCL_DEVICE_ONLY__
#define OP(opassign, op) \
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 &rhs) { \
wi_element &operator opassign(const sycl::ext::oneapi::bfloat16 & rhs) { \
(void)rhs; \
throw runtime_error("joint matrix is not supported on host device.", \
PI_ERROR_INVALID_DEVICE); \
@@ -479,7 +470,10 @@ get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix<
}

// End wi_data definition
} // namespace detail
} // namespace oneapi

namespace intel::experimental::matrix {
template <
typename Group, typename T, typename Tp,
sycl::ext::oneapi::experimental::matrix::use Use, size_t NumRows,
@@ -490,7 +484,7 @@ template <
bool> = true>
inline __SYCL_ALWAYS_INLINE void
joint_matrix_store(Group sg,
sycl::ext::oneapi::experimental::matrix::joint_matrix<
const sycl::ext::oneapi::experimental::matrix::joint_matrix<
Group, Tp, Use, NumRows, NumCols, Layout> &src,
multi_ptr<T, Space, IsDecorated> dst, size_t stride) {
#if defined(__SYCL_DEVICE_ONLY__)
@@ -528,6 +522,43 @@ joint_matrix_store(
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}
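For illustration, a minimal caller-side sketch of this overload with its new const qualification; the half element type, tile shape, and function name are assumptions, and the function is meant to be called from device code:

#include <sycl/sycl.hpp>

namespace syclex = sycl::ext::oneapi::experimental::matrix;
namespace intelex = sycl::ext::intel::experimental::matrix;

// Stores a read-only use::a tile; the layout is fixed by the matrix type,
// so no runtime layout argument is passed (unlike the accumulator overload).
template <size_t TM, size_t TK>
void store_a_tile(sycl::sub_group sg,
                  const syclex::joint_matrix<sycl::sub_group, sycl::half,
                                             syclex::use::a, TM, TK,
                                             syclex::layout::row_major> &tA,
                  sycl::half *out, size_t stride) {
  auto p = sycl::address_space_cast<sycl::access::address_space::global_space,
                                    sycl::access::decorated::no>(out);
  intelex::joint_matrix_store(sg, tA, p, stride); // src now binds as const &
}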

template <typename Group, typename T,
sycl::ext::oneapi::experimental::matrix::use Use, size_t M, size_t N,
Contributor comment: To be consistent with the spec. Also, depending on the matrix use, the dimensions can be N and K instead of M and N, etc., so to avoid ambiguity I suggest not using M and N for an API that can be applied to matrices with different uses.

Suggested change:
-          sycl::ext::oneapi::experimental::matrix::use Use, size_t M, size_t N,
+          sycl::ext::oneapi::experimental::matrix::use Use, size_t Rows, size_t Cols,

sycl::ext::oneapi::experimental::matrix::layout Layout, typename F>
inline __SYCL_ALWAYS_INLINE void joint_matrix_apply(
Group sg,
sycl::ext::oneapi::experimental::matrix::joint_matrix<Group, T, Use, M, N,
Layout> &jm,
F &&lambda) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
std::ignore = sg;
for (int i = 0; i < jm.cuda_impl.wi_marray.size(); i++) {
lambda(jm.cuda_impl.wi_marray[i]);
}
#else // NVPTX
using storage_element_type =
typename oneapi::detail::jm_type_interpretation_helper_trait<
T>::storage_element_type;
auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jm);
for (int i = 0; i < wi_data_c.length(); i++) {
storage_element_type element = wi_data_c[i];
auto [row, col] = wi_data_c[i].get_coord();
lambda(element, row, col);
wi_data_c[i] = element;
}
#endif
#else
std::ignore = sg;
std::ignore = jm;
std::ignore = lambda;
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif
return;
Contributor comment: The return statement is unnecessary, please remove.

Suggested change:
-  return;

}

} // namespace intel::experimental::matrix

} // namespace ext
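A usage sketch for the coordinate-aware joint_matrix_apply added above, which replaces the removed per-element wi_data loops. The float accumulator tile and the upper-triangle update are illustrative assumptions:

#include <sycl/sycl.hpp>

namespace syclex = sycl::ext::oneapi::experimental::matrix;
namespace intelex = sycl::ext::intel::experimental::matrix;

// Scales only the strictly-upper-triangle elements of a tile; the lambda
// receives each element together with its (row, col) coordinate.
template <size_t TM, size_t TN>
void halve_upper_triangle(sycl::sub_group sg,
                          syclex::joint_matrix<sycl::sub_group, float,
                                               syclex::use::accumulator,
                                               TM, TN> &tC) {
  intelex::joint_matrix_apply(sg, tC, [](float &x, size_t row, size_t col) {
    if (col > row)
      x *= 0.5f;
  });
}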
12 changes: 6 additions & 6 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcores.hpp
@@ -413,7 +413,7 @@ void store_layoutT(
template <typename T, size_t NumRows, size_t NumCols,
access::address_space Space, access::decorated IsDecorated>
void joint_matrix_store_cuda(
joint_matrix_cuda<
const joint_matrix_cuda<
T, sycl::ext::oneapi::experimental::matrix::use::accumulator, NumRows,
NumCols, sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
multi_ptr<T, Space, IsDecorated> dst, size_t stride,
@@ -482,11 +482,11 @@ void joint_matrix_mad_cuda(
joint_matrix_cuda<
Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D,
joint_matrix_cuda<Tm, sycl::ext::oneapi::experimental::matrix::use::a, M, K,
LayoutA> &A,
joint_matrix_cuda<Tm, sycl::ext::oneapi::experimental::matrix::use::b, K, N,
LayoutB> &B,
joint_matrix_cuda<
const joint_matrix_cuda<Tm, sycl::ext::oneapi::experimental::matrix::use::a,
M, K, LayoutA> &A,
const joint_matrix_cuda<Tm, sycl::ext::oneapi::experimental::matrix::use::b,
K, N, LayoutB> &B,
const joint_matrix_cuda<
Tc, sycl::ext::oneapi::experimental::matrix::use::accumulator, M, N,
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &C) {
if constexpr (M == 16 && N == 16 && K == 16) {
@@ -16,7 +16,12 @@ namespace matrix {

enum class use { a, b, accumulator };

enum class layout { row_major = 0, col_major = 1, dynamic = 3 };
enum class layout {
row_major = 0,
col_major = 1,
ext_intel_packed = 2,
dynamic = 3
};

namespace precision {
class tf32 {
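A sketch of the renamed enumerator in use, loading a VNNI-packed B tile. The bfloat16 element type, tile shape, and function name are assumptions; before this change the same value was spelled sycl::ext::intel::experimental::matrix::layout::packed:

#include <sycl/sycl.hpp>

namespace syclex = sycl::ext::oneapi::experimental::matrix;

// The packed layout is declared in the matrix type itself, so
// joint_matrix_load needs no runtime layout argument for use::b.
template <size_t TK, size_t TN>
void load_packed_b(sycl::sub_group sg,
                   syclex::joint_matrix<sycl::sub_group,
                                        sycl::ext::oneapi::bfloat16,
                                        syclex::use::b, TK, TN,
                                        syclex::layout::ext_intel_packed> &tB,
                   sycl::ext::oneapi::bfloat16 *B, size_t stride) {
  auto p = sycl::address_space_cast<sycl::access::address_space::global_space,
                                    sycl::access::decorated::no>(B);
  syclex::joint_matrix_load(sg, tB, p, stride);
}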
116 changes: 67 additions & 49 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp
@@ -61,19 +61,8 @@ struct joint_matrix {
}
#ifdef __SYCL_DEVICE_ONLY__
#if defined(__SPIR__)
// Generate a non-trivial assignment operator and copy c'tor that prevents
// memcpy from being generated.
// TODO: to remove, when either IGC can handle alloca JointMatrix or
// combination of InstCombine + SROA + mem2reg can remove it
joint_matrix(const joint_matrix &other) {
spvm = other.spvm;
return *this;
}

joint_matrix &operator=(const joint_matrix &rhs) {
spvm = rhs.spvm;
return *this;
}
joint_matrix(const joint_matrix &other) = delete;
joint_matrix &operator=(const joint_matrix &rhs) = delete;
#endif // defined(__SPIR__)
#endif
};
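The caller-visible effect of deleting the copy operations, sketched under the assumption of SPIR device compilation; explicit duplication now goes through the joint_matrix_copy function introduced in matrix-unified.hpp below:

#include <sycl/sycl.hpp>

namespace syclex = sycl::ext::oneapi::experimental::matrix;

template <size_t TM, size_t TN>
void duplicate_tile(sycl::sub_group sg,
                    syclex::joint_matrix<sycl::sub_group, float,
                                         syclex::use::accumulator,
                                         TM, TN> &tC) {
  // auto tC2 = tC;  // was a silent member-wise copy; now ill-formed on SPIR
  syclex::joint_matrix<sycl::sub_group, float, syclex::use::accumulator,
                       TM, TN> tC2;
  syclex::joint_matrix_copy(sg, tC, tC2); // explicit element-wise copy
}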
@@ -98,19 +87,21 @@ class wi_data {
#if defined(__NVPTX__)
return jm.cuda_impl.wi_marray.size();
#else
throw runtime_error("get_wi_data is available using: "
"ext::intel::experimental::matrix::get_wi_data.",
PI_ERROR_INVALID_DEVICE);
throw runtime_error(
Contributor comment: I have not found wi_data in the specification. Is the intent to remove it? If so, should the "deprecated" attribute be added to flag, at compile time, any use of this class at all, not only of its methods? Or should it just be removed?

"get_wi_data is available using: ext::oneapi::detail::get_wi_data, but "
"intel users are expected to use joint_matrix_copy instead.",
PI_ERROR_INVALID_DEVICE);
#endif
};

decltype(auto) operator[](size_t i) {
#if defined(__NVPTX__)
return (jm.cuda_impl.wi_marray[i]);
#else
throw runtime_error("get_wi_data is available using: "
"ext::intel::experimental::matrix::get_wi_data.",
PI_ERROR_INVALID_DEVICE);
throw runtime_error(
"get_wi_data is available using: ext::oneapi::detail::get_wi_data, but "
Contributor comment: Just say: "get_wi_data is unavailable, use joint_matrix_copy instead."

"intel users are expected to use joint_matrix_copy instead.",
PI_ERROR_INVALID_DEVICE);
#endif
};
};
@@ -138,9 +129,9 @@ template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
__SYCL2020_DEPRECATED("get_wi_data() is deprecated for CUDA backend. Please "
Contributor comment: We should remove this. It is not really deprecated: joint_matrix is experimental, so we can just remove APIs. Deprecated means they still exist and implementations maintain them. In the case of get_wi_data, it is replaced by joint_matrix_apply.

Contributor comment: Was this addressed?

Contributor comment (dkhaldi, Oct 12, 2023): This will be addressed by @JackAKirk among other CUDA changes in a separate PR.

Contributor comment: Yes, I will make this change as soon as this PR is merged.

"use joint_matrix_apply() instead.")
#else
__attribute__((unavailable(
"get_wi_data can't be used on intel device, please use "
"sycl::ext::intel::experimental::matrix::get_wi_data instead!")))
__attribute__((
Contributor comment: If get_wi_data is intended to be removed, we should not show messages saying that it is deprecated on the CUDA backend and that it can't be used on Intel devices. I think get_wi_data should either be removed, or consistent deprecation messaging should be provided for all backends.

Author comment: __SYCL2020_DEPRECATED produces a compile warning, while __attribute__((unavailable)) produces a compile failure. We can't let Intel users use ext::oneapi::experimental::matrix::get_wi_data, so __attribute__((unavailable)) is reasonable. @dkhaldi, right?

Contributor comment: We should change the NVIDIA deprecation message as well. Deprecation means it is still supported; get_wi_data is not supported at all. We removed it and replaced it with joint_matrix_apply. So a better message for all backends could be: "get_wi_data() has been removed from the API and replaced with joint_matrix_apply".

Author comment: Used "get_wi_data() has been removed from the API and replaced with joint_matrix_apply" in all messages.

unavailable("get_wi_data can't be used on intel device, please use "
"joint_matrix_apply instead!")))
#endif
#endif
inline __SYCL_ALWAYS_INLINE decltype(auto)
@@ -176,7 +167,7 @@ joint_matrix_apply(Group sg, joint_matrix<Group, T, Use, M, N, Layout> &jm,
using storage_element_type =
typename oneapi::detail::jm_type_interpretation_helper_trait<
T>::storage_element_type;
auto wi_data_c = sycl::ext::intel::experimental::matrix::get_wi_data(sg, jm);
auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jm);
for (int i = 0; i < wi_data_c.length(); i++) {
storage_element_type element = wi_data_c[i];
lambda(element);
@@ -262,7 +253,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load(
Ptr, stride, __spv::MatrixLayout::ColumnMajor,
spv_scope_traits<Group>::value);
break;
case sycl::ext::intel::experimental::matrix::layout::packed:
case layout::ext_intel_packed:
res.spvm = __spirv_JointMatrixLoadINTEL<
DecorT, S, NumRows, NumCols,
spv_matrix_use_traits<use::accumulator>::value,
@@ -327,8 +318,9 @@ template <typename Group, typename T, size_t NumRows, size_t NumCols,
access::address_space Space, access::decorated IsDecorated>
inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
Group sg,
joint_matrix<Group, T, use::accumulator, NumRows, NumCols,
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src,
const joint_matrix<Group, T, use::accumulator, NumRows, NumCols,
sycl::ext::oneapi::experimental::matrix::layout::dynamic>
&src,
multi_ptr<T, Space, IsDecorated> dst, size_t stride,
sycl::ext::oneapi::experimental::matrix::layout Layout) {
#if defined(__SYCL_DEVICE_ONLY__)
@@ -361,7 +353,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor,
spv_scope_traits<Group>::value);
break;
case sycl::ext::intel::experimental::matrix::layout::packed:
case layout::ext_intel_packed:
__spirv_JointMatrixStoreINTEL<
DecorT, T, NumRows, NumCols,
spv_matrix_use_traits<use::accumulator>::value,
@@ -382,53 +374,79 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
#endif // defined(__SYCL_DEVICE_ONLY__)
}
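A caller-side sketch of the accumulator store with its new const binding; the float element type, tile shape, and row-major runtime layout are assumptions:

#include <sycl/sycl.hpp>

namespace syclex = sycl::ext::oneapi::experimental::matrix;

// The accumulator tile's layout is chosen at the call site via the runtime
// argument, which may also be layout::ext_intel_packed.
template <size_t TM, size_t TN>
void store_c_tile(sycl::sub_group sg,
                  const syclex::joint_matrix<sycl::sub_group, float,
                                             syclex::use::accumulator,
                                             TM, TN> &tC,
                  float *C, size_t stride) {
  auto p = sycl::address_space_cast<sycl::access::address_space::global_space,
                                    sycl::access::decorated::no>(C);
  syclex::joint_matrix_store(sg, tC, p, stride, syclex::layout::row_major);
}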

template <typename Group, typename Ta, typename Tb, typename Tc, std::size_t M,
std::size_t K, std::size_t N, layout LayoutA, layout LayoutB>
inline __SYCL_ALWAYS_INLINE
joint_matrix<Group, Tc, use::accumulator, M, N,
sycl::ext::oneapi::experimental::matrix::layout::dynamic>
joint_matrix_mad(
Group sg, joint_matrix<Group, Ta, use::a, M, K, LayoutA> &A,
joint_matrix<Group, Tb, use::b, K, N, LayoutB> &B,
joint_matrix<Group, Tc, use::accumulator, M, N,
sycl::ext::oneapi::experimental::matrix::layout::dynamic>
&C) {
template <typename Group, typename Ta, typename Tb, typename Tc, typename Td,
std::size_t M, std::size_t K, std::size_t N, layout LayoutA,
layout LayoutB>
inline __SYCL_ALWAYS_INLINE void joint_matrix_mad(
Group sg,
joint_matrix<Group, Td, use::accumulator, M, N,
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D,
const joint_matrix<Group, Ta, use::a, M, K, LayoutA> &A,
const joint_matrix<Group, Tb, use::b, K, N, LayoutB> &B,
const joint_matrix<Group, Tc, use::accumulator, M, N,
sycl::ext::oneapi::experimental::matrix::layout::dynamic>
&C) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
std::ignore = sg;
if constexpr (std::is_same<Ta, Tb>::value) {
joint_matrix<Group, Tc, use::accumulator, M, N,
sycl::ext::oneapi::experimental::matrix::layout::dynamic>
D;
sycl::ext::oneapi::detail::joint_matrix_mad_cuda<Ta, Tc, M, K, N, LayoutA,
LayoutB>(
D.cuda_impl, A.cuda_impl, B.cuda_impl, C.cuda_impl);
return D;
} else {
assert(false && "Ta != Tb : In the CUDA backend joint_matrix_mad "
"requires that joint_matrix data types Ta and Tb match");
}
#else
joint_matrix<Group, Tc, use::accumulator, M, N, layout::dynamic> res;
if constexpr (std::is_same<Ta, uint16_t>::value &&
std::is_same<Tb, uint16_t>::value &&
std::is_same<Tc, float>::value)
res.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
else if constexpr (std::is_unsigned<Ta>::value && std::is_unsigned<Tb>::value)
res.spvm = __spirv_JointMatrixUUMadINTEL(A.spvm, B.spvm, C.spvm);
D.spvm = __spirv_JointMatrixUUMadINTEL(A.spvm, B.spvm, C.spvm);
else if constexpr (std::is_signed<Ta>::value && std::is_unsigned<Tb>::value)
res.spvm = __spirv_JointMatrixSUMadINTEL(A.spvm, B.spvm, C.spvm);
D.spvm = __spirv_JointMatrixSUMadINTEL(A.spvm, B.spvm, C.spvm);
else if constexpr (std::is_unsigned<Ta>::value && std::is_signed<Tb>::value)
res.spvm = __spirv_JointMatrixUSMadINTEL(A.spvm, B.spvm, C.spvm);
D.spvm = __spirv_JointMatrixUSMadINTEL(A.spvm, B.spvm, C.spvm);
else
res.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
return res;
D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm);
#endif // defined(__NVPTX__)
#else
std::ignore = sg;
std::ignore = A;
std::ignore = B;
std::ignore = C;
std::ignore = D;
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
}
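A migration sketch for the new out-parameter form of joint_matrix_mad; the bfloat16/float type combination, tile shapes, and in-place D = C pattern are illustrative assumptions:

#include <sycl/sycl.hpp>

namespace syclex = sycl::ext::oneapi::experimental::matrix;

template <size_t TM, size_t TN, size_t TK>
void mad_tile(sycl::sub_group sg,
              const syclex::joint_matrix<sycl::sub_group,
                                         sycl::ext::oneapi::bfloat16,
                                         syclex::use::a, TM, TK,
                                         syclex::layout::row_major> &tA,
              const syclex::joint_matrix<sycl::sub_group,
                                         sycl::ext::oneapi::bfloat16,
                                         syclex::use::b, TK, TN,
                                         syclex::layout::ext_intel_packed> &tB,
              syclex::joint_matrix<sycl::sub_group, float,
                                   syclex::use::accumulator, TM, TN> &tC) {
  // Old form: tC = joint_matrix_mad(sg, tA, tB, tC);
  // New form: the destination comes first; A, B, and C now bind as const.
  syclex::joint_matrix_mad(sg, tC, tA, tB, tC);
}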

template <typename Group, typename T1, typename T2, size_t Rows, size_t Cols,
use Use1, use Use2, layout Layout1, layout Layout2>
void joint_matrix_copy(
Group sg, joint_matrix<Group, T1, Use1, Rows, Cols, Layout1> &src,
joint_matrix<Group, T2, Use2, Rows, Cols, Layout2> &dst) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
std::ignore = sg;
for (int i = 0; i < src.cuda_impl.wi_marray.size(); i++) {
dst.cuda_impl.wi_marray[i] = src.cuda_impl.wi_marray[i];
}
#else
using storage_element_type =
typename oneapi::detail::jm_type_interpretation_helper_trait<
T2>::storage_element_type;
auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, src);
auto wi_data_dst = sycl::ext::oneapi::detail::get_wi_data(sg, dst);
for (int i = 0; i < wi_data_c.length(); i++) {
wi_data_dst[i] = static_cast<storage_element_type>(wi_data_c[i]);
}
#endif // defined(__NVPTX__)
#else
std::ignore = sg;
std::ignore = dst;
std::ignore = src;
throw runtime_error("joint matrix is not supported on host device.",
PI_ERROR_INVALID_DEVICE);
#endif // defined(__SYCL_DEVICE_ONLY__)
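Finally, a usage sketch for the new joint_matrix_copy; the float accumulator tile and function name are assumptions, and per the signature above src and dst may also differ in element type, use, and layout (the conversion goes through storage_element_type):

#include <sycl/sycl.hpp>

namespace syclex = sycl::ext::oneapi::experimental::matrix;

// Snapshots an accumulator tile before it is overwritten by further
// joint_matrix_mad calls; joint_matrix itself is no longer copyable.
template <size_t TM, size_t TN>
void snapshot_accumulator(sycl::sub_group sg,
                          syclex::joint_matrix<sycl::sub_group, float,
                                               syclex::use::accumulator,
                                               TM, TN> &tC) {
  syclex::joint_matrix<sycl::sub_group, float, syclex::use::accumulator,
                       TM, TN> tSaved;
  syclex::joint_matrix_copy(sg, tC, tSaved);
}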