Skip to content

Commit 0bb95c1

Browse files
rrule for findnz
1 parent 8108a77 commit 0bb95c1

File tree

2 files changed

+51
-2
lines changed

2 files changed

+51
-2
lines changed

src/rulesets/SparseArrays/sparsematrix.jl

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,57 @@ function rrule(::typeof(sparse), I::AbstractVector, J::AbstractVector, V::Abstra
1010
return sparse(I, J, V, m, n, combine), sparse_pullback
1111
end
1212

13-
function rrule(::Type{T}, A::AbstractMatrix) where T <: SparseMatrixCSC
13+
function rrule(::Type{T}, A::AbstractMatrix) where T <: AbstractSparseMatrix
1414
function sparse_pullback(Ω̄)
1515
return NoTangent(), Ω̄
1616
end
1717
return T(A), sparse_pullback
1818
end
1919

20-
function rrule(::Type{T}, v::AbstractVector) where T <: SparseVector
20+
function rrule(::Type{T}, v::AbstractVector) where T <: AbstractSparseVector
2121
function sparse_pullback(Ω̄)
2222
return NoTangent(), Ω̄
2323
end
2424
return T(v), sparse_pullback
2525
end
26+
27+
function rrule(::typeof(findnz), A::AbstractSparseMatrix)
28+
I, J, V = findnz(A)
29+
@show I, J, V
30+
m, n = size(A)
31+
32+
function findnz_pullback(Δ)
33+
Δ === NoTangent() && return (NoTangent(), Δ)
34+
Δ === ZeroTangent() && return (NoTangent(), Δ)
35+
36+
_, _, V̄ = unthunk(Δ)
37+
38+
=== NoTangent() && return (NoTangent(), V̄)
39+
=== ZeroTangent() && return (NoTangent(), V̄)
40+
41+
@show I, J, V̄
42+
43+
return NoTangent(), sparse(I, J, V̄, m, n)
44+
end
45+
46+
return (I, J, V), findnz_pullback
47+
end
48+
49+
function rrule(::typeof(findnz), v::AbstractSparseVector)
50+
I, V = findnz(v)
51+
n = length(v)
52+
53+
function findnz_pullback(Δ)
54+
Δ === NoTangent() && return (NoTangent(), Δ)
55+
Δ === ZeroTangent() && return (NoTangent(), Δ)
56+
57+
_, V̄ = unthunk(Δ)
58+
59+
=== NoTangent() && return (NoTangent(), V̄)
60+
=== ZeroTangent() && return (NoTangent(), V̄)
61+
62+
return NoTangent(), sparse(I, V̄, n)
63+
end
64+
65+
return (I, V), findnz_pullback
66+
end

test/rulesets/SparseArrays/sparsematrix.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,11 @@ end
1717
test_rrule(SparseVector, v)
1818
test_rrule(SparseVector{Float32}, Float32.(v), rtol=1e-5)
1919
end
20+
21+
@testset "findnz" begin
22+
A = sprand(5, 5, 0.2)
23+
test_rrule(findnz, A)
24+
25+
v = sprand(10, 0.5)
26+
test_rrule(findnz, v)
27+
end

0 commit comments

Comments
 (0)