@@ -10,16 +10,57 @@ function rrule(::typeof(sparse), I::AbstractVector, J::AbstractVector, V::Abstra
10
10
return sparse (I, J, V, m, n, combine), sparse_pullback
11
11
end
12
12
13
- function rrule (:: Type{T} , A:: AbstractMatrix ) where T <: SparseMatrixCSC
13
+ function rrule (:: Type{T} , A:: AbstractMatrix ) where T <: AbstractSparseMatrix
14
14
function sparse_pullback (Ω̄)
15
15
return NoTangent (), Ω̄
16
16
end
17
17
return T (A), sparse_pullback
18
18
end
19
19
20
- function rrule (:: Type{T} , v:: AbstractVector ) where T <: SparseVector
20
+ function rrule (:: Type{T} , v:: AbstractVector ) where T <: AbstractSparseVector
21
21
function sparse_pullback (Ω̄)
22
22
return NoTangent (), Ω̄
23
23
end
24
24
return T (v), sparse_pullback
25
25
end
26
+
27
+ function rrule (:: typeof (findnz), A:: AbstractSparseMatrix )
28
+ I, J, V = findnz (A)
29
+ @show I, J, V
30
+ m, n = size (A)
31
+
32
+ function findnz_pullback (Δ)
33
+ Δ === NoTangent () && return (NoTangent (), Δ)
34
+ Δ === ZeroTangent () && return (NoTangent (), Δ)
35
+
36
+ _, _, V̄ = unthunk (Δ)
37
+
38
+ V̄ === NoTangent () && return (NoTangent (), V̄)
39
+ V̄ === ZeroTangent () && return (NoTangent (), V̄)
40
+
41
+ @show I, J, V̄
42
+
43
+ return NoTangent (), sparse (I, J, V̄, m, n)
44
+ end
45
+
46
+ return (I, J, V), findnz_pullback
47
+ end
48
+
49
+ function rrule (:: typeof (findnz), v:: AbstractSparseVector )
50
+ I, V = findnz (v)
51
+ n = length (v)
52
+
53
+ function findnz_pullback (Δ)
54
+ Δ === NoTangent () && return (NoTangent (), Δ)
55
+ Δ === ZeroTangent () && return (NoTangent (), Δ)
56
+
57
+ _, V̄ = unthunk (Δ)
58
+
59
+ V̄ === NoTangent () && return (NoTangent (), V̄)
60
+ V̄ === ZeroTangent () && return (NoTangent (), V̄)
61
+
62
+ return NoTangent (), sparse (I, V̄, n)
63
+ end
64
+
65
+ return (I, V), findnz_pullback
66
+ end
0 commit comments