Skip to content

Commit 8108a77

Browse files
add sparse rrule (#579)
* add sparse(I, J, V, m, n, +) rrule * cleanup * fix test * sparse(A) and sparse(v) * SparseMatrixCSC and SparseVector * cleanup Co-authored-by: Michael Abbott <[email protected]>
1 parent 970fce4 commit 8108a77

File tree

5 files changed

+53
-0
lines changed

5 files changed

+53
-0
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1111
RealDot = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9"
12+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1213
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1314

1415
[compat]

src/ChainRules.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using LinearAlgebra
88
using LinearAlgebra.BLAS
99
using Random
1010
using RealDot: realdot
11+
using SparseArrays
1112
using Statistics
1213

1314
# Basically everything this package does is overloading these, so we make an exception
@@ -43,6 +44,8 @@ include("rulesets/LinearAlgebra/symmetric.jl")
4344
include("rulesets/LinearAlgebra/factorization.jl")
4445
include("rulesets/LinearAlgebra/uniformscaling.jl")
4546

47+
include("rulesets/SparseArrays/sparsematrix.jl")
48+
4649
include("rulesets/Random/random.jl")
4750

4851
end # module
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
function rrule(::typeof(sparse), I::AbstractVector, J::AbstractVector, V::AbstractVector, m, n, combine::typeof(+))
2+
project_V = ProjectTo(V)
3+
4+
function sparse_pullback(Ω̄)
5+
ΔΩ = unthunk(Ω̄)
6+
ΔV = project_V(ΔΩ[I .+ m .* (J .- 1)])
7+
return NoTangent(), NoTangent(), NoTangent(), ΔV, NoTangent(), NoTangent(), NoTangent()
8+
end
9+
10+
return sparse(I, J, V, m, n, combine), sparse_pullback
11+
end
12+
13+
function rrule(::Type{T}, A::AbstractMatrix) where T <: SparseMatrixCSC
14+
function sparse_pullback(Ω̄)
15+
return NoTangent(), Ω̄
16+
end
17+
return T(A), sparse_pullback
18+
end
19+
20+
function rrule(::Type{T}, v::AbstractVector) where T <: SparseVector
21+
function sparse_pullback(Ω̄)
22+
return NoTangent(), Ω̄
23+
end
24+
return T(v), sparse_pullback
25+
end
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
2+
@testset "sparse(I, J, V, m, n, +)" begin
3+
m, n = 3, 5
4+
s, t, w = [1,2], [2,3], [0.5,0.5]
5+
6+
test_rrule(sparse, s, t, w, m, n, +)
7+
end
8+
9+
@testset "SparseMatrixCSC(A)" begin
10+
A = rand(5, 3)
11+
test_rrule(SparseMatrixCSC, A)
12+
test_rrule(SparseMatrixCSC{Float32,Int}, A, rtol=1e-5)
13+
end
14+
15+
@testset "SparseVector(v)" begin
16+
v = rand(5)
17+
test_rrule(SparseVector, v)
18+
test_rrule(SparseVector{Float32}, Float32.(v), rtol=1e-5)
19+
end

test/runtests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using LinearAlgebra
1212
using LinearAlgebra.BLAS
1313
using LinearAlgebra: dot
1414
using Random
15+
using SparseArrays
1516
using StaticArrays
1617
using Statistics
1718
using Test
@@ -75,6 +76,10 @@ end
7576

7677
println()
7778

79+
include_test("rulesets/SparseArrays/sparsematrix.jl")
80+
81+
println()
82+
7883
include_test("rulesets/Random/random.jl")
7984
println()
8085
end

0 commit comments

Comments
 (0)