Skip to content

Commit 77ef0eb

Browse files
authored
Indexing (#655)
* move and rename _zerolike_writeat, NFC * simplify, use it for getindex, tests * add unsafe_getindex too * tidy, make weird types work via _setindex_zero * fix view & its zero-arrays * test unsafe_getindex * handle indexing of GPU arrays * suggested changes * restore some gpu tests * avoid the error * in fact, mystery errors persist
1 parent 4c3a869 commit 77ef0eb

File tree

9 files changed

+260
-92
lines changed

9 files changed

+260
-92
lines changed

Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.43.2"
3+
version = "1.44.0"
44

55
[deps]
6+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
67
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
78
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
89
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@@ -16,6 +17,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1617
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
1718

1819
[compat]
20+
Adapt = "3.4.0"
1921
ChainRulesCore = "1.15.3"
2022
ChainRulesTestUtils = "1.5"
2123
Compat = "3.42.0, 4"
@@ -30,7 +32,6 @@ StructArrays = "0.6.11"
3032
julia = "1.6"
3133

3234
[extras]
33-
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3435
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3536
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
3637
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
@@ -40,4 +41,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
4041
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4142

4243
[targets]
43-
test = ["Adapt", "ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"]
44+
test = ["ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"]

src/ChainRules.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
module ChainRules
22

3+
using Adapt: adapt
34
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
45
using ChainRulesCore
56
using Compat
67
using Distributed
7-
using GPUArraysCore: AbstractGPUArrayStyle
8+
using GPUArraysCore: AbstractGPUArray, AbstractGPUArrayStyle
89
using IrrationalConstants: logtwo, logten
910
using LinearAlgebra
1011
using LinearAlgebra.BLAS

src/rulesets/Base/array.jl

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -515,64 +515,14 @@ for findm in (:findmin, :findmax)
515515

516516
@eval function rrule(::typeof($findm), x::AbstractArray; dims=:)
517517
y, ind = $findm(x; dims=dims)
518-
project = ProjectTo(x)
519-
# This pullback is a lot like the one for getindex. Ideally they would probably be combined?
520518
function $findm_pullback((dy, _)) # this accepts e.g. Tangent{Tuple{Float64, Int64}}(4.0, nothing)
521519
dy isa AbstractZero && return (NoTangent(), NoTangent())
522-
x_thunk = @thunk project(_zerolike_writeat(x, unthunk(dy), dims, ind))
523-
x_ithunk = InplaceableThunk(x_thunk) do dx
524-
if dims isa Colon
525-
view(dx, ind) .= view(dx, ind) .+ Ref(unthunk(dy))
526-
else
527-
view(dx, ind) .= view(dx, ind) .+ unthunk(dy) # this could be .+=, but not on Julia 1.0
528-
end
529-
dx
530-
end
531-
return (NoTangent(), x_ithunk)
520+
return (NoTangent(), thunked_∇getindex(x, dy, ind),)
532521
end
533522
return (y, ind), $findm_pullback
534523
end
535524
end
536525

537-
# This function is roughly `setindex!(zero(x), dy, inds...)`:
538-
539-
function _zerolike_writeat(x::AbstractArray{<:Number}, dy, dims, inds...)
540-
_zero_fill = eltype(dy) == Any ? 0 : zero(eltype(dy))
541-
542-
# It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't
543-
# allow `eltype(dy)`, nor does it work for many structured matrices.
544-
dx = fill!(similar(x, eltype(dy), axes(x)), _zero_fill)
545-
view(dx, inds...) .= dy # possibly 0-dim view, allows dy::Number and dy::Array, and dx::CuArray
546-
dx
547-
end
548-
function _zerolike_writeat(x::AbstractArray, dy, dims, inds...)
549-
# Since we have `x`, we can also handle arrays of arrays.
550-
dx = map(zero, x)
551-
if dims isa Colon
552-
view(dx, inds...) .= Ref(dy)
553-
else
554-
view(dx, inds...) .= dy
555-
end
556-
dx
557-
end
558-
559-
# Allow for second derivatives, by writing rules for `_zerolike_writeat`;
560-
# these rules are the reason it takes a `dims` argument.
561-
562-
function frule((_, _, dẏ), ::typeof(_zerolike_writeat), x, dy, dims, inds...)
563-
return _zerolike_writeat(x, dy, dims, inds...), _zerolike_writeat(x, dẏ, dims, inds...)
564-
end
565-
566-
function rrule(::typeof(_zerolike_writeat), x, dy, dims, inds...)
567-
z = _zerolike_writeat(x, dy, dims, inds...)
568-
function _zerolike_writeat_pullback(dz)
569-
dx = sum(view(unthunk(dz), inds...); dims=dims)
570-
nots = map(_ -> NoTangent(), inds)
571-
return (NoTangent(), NoTangent(), dx, NoTangent(), nots...)
572-
end
573-
return z, _zerolike_writeat_pullback
574-
end
575-
576526
# These rules for `maximum` pick the same subgradient as `findmax`:
577527

578528
function frule((_, ẋ), ::typeof(maximum), x; dims=:)

src/rulesets/Base/base.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,23 @@ function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, xs::Tu
243243
map_back(dy::AbstractZero) = (NoTangent(), NoTangent(), ntuple(Returns(NoTangent()), num_xs)...)
244244
return y, map_pullback
245245
end
246+
247+
#####
248+
##### `task_local_storage`
249+
#####
250+
251+
# Called by `@allowscalar` from GPUArrays
252+
253+
ChainRules.@non_differentiable task_local_storage(key::Any)
254+
ChainRules.@non_differentiable task_local_storage(key::Any, value::Any)
255+
256+
function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(task_local_storage), body::Function, key, value)
257+
y, back = task_local_storage(key, value) do
258+
rrule_via_ad(config, body)
259+
end
260+
function task_local_storage_pullback(dy)
261+
dbody = only(back(dy))
262+
return (NoTangent(), dbody, NoTangent(), NoTangent())
263+
end
264+
return y, task_local_storage_pullback
265+
end

