diff --git a/Project.toml b/Project.toml index b66176754..a420d3ed9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.44.0" +version = "1.44.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 30e492d2e..9f63eeb11 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -5,7 +5,7 @@ using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broad using ChainRulesCore using Compat using Distributed -using GPUArraysCore: AbstractGPUArray, AbstractGPUArrayStyle +using GPUArraysCore: AbstractGPUArray, AbstractGPUArrayStyle, @allowscalar using IrrationalConstants: logtwo, logten using LinearAlgebra using LinearAlgebra.BLAS diff --git a/src/rulesets/Base/array.jl b/src/rulesets/Base/array.jl index 49a4e1ac2..cc1d3d36c 100644 --- a/src/rulesets/Base/array.jl +++ b/src/rulesets/Base/array.jl @@ -216,6 +216,7 @@ _catsize(x::AbstractArray) = size(x) function rrule(::typeof(hcat), Xs...) Y = hcat(Xs...) # note that Y always has 1-based indexing, even if X isa OffsetArray + Base.require_one_based_indexing(Y) ndimsY = Val(ndims(Y)) # this avoids closing over Y, Val() is essential for type-stability sizes = map(_catsize, Xs) # this avoids closing over Xs project_Xs = map(ProjectTo, Xs) @@ -233,15 +234,10 @@ function rrule(::typeof(hcat), Xs...) d > ndimsX ? 1 : (:) end end - dX = if ndimsX > 0 - # Here InplaceableThunk breaks @inferred, removed for now - # InplaceableThunk(dX -> dX .+= view(dY, ind...), @thunk(dY[ind...])) - dY[ind...] - else - # This is a hack to perhaps avoid GPU scalar indexing - sum(view(dY, ind...)) - end - return project(dX) + InplaceableThunk( + dX -> dX .+= view(dY, ind...), + @thunk project(@allowscalar dY[ind...]) + ) end return (NoTangent(), dXs...) end @@ -253,6 +249,8 @@ function frule((_, _, Ȧs), ::typeof(reduce), ::typeof(hcat), As::AbstractVecto end function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVecOrMat}) + Y = reduce(hcat, As) + Base.require_one_based_indexing(Y) widths = map(A -> size(A,2), As) function reduce_hcat_pullback_2(dY) hi = Ref(0) @@ -263,7 +261,7 @@ function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVe end return (NoTangent(), NoTangent(), dAs) end - return reduce(hcat, As), reduce_hcat_pullback_2 + return Y, reduce_hcat_pullback_2 end function rrule(::typeof(reduce), ::typeof(hcat), As::AbstractVector{<:AbstractVector}) @@ -286,6 +284,7 @@ end function rrule(::typeof(vcat), Xs...) Y = vcat(Xs...) + Base.require_one_based_indexing(Y) ndimsY = Val(ndims(Y)) sizes = map(_catsize, Xs) project_Xs = map(ProjectTo, Xs) @@ -303,13 +302,10 @@ function rrule(::typeof(vcat), Xs...) d > ndimsX ? 1 : (:) end end - dX = if ndimsX > 0 - # InplaceableThunk(@thunk(dY[ind...]), dX -> dX .+= view(dY, ind...)) - dY[ind...] - else - sum(view(dY, ind...)) - end - return project(dX) + InplaceableThunk( + dX -> dX .+= view(dY, ind...), + @thunk project(@allowscalar dY[ind...]) + ) end return (NoTangent(), dXs...) end @@ -322,6 +318,7 @@ end function rrule(::typeof(reduce), ::typeof(vcat), As::AbstractVector{<:AbstractVecOrMat}) Y = reduce(vcat, As) + Base.require_one_based_indexing(Y) ndimsY = Val(ndims(Y)) heights = map(A -> size(A,1), As) function reduce_vcat_pullback(dY) @@ -349,6 +346,7 @@ end function rrule(::typeof(cat), Xs...; dims) Y = cat(Xs...; dims=dims) + Base.require_one_based_indexing(Y) cdims = dims isa Val ? Int(_val(dims)) : dims isa Integer ? Int(dims) : Tuple(dims) ndimsY = Val(ndims(Y)) sizes = map(_catsize, Xs) @@ -368,13 +366,10 @@ function rrule(::typeof(cat), Xs...; dims) for d in cdims prev[d] += get(sizeX, d, 1) end - dX = if ndimsX > 0 - # InplaceableThunk(@thunk(dY[index...]), dX -> dX .+= view(dY, index...)) - dY[index...] - else - sum(view(dY, index...)) - end - return project(dX) + InplaceableThunk( + dX -> dX .+= view(dY, index...), + @thunk project(@allowscalar dY[index...]) + ) end return (NoTangent(), dXs...) end