Skip to content

Commit c335d6d

Browse files
authored
Merge pull request #1286 from FluxML/bc/pairs-kwarg-indexing
Treat `pairs(NamedTuple)` as `NamedTuple` for indexing
2 parents bd5ce6e + 24a6111 commit c335d6d

File tree

2 files changed

+40
-5
lines changed

2 files changed

+40
-5
lines changed

src/lib/base.jl

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,11 @@ end
119119

120120
# named tuple
121121
@adjoint function pairs(t::NamedTuple{N}) where N
122-
122+
123123
pairs_namedtuple_pullback(dx::NamedTuple) = (dx.data,)
124124

125125
pairs_namedtuple_pullback(dx::Tuple{}) = (NamedTuple(),)
126-
126+
127127
function pairs_namedtuple_pullback::Dict)
128128
t0 = map(zero, t)
129129
for (idx, v) in Δ
@@ -145,6 +145,30 @@ else
145145
@adjoint merge(nt::NamedTuple, dict::Dict) = pullback(merge, nt, (;dict...))
146146
end
147147

148+
# Keyword arguments pretend to be a Dict, but are secretly wrapping a NamedTuple.
149+
# We can treat them much the same, just with some plumbing to handle the extra `itr` field.
150+
function _pullback(::AContext, ::typeof(getindex),
151+
ps::Iterators.Pairs{<:Any,<:Any,<:Any,<:NamedTuple}, k)
152+
# So we don't close over kwarg values in the pullback
153+
data = map(_ -> nothing, NamedTuple(ps))
154+
function kwargs_getindex_pullback(Δ)
155+
dps = (data = Base.setindex(data, Δ, k), itr = nothing)
156+
return (nothing, dps, nothing)
157+
end
158+
return ps[k], kwargs_getindex_pullback
159+
end
160+
161+
function _pullback(cx::AContext, ::typeof(literal_getindex),
162+
ps::Iterators.Pairs{<:Any,<:Any,<:Any,<:NamedTuple}, ::Val{K}) where K
163+
val, gf_back = _pullback(cx, literal_getfield, NamedTuple(ps), Val(K))
164+
function kwargs_literal_getindex_pullback(Δ)
165+
dps = (data = gf_back(Δ)[2], itr = nothing)
166+
return (nothing, dps, nothing)
167+
end
168+
return val, kwargs_literal_getindex_pullback
169+
end
170+
171+
# Misc.
148172
@adjoint function Base.getfield(p::Pair, i::Int)
149173
function pair_getfield_pullback(Δ)
150174
f, s = i == 1 ? (Δ, nothing) : (nothing, Δ)

test/features.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,17 @@ end
552552
@test gradient(x -> x[].a, Ref((a=1, b=2))) == ((x = (a = 1, b = nothing),),)
553553
@test gradient(x -> x[1][].a, [Ref((a=1, b=2)), Ref((a=3, b=4))]) == ([(x = (a = 1, b = nothing),), nothing],)
554554
@test gradient(x -> x[1].a, [(a=1, b=2), "three"]) == ([(a = 1, b = nothing), nothing],)
555+
556+
@testset "indexing kwargs" begin
557+
inner_lit_index(; kwargs...) = kwargs[:x]
558+
outer_lit_index(; kwargs...) = inner_lit_index(; x=kwargs[:x])
559+
560+
inner_dyn_index(k; kwargs...) = kwargs[k]
561+
outer_dyn_index(k; kwargs...) = inner_dyn_index(k; x=kwargs[k])
562+
563+
@test gradient(x -> outer_lit_index(; x), 0.0) == (1.0,)
564+
@test gradient((x, k) -> outer_dyn_index(k; x), 0.0, :x) == (1.0, nothing)
565+
end
555566
end
556567

557568
function type_test()
@@ -562,7 +573,7 @@ end
562573

563574
@testset "Pairs" begin
564575
@test (x->10*pairs((a=x, b=2))[1])'(100) === 10.0
565-
@test (x->10*pairs((a=x, b=2))[2])'(100) === 0
576+
@test (x->10*pairs((a=x, b=2))[2])'(100) === nothing
566577
foo(;kw...) = 1
567578
@test gradient(() -> foo(a=1,b=2.0)) === ()
568579

@@ -578,8 +589,8 @@ end
578589
@testset "kwarg splatting, pass in object" begin
579590
g(; kwargs...) = kwargs[:x] * kwargs[:z]
580591
h(somedata) = g(; somedata...)
581-
@test gradient(h, (; x=3.0, y=4.0, z=2.3)) == ((x = 2.3, y = 0.0, z = 3.0),)
582-
@test gradient(h, Dict(:x=>3.0, :y=>4.0, :z=>2.3)) == ((y = 0.0, z = 3.0, x = 2.3),)
592+
@test gradient(h, (; x=3.0, y=4.0, z=2.3)) == ((x = 2.3, y = nothing, z = 3.0),)
593+
@test gradient(h, Dict(:x=>3.0, :y=>4.0, :z=>2.3)) == ((y = nothing, z = 3.0, x = 2.3),)
583594
end
584595

585596
@testset "Iterators" begin

0 commit comments

Comments
 (0)