@@ -52,38 +52,111 @@ function rrule(::typeof(getindex), x::Tuple, ::Colon)
52
52
return x, getindex_back_4
53
53
end
54
54
55
-
56
55
# ####
57
- # #### getindex
56
+ # #### getindex(::AbstractArray)
58
57
# ####
59
58
60
59
function frule ((_, ẋ), :: typeof (getindex), x:: AbstractArray , inds... )
61
60
return x[inds... ], ẋ[inds... ]
62
61
end
63
62
64
- function rrule (:: typeof (getindex), x:: Array{<:Number} , inds... )
65
- # removes any logical indexing, CartesianIndex etc
66
- # leaving us just with a tuple of Int, Arrays of Int and Ranges of Int
63
+ function rrule (:: typeof (getindex), x:: AbstractArray , inds... )
64
+ function getindex_pullback (dy)
65
+ nots = map (Returns (NoTangent ()), inds)
66
+ return (NoTangent (), thunked_∇getindex (x, dy, inds... ), nots... )
67
+ end
68
+ return x[inds... ], getindex_pullback
69
+ end
70
+
71
+ function thunked_∇getindex (x, dy, inds... )
72
+ return InplaceableThunk (
73
+ dx -> ∇getindex! (dx, unthunk (dy), Base. to_indices (x, inds)... ),
74
+ @thunk (∇getindex (x, unthunk (dy), inds... )),
75
+ )
76
+ end
77
+
78
+ """
79
+ ∇getindex(x, dy, inds...)
80
+
81
+ For the `rrule` of `y = x[inds...]`, this function is roughly
82
+ `setindex(zero(x), dy, inds...)`, returning the array `dx`.
83
+ Differentiable. Includes `ProjectTo(x)(dx)`.
84
+ """
85
+ function ∇getindex (x:: AbstractArray , dy, inds... )
86
+ # `to_indices` removes any logical indexing, colons, CartesianIndex etc,
87
+ # leaving just Int / AbstractVector of Int
67
88
plain_inds = Base. to_indices (x, inds)
68
- y = getindex (x, plain_inds... )
69
- function getindex_pullback (ȳ)
70
- function getindex_add! (Δ)
71
- # this a optimizes away for simple cases
72
- for (ȳ_ii, ii) in zip (ȳ, Iterators. product (plain_inds... ))
73
- Δ[ii... ] += ȳ_ii
74
- end
75
- return Δ
76
- end
89
+ dx = _setindex_zero (x, dy, plain_inds... )
90
+ ∇getindex! (dx, dy, plain_inds... )
91
+ return ProjectTo (x)(dx) # since we have x, may as well do this inside, not in rules
92
+ end
93
+
94
+ """
95
+ _setindex_zero(x, dy, inds...)
77
96
78
- x̄ = InplaceableThunk (
79
- getindex_add!,
80
- @thunk (getindex_add! (zero (x))),
81
- )
82
- īnds = broadcast (Returns (NoTangent ()), inds)
83
- return (NoTangent (), x̄, īnds... )
97
+ This returns roughly `dx = zero(x)`, except that this is guaranteed to be mutable via `similar`,
98
+ and its element type is wide enough to allow `setindex!(dx, dy, inds...)`, which is exactly what
99
+ `∇getindex` does next.
100
+
101
+ It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't
102
+ allow `eltype(dy)`, nor does it work for many structured matrices.
103
+ """
104
+ _setindex_zero (x:: AbstractArray{<:Number} , dy, inds:: Integer... ) = fill! (similar (x, typeof (dy), axes (x)), ZeroTangent ())
105
+ _setindex_zero (x:: AbstractArray{<:Number} , dy, inds... ) = fill! (similar (x, eltype (dy), axes (x)), ZeroTangent ())
106
+ function _setindex_zero (x:: AbstractArray , dy, inds:: Integer... )
107
+ # This allows for types which don't define zero (like Vector) and types whose zero special (like Tangent),
108
+ # but always makes an abstract type. TODO : make it infer concrete type for e.g. vectors of SVectors
109
+ T = Union{typeof (dy), ZeroTangent}
110
+ return fill! (similar (x, T, axes (x)), ZeroTangent ())
111
+ end
112
+ function _setindex_zero (x:: AbstractArray , dy, inds... )
113
+ T = Union{eltype (dy), ZeroTangent}
114
+ return fill! (similar (x, T, axes (x)), ZeroTangent ())
115
+ end
116
+ ChainRules. @non_differentiable _setindex_zero (x:: AbstractArray , dy:: Any , inds:: Any... )
117
+
118
+ function ∇getindex! (dx:: AbstractArray , dy, inds:: Integer... )
119
+ view (dx, inds... ) .+ = Ref (dy)
120
+ return dx
121
+ end
122
+ function ∇getindex! (dx:: AbstractArray , dy, inds... )
123
+ view (dx, inds... ) .+ = dy
124
+ return dx
125
+ end
126
+
127
+ # Allow for second derivatives, by writing rules for `∇getindex`:
128
+
129
+ function frule ((_, _, dẏ), :: typeof (∇getindex), x, dy, inds... )
130
+ return ∇getindex (x, dy, inds... ), ∇getindex (x, dẏ, inds... )
131
+ end
132
+
133
+ function rrule (:: typeof (∇getindex), x, dy, inds... )
134
+ z = ∇getindex (x, dy, inds... )
135
+ function ∇getindex_pullback (dz)
136
+ d2y = getindex (unthunk (dz), inds... )
137
+ nots = map (Returns (NoTangent ()), inds)
138
+ return (NoTangent (), NoTangent (), ProjectTo (dy)(d2y), nots... )
84
139
end
140
+ return z, ∇getindex_pullback
141
+ end
142
+
143
+ # Indexing with repeated indices on a GPU will lead ∇getindex to have race conditions & wrong answers.
144
+ # To avoid this, copy everything back to the CPU.
145
+ # But don't do that for indices which are known to be unique, e.g. `A[1, 2:3, :]` the colon gives Base.Slice:
85
146
86
- return y, getindex_pullback
147
+ function ∇getindex! (dx:: AbstractGPUArray , dy, inds:: Integer... )
148
+ view (dx, inds... ) .+ = Ref (dy)
149
+ return dx
150
+ end
151
+ function ∇getindex! (dx:: AbstractGPUArray , dy, inds:: Union{Integer, AbstractUnitRange, Base.Slice} ...)
152
+ view (dx, inds... ) .+ = dy
153
+ return dx
154
+ end
155
+ function ∇getindex! (dx:: AbstractGPUArray , dy, inds... )
156
+ dx_cpu = adapt (Array, dx)
157
+ view (dx_cpu, adapt (Array, inds)... ) .+ = adapt (Array, dy)
158
+ copyto! (dx, dx_cpu)
159
+ return dx
87
160
end
88
161
89
162
# ####
@@ -117,6 +190,23 @@ function frule((_, ẋ), ::typeof(view), x::AbstractArray, inds...)
117
190
return view (x, inds... ), view (ẋ, inds... )
118
191
end
119
192
193
+ function rrule (:: typeof (view), x:: AbstractArray , inds... )
194
+ function view_pullback (dy)
195
+ nots = map (Returns (NoTangent ()), inds)
196
+ return (NoTangent (), thunked_∇getindex (x, dy, inds... ), nots... )
197
+ end
198
+ return view (x, inds... ), view_pullback
199
+ end
200
+
201
+ function rrule (:: typeof (view), x:: AbstractArray , i:: Integer , jkl:: Integer... )
202
+ # This case returns a zero-dim array, unlike getindex. So we fool ∇getindex:
203
+ function view_pullback_0 (dy)
204
+ nots = map (Returns (NoTangent ()), (i, jkl... ))
205
+ return (NoTangent (), thunked_∇getindex (x, dy, i: i, jkl... ), nots... )
206
+ end
207
+ return view (x, i, jkl... ), view_pullback_0
208
+ end
209
+
120
210
# ####
121
211
# #### setindex!
122
212
# ####
@@ -125,6 +215,21 @@ function frule((_, ẋ, v̇), ::typeof(setindex!), x::AbstractArray, v, inds...)
125
215
return setindex! (x, v, inds... ), setindex! (ẋ, v̇, inds... )
126
216
end
127
217
218
+ # ####
219
+ # #### unsafe_getindex
220
+ # ####
221
+
222
+ # This is called by e.g. `iterate(1:0.1:2)`,
223
+ # and fixes https://github.com/FluxML/Zygote.jl/issues/1247
224
+ # Only needs to accept AbstractRange, but AbstractVector makes testing easier.
225
+
226
+ function frule ((_, ẋ), :: typeof (Base. unsafe_getindex), x:: AbstractVector , i:: Integer )
227
+ return Base. unsafe_getindex (x, i), getindex (ẋ, i)
228
+ end
229
+
230
+ function rrule (cfg:: RuleConfig{>:HasReverseMode} , :: typeof (Base. unsafe_getindex), x:: AbstractVector , i:: Integer )
231
+ return rrule_via_ad (cfg, getindex, x, i)
232
+ end
128
233
129
234
# ####
130
235
# #### `eachslice` and friends
0 commit comments