From 05096e7b56a120fbb83c54a9add4fe64be17c322 Mon Sep 17 00:00:00 2001
From: Pawel Swider
Date: Fri, 8 Aug 2025 12:25:08 +0000
Subject: [PATCH 1/5] Matmul complex POC

---
 src/ATen/native/xpu/Blas.cpp            | 44 ++++++++++++++++++
 src/ATen/native/xpu/mkl/SpectralOps.cpp | 62 +++++++++++++++++++++++++
 2 files changed, 106 insertions(+)
 create mode 100644 src/ATen/native/xpu/Blas.cpp

diff --git a/src/ATen/native/xpu/Blas.cpp b/src/ATen/native/xpu/Blas.cpp
new file mode 100644
index 0000000000..2c5af7abd2
--- /dev/null
+++ b/src/ATen/native/xpu/Blas.cpp
@@ -0,0 +1,44 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace at::native {
+
+  at::Tensor& mm_complex_out_xpu(const at::Tensor &self, const at::Tensor &mat2, at::Tensor &out) {
+    at::Tensor self_cont = self.contiguous();
+    at::Tensor mat2_cont = mat2.contiguous();
+    at::Tensor out_cont = out.contiguous();
+
+    const int64_t m = self_cont.sizes().at(0);
+    const int64_t n = mat2_cont.sizes().at(1);
+    const int64_t k = self_cont.sizes().at(1);
+
+    constexpr std::complex alpha = {1.0f, 0.0f};
+    constexpr std::complex beta = {0.0f, 0.0f};
+
+    oneapi::mkl::blas::row_major::gemm(
+        at::xpu::getCurrentSYCLQueue(),
+        oneapi::mkl::transpose::nontrans,
+        oneapi::mkl::transpose::nontrans,
+        m,
+        n,
+        k,
+        alpha,
+        reinterpret_cast*>(self_cont.const_data_ptr()),
+        k,
+        reinterpret_cast*>(mat2_cont.const_data_ptr()),
+        n,
+        beta,
+        reinterpret_cast*>(out_cont.data_ptr()),
+        n);
+
+    return out;
+}
+
+REGISTER_XPU_DISPATCH(mm_complex_stub, &mm_complex_out_xpu)
+
+} // namespace at::native
\ No newline at end of file
diff --git a/src/ATen/native/xpu/mkl/SpectralOps.cpp b/src/ATen/native/xpu/mkl/SpectralOps.cpp
index 96d4118f5b..492c4da6cd 100644
--- a/src/ATen/native/xpu/mkl/SpectralOps.cpp
+++ b/src/ATen/native/xpu/mkl/SpectralOps.cpp
@@ -10,6 +10,10 @@
 #include
 #include
 #include
+#include
+#include
+#include
+#include
 
 using namespace oneapi::mkl::dft;
 
@@ -578,3 +582,61 @@ Tensor& _fft_r2c_mkl_out(
 }
 
 } // namespace at::native::xpu
+
+
+namespace at::native::xpu {
+
+at::Tensor& mm_out_xpu(at::Tensor &out, const at::Tensor &self, const at::Tensor &mat2) {
+  at::Tensor self_cont = self.contiguous();
+  at::Tensor mat2_cont = mat2.contiguous();
+  at::Tensor out_cont = out.contiguous();
+
+  const int64_t m = self_cont.sizes().at(0);
+  const int64_t n = mat2_cont.sizes().at(1);
+  const int64_t k = self_cont.sizes().at(1);
+
+  constexpr std::complex alpha = {1.0f, 0.0f};
+  constexpr std::complex beta = {0.0f, 0.0f};
+
+  oneapi::mkl::blas::row_major::gemm(
+      at::xpu::getCurrentSYCLQueue(),
+      oneapi::mkl::transpose::nontrans,
+      oneapi::mkl::transpose::nontrans,
+      m,
+      n,
+      k,
+      alpha,
+      reinterpret_cast*>(self_cont.const_data_ptr()),
+      k,
+      reinterpret_cast*>(mat2_cont.const_data_ptr()),
+      n,
+      beta,
+      reinterpret_cast*>(out_cont.data_ptr()),
+      n);
+
+  return out;
+}
+
+Tensor mm_xpu(const Tensor& self, const Tensor& other) {
+  TORCH_CHECK(self.is_xpu() && other.is_xpu(),
+      "mm_xpu only supports XPU tensors");
+
+  // Your SYCL implementation here
+  auto result = at::empty({self.size(0), other.size(1)}, self.options());
+
+  std::cout << "Example change" << std::endl;
+  mm_out_xpu(result, self, other);
+
+  return result;
+}
+
+} // namespace at::native::xpu
+
+// Register ONLY for XPU
+TORCH_LIBRARY(xpu_ops, m) {
+  m.def("mm_xpu(Tensor self, Tensor other) -> Tensor");
+}
+
+TORCH_LIBRARY_IMPL(xpu_ops, XPU, m) {
+  m.impl("mm_xpu", TORCH_FN(at::native::xpu::mm_xpu));
+}
\ No newline at end of file
From e865b3f1311890ad3f6c54fa4e313b50c84ac40e
Mon Sep 17 00:00:00 2001 From: Pawel Swider Date: Tue, 26 Aug 2025 13:45:15 +0000 Subject: [PATCH 2/5] MM kernels improvements --- src/ATen/native/xpu/Blas.cpp | 320 +++++++++++++++++++++++++++++++---- 1 file changed, 287 insertions(+), 33 deletions(-) diff --git a/src/ATen/native/xpu/Blas.cpp b/src/ATen/native/xpu/Blas.cpp index 2c5af7abd2..a2bb2f9e0e 100644 --- a/src/ATen/native/xpu/Blas.cpp +++ b/src/ATen/native/xpu/Blas.cpp @@ -1,44 +1,298 @@ -#include #include -#include -#include #include -#include +#include +#include #include +#include +#include namespace at::native { - at::Tensor& mm_complex_out_xpu(const at::Tensor &self, const at::Tensor &mat2, at::Tensor &out) { - at::Tensor self_cont = self.contiguous(); - at::Tensor mat2_cont = mat2.contiguous(); - at::Tensor out_cont = out.contiguous(); - - const int64_t m = self_cont.sizes().at(0); - const int64_t n = mat2_cont.sizes().at(1); - const int64_t k = self_cont.sizes().at(1); - - constexpr std::complex alpha = {1.0f, 0.0f}; - constexpr std::complex beta = {0.0f, 0.0f}; - - oneapi::mkl::blas::row_major::gemm( - at::xpu::getCurrentSYCLQueue(), - oneapi::mkl::transpose::nontrans, - oneapi::mkl::transpose::nontrans, - m, - n, - k, - alpha, - reinterpret_cast*>(self_cont.const_data_ptr()), - k, - reinterpret_cast*>(mat2_cont.const_data_ptr()), - n, - beta, - reinterpret_cast*>(out_cont.data_ptr()), - n); - - return out; +inline at::Tensor resolveViewsAndConjugation(const at::Tensor& input) { + at::Tensor input_resolved = input.is_conj() ? input.resolve_conj() : input; + at::Tensor input_contiguous = input_resolved.is_contiguous() + ? input_resolved + : input_resolved.contiguous(); + + return input_contiguous; +} + +template +at::Tensor& mm_complex_out_xpu_impl( + const at::Tensor& self, + const at::Tensor& mat2, + at::Tensor& out) { + at::Tensor self_cont = resolveViewsAndConjugation(self); + at::Tensor mat2_cont = resolveViewsAndConjugation(mat2); + at::Tensor out_cont = resolveViewsAndConjugation(out); + + const int64_t m = self_cont.sizes().at(0); + const int64_t n = mat2_cont.sizes().at(1); + const int64_t k = self_cont.sizes().at(1); + + constexpr std::complex alpha = {T(1.0), T(0.0)}; + constexpr std::complex beta = {T(0.0), T(0.0)}; + + oneapi::mkl::blas::row_major::gemm( + c10::xpu::getCurrentXPUStream().queue(), + oneapi::mkl::transpose::nontrans, + oneapi::mkl::transpose::nontrans, + m, + n, + k, + alpha, + reinterpret_cast*>(self_cont.const_data_ptr()), + k, + reinterpret_cast*>(mat2_cont.const_data_ptr()), + n, + beta, + reinterpret_cast*>(out_cont.data_ptr()), + n); + + if (!out.is_same(out_cont)) { + out.copy_(out_cont); + } + + return out; +} + +at::Tensor& mm_complex_out_xpu( + const at::Tensor& self, + const at::Tensor& mat2, + at::Tensor& out) { + at::Tensor out_ref = at::mm(self.cpu(), mat2.cpu()); + + AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "mm_complex_out_xpu", [&] { + using underlying_t = typename c10::scalar_value_type::type; + mm_complex_out_xpu_impl(self, mat2, out); + }); + + return out; +} + +template +at::Tensor& bmm_complex_out_xpu_impl( + const at::Tensor& self, + const at::Tensor& batch2, + at::Tensor& out) { + at::Tensor self_cont = resolveViewsAndConjugation(self); + at::Tensor batch2_cont = resolveViewsAndConjugation(batch2); + at::Tensor out_cont = resolveViewsAndConjugation(out); + + const int64_t batch_size = self_cont.sizes().at(0); + const int64_t m = self_cont.sizes().at(1); + const int64_t n = batch2_cont.sizes().at(2); + const int64_t k = self_cont.sizes().at(2); + + constexpr std::complex 
alpha = {T(1.0f), T(0.0f)}; + constexpr std::complex beta = {T(0.0f), T(0.0f)}; + + oneapi::mkl::blas::row_major::gemm_batch( + c10::xpu::getCurrentXPUStream().queue(), + oneapi::mkl::transpose::nontrans, + oneapi::mkl::transpose::nontrans, + m, + n, + k, + alpha, + reinterpret_cast*>(self_cont.const_data_ptr()), + k, + m * k, + reinterpret_cast*>(batch2_cont.const_data_ptr()), + n, + k * n, + beta, + reinterpret_cast*>(out_cont.data_ptr()), + n, + m * n, + batch_size); + + if (!out.is_same(out_cont)) { + out.copy_(out_cont); + } + + return out; +} + +at::Tensor& bmm_complex_out_xpu( + const at::Tensor& self, + const at::Tensor& mat2, + at::Tensor& out) { + + AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "bmm_complex_out_xpu", [&] { + using underlying_t = typename c10::scalar_value_type::type; + bmm_complex_out_xpu_impl(self, mat2, out); + }); + + return out; +} + +template +at::Tensor& addmm_complex_out_xpu_impl( + const Tensor& input, + const Tensor& mat1, + const Tensor& mat2, + const Scalar& beta, + const Scalar& alpha, + Tensor& result) { + at::Tensor mat1_cont = resolveViewsAndConjugation(mat1); + at::Tensor mat2_cont = resolveViewsAndConjugation(mat2); + at::Tensor input_cont = resolveViewsAndConjugation(input).clone().detach(); + + const int64_t m = mat1_cont.sizes().at(0); + const int64_t n = mat2_cont.sizes().at(1); + const int64_t k = mat1_cont.sizes().at(1); + + // Some paths in the code below do not handle multiplications of the form [n, 0] x [0, m] + if (k == 0) { + if (result.numel() == 0) { + return result; + } + if (beta.toComplexDouble() == 0.0) { + result.zero_(); + } else { + if (!input.is_same(result)) { + result.copy_(input); + } + result.mul_(beta); + } + return result; + } + + if (m == 0 || n == 0) { + return result; + } + + const std::vector mm_output_size = {m, n}; + if (input_cont.sizes() != mm_output_size) { + input_cont = at::broadcast_to(input_cont, mm_output_size).contiguous(); + } + + + std::complex complex_alpha = + static_cast>(alpha.toComplexDouble()); + std::complex complex_beta = + static_cast>(beta.toComplexDouble()); + + oneapi::mkl::blas::row_major::gemm( + c10::xpu::getCurrentXPUStream().queue(), + oneapi::mkl::transpose::nontrans, + oneapi::mkl::transpose::nontrans, + m, + n, + k, + complex_alpha, + reinterpret_cast*>(mat1_cont.const_data_ptr()), + k, + reinterpret_cast*>(mat2_cont.const_data_ptr()), + n, + complex_beta, + reinterpret_cast*>(input_cont.data_ptr()), + n); + + if (result.sizes() == input_cont.sizes()) { + result.copy_(input_cont); + } else { + result.copy_(input_cont.view(result.sizes())); + } + + return result; +} + +at::Tensor& addmm_complex_out_xpu( + const Tensor& input, + const Tensor& mat1, + const Tensor& mat2, + const Scalar& beta, + const Scalar& alpha, + Tensor& result) { + + AT_DISPATCH_COMPLEX_TYPES(input.scalar_type(), "addmm_complex_out_xpu", [&] { + using underlying_t = typename c10::scalar_value_type::type; + addmm_complex_out_xpu_impl( + input, mat1, mat2, beta, alpha, result); + }); + + return result; +} + +template +at::Tensor& baddbmm_complex_out_xpu_impl( + const Tensor& input, + const Tensor& batch1, + const Tensor& batch2, + const Scalar& beta, + const Scalar& alpha, + Tensor& result) { + at::Tensor batch1_cont = resolveViewsAndConjugation(batch1); + at::Tensor batch2_cont = resolveViewsAndConjugation(batch2); + at::Tensor input_cont = resolveViewsAndConjugation(input).clone().detach(); + + const int64_t batch_size = batch1_cont.sizes().at(0); + const int64_t m = batch1_cont.sizes().at(1); + const int64_t n = 
batch2_cont.sizes().at(2); + const int64_t k = batch1_cont.sizes().at(2); + + const std::vector mm_output_size = {batch_size, m, n}; + if (input_cont.sizes() != mm_output_size) { + input_cont = at::broadcast_to(input_cont, mm_output_size).contiguous();; + } + + std::complex complex_alpha = + static_cast>(alpha.toComplexDouble()); + std::complex complex_beta = + static_cast>(beta.toComplexDouble()); + + oneapi::mkl::blas::row_major::gemm_batch( + c10::xpu::getCurrentXPUStream().queue(), + oneapi::mkl::transpose::nontrans, + oneapi::mkl::transpose::nontrans, + m, + n, + k, + complex_alpha, + reinterpret_cast*>(batch1_cont.const_data_ptr()), + k, + m * k, + reinterpret_cast*>(batch2_cont.const_data_ptr()), + n, + k * n, + complex_beta, + reinterpret_cast*>(input_cont.data_ptr()), + n, + m * n, + batch_size); + + if (result.sizes() == input_cont.sizes()) { + result.copy_(input_cont); + } else { + result.copy_(input_cont.view(result.sizes())); + } + + return result; +} + +at::Tensor& baddbmm_complex_out_xpu( + const Tensor& input, + const Tensor& batch1, + const Tensor& batch2, + const Scalar& beta, + const Scalar& alpha, + Tensor& result) { + + AT_DISPATCH_COMPLEX_TYPES( + input.scalar_type(), "baddbmm_complex_out_xpu", [&] { + using underlying_t = typename c10::scalar_value_type::type; + baddbmm_complex_out_xpu_impl( + input, batch1, batch2, beta, alpha, result); + }); + + return result; } REGISTER_XPU_DISPATCH(mm_complex_stub, &mm_complex_out_xpu) +REGISTER_XPU_DISPATCH(bmm_complex_stub, &bmm_complex_out_xpu) +REGISTER_XPU_DISPATCH(addmm_complex_stub, &addmm_complex_out_xpu) +REGISTER_XPU_DISPATCH(baddbmm_complex_stub, &baddbmm_complex_out_xpu) } // namespace at::native \ No newline at end of file From 55dc07e217aaf47ac60d5c5657f43dbd9de2d812 Mon Sep 17 00:00:00 2001 From: Pawel Swider Date: Thu, 28 Aug 2025 12:44:05 +0000 Subject: [PATCH 3/5] Switch to TORCH_LIBRARY makro --- src/ATen/native/xpu/Blas.cpp | 97 +++++++++++++------------ src/ATen/native/xpu/mkl/SpectralOps.cpp | 62 ---------------- 2 files changed, 52 insertions(+), 107 deletions(-) diff --git a/src/ATen/native/xpu/Blas.cpp b/src/ATen/native/xpu/Blas.cpp index a2bb2f9e0e..6738da10df 100644 --- a/src/ATen/native/xpu/Blas.cpp +++ b/src/ATen/native/xpu/Blas.cpp @@ -1,6 +1,4 @@ #include -#include -#include #include #include #include @@ -129,15 +127,15 @@ at::Tensor& bmm_complex_out_xpu( template at::Tensor& addmm_complex_out_xpu_impl( - const Tensor& input, + const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, - Tensor& result) { + Tensor& out) { at::Tensor mat1_cont = resolveViewsAndConjugation(mat1); at::Tensor mat2_cont = resolveViewsAndConjugation(mat2); - at::Tensor input_cont = resolveViewsAndConjugation(input).clone().detach(); + at::Tensor self_cont = resolveViewsAndConjugation(self).clone().detach(); const int64_t m = mat1_cont.sizes().at(0); const int64_t n = mat2_cont.sizes().at(1); @@ -145,27 +143,27 @@ at::Tensor& addmm_complex_out_xpu_impl( // Some paths in the code below do not handle multiplications of the form [n, 0] x [0, m] if (k == 0) { - if (result.numel() == 0) { - return result; + if (out.numel() == 0) { + return out; } if (beta.toComplexDouble() == 0.0) { - result.zero_(); + out.zero_(); } else { - if (!input.is_same(result)) { - result.copy_(input); + if (!self.is_same(out)) { + out.copy_(self); } - result.mul_(beta); + out.mul_(beta); } - return result; + return out; } if (m == 0 || n == 0) { - return result; + return out; } const std::vector mm_output_size 
= {m, n}; - if (input_cont.sizes() != mm_output_size) { - input_cont = at::broadcast_to(input_cont, mm_output_size).contiguous(); + if (self_cont.sizes() != mm_output_size) { + self_cont = at::broadcast_to(self_cont, mm_output_size).contiguous(); } @@ -187,46 +185,46 @@ at::Tensor& addmm_complex_out_xpu_impl( reinterpret_cast*>(mat2_cont.const_data_ptr()), n, complex_beta, - reinterpret_cast*>(input_cont.data_ptr()), + reinterpret_cast*>(self_cont.data_ptr()), n); - if (result.sizes() == input_cont.sizes()) { - result.copy_(input_cont); + if (out.sizes() == self_cont.sizes()) { + out.copy_(self_cont); } else { - result.copy_(input_cont.view(result.sizes())); + out.copy_(self_cont.view(out.sizes())); } - return result; + return out; } at::Tensor& addmm_complex_out_xpu( - const Tensor& input, + const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, - Tensor& result) { + Tensor& out) { - AT_DISPATCH_COMPLEX_TYPES(input.scalar_type(), "addmm_complex_out_xpu", [&] { + AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "addmm_complex_out_xpu", [&] { using underlying_t = typename c10::scalar_value_type::type; addmm_complex_out_xpu_impl( - input, mat1, mat2, beta, alpha, result); + self, mat1, mat2, beta, alpha, out); }); - return result; + return out; } template at::Tensor& baddbmm_complex_out_xpu_impl( - const Tensor& input, + const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, - Tensor& result) { + Tensor& out) { at::Tensor batch1_cont = resolveViewsAndConjugation(batch1); at::Tensor batch2_cont = resolveViewsAndConjugation(batch2); - at::Tensor input_cont = resolveViewsAndConjugation(input).clone().detach(); + at::Tensor self_cont = resolveViewsAndConjugation(self).clone().detach(); const int64_t batch_size = batch1_cont.sizes().at(0); const int64_t m = batch1_cont.sizes().at(1); @@ -234,8 +232,8 @@ at::Tensor& baddbmm_complex_out_xpu_impl( const int64_t k = batch1_cont.sizes().at(2); const std::vector mm_output_size = {batch_size, m, n}; - if (input_cont.sizes() != mm_output_size) { - input_cont = at::broadcast_to(input_cont, mm_output_size).contiguous();; + if (self_cont.sizes() != mm_output_size) { + self_cont = at::broadcast_to(self_cont, mm_output_size).contiguous();; } std::complex complex_alpha = @@ -258,41 +256,50 @@ at::Tensor& baddbmm_complex_out_xpu_impl( n, k * n, complex_beta, - reinterpret_cast*>(input_cont.data_ptr()), + reinterpret_cast*>(self_cont.data_ptr()), n, m * n, batch_size); - if (result.sizes() == input_cont.sizes()) { - result.copy_(input_cont); + if (out.sizes() == self_cont.sizes()) { + out.copy_(self_cont); } else { - result.copy_(input_cont.view(result.sizes())); + out.copy_(self_cont.view(out.sizes())); } - return result; + return out; } at::Tensor& baddbmm_complex_out_xpu( - const Tensor& input, + const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha, - Tensor& result) { + Tensor& out) { AT_DISPATCH_COMPLEX_TYPES( - input.scalar_type(), "baddbmm_complex_out_xpu", [&] { + self.scalar_type(), "baddbmm_complex_out_xpu", [&] { using underlying_t = typename c10::scalar_value_type::type; baddbmm_complex_out_xpu_impl( - input, batch1, batch2, beta, alpha, result); + self, batch1, batch2, beta, alpha, out); }); - return result; + return out; } -REGISTER_XPU_DISPATCH(mm_complex_stub, &mm_complex_out_xpu) -REGISTER_XPU_DISPATCH(bmm_complex_stub, &bmm_complex_out_xpu) -REGISTER_XPU_DISPATCH(addmm_complex_stub, 
&addmm_complex_out_xpu) -REGISTER_XPU_DISPATCH(baddbmm_complex_stub, &baddbmm_complex_out_xpu) +TORCH_LIBRARY(xpu_mkl, m) { + m.def("xpu_mkl::mm(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)"); + m.def("xpu_mkl::bmm(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)"); + m.def("xpu_mkl::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"); + m.def("xpu_mkl::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"); +} + +TORCH_LIBRARY_IMPL(xpu_mkl, XPU, m) { + m.impl("xpu_mkl::mm", mm_complex_out_xpu); + m.impl("xpu_mkl::bmm", bmm_complex_out_xpu); + m.impl("xpu_mkl::addmm", addmm_complex_out_xpu); + m.impl("xpu_mkl::baddbmm", baddbmm_complex_out_xpu); +} -} // namespace at::native \ No newline at end of file +} // namespace at::native diff --git a/src/ATen/native/xpu/mkl/SpectralOps.cpp b/src/ATen/native/xpu/mkl/SpectralOps.cpp index 492c4da6cd..96d4118f5b 100644 --- a/src/ATen/native/xpu/mkl/SpectralOps.cpp +++ b/src/ATen/native/xpu/mkl/SpectralOps.cpp @@ -10,10 +10,6 @@ #include #include #include -#include -#include -#include -#include using namespace oneapi::mkl::dft; @@ -582,61 +578,3 @@ Tensor& _fft_r2c_mkl_out( } } // namespace at::native::xpu - - -namespace at::native::xpu { - -at::Tensor& mm_out_xpu(at::Tensor &out, const at::Tensor &self, const at::Tensor &mat2) { - at::Tensor self_cont = self.contiguous(); - at::Tensor mat2_cont = mat2.contiguous(); - at::Tensor out_cont = out.contiguous(); - - const int64_t m = self_cont.sizes().at(0); - const int64_t n = mat2_cont.sizes().at(1); - const int64_t k = self_cont.sizes().at(1); - - constexpr std::complex alpha = {1.0f, 0.0f}; - constexpr std::complex beta = {0.0f, 0.0f}; - - oneapi::mkl::blas::row_major::gemm( - at::xpu::getCurrentSYCLQueue(), - oneapi::mkl::transpose::nontrans, - oneapi::mkl::transpose::nontrans, - m, - n, - k, - alpha, - reinterpret_cast*>(self_cont.const_data_ptr()), - k, - reinterpret_cast*>(mat2_cont.const_data_ptr()), - n, - beta, - reinterpret_cast*>(out_cont.data_ptr()), - n); - - return out; -} - -Tensor mm_xpu(const Tensor& self, const Tensor& other) { - TORCH_CHECK(self.is_xpu() && other.is_xpu(), - "mm_xpu only supports XPU tensors"); - - // Your SYCL implementation here - auto result = at::empty({self.size(0), other.size(1)}, self.options()); - - std::cout << "Example change" << std::endl; - mm_out_xpu(result, self, other); - - return result; -} - -} // namespace at::native::xpu - -// Register ONLY for XPU -TORCH_LIBRARY(xpu_ops, m) { - m.def("mm_xpu(Tensor self, Tensor other) -> Tensor"); -} - -TORCH_LIBRARY_IMPL(xpu_ops, XPU, m) { - m.impl("mm_xpu", TORCH_FN(at::native::xpu::mm_xpu)); -} \ No newline at end of file From 963531cfcb35397e46e4807937052b59c83446a4 Mon Sep 17 00:00:00 2001 From: Pawel Swider Date: Fri, 29 Aug 2025 10:52:07 +0000 Subject: [PATCH 4/5] Refactor --- src/ATen/native/xpu/Blas.cpp | 139 ++++++++++++++++++----------------- 1 file changed, 70 insertions(+), 69 deletions(-) diff --git a/src/ATen/native/xpu/Blas.cpp b/src/ATen/native/xpu/Blas.cpp index 6738da10df..31527722d1 100644 --- a/src/ATen/native/xpu/Blas.cpp +++ b/src/ATen/native/xpu/Blas.cpp @@ -6,13 +6,18 @@ namespace at::native { -inline at::Tensor resolveViewsAndConjugation(const at::Tensor& input) { - at::Tensor input_resolved = input.is_conj() ? input.resolve_conj() : input; - at::Tensor input_contiguous = input_resolved.is_contiguous() - ? 
input_resolved - : input_resolved.contiguous(); +#if defined(USE_ONEMKL_XPU) - return input_contiguous; +at::Tensor& handle_output_copy(at::Tensor& out, const at::Tensor& result) { + if (!out.is_same(result)) { + if (out.sizes() == result.sizes()) { + out.copy_(result); + } else { + out.copy_(result.view(out.sizes())); + } + } + + return out; } template @@ -20,16 +25,16 @@ at::Tensor& mm_complex_out_xpu_impl( const at::Tensor& self, const at::Tensor& mat2, at::Tensor& out) { - at::Tensor self_cont = resolveViewsAndConjugation(self); - at::Tensor mat2_cont = resolveViewsAndConjugation(mat2); - at::Tensor out_cont = resolveViewsAndConjugation(out); + at::Tensor self_cont = self.contiguous().resolve_conj(); + at::Tensor mat2_cont = mat2.contiguous().resolve_conj(); + at::Tensor out_cont = out.contiguous().resolve_conj(); const int64_t m = self_cont.sizes().at(0); const int64_t n = mat2_cont.sizes().at(1); const int64_t k = self_cont.sizes().at(1); - constexpr std::complex alpha = {T(1.0), T(0.0)}; - constexpr std::complex beta = {T(0.0), T(0.0)}; + constexpr std::complex alpha = {T(1), T(0)}; + constexpr std::complex beta = {T(0), T(0)}; oneapi::mkl::blas::row_major::gemm( c10::xpu::getCurrentXPUStream().queue(), @@ -47,18 +52,15 @@ at::Tensor& mm_complex_out_xpu_impl( reinterpret_cast*>(out_cont.data_ptr()), n); - if (!out.is_same(out_cont)) { - out.copy_(out_cont); - } - - return out; + return handle_output_copy(out, out_cont); } at::Tensor& mm_complex_out_xpu( const at::Tensor& self, const at::Tensor& mat2, at::Tensor& out) { - at::Tensor out_ref = at::mm(self.cpu(), mat2.cpu()); + TORCH_CHECK( + self.is_complex(), "_mm_mkl.out expects self to be a complex datatype."); AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "mm_complex_out_xpu", [&] { using underlying_t = typename c10::scalar_value_type::type; @@ -71,19 +73,19 @@ at::Tensor& mm_complex_out_xpu( template at::Tensor& bmm_complex_out_xpu_impl( const at::Tensor& self, - const at::Tensor& batch2, + const at::Tensor& mat2, at::Tensor& out) { - at::Tensor self_cont = resolveViewsAndConjugation(self); - at::Tensor batch2_cont = resolveViewsAndConjugation(batch2); - at::Tensor out_cont = resolveViewsAndConjugation(out); + at::Tensor self_cont = self.contiguous().resolve_conj(); + at::Tensor mat2_cont = mat2.contiguous().resolve_conj(); + at::Tensor out_cont = out.contiguous().resolve_conj(); const int64_t batch_size = self_cont.sizes().at(0); const int64_t m = self_cont.sizes().at(1); - const int64_t n = batch2_cont.sizes().at(2); + const int64_t n = mat2_cont.sizes().at(2); const int64_t k = self_cont.sizes().at(2); - constexpr std::complex alpha = {T(1.0f), T(0.0f)}; - constexpr std::complex beta = {T(0.0f), T(0.0f)}; + constexpr std::complex alpha = {T(1), T(0)}; + constexpr std::complex beta = {T(0), T(0)}; oneapi::mkl::blas::row_major::gemm_batch( c10::xpu::getCurrentXPUStream().queue(), @@ -96,7 +98,7 @@ at::Tensor& bmm_complex_out_xpu_impl( reinterpret_cast*>(self_cont.const_data_ptr()), k, m * k, - reinterpret_cast*>(batch2_cont.const_data_ptr()), + reinterpret_cast*>(mat2_cont.const_data_ptr()), n, k * n, beta, @@ -105,17 +107,15 @@ at::Tensor& bmm_complex_out_xpu_impl( m * n, batch_size); - if (!out.is_same(out_cont)) { - out.copy_(out_cont); - } - - return out; + return handle_output_copy(out, out_cont); } at::Tensor& bmm_complex_out_xpu( const at::Tensor& self, const at::Tensor& mat2, at::Tensor& out) { + TORCH_CHECK( + self.is_complex(), "_bmm_mkl.out expects self to be a complex datatype."); 
AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "bmm_complex_out_xpu", [&] { using underlying_t = typename c10::scalar_value_type::type; @@ -133,15 +133,14 @@ at::Tensor& addmm_complex_out_xpu_impl( const Scalar& beta, const Scalar& alpha, Tensor& out) { - at::Tensor mat1_cont = resolveViewsAndConjugation(mat1); - at::Tensor mat2_cont = resolveViewsAndConjugation(mat2); - at::Tensor self_cont = resolveViewsAndConjugation(self).clone().detach(); + at::Tensor mat1_cont = mat1.contiguous().resolve_conj(); + at::Tensor mat2_cont = mat2.contiguous().resolve_conj(); + at::Tensor self_cont = self.contiguous().resolve_conj().clone().detach(); const int64_t m = mat1_cont.sizes().at(0); const int64_t n = mat2_cont.sizes().at(1); const int64_t k = mat1_cont.sizes().at(1); - // Some paths in the code below do not handle multiplications of the form [n, 0] x [0, m] if (k == 0) { if (out.numel() == 0) { return out; @@ -166,7 +165,6 @@ at::Tensor& addmm_complex_out_xpu_impl( self_cont = at::broadcast_to(self_cont, mm_output_size).contiguous(); } - std::complex complex_alpha = static_cast>(alpha.toComplexDouble()); std::complex complex_beta = @@ -188,13 +186,7 @@ at::Tensor& addmm_complex_out_xpu_impl( reinterpret_cast*>(self_cont.data_ptr()), n); - if (out.sizes() == self_cont.sizes()) { - out.copy_(self_cont); - } else { - out.copy_(self_cont.view(out.sizes())); - } - - return out; + return handle_output_copy(out, self_cont); } at::Tensor& addmm_complex_out_xpu( @@ -204,6 +196,9 @@ at::Tensor& addmm_complex_out_xpu( const Scalar& beta, const Scalar& alpha, Tensor& out) { + TORCH_CHECK( + self.is_complex(), + "_addmm_mkl.out expects self to be a complex datatype."); AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "addmm_complex_out_xpu", [&] { using underlying_t = typename c10::scalar_value_type::type; @@ -222,9 +217,9 @@ at::Tensor& baddbmm_complex_out_xpu_impl( const Scalar& beta, const Scalar& alpha, Tensor& out) { - at::Tensor batch1_cont = resolveViewsAndConjugation(batch1); - at::Tensor batch2_cont = resolveViewsAndConjugation(batch2); - at::Tensor self_cont = resolveViewsAndConjugation(self).clone().detach(); + at::Tensor batch1_cont = batch1.contiguous().resolve_conj(); + at::Tensor batch2_cont = batch2.contiguous().resolve_conj(); + at::Tensor self_cont = self.contiguous().resolve_conj().clone().detach(); const int64_t batch_size = batch1_cont.sizes().at(0); const int64_t m = batch1_cont.sizes().at(1); @@ -233,7 +228,7 @@ at::Tensor& baddbmm_complex_out_xpu_impl( const std::vector mm_output_size = {batch_size, m, n}; if (self_cont.sizes() != mm_output_size) { - self_cont = at::broadcast_to(self_cont, mm_output_size).contiguous();; + self_cont = at::broadcast_to(self_cont, mm_output_size).contiguous(); } std::complex complex_alpha = @@ -261,13 +256,7 @@ at::Tensor& baddbmm_complex_out_xpu_impl( m * n, batch_size); - if (out.sizes() == self_cont.sizes()) { - out.copy_(self_cont); - } else { - out.copy_(self_cont.view(out.sizes())); - } - - return out; + return handle_output_copy(out, self_cont); } at::Tensor& baddbmm_complex_out_xpu( @@ -277,29 +266,41 @@ at::Tensor& baddbmm_complex_out_xpu( const Scalar& beta, const Scalar& alpha, Tensor& out) { + TORCH_CHECK( + self.is_complex(), + "_baddbmm_mkl.out expects self to be a complex datatype."); - AT_DISPATCH_COMPLEX_TYPES( - self.scalar_type(), "baddbmm_complex_out_xpu", [&] { - using underlying_t = typename c10::scalar_value_type::type; - baddbmm_complex_out_xpu_impl( - self, batch1, batch2, beta, alpha, out); - }); + 
AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "baddbmm_complex_out_xpu", [&] { + using underlying_t = typename c10::scalar_value_type::type; + baddbmm_complex_out_xpu_impl( + self, batch1, batch2, beta, alpha, out); + }); return out; } -TORCH_LIBRARY(xpu_mkl, m) { - m.def("xpu_mkl::mm(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)"); - m.def("xpu_mkl::bmm(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)"); - m.def("xpu_mkl::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"); - m.def("xpu_mkl::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"); +#endif // USE_ONEMKL_XPU + +TORCH_LIBRARY_FRAGMENT(aten, m) { + m.def( + "aten::_mm_mkl.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)"); + m.def( + "aten::_bmm_mkl.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)"); + m.def( + "aten::_addmm_mkl.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"); + m.def( + "aten::_baddbmm_mkl.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"); } -TORCH_LIBRARY_IMPL(xpu_mkl, XPU, m) { - m.impl("xpu_mkl::mm", mm_complex_out_xpu); - m.impl("xpu_mkl::bmm", bmm_complex_out_xpu); - m.impl("xpu_mkl::addmm", addmm_complex_out_xpu); - m.impl("xpu_mkl::baddbmm", baddbmm_complex_out_xpu); +#if defined(USE_ONEMKL_XPU) + +TORCH_LIBRARY_IMPL(aten, XPU, m) { + m.impl("aten::_mm_mkl.out", mm_complex_out_xpu); + m.impl("aten::_bmm_mkl.out", bmm_complex_out_xpu); + m.impl("aten::_addmm_mkl.out", addmm_complex_out_xpu); + m.impl("aten::_baddbmm_mkl.out", baddbmm_complex_out_xpu); } +#endif // USE_ONEMKL_XPU + } // namespace at::native From 312d8edc1b73888d8828e6d9f04a86c05880c824 Mon Sep 17 00:00:00 2001 From: Pawel Swider Date: Mon, 1 Sep 2025 13:02:28 +0000 Subject: [PATCH 5/5] Add device guard --- src/ATen/native/xpu/Blas.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/ATen/native/xpu/Blas.cpp b/src/ATen/native/xpu/Blas.cpp index 31527722d1..7434a197a8 100644 --- a/src/ATen/native/xpu/Blas.cpp +++ b/src/ATen/native/xpu/Blas.cpp @@ -59,6 +59,7 @@ at::Tensor& mm_complex_out_xpu( const at::Tensor& self, const at::Tensor& mat2, at::Tensor& out) { + c10::DeviceGuard guard(self.device()); TORCH_CHECK( self.is_complex(), "_mm_mkl.out expects self to be a complex datatype."); @@ -114,6 +115,7 @@ at::Tensor& bmm_complex_out_xpu( const at::Tensor& self, const at::Tensor& mat2, at::Tensor& out) { + c10::DeviceGuard guard(self.device()); TORCH_CHECK( self.is_complex(), "_bmm_mkl.out expects self to be a complex datatype."); @@ -196,6 +198,7 @@ at::Tensor& addmm_complex_out_xpu( const Scalar& beta, const Scalar& alpha, Tensor& out) { + c10::DeviceGuard guard(self.device()); TORCH_CHECK( self.is_complex(), "_addmm_mkl.out expects self to be a complex datatype."); @@ -266,6 +269,7 @@ at::Tensor& baddbmm_complex_out_xpu( const Scalar& beta, const Scalar& alpha, Tensor& out) { + c10::DeviceGuard guard(self.device()); TORCH_CHECK( self.is_complex(), "_baddbmm_mkl.out expects self to be a complex datatype.");
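
---

Appended for reference (not part of the patches above): a minimal C++ sketch of how the aten::_mm_mkl.out overload registered by this series could be exercised and cross-checked against the CPU reference. It assumes the series is applied, torch-xpu-ops was built with USE_ONEMKL_XPU, and an XPU device is available; the shapes, dtype, and main() harness are illustrative assumptions only.

// Sketch only: calls the out-variant overload registered in the
// TORCH_LIBRARY_FRAGMENT block of this series through the dispatcher.
#include <iostream>

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

int main() {
  auto opts = at::TensorOptions().device(at::kXPU).dtype(at::kComplexFloat);
  at::Tensor a = at::randn({4, 3}, opts);
  at::Tensor b = at::randn({3, 5}, opts);
  at::Tensor out = at::empty({4, 5}, opts);

  // Look up the schema "aten::_mm_mkl.out" and obtain a typed handle that
  // matches "(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)".
  auto op = c10::Dispatcher::singleton()
                .findSchemaOrThrow("aten::_mm_mkl", "out")
                .typed<at::Tensor&(const at::Tensor&, const at::Tensor&, at::Tensor&)>();
  op.call(a, b, out);

  // Compare against the CPU reference implementation of complex mm.
  at::Tensor ref = at::mm(a.cpu(), b.cpu());
  std::cout << "matches CPU reference: " << at::allclose(out.cpu(), ref)
            << std::endl;
  return 0;
}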