diff --git a/Project.toml b/Project.toml index 581934ac9..7ea574d09 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.44.5" +version = "1.44.6" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index d376a64f0..3519a097b 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -128,13 +128,14 @@ function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N} rrule_via_ad(cfg, f, a...) end function back_generic(dys) - deltas = unzip_broadcast(backs, unthunk(dys)) do back, dy # (could be map, sizes match) + deltas = unzip_broadcast(backs, dys) do back, dy # (could be map, sizes match) map(unthunk, back(dy)) end dargs = map(unbroadcast, args, Base.tail(deltas)) df = ProjectTo(f)(sum(first(deltas))) return (NoTangent(), NoTangent(), df, dargs...) end + back_generic(dys::AbstractThunk) = back_generic(unthunk(dys)) back_generic(z::AbstractZero) = (TRI_NO..., map(Returns(z), args)...) return ys3, back_generic end @@ -318,7 +319,7 @@ rrule(::typeof(broadcasted), ::typeof(complex), x::Number) = rrule(complex, x) | function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx_raw) dx = unthunk(dx_raw) - N = ndims(dx) + N = _ndims(dx) if length(x) == length(dx) ProjectTo(x)(dx) # handles trivial reshapes, offsets, structured matrices, row vectors else @@ -328,6 +329,9 @@ function unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx_raw) end unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero) = dx +_ndims(x) = ndims(x) +_ndims(::Tuple) = 1 + function unbroadcast(x::T, dx_raw) where {T<:Tuple{Vararg{Any,N}}} where {N} dx = unthunk(dx_raw) val = if N == length(dx) diff --git a/src/rulesets/Base/nondiff.jl b/src/rulesets/Base/nondiff.jl index 4d208a95b..22aeb1748 100644 --- a/src/rulesets/Base/nondiff.jl +++ b/src/rulesets/Base/nondiff.jl @@ -477,6 +477,11 @@ end @non_differentiable Broadcast.result_style(::Any) @non_differentiable Broadcast.result_style(::Any, ::Any) +@non_differentiable Base.CoreLogging.current_logger_for_env(::Any...) +@non_differentiable Base.CoreLogging._invoked_shouldlog(::Any...) +@non_differentiable Base.CoreLogging.Base.fixup_stdlib_path(::Any) +@non_differentiable Base.CoreLogging.handle_message(::Any...) + @non_differentiable Libc.free(::Any) @non_differentiable Libc.getpid() @non_differentiable Libc.strptime(::AbstractString) diff --git a/test/rulesets/Base/broadcast.jl b/test/rulesets/Base/broadcast.jl index 219b45a71..331616cce 100644 --- a/test/rulesets/Base/broadcast.jl +++ b/test/rulesets/Base/broadcast.jl @@ -176,5 +176,6 @@ BT1 = Broadcast.BroadcastStyle(Tuple) @testset "bugs" begin @test ChainRules.unbroadcast((1, 2, [3]), [4, 5, [6]]) isa Tangent # earlier, NTuple demanded same type + @test ChainRules.unbroadcast(broadcasted(-, (1, 2), 3), (4, 5)) == (4, 5) # earlier, called ndims(::Tuple) end end \ No newline at end of file