src/rulesets/Base/indexing.jl

Lines changed: 126 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -52,38 +52,111 @@ function rrule(::typeof(getindex), x::Tuple, ::Colon)
5252
return x, getindex_back_4
5353
end
5454

55-
5655
#####
57-
##### getindex
56+
##### getindex(::AbstractArray)
5857
#####
5958

6059
function frule((_, ẋ), ::typeof(getindex), x::AbstractArray, inds...)
6160
return x[inds...], ẋ[inds...]
6261
end
6362

64-
function rrule(::typeof(getindex), x::Array{<:Number}, inds...)
65-
# removes any logical indexing, CartesianIndex etc
66-
# leaving us just with a tuple of Int, Arrays of Int and Ranges of Int
63+
function rrule(::typeof(getindex), x::AbstractArray, inds...)
64+
function getindex_pullback(dy)
65+
nots = map(Returns(NoTangent()), inds)
66+
return (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...)
67+
end
68+
return x[inds...], getindex_pullback
69+
end
70+
71+
function thunked_∇getindex(x, dy, inds...)
72+
return InplaceableThunk(
73+
dx -> ∇getindex!(dx, unthunk(dy), Base.to_indices(x, inds)...),
74+
@thunk(∇getindex(x, unthunk(dy), inds...)),
75+
)
76+
end
77+
78+
"""
79+
∇getindex(x, dy, inds...)
80+
81+
For the `rrule` of `y = x[inds...]`, this function is roughly
82+
`setindex(zero(x), dy, inds...)`, returning the array `dx`.
83+
Differentiable. Includes `ProjectTo(x)(dx)`.
84+
"""
85+
function ∇getindex(x::AbstractArray, dy, inds...)
86+
# `to_indices` removes any logical indexing, colons, CartesianIndex etc,
87+
# leaving just Int / AbstractVector of Int
6788
plain_inds = Base.to_indices(x, inds)
68-
y = getindex(x, plain_inds...)
69-
function getindex_pullback(ȳ)
70-
function getindex_add!(Δ)
71-
# this a optimizes away for simple cases
72-
for (ȳ_ii, ii) in zip(ȳ, Iterators.product(plain_inds...))
73-
Δ[ii...] += ȳ_ii
74-
end
75-
return Δ
76-
end
89+
dx = _setindex_zero(x, dy, plain_inds...)
90+
∇getindex!(dx, dy, plain_inds...)
91+
return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules
92+
end
93+
94+
"""
95+
_setindex_zero(x, dy, inds...)
7796
78-
= InplaceableThunk(
79-
getindex_add!,
80-
@thunk(getindex_add!(zero(x))),
81-
)
82-
īnds = broadcast(Returns(NoTangent()), inds)
83-
return (NoTangent(), x̄, īnds...)
97+
This returns roughly `dx = zero(x)`, except that this is guaranteed to be mutable via `similar`,
98+
and its element type is wide enough to allow `setindex!(dx, dy, inds...)`, which is exactly what
99+
`∇getindex` does next.
100+
101+
It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't
102+
allow `eltype(dy)`, nor does it work for many structured matrices.
103+
"""
104+
_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), ZeroTangent())
105+
_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), ZeroTangent())
106+
function _setindex_zero(x::AbstractArray, dy, inds::Integer...)
107+
# This allows for types which don't define zero (like Vector) and types whose zero special (like Tangent),
108+
# but always makes an abstract type. TODO: make it infer concrete type for e.g. vectors of SVectors
109+
T = Union{typeof(dy), ZeroTangent}
110+
return fill!(similar(x, T, axes(x)), ZeroTangent())
111+
end
112+
function _setindex_zero(x::AbstractArray, dy, inds...)
113+
T = Union{eltype(dy), ZeroTangent}
114+
return fill!(similar(x, T, axes(x)), ZeroTangent())
115+
end
116+
ChainRules.@non_differentiable _setindex_zero(x::AbstractArray, dy::Any, inds::Any...)
117+
118+
function ∇getindex!(dx::AbstractArray, dy, inds::Integer...)
119+
view(dx, inds...) .+= Ref(dy)
120+
return dx
121+
end
122+
function ∇getindex!(dx::AbstractArray, dy, inds...)
123+
view(dx, inds...) .+= dy
124+
return dx
125+
end
126+
127+
# Allow for second derivatives, by writing rules for `∇getindex`:
128+
129+
function frule((_, _, dẏ), ::typeof(∇getindex), x, dy, inds...)
130+
return ∇getindex(x, dy, inds...), ∇getindex(x, dẏ, inds...)
131+
end
132+
133+
function rrule(::typeof(∇getindex), x, dy, inds...)
134+
z = ∇getindex(x, dy, inds...)
135+
function ∇getindex_pullback(dz)
136+
d2y = getindex(unthunk(dz), inds...)
137+
nots = map(Returns(NoTangent()), inds)
138+
return (NoTangent(), NoTangent(), ProjectTo(dy)(d2y), nots...)
84139
end
140+
return z, ∇getindex_pullback
141+
end
142+
143+
# Indexing with repeated indices on a GPU will lead ∇getindex to have race conditions & wrong answers.
144+
# To avoid this, copy everything back to the CPU.
145+
# But don't do that for indices which are known to be unique, e.g. `A[1, 2:3, :]` the colon gives Base.Slice:
85146

