Skip to content

Commit 63cc4e0

Browse files
authored
Improve cat rules (#660)
* use allowscalar in cat rules
* use require_one_based_indexing
* restore InplaceableThunk
1 parent 77ef0eb commit 63cc4e0

File tree

3 files changed

+21
-26
lines changed

3 files changed

+21
-26
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.44.0"
3+
version = "1.44.1"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/ChainRules.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broad
55
using ChainRulesCore
66
using Compat
77
using Distributed
8-
using GPUArraysCore: AbstractGPUArray, AbstractGPUArrayStyle
8+
using GPUArraysCore: AbstractGPUArray, AbstractGPUArrayStyle, @allowscalar
99
using IrrationalConstants: logtwo, logten
1010
using LinearAlgebra
1111
using LinearAlgebra.BLAS

src/rulesets/Base/array.jl

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ _catsize(x::AbstractArray) = size(x)
216216

217217
function rrule(::typeof(hcat), Xs...)
218218
Y = hcat(Xs...) # note that Y always has 1-based indexing, even if X isa OffsetArray
219+
Base.require_one_based_indexing(Y)
219220
ndimsY = Val(ndims(Y)) # this avoids closing over Y, Val() is essential for type-stability
220221
sizes = map(_catsize, Xs) # this avoids closing over Xs
221222
project_Xs = map(ProjectTo, Xs)
@@ -233,15 +234,10 @@ function rrule(::typeof(hcat), Xs...)
233234
d > ndimsX ? 1 : (:)
234235
end
235236
end
236-
dX = if ndimsX > 0
237-
# Here InplaceableThunk breaks @inferred, removed for now
238-
# InplaceableThunk(dX -> dX .+= view(dY, ind...), @thunk(dY[ind...]))
239-
dY[ind...]
240-
else
241-
# This is a hack to perhaps avoid GPU scalar indexing
242-
sum(view(dY, ind...))
243-
end
244-
return project(dX)
237+
InplaceableThunk(
238+
dX -> dX .+= view(dY, ind...),
239+
@thunk project(@allowscalar dY[ind...])
240+
)
245241
end
246242
return (NoTangent(), dXs...)
247243
end
@@ -253,6 +249,8 @@ function frule((_, _, Ȧs), ::typeof(reduce), ::typeof(hcat), As::AbstractVecto
253249
end
254250

255251
function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVecOrMat})
252+
Y = reduce(hcat, As)
253+
Base.require_one_based_indexing(Y)
256254
widths = map(A -> size(A,2), As)
257255
function reduce_hcat_pullback_2(dY)
258256
hi = Ref(0)
@@ -263,7 +261,7 @@ function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVe
263261
end
264262
return (NoTangent(), NoTangent(), dAs)
265263
end
266-
return reduce(hcat, As), reduce_hcat_pullback_2
264+
return Y, reduce_hcat_pullback_2
267265
end
268266

269267
function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVector})
@@ -286,6 +284,7 @@ end
286284

287285
function rrule(::typeof(vcat), Xs...)
288286
Y = vcat(Xs...)
287+
Base.require_one_based_indexing(Y)
289288
ndimsY = Val(ndims(Y))
290289
sizes = map(_catsize, Xs)
291290
project_Xs = map(ProjectTo, Xs)
@@ -303,13 +302,10 @@ function rrule(::typeof(vcat), Xs...)
303302
d > ndimsX ? 1 : (:)
304303
end
305304
end
306-
dX = if ndimsX > 0
307-
# InplaceableThunk(@thunk(dY[ind...]), dX -> dX .+= view(dY, ind...))
308-
dY[ind...]
309-
else
310-
sum(view(dY, ind...))
311-
end
312-
return project(dX)
305+
InplaceableThunk(
306+
dX -> dX .+= view(dY, ind...),
307+
@thunk project(@allowscalar dY[ind...])
308+
)
313309
end
314310
return (NoTangent(), dXs...)
315311
end
@@ -322,6 +318,7 @@ end
322318

323319
function rrule(::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVecOrMat})
324320
Y = reduce(vcat, As)
321+
Base.require_one_based_indexing(Y)
325322
ndimsY = Val(ndims(Y))
326323
heights = map(A -> size(A,1), As)
327324
function reduce_vcat_pullback(dY)
@@ -349,6 +346,7 @@ end
349346

350347
function rrule(::typeof(cat), Xs...; dims)
351348
Y = cat(Xs...; dims=dims)
349+
Base.require_one_based_indexing(Y)
352350
cdims = dims isa Val ? Int(_val(dims)) : dims isa Integer ? Int(dims) : Tuple(dims)
353351
ndimsY = Val(ndims(Y))
354352
sizes = map(_catsize, Xs)
@@ -368,13 +366,10 @@ function rrule(::typeof(cat), Xs...; dims)
368366
for d in cdims
369367
prev[d] += get(sizeX, d, 1)
370368
end
371-
dX = if ndimsX > 0
372-
# InplaceableThunk(@thunk(dY[index...]), dX -> dX .+= view(dY, index...))
373-
dY[index...]
374-
else
375-
sum(view(dY, index...))
376-
end
377-
return project(dX)
369+
InplaceableThunk(
370+
dX -> dX .+= view(dY, index...),
371+
@thunk project(@allowscalar dY[index...])
372+
)
378373
end
379374
return (NoTangent(), dXs...)
380375
end

0 commit comments

Comments
 (0)