Skip to content

Commit abf4e07

Browse files
Adding Any for rmul and lmul
1 parent 0c26c5f commit abf4e07

File tree

1 file changed

+10
-29
lines changed

1 file changed

+10
-29
lines changed

src/host/linalg.jl

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ else
178178
end
179179
end
180180

181-
function Base.:\(D::Diagonal{<:Any, <:AbstractGPUArray}, B::AbstractGPUVecOrMat)
181+
function Base.:\(D::Diagonal{<:Any, <:AnyGPUArray}, B::AnyGPUVecOrMat)
182182
z = D.diag .== 0
183183
if any(z)
184184
i = findfirst(collect(z))
@@ -189,7 +189,7 @@ function Base.:\(D::Diagonal{<:Any, <:AbstractGPUArray}, B::AbstractGPUVecOrMat)
189189
end
190190

191191
if VERSION < v"1.8-"
192-
function LinearAlgebra.ldiv!(D::Diagonal{<:Any, <:AbstractGPUArray},
192+
function LinearAlgebra.ldiv!(D::Diagonal{<:Any, <:AnyGPUArray},
193193
B::StridedVecOrMat)
194194
m, n = size(B, 1), size(B, 2)
195195
if m != length(D.diag)
@@ -289,7 +289,7 @@ end
289289

290290
## matrix multiplication
291291

292-
function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::AbstractArray{S}, a::Number, b::Number) where {T,S,R}
292+
function generic_matmatmul!(C::AnyArray{R}, A::AnyArray{T}, B::AnyArray{S}, a::Number, b::Number) where {T,S,R}
293293
if size(A,2) != size(B,1)
294294
throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))"))
295295
end
@@ -319,29 +319,10 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac
319319
C
320320
end
321321

322-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
323-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
324-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
325-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
326-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
327-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
328-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
329-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
330-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
331-
332-
# specificity hacks
333-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::AbstractGPUVecOrMat, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
334-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
335-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::AbstractGPUVecOrMat, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
336-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
337-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::AbstractGPUVecOrMat, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
338-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
339-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
340-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Adjoint{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
341-
LinearAlgebra.mul!(C::AbstractGPUVecOrMat, A::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, B::LinearAlgebra.Transpose{<:Any, <:AbstractGPUVecOrMat}, a::Real, b::Real) = generic_matmatmul!(C, A, B, a, b)
342-
343-
344-
function generic_rmul!(X::AbstractArray, s::Number)
322+
LinearAlgebra.mul!(C::AnyGPUVecOrMat, A::AnyGPUVecOrMat, B::AnyGPUVecOrMat, a::Number, b::Number) = generic_matmatmul!(C, A, B, a, b)
323+
324+
325+
function generic_rmul!(X::AnyArray, s::Number)
345326
gpu_call(X, s; name="rmul!") do ctx, X, s
346327
i = @linearidx X
347328
@inbounds X[i] *= s
@@ -350,9 +331,9 @@ function generic_rmul!(X::AbstractArray, s::Number)
350331
return X
351332
end
352333

353-
LinearAlgebra.rmul!(A::AbstractGPUArray, b::Number) = generic_rmul!(A, b)
334+
LinearAlgebra.rmul!(A::AnyGPUArray, b::Number) = generic_rmul!(A, b)
354335

355-
function generic_lmul!(s::Number, X::AbstractArray)
336+
function generic_lmul!(s::Number, X::AnyArray)
356337
gpu_call(X, s; name="lmul!") do ctx, X, s
357338
i = @linearidx X
358339
@inbounds X[i] = s*X[i]
@@ -361,7 +342,7 @@ function generic_lmul!(s::Number, X::AbstractArray)
361342
return X
362343
end
363344

364-
LinearAlgebra.lmul!(a::Number, B::AbstractGPUArray) = generic_lmul!(a, B)
345+
LinearAlgebra.lmul!(a::Number, B::AnyGPUArray) = generic_lmul!(a, B)
365346

366347

367348
## permutedims

0 commit comments

Comments
 (0)