Skip to content

Commit ae37562

Browse files
authored
Allow single indexing of arrays of GPU arrays (#760)
* Allow single indexing of arrays of GPU arrays * bump version
1 parent 40b9058 commit ae37562

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.58.0"
3+
version = "1.58.1"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/rulesets/Base/indexing.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ end
144144
ChainRules.@non_differentiable _setindex_zero(x::AbstractArray, dy::Any, inds::Any...)
145145

146146
function ∇getindex!(dx::AbstractArray, dy, inds::Integer...)
147-
view(dx, inds...) .+= Ref(dy)
147+
@views dx[inds...] += dy
148148
return dx
149149
end
150150
function ∇getindex!(dx::AbstractArray, dy, inds...)

test/rulesets/Base/indexing.jl

+8
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,14 @@ end
177177
@test Array(y3) == Array(x_23_gpu)[1, [1,1,2]]
178178
@test unthunk(bk3(jl(ones(3)))[2]) == jl([2 1 0; 0 0 0])
179179
end
180+
181+
@testset "getindex(::Array{<:AbstractGPUArray})" begin
182+
x_gpu = jl(rand(1))
183+
y, back = rrule(getindex, [x_gpu], 1)
184+
@test y === x_gpu
185+
dxs_gpu = unthunk(back(jl([1.0]))[2])
186+
@test dxs_gpu == [jl([1.0])]
187+
end
180188
end
181189

182190
# first & tail handled by getfield rules

0 commit comments

Comments
 (0)