Skip to content

Commit 5e9b4b5

Browse files
rrule for broadcasted cast of sparse matrix
1 parent 0be4d48 commit 5e9b4b5

File tree

3 files changed

+21
-7
lines changed

3 files changed

+21
-7
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ version = "1.27.0"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
7+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
78
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
89
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/rulesets/SparseArrays/sparsematrix.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,6 @@ function rrule(::typeof(findnz), v::AbstractSparseVector)
4848
n = length(v)
4949

5050
function findnz_pullback(Δ)
51-
Δ === NoTangent() && return (NoTangent(), Δ)
52-
Δ === ZeroTangent() && return (NoTangent(), Δ)
53-
5451
_, V̄ = unthunk(Δ)
5552

5653
=== NoTangent() && return (NoTangent(), V̄)
@@ -61,3 +58,11 @@ function rrule(::typeof(findnz), v::AbstractSparseVector)
6158

6259
return (I, V), findnz_pullback
6360
end
61+
62+
function rrule(::typeof(Broadcast.broadcasted), T::Type{<:Number}, x::AbstractSparseArray)
63+
proj = ProjectTo(x)
64+
function broadcasted_cast_sparse(Δ)
65+
return NoTangent(), NoTangent(), proj(Δ)
66+
end
67+
T.(x), broadcasted_cast_sparse
68+
end

test/rulesets/SparseArrays/sparsematrix.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,17 @@ end
1919
end
2020

2121
@testset "findnz" begin
22-
A = sprand(5, 5, 0.2)
23-
test_rrule(findnz, A)
22+
A = sprand(5, 5, 0.5)
23+
dA = similar(A)
24+
rand!(dA.nzval)
25+
I, J, V = findnz(A)
26+
= rand!(similar(V))
27+
test_rrule(findnz, A dA, output_tangent=(NoTangent(), NoTangent(), V̄))
2428

25-
v = sprand(10, 0.5)
26-
test_rrule(findnz, v)
29+
# A = sprand(5, 5, 0.5)
30+
# test_rrule(findnz, A)
31+
32+
# v = sprand(10, 0.5)
33+
# test_rrule(findnz, v)
2734
end
35+

0 commit comments

Comments
 (0)