Skip to content

Commit 2d334a6

Browse files
authored
implement to_vec for sparse arrays (#202)
* implement to_vec for sparse arrays * densify the sparse matrix like we do with the diagonal
1 parent 4d7835e commit 2d334a6

File tree

5 files changed

+33
-1
lines changed

5 files changed

+33
-1
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
name = "FiniteDifferences"
22
uuid = "26cc04aa-876d-5657-8c51-4c34ba976000"
3-
version = "0.12.22"
3+
version = "0.12.23"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
88
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
99
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1010
Richardson = "708f8203-808e-40c0-ba2d-98a6953ed40d"
11+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1112
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1213

1314
[compat]

src/FiniteDifferences.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using LinearAlgebra
55
using Printf
66
using Random
77
using Richardson
8+
using SparseArrays
89
using StaticArrays
910

1011
export to_vec, grad, jacobian, jvp, j′vp

src/to_vec.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,30 @@ function to_vec(X::T) where {T<:PermutedDimsArray}
156156
return x_vec, PermutedDimsArray_from_vec
157157
end
158158

159+
function to_vec(v::SparseVector)
160+
inds, _ = findnz(v)
161+
sizes = size(v)
162+
163+
x_vec, back = to_vec(collect(v))
164+
function SparseVector_from_vec(x_v)
165+
v_values = back(x_v)
166+
return sparsevec(inds, v_values[inds], sizes...)
167+
end
168+
return x_vec, SparseVector_from_vec
169+
end
170+
171+
function to_vec(m::SparseMatrixCSC)
172+
is, js, _ = findnz(m)
173+
sizes = size(m)
174+
175+
x_vec, back = to_vec(collect(m))
176+
function SparseMatrixCSC_from_vec(x_v)
177+
v_values = back(x_v)
178+
return sparse(is, js, [v_values[i, j] for (i, j) in zip(is, js)], sizes...)
179+
end
180+
return x_vec, SparseMatrixCSC_from_vec
181+
end
182+
159183
# Factorizations
160184

161185
function to_vec(x::F) where {F <: SVD}

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using FiniteDifferences
44
using LinearAlgebra
55
using Printf
66
using Random
7+
using SparseArrays
78
using StaticArrays
89
using Test
910

test/to_vec.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,11 @@ end
129129
)
130130
end
131131

132+
@testset "SparseArrays" begin
133+
test_to_vec(sparsevec([1 2 0; 0 0 3; 0 4 0.0]))
134+
test_to_vec(sparse([1 2 0; 0 0 3; 0 4 0.0]))
135+
end
136+
132137
@testset "Factorizations" begin
133138
# (100, 100) is needed to test for the NaNs that can appear in the
134139
# qr(M).T matrix

0 commit comments

Comments
 (0)