Tests for matmul for complex datatypes. #1993
Conversation
Pull Request Overview
This PR enables complex datatype support for matrix multiplication operations on XPU devices by implementing oneMKL-based BLAS operations and unskipping relevant tests that were previously failing due to lack of complex matmul support.
- Implements complex datatype matrix multiplication operations using oneMKL BLAS library
- Unskips hundreds of tests that were disabled due to missing complex matmul support
- Updates error message comments to reflect that only double precision (not complex) matmul is unsupported in oneDNN
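For context, here is a minimal sketch of how a complex matmul can be routed to oneMKL's SYCL BLAS interface. This is illustrative only — the function name, layout choices, and blocking synchronization are assumptions, not the PR's exact code in src/ATen/native/xpu/Blas.cpp:

```cpp
// Hedged sketch: C = A * B for complex64 via oneMKL's USM BLAS API.
// Assumes column-major, contiguous, conjugate-resolved inputs; the PR's
// actual implementation may handle layouts and dtype dispatch differently.
#include <complex>
#include <cstdint>
#include <oneapi/mkl.hpp>
#include <sycl/sycl.hpp>

void complex_mm_sketch(
    sycl::queue& q,
    const std::complex<float>* a, // m x k
    const std::complex<float>* b, // k x n
    std::complex<float>* c,       // m x n
    std::int64_t m,
    std::int64_t n,
    std::int64_t k) {
  const std::complex<float> alpha{1.0f, 0.0f};
  const std::complex<float> beta{0.0f, 0.0f};
  // C = alpha * op(A) * op(B) + beta * C, column-major leading dimensions.
  oneapi::mkl::blas::column_major::gemm(
      q,
      oneapi::mkl::transpose::nontrans,
      oneapi::mkl::transpose::nontrans,
      m, n, k,
      alpha,
      a, /*lda=*/m,
      b, /*ldb=*/k,
      beta,
      c, /*ldc=*/m)
      .wait(); // gemm returns a sycl::event; wait here for simplicity
}
```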
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
File | Description
--- | ---
test/xpu/skip_list_common.py | Removes complex64/complex128 test skips and updates comments to reflect current oneDNN limitations
src/ATen/native/xpu/Blas.cpp | Implements oneMKL-based complex matrix multiplication operations (mm, bmm, addmm, baddbmm)
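The batched variants (bmm, baddbmm) map naturally onto oneMKL's strided gemm_batch. Again a hedged sketch under the same assumptions (column-major, contiguous batches), not the PR's verbatim code:

```cpp
// Hedged sketch: C[i] = alpha * A[i] * B[i] + beta * C[i] for each batch i.
// A nonzero beta accumulates into C, which is the baddbmm-style behavior.
#include <complex>
#include <cstdint>
#include <oneapi/mkl.hpp>
#include <sycl/sycl.hpp>

void complex_bmm_sketch(
    sycl::queue& q,
    const std::complex<float>* a,
    const std::complex<float>* b,
    std::complex<float>* c,
    std::int64_t m, std::int64_t n, std::int64_t k,
    std::int64_t batch) {
  const std::complex<float> alpha{1.0f, 0.0f};
  const std::complex<float> beta{0.0f, 0.0f};
  oneapi::mkl::blas::column_major::gemm_batch(
      q,
      oneapi::mkl::transpose::nontrans,
      oneapi::mkl::transpose::nontrans,
      m, n, k,
      alpha,
      a, /*lda=*/m, /*stride_a=*/m * k,
      b, /*ldb=*/k, /*stride_b=*/k * n,
      beta,
      c, /*ldc=*/m, /*stride_c=*/m * n,
      batch)
      .wait();
}
```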
@@ -49,7 +49,7 @@
 # OneDNN issues, https://github.com/intel/torch-xpu-ops/issues/253
 # RuntimeError: Long is not supported in oneDNN!
 # RuntimeError: could not create a primitive descriptor for a deconvolution forward propagation primitive
-# RuntimeError: Double and complex datatype matmul is not supported in oneDNN
+# RuntimeError: Double datatype matmul is not supported in oneDNN
[nitpick] The comment should be more descriptive about what operations are affected. Consider: "RuntimeError: Double datatype matmul is not supported in oneDNN for certain operations"
Suggested change:
- # RuntimeError: Double datatype matmul is not supported in oneDNN
+ # RuntimeError: Double datatype matmul is not supported in oneDNN for certain operations (e.g., matmul, convolution)
    Tensor& out) {
  at::Tensor mat1_cont = mat1.contiguous().resolve_conj();
  at::Tensor mat2_cont = mat2.contiguous().resolve_conj();
  at::Tensor self_cont = self.contiguous().resolve_conj().clone().detach();
The `clone().detach()` operation creates an unnecessary copy of the tensor data. Since the tensor is already contiguous and conjugate-resolved, consider using it directly, or cloning only when a modification is actually needed, to avoid the performance overhead.
Suggested change:
- at::Tensor self_cont = self.contiguous().resolve_conj().clone().detach();
+ at::Tensor self_cont = self.contiguous().resolve_conj();
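One plausible reason for the original clone is aliasing: the `*_out` overloads can be called with `out` being the same tensor as `self`, in which case the GEMM write into `out` would clobber the bias operand while it is still being read. A hedged alternative that pays for the copy only in that case (illustrative, not the PR's code):

```cpp
// Sketch: copy `self` only when it may alias `out`; otherwise use it as-is.
at::Tensor self_cont = self.contiguous().resolve_conj();
if (self_cont.is_alias_of(out)) {
  self_cont = self_cont.clone();
}
```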
    Tensor& out) {
  at::Tensor batch1_cont = batch1.contiguous().resolve_conj();
  at::Tensor batch2_cont = batch2.contiguous().resolve_conj();
  at::Tensor self_cont = self.contiguous().resolve_conj().clone().detach();
Similar to the addmm implementation, this `clone().detach()` operation creates an unnecessary copy. Consider avoiding the clone unless the tensor needs to be modified in place, to improve performance.
Suggested change:
- at::Tensor self_cont = self.contiguous().resolve_conj().clone().detach();
+ at::Tensor self_cont = self.contiguous().resolve_conj();
Although many test cases require complex MatMul for their assertions, we don't need to enable all of them at once. Activating the test cases that directly test complex MatMul is sufficient.
Sure, I think we should run these tests on CI and check them; if they pass, I see no reason not to unskip them. However, if they fail and are not directly testing MatMul, they could stay skipped for now.
Unskipped tests that were passing after implementing matmul for complex datatypes.
Some tests still fail, but with different errors unrelated to matmul, such as:
# NotImplementedError: The operator 'aten::cholesky_inverse.out' is not currently implemented for the XPU device.
Created as a separate PR (implementation: #1992), as this will require an additional PR in upstream PyTorch.