
Conversation

PawelSwider2000

Implementation of kernels for complex datatype support for 4 ops: mm, bmm, addmm, baddbmm, using OneMKL.

The current implementation of these ops for XPU lives in pytorch/aten/src/ATen/native/mkldnn/xpu/Blas.cpp. Since OneMKL is a torch-xpu-ops dependency and is available only with USE_ONEMKL_XPU=ON (the default), the implementation needs to live in torch-xpu-ops, and the kernels and the TORCH_LIBRARY_IMPL registration are wrapped in an ifdef macro to avoid compilation errors when OneMKL is not available. The newly declared ops are called from the existing torch implementation through c10::Dispatcher.
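
A minimal sketch of the registration pattern this describes (the op name `torch_xpu_ops::mm_complex`, its schema, and `mm_complex_kernel` are illustrative placeholders, not the identifiers used in this PR):

```cpp
#include <ATen/core/Tensor.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>

#ifdef USE_ONEMKL_XPU
namespace at::native::xpu {

// OneMKL-backed complex mm kernel; the actual gemm call is omitted in this sketch.
at::Tensor& mm_complex_kernel(const at::Tensor& self, const at::Tensor& mat2, at::Tensor& out) {
  // oneapi::mkl::blas::...::gemm(...) would go here.
  return out;
}

// Schema and XPU implementation are registered only when OneMKL is built in,
// so nothing references OneMKL symbols when USE_ONEMKL_XPU=OFF.
TORCH_LIBRARY_FRAGMENT(torch_xpu_ops, m) {
  m.def("mm_complex.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)");
}
TORCH_LIBRARY_IMPL(torch_xpu_ops, XPU, m) {
  m.impl("mm_complex.out", TORCH_FN(mm_complex_kernel));
}

} // namespace at::native::xpu
#endif // USE_ONEMKL_XPU

// Caller side in the existing XPU Blas implementation: the op is resolved through
// c10::Dispatcher at runtime, so the call site itself needs no ifdef.
at::Tensor& call_mm_complex(const at::Tensor& self, const at::Tensor& mat2, at::Tensor& out) {
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("torch_xpu_ops::mm_complex", "out")
      .typed<at::Tensor&(const at::Tensor&, const at::Tensor&, at::Tensor&)>();
  return op.call(self, mat2, out);
}
```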

This is part of: #1853

@Copilot Copilot AI review requested due to automatic review settings August 29, 2025 11:02
Contributor

@Copilot Copilot AI left a comment


Pull Request Overview

This PR implements OneMKL-based kernels for complex datatype support in matrix multiplication operations (mm, bmm, addmm, baddbmm) on XPU devices. The implementation provides optimized BLAS operations for complex numbers using OneMKL library integration.

Key changes include:

  • Addition of OneMKL-based complex matrix multiplication kernels
  • Implementation of four core matrix operations with complex number support
  • Conditional compilation support for OneMKL availability


@PawelSwider2000
Author

@CuiYifeng @kbinias Please review

@PawelSwider2000
Author

Follow-up change with tests: #1993

Comment on lines +39 to +42
oneapi::mkl::blas::row_major::gemm(
c10::xpu::getCurrentXPUStream().queue(),
oneapi::mkl::transpose::nontrans,
oneapi::mkl::transpose::nontrans,
Contributor


Please note that the storage of the input tensors and the output tensor is not always row-major.
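
For reference, a minimal sketch of the full row_major::gemm USM call that the quoted lines belong to, with placeholder dimensions and leading dimensions (the PR's actual values come from the tensors' sizes and strides):

```cpp
#include <complex>
#include <cstdint>
#include <oneapi/mkl.hpp>
#include <c10/xpu/XPUStream.h>

// Sketch only: all names here are placeholders, not code from the PR.
void gemm_row_major_sketch(const std::complex<float>* a,
                           const std::complex<float>* b,
                           std::complex<float>* c,
                           std::int64_t m, std::int64_t n, std::int64_t k) {
  // Leading dimensions for densely packed row-major 2-D data.
  const std::int64_t lda = k, ldb = n, ldc = n;
  const std::complex<float> alpha{1.0f, 0.0f}, beta{0.0f, 0.0f};

  // nontrans/nontrans assumes both operands are already stored row-major;
  // that is exactly the assumption which does not always hold for PyTorch tensors.
  oneapi::mkl::blas::row_major::gemm(
      c10::xpu::getCurrentXPUStream().queue(),
      oneapi::mkl::transpose::nontrans,
      oneapi::mkl::transpose::nontrans,
      m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
}
```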

Author

@PawelSwider2000 PawelSwider2000 Sep 1, 2025


Yes, we could decide which algorithm (row-major or col-major) to use based on the input/output tensors. However, the implementation would be much more complicated, since the decision has to be made from strides and shapes. For example, one tensor could be row-major and the other col-major, and then deciding what to use becomes more complicated.

Using the row-major path and transposing inputs into that format does lead to worse performance than operating on genuinely row-major data; however, the contiguous, second_col_major, and first_col_major cases are still comparable. Only both_col_major is visibly worse.

I would suggest making these performance improvements in subsequent PRs, as it is much better to have complex support at all first. Also, mm is the simplest of these ops; for baddbmm, an algorithm for efficiently selecting between row-major and col-major could be more complicated (see the sketch below).
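
A rough sketch of the kind of stride-based layout check this would involve (assuming 2-D tensors; the helper names are hypothetical and this is not code from the PR):

```cpp
#include <ATen/core/Tensor.h>

// Hypothetical layout classification for a 2-D tensor. A real implementation
// would also need to handle size-1 dimensions, conjugated tensors, and
// arbitrary strides (e.g. by falling back to .contiguous()).
enum class Layout2d { RowMajor, ColMajor, Other };

inline Layout2d classify(const at::Tensor& t) {
  const auto rows = t.size(0);
  const auto cols = t.size(1);
  if (t.stride(1) == 1 && t.stride(0) == cols) {
    return Layout2d::RowMajor;  // rows are contiguous
  }
  if (t.stride(0) == 1 && t.stride(1) == rows) {
    return Layout2d::ColMajor;  // columns are contiguous (a transposed view)
  }
  return Layout2d::Other;
}
```

The complication mentioned above is that the two inputs and the output can each land in a different bucket, so a single row-major/col-major choice rarely fits all operands.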

Author

@PawelSwider2000 PawelSwider2000 Sep 2, 2025


I made a more detailed comparison against the existing reference on IPEX GPU and found the following performance issues:

  1. Perf degradation when one/both inputs are not row-major
  2. Perf degradation when one/both inputs are conjugated
  3. Worse performance for smaller sizes

The observed differences in perf are large; in some cases the current implementation is a few times slower than the reference.

@CuiYifeng do you know if there is a performance difference between row-major and column-major implementations?

The other issue I noticed is that for addmm and baddbmm some tests fail, e.g. TestCommonXPU::test_noncontiguous_samples_addmm_xpu_complex64, which was passing with the implementation proposed in this PR.

Author

@PawelSwider2000 PawelSwider2000 Sep 2, 2025


To be more precise about perf: for 4096x4096 tensors with contiguous, non-conjugated inputs the speedup of this implementation is 1.028, and it is similar when both inputs are conjugated.

The same measurement for 256x256 tensors is around 0.565.

For 4096x4096 tensors with both inputs column-major we get 0.255.

When only one tensor is column-major, the differences are smaller but still large.

Contributor

@CuiYifeng CuiYifeng Sep 3, 2025


Different memory layouts may lead to performance differences. In a complex MatMul kernel, data reordering for different layouts may also introduce differences.

Comment on lines +90 to +93
oneapi::mkl::blas::row_major::gemm_batch(
c10::xpu::getCurrentXPUStream().queue(),
oneapi::mkl::transpose::nontrans,
oneapi::mkl::transpose::nontrans,
Contributor


Ditto.

Comment on lines +173 to +176
oneapi::mkl::blas::row_major::gemm(
c10::xpu::getCurrentXPUStream().queue(),
oneapi::mkl::transpose::nontrans,
oneapi::mkl::transpose::nontrans,
Contributor


Ditto.

Comment on lines +239 to +242
oneapi::mkl::blas::row_major::gemm_batch(
c10::xpu::getCurrentXPUStream().queue(),
oneapi::mkl::transpose::nontrans,
oneapi::mkl::transpose::nontrans,
Contributor


Ditto.
