-
Notifications
You must be signed in to change notification settings - Fork 2
Updates for TensorKit compatibility #49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Codecov Report❌ Patch coverage is
🚀 New features to boost your workflow:
|
|
||
include("yacusolver.jl") | ||
|
||
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T<:StridedCuMatrix} | ||
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {TT<:BlasFloat, T<:StridedCuMatrix{TT}} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are probably a couple more of these somewhat complex wrapper types that can still be handled by these algorithms, how do you feel about doing something like
for MatType in [...]
@eval ...
end
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds fine to me, do we have a list of the ones we want?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was this particular change also induced by TensorKit requirements, or simply more strictness (which I fully support)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be equivalent to defining a new type constant
const StridedCuBLASMatrix{T} = StridedCuMatrix{T} where {T<:BlasFloat}
and then using default_xxx_algorithm(::Type{<:StridedCuBLASMatrix}; kwargs...)
everywhere?
@@ -288,7 +288,7 @@ function MatrixAlgebraKit.svd_compact!(A::AbstractMatrix, USVᴴ, alg::GPU_SVDAl | |||
U, S, Vᴴ = USVᴴ | |||
if alg isa GPU_QRIteration | |||
isempty(alg.kwargs) || | |||
throw(ArgumentError("GPU_QRIteration does not accept any keyword arguments")) | |||
@warn "GPU_QRIteration does not accept any keyword arguments" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Something I've been wondering is if these kinds of warnings and checks should just be moved to the actual algorithm constructors, which might avoid having to repeat them
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would need to check, but some of the LAPACK algorithm names are used for different factorisations, and might support keywords for one and not for the other, or might support a different set of keywords. We can also remove these explicit checks, just pass on alg.kwargs
to the underlying LAPACK/CUSOLVER/rocSOLVER routine always, and then have these complain if the set of keyword arguments is invalid.
@@ -181,7 +181,7 @@ function lq_via_qr!(A::AbstractMatrix, L::AbstractMatrix, Q::AbstractMatrix, | |||
qr_alg::AbstractAlgorithm) | |||
m, n = size(A) | |||
minmn = min(m, n) | |||
At = adjoint!(similar(A'), A)::AbstractMatrix | |||
At = min(m, n) > 0 ? adjoint!(similar(A'), A)::AbstractMatrix : similar(A') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this ::AbstractMatrix
type assert useful?
Changed the format of some of the |
src/implementations/svd.jl
Outdated
Ut = similar(U') | ||
Vᴴt = similar(Vᴴ') | ||
if size(U) == (m, m) | ||
_gpu_gesvd!(At, view(S, 1:minmn, 1), Vᴴt, Ut) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Didn't think long about it, but it was not immediately clear to me why this was necessary. Isn't S
always of length minmn
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here I'm following what the CPU bindings have done, but I suppose we could be reusing an S over and over between differently sized arrays?
Co-authored-by: Lukas Devos <[email protected]>
The
ReshapedArray
overrides are needed to dispatch to the correct GPU algorithms. Needed to modify the type signature for the default algorithms to avoid ambiguities. Also it's nice to give some more info about dimension mismatches.