86-
return y, getindex_pullback
147+
function ∇getindex!(dx::AbstractGPUArray, dy, inds::Integer...)
148+
view(dx, inds...) .+= Ref(dy)
149+
return dx
150+
end
151+
function ∇getindex!(dx::AbstractGPUArray, dy, inds::Union{Integer, AbstractUnitRange, Base.Slice}...)
152+
view(dx, inds...) .+= dy
153+
return dx
154+
end
155+
function ∇getindex!(dx::AbstractGPUArray, dy, inds...)
156+
dx_cpu = adapt(Array, dx)
157+
view(dx_cpu, adapt(Array, inds)...) .+= adapt(Array, dy)
158+
copyto!(dx, dx_cpu)
159+
return dx
87160
end
88161

89162
#####
@@ -117,6 +190,23 @@ function frule((_, ẋ), ::typeof(view), x::AbstractArray, inds...)
117190
return view(x, inds...), view(ẋ, inds...)
118191
end
119192

193+
function rrule(::typeof(view), x::AbstractArray, inds...)
194+
function view_pullback(dy)
195+
nots = map(Returns(NoTangent()), inds)
196+
return (NoTangent(), thunked_∇getindex(x, dy, inds...), nots...)
197+
end
198+
return view(x, inds...), view_pullback
199+
end
200+
201+
function rrule(::typeof(view), x::AbstractArray, i::Integer, jkl::Integer...)
202+
# This case returns a zero-dim array, unlike getindex. So we fool ∇getindex:
203+
function view_pullback_0(dy)
204+
nots = map(Returns(NoTangent()), (i, jkl...))
205+
return (NoTangent(), thunked_∇getindex(x, dy, i:i, jkl...), nots...)
206+
end
207+
return view(x, i, jkl...), view_pullback_0
208+
end
209+
120210
#####
121211
##### setindex!
122212
#####
@@ -125,6 +215,21 @@ function frule((_, ẋ, v̇), ::typeof(setindex!), x::AbstractArray, v, inds...)
125215
return setindex!(x, v, inds...), setindex!(ẋ, v̇, inds...)
126216
end
127217

