Description
While replying to this Discourse post, https://discourse.julialang.org/t/ann-forwarddiffpullbacks-jl-forwarddiff-based-chainrulescore-pullbacks/78737/10, I ran into an error in the following example:
julia> using ForwardDiffPullbacks, Zygote
julia> Zygote.gradient([1,2,3], 4) do xs, y
f3 = x -> abs2(@show(x)/y)
sum(fwddiff(f3).(xs)) # this cannot track gradient w.r.t. y
end
x = 1
x = 2
x = 3
x = Dual{ForwardDiff.Tag{Tuple{var"#84#86"{Int64}, Val{1}}, Int64}}(1,1)
x = Dual{ForwardDiff.Tag{Tuple{var"#84#86"{Int64}, Val{1}}, Int64}}(2,1)
x = Dual{ForwardDiff.Tag{Tuple{var"#84#86"{Int64}, Val{1}}, Int64}}(3,1)
([0.125, 0.25, 0.375], nothing)
julia> Zygote.gradient([1,2,3], 4) do xs, y
f3 = x -> abs2(@show(x)/y)
sum(f3.(xs)) # reverts to slower generic broadcast, no Dual
end
x = 1
x = 2
x = 3
([0.125, 0.25, 0.375], -0.4375)
julia> f4(x,y) = abs2(@show(x)/y);
julia> Zygote.gradient((xs,y) -> sum(f4.(xs, y)), [1,2,3], 4)
x = Dual{Nothing}(1,1,0)
x = Dual{Nothing}(2,1,0)
x = Dual{Nothing}(3,1,0)
([0.125, 0.25, 0.375], -0.4375)
julia> Zygote.gradient((xs,y) -> sum(fwddiff(f4).(xs, y)), [1,2,3], 4)
x = 1
x = 2
x = 3
x = Dual{ForwardDiff.Tag{Tuple{typeof(f4), Val{1}}, Int64}}(1,1)
x = Dual{ForwardDiff.Tag{Tuple{typeof(f4), Val{1}}, Int64}}(2,1)
x = Dual{ForwardDiff.Tag{Tuple{typeof(f4), Val{1}}, Int64}}(3,1)
x = 1
x = 2
x = 3
ERROR: MethodError: no method matching (::ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}})(::Vector{Float64})
Closest candidates are:
(::ChainRulesCore.ProjectTo{T})(::AbstractFloat) where T<:AbstractFloat at ~/.julia/packages/ChainRulesCore/RbX5a/src/projection.jl:171
(::ChainRulesCore.ProjectTo{<:Number})(::ChainRulesCore.Tangent{<:Complex}) at ~/.julia/packages/ChainRulesCore/RbX5a/src/projection.jl:192
(::ChainRulesCore.ProjectTo{<:Number})(::ChainRulesCore.Tangent{<:Number}) at ~/.julia/packages/ChainRulesCore/RbX5a/src/projection.jl:193
...
Stacktrace:
[1] _project
@ ~/.julia/packages/Zygote/H6vD3/src/compiler/chainrules.jl:182 [inlined]
[2] map(f::typeof(Zygote._project), t::Tuple{Vector{Int64}, Int64}, s::Tuple{Vector{Float64}, Vector{Float64}})
@ Base ./tuple.jl:247
[3] gradient(::Function, ::Vector{Int64}, ::Vararg{Any})
@ Zygote ~/.julia/packages/Zygote/H6vD3/src/compiler/interface.jl:77

The bug may well be in Zygote or ChainRulesCore (and thus my fault), but I'm opening this issue here so as not to forget it.
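In the stacktrace, `Zygote._project` receives cotangents `s::Tuple{Vector{Float64}, Vector{Float64}}` for primal arguments `t::Tuple{Vector{Int64}, Int64}`. So the cotangent for the scalar `y` seems to arrive as one entry per element of `xs`, not summed down to a scalar, and `ProjectTo{Float64}` has no method for a `Vector{Float64}`. The sketch below shows the reduction that appears to be missing. It uses plain Base Julia with hand-derived partials of `abs2(x/y)`. The names `reduce_to_primal`, `df_dx` and `df_dy` are hypothetical and only illustrate the idea; they are not the API of ForwardDiffPullbacks, Zygote or ChainRulesCore.

```julia
# Hypothetical helper (not from any package): reduce a per-element cotangent
# back to the shape of the primal argument, as a broadcast pullback must do
# for arguments that were broadcast as scalars.
reduce_to_primal(x::Number, dx::AbstractArray) = sum(dx)          # scalar arg: sum contributions
reduce_to_primal(x::AbstractArray, dx::AbstractArray) = dx        # array arg: keep elementwise

# Hand-derived partials of f(x, y) = abs2(x / y) = x^2 / y^2
df_dx(x, y) = 2x / y^2        # ∂f/∂x
df_dy(x, y) = -2x^2 / y^3     # ∂f/∂y

xs, y = [1, 2, 3], 4

# Per-element cotangents, as a broadcast pullback would first produce them
dxs = df_dx.(xs, y)   # Vector, one entry per element of xs
dys = df_dy.(xs, y)   # also a Vector: this is the shape seen in the error

grads = (reduce_to_primal(xs, dxs), reduce_to_primal(y, dys))
println(grads)  # ([0.125, 0.25, 0.375], -0.4375)
```

The reduced result matches what Zygote returns for the plain `f4.(xs, y)` broadcast, so summing the `y` cotangent over the broadcast dimension seems to be the step skipped on the `fwddiff(f4)` path.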