Skip to content

Commit 8cb4ef8

Browse files
bors[bot]GiggleLiu
andcommitted
Merge #176
176: fix permutedims dispatch, isapprox for complex numbers r=maleadt a=GiggleLiu several bug fixes. There should be some tests. Am I supposed to put tests in `CuArrays.jl` and submit another PR? Co-authored-by: Leo <[email protected]>
2 parents b7a9eee + eef51d7 commit 8cb4ef8

File tree

5 files changed

+14
-5
lines changed

5 files changed

+14
-5
lines changed

src/indexing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ Base.setindex!(xs::GPUArray, v, i::Integer) = xs[i] = convert(eltype(xs), v)
6262
# Vector indexing
6363

6464
to_index(a, x) = x
65-
to_index(::A, x::Array{ET}) where {A, ET} = copyto!(similar(A, ET, size(x)), x)
65+
to_index(a::A, x::Array{ET}) where {A, ET} = copyto!(similar(a, ET, size(x)...), x)
6666
to_index(a, x::UnitRange{<: Integer}) = convert(UnitRange{Int}, x)
6767
to_index(a, x::Base.LogicalIndex) = error("Logical indexing not implemented")
6868

src/linalg.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ function genperm(I::NTuple{N}, perm::NTuple{N}) where N
8080
ntuple(d-> (@inbounds return I[perm[d]]), Val(N))
8181
end
8282

83-
function LinearAlgebra.permutedims!(dest::GPUArray, src::GPUArray, perm::NTuple{N, Integer}) where N
83+
function LinearAlgebra.permutedims!(dest::GPUArray, src::GPUArray, perm) where N
84+
perm isa Tuple || (perm = Tuple(perm))
8485
gpu_call(dest, (dest, src, perm)) do state, dest, src, perm
8586
I = @cartesianidx src state
8687
@inbounds dest[genperm(I, perm)...] = src[I...]

src/mapreduce.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,6 @@ function fast_isapprox(x::Number, y::Number, rtol::Real = Base.rtoldefault(x, y)
189189
x == y || (isfinite(x) && isfinite(y) && abs(x - y) <= atol + rtol*max(abs(x), abs(y)))
190190
end
191191

192-
Base.isapprox(A::GPUArray{T1}, B::GPUArray{T2}, rtol::Real = Base.rtoldefault(T1, T2, 0), atol::Real=0) where {T1, T2} = all(fast_isapprox.(A, B, T1(rtol), T1(atol)))
193-
Base.isapprox(A::AbstractArray{T1}, B::GPUArray{T2}, rtol::Real = Base.rtoldefault(T1, T2, 0), atol::Real=0) where {T1, T2} = all(fast_isapprox.(A, Array(B), T1(rtol), T1(atol)))
194-
Base.isapprox(A::GPUArray{T1}, B::AbstractArray{T2}, rtol::Real = Base.rtoldefault(T1, T2, 0), atol::Real=0) where {T1, T2} = all(fast_isapprox.(Array(A), B, T1(rtol), T1(atol)))
192+
Base.isapprox(A::GPUArray{T1}, B::GPUArray{T2}, rtol::Real = Base.rtoldefault(T1, T2, 0), atol::Real=0) where {T1, T2} = all(fast_isapprox.(A, B, T1(rtol)|>real, T1(atol)|>real))
193+
Base.isapprox(A::AbstractArray{T1}, B::GPUArray{T2}, rtol::Real = Base.rtoldefault(T1, T2, 0), atol::Real=0) where {T1, T2} = all(fast_isapprox.(A, Array(B), T1(rtol)|>real, T1(atol)|>real))
194+
Base.isapprox(A::GPUArray{T1}, B::AbstractArray{T2}, rtol::Real = Base.rtoldefault(T1, T2, 0), atol::Real=0) where {T1, T2} = all(fast_isapprox.(Array(A), B, T1(rtol)|>real, T1(atol)|>real))

src/testsuite/linalg.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ function test_linalg(AT)
99
@test compare(x -> permutedims(x, (2, 1)), AT, rand(Float32, 2, 3))
1010
@test compare(x -> permutedims(x, (2, 1, 3)), AT, rand(Float32, 4, 5, 6))
1111
@test compare(x -> permutedims(x, (3, 1, 2)), AT, rand(Float32, 4, 5, 6))
12+
@test compare(x -> permutedims(x, [2,1,4,3]), AT, randn(ComplexF64,3,4,5,1))
1213
end
1314

1415
@testset "issymmetric/ishermitian" begin

src/testsuite/mapreduce.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,13 @@ function test_mapreduce(AT)
5555
@test !(A B)
5656
@test !(A Array(B))
5757
@test !(Array(A) B)
58+
59+
60+
ca = AT(randn(ComplexF64,3,3))
61+
cb = copy(ca)
62+
cb[1:1, 1:1] .+= 1e-7im
63+
@test isapprox(ca, cb, atol=1e-5)
64+
@test !isapprox(ca, cb, atol=1e-9)
5865
end
5966
end
6067
end

0 commit comments

Comments
 (0)