218+
#####
219+
##### unsafe_getindex
220+
#####
221+
222+
# This is called by e.g. `iterate(1:0.1:2)`,
223+
# and fixes https://github.com/FluxML/Zygote.jl/issues/1247
224+
# Only needs to accept AbstractRange, but AbstractVector makes testing easier.
225+
226+
function frule((_, ẋ), ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer)
227+
return Base.unsafe_getindex(x, i), getindex(ẋ, i)
228+
end
229+
230+
function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(Base.unsafe_getindex), x::AbstractVector, i::Integer)
231+
return rrule_via_ad(cfg, getindex, x, i)
232+
end
128233

129234
#####
130235
##### `eachslice` and friends

src/rulesets/Base/sort.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,7 @@ function rrule(::typeof(sortslices), x::AbstractArray; dims::Integer, kw...)
6262
p = sortperm(collect(eachslice(x; dims=dims)); kw...)
6363
inds = ntuple(d -> d == dims ? p : (:), ndims(x))
6464
function sortslices_pullback(dy)
65-
# No actual need to zero this, and if you didn't, then you could widen eltype
66-
# Also, you could use similar(dy) here not x, same size?
67-
dx = _zerolike_writeat(x, unthunk(dy), (), inds...)
68-
return (NoTangent(), ProjectTo(x)(dx))
65+
return (NoTangent(), ∇getindex(x, unthunk(dy), inds...))
6966
end
7067
return x[inds...], sortslices_pullback
7168
end
@@ -94,12 +91,11 @@ function rrule(::typeof(unique), x::AbstractArray{<:Number}; dims=:)
9491
mask .= (mask .== cumsum(mask, dims=1) .== true) # this implements findfirst(mask; dims=1)
9592
keep = map(I -> I[1], findall(mask))
9693
if dims isa Colon
97-
# The function `_zerolike_writeat` allows second derivatives.
98-
# Should perhaps eventually be shared with `getindex`.
99-
dx = reshape(_zerolike_writeat(vec(x), vec(dy), (), keep), axes_x)
94+
# The function `∇getindex` allows second derivatives.
95+
dx = reshape(∇getindex(vec(x), vec(dy), keep), axes_x) ## TODO understand again why vec!
10096
else
10197
inds = ntuple(d -> d==dims ? keep : (:), length(axes_x))
102-
dx = _zerolike_writeat(x, dy, (), inds...)
98+
dx = ∇getindex(x, dy, inds...)
10399
end
104100
return (NoTangent(), ProjectTo(x)(dx))
105101
end

test/rulesets/Base/array.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -366,14 +366,7 @@ end
366366
@test [5 0; 6 0] == @inferred unthunk(rrule(findmin, [1 2; 3 4], dims=2)[2]((hcat([5,6]), nothing))[2])
367367
test_rrule(findmin, rand(3,4), fkwargs=(dims=1,), output_tangent = (rand(1,4), NoTangent()))
368368
test_rrule(findmin, rand(3,4), fkwargs=(dims=2,))
369-
370-
# Second derivatives
371-
test_frule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, :, CartesianIndex(2, 2))
372-
test_rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, :, CartesianIndex(2, 2))
373-
@test_skip test_rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)] NoTangent()) # MethodError: no method matching isapprox(::Matrix{Float64}, ::Float64; rtol=1.0e-9, atol=1.0e-9)
374-
y, bk = rrule(ChainRules._zerolike_writeat, rand(2, 2), 5.0, 1, [CartesianIndex(2, 1) CartesianIndex(2, 2)])
375-
@test y == [0 0; 5 5]
376-
@test bk([1 2; 3 4]) == (NoTangent(), NoTangent(), [3 4], NoTangent(), NoTangent())
369+
test_rrule(findmin, rand(3,4), fkwargs=(dims=(1,2),))
377370
end
378371

379372
@testset "$imum" for imum in [maximum, minimum]

0 commit comments

Comments
 (0)