Implementation of matmul for complex datatypes. #1992
base: main
Conversation
Pull Request Overview
This PR implements OneMKL-based kernels for complex datatype support in matrix multiplication operations (mm, bmm, addmm, baddbmm) on XPU devices. The implementation provides optimized BLAS operations for complex numbers using OneMKL library integration.
Key changes include:
- Addition of OneMKL-based complex matrix multiplication kernels
- Implementation of four core matrix operations with complex number support
- Conditional compilation support for OneMKL availability
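For context, below is a minimal sketch of what such an OneMKL-based complex GEMM call can look like for the contiguous row-major case, based on the row_major::gemm snippets quoted later in this thread. The function name and the contiguity assumptions are illustrative, not the PR's actual code.

```cpp
// Hypothetical sketch: complex single-precision matmul C = A * B on XPU
// using the oneMKL USM GEMM API. Assumes all tensors are contiguous and
// row-major; the function name is illustrative only.
#include <complex>
#include <oneapi/mkl/blas.hpp>
#include <c10/xpu/XPUStream.h>

void complex_mm_rowmajor_sketch(
    const std::complex<float>* a, // [m, k], row-major contiguous
    const std::complex<float>* b, // [k, n], row-major contiguous
    std::complex<float>* c,       // [m, n], row-major contiguous
    int64_t m, int64_t n, int64_t k) {
  const std::complex<float> alpha{1.0f, 0.0f};
  const std::complex<float> beta{0.0f, 0.0f};
  // Submit the GEMM on the SYCL queue of the current XPU stream.
  oneapi::mkl::blas::row_major::gemm(
      c10::xpu::getCurrentXPUStream().queue(),
      oneapi::mkl::transpose::nontrans,
      oneapi::mkl::transpose::nontrans,
      m, n, k,
      alpha,
      a, /*lda=*/k,
      b, /*ldb=*/n,
      beta,
      c, /*ldc=*/n);
}
```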
@CuiYifeng @kbinias Please review.
Follow-up change with tests: #1993
oneapi::mkl::blas::row_major::gemm(
    c10::xpu::getCurrentXPUStream().queue(),
    oneapi::mkl::transpose::nontrans,
    oneapi::mkl::transpose::nontrans,
Please note that the storage of the input tensors and the output tensor is not always row-major.
Yes, we could decide which algorithm (row-major or col-major) to use based on the input/output tensors. The implementation would be much more complicated, however, since we would need to make the decision based on strides and shapes; for example, one tensor could be row-major and the other col-major, and then deciding what to use becomes harder.
Using row-major along with transposition into this format leads to worse performance compared to plain row-major inputs; however, the contiguous, second_col_major, and first_col_major cases are still comparable. Only both_col_major is visibly worse.
I would suggest making these performance improvements in subsequent PRs, as it is much better to have complex support at all first. Also, mm is the simplest of these ops; for baddbmm, an algorithm for efficient selection between row-major and col-major could be more complicated.
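For reference, here is a minimal sketch (not part of this PR) of the kind of stride check such a row-major/col-major selection would need for a 2-D tensor. The helper name classify_layout and the enum are hypothetical.

```cpp
// Hypothetical helper: classify a 2-D tensor's layout from its strides,
// which is the information a row-major vs. col-major GEMM selection
// would have to inspect.
#include <ATen/core/Tensor.h>

enum class Layout2D { RowMajor, ColMajor, Other };

Layout2D classify_layout(const at::Tensor& t) {
  // For a 2-D tensor of shape [m, n]:
  //   row-major (contiguous) strides are {n, 1},
  //   col-major (transposed)  strides are {1, m}.
  const auto m = t.size(0);
  const auto n = t.size(1);
  if (t.stride(0) == n && t.stride(1) == 1) return Layout2D::RowMajor;
  if (t.stride(0) == 1 && t.stride(1) == m) return Layout2D::ColMajor;
  return Layout2D::Other; // would need a contiguous copy before GEMM
}
```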
I made a more detailed comparison against the existing reference on IPEX GPU and found the following performance issues:
- Perf degradation when one/both inputs are not row-major
- Perf degradation when one/both inputs are conjugated
- Worse performance for smaller sizes
The observed differences in perf are big; in some cases the current implementation is a few times slower than the reference.
@CuiYifeng do you know if there is a performance difference between row-major and column-major implementations?
The other issue that I noticed is that for addmm and baddbmm some tests are failing, e.g. TestCommonXPU::test_noncontiguous_samples_addmm_xpu_complex64, which was passing with the implementation proposed in this PR.
To be more precise about the perf numbers: for 4096x4096 tensors with contiguous, non-conjugate inputs, the speedup of this implementation is 1.028, and it is similar when both inputs are conjugated, whereas for 256x256 tensors it is around 0.565. For 4096x4096 tensors where both inputs are column-major the speedup is 0.255; when only one tensor is column-major, the differences are smaller but still large.
Different memory layouts may lead to performance differences. In the complex MatMul kernel, data reordering for different layouts may also introduce differences.
oneapi::mkl::blas::row_major::gemm_batch(
    c10::xpu::getCurrentXPUStream().queue(),
    oneapi::mkl::transpose::nontrans,
    oneapi::mkl::transpose::nontrans,
Ditto.
oneapi::mkl::blas::row_major::gemm(
    c10::xpu::getCurrentXPUStream().queue(),
    oneapi::mkl::transpose::nontrans,
    oneapi::mkl::transpose::nontrans,
Ditto.
oneapi::mkl::blas::row_major::gemm_batch(
    c10::xpu::getCurrentXPUStream().queue(),
    oneapi::mkl::transpose::nontrans,
    oneapi::mkl::transpose::nontrans,
Ditto.
Implementation of kernels for complex datatype support for 4 ops: mm, bmm, addmm, baddbmm using OneMKL.
The current implementation of these ops for XPU is in pytorch/aten/src/ATen/native/mkldnn/xpu/Blas.cpp. Since OneMKL is a torch-xpu-ops dependency and is available only with USE_ONEMKL_XPU=ON (which is the default), the implementation needs to live in torch-xpu-ops, and both the kernels and the TORCH_LIBRARY_IMPL registration are wrapped in an ifdef macro to avoid compilation errors when OneMKL is not supported. The newly declared ops will be called from the existing torch implementation using c10::Dispatcher.
This is part of: #1853
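As a rough illustration of the registration and dispatch pattern described above, here is a minimal sketch. The namespace and op name (torch_xpu_ops::mm_complex.out), the kernel body, and the caller function are hypothetical and not the PR's actual code; only the general mechanism (ifdef-guarded TORCH_LIBRARY registration plus a c10::Dispatcher lookup on the caller side) is what the description refers to.

```cpp
// Sketch: guard the oneMKL-backed kernel and its registration behind
// USE_ONEMKL_XPU, and call it from the existing implementation via the
// dispatcher. Op/kernel names below are illustrative assumptions.
#include <ATen/core/Tensor.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>

#if defined(USE_ONEMKL_XPU)

namespace {

at::Tensor& mm_complex_out_xpu(
    const at::Tensor& self,
    const at::Tensor& mat2,
    at::Tensor& out) {
  // ... oneMKL gemm call goes here (see the sketch earlier in the thread) ...
  return out;
}

} // namespace

// Schema and XPU kernel are only registered when oneMKL is available, so a
// build with USE_ONEMKL_XPU=OFF contains no references to oneMKL symbols.
TORCH_LIBRARY_FRAGMENT(torch_xpu_ops, m) {
  m.def("mm_complex.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)");
}

TORCH_LIBRARY_IMPL(torch_xpu_ops, XPU, m) {
  m.impl("mm_complex.out", mm_complex_out_xpu);
}

#endif // USE_ONEMKL_XPU

// Caller side (e.g. in the existing Blas.cpp): look the op up at runtime,
// so this code compiles even when the kernel was not registered.
at::Tensor& call_mm_complex_out(
    const at::Tensor& self,
    const at::Tensor& mat2,
    at::Tensor& out) {
  static auto op = c10::Dispatcher::singleton()
      .findSchemaOrThrow("torch_xpu_ops::mm_complex", "out")
      .typed<at::Tensor&(const at::Tensor&, const at::Tensor&, at::Tensor&)>();
  return op.call(self, mat2, out);
}
```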