
Commit 8981283

fixup
1 parent 6370374 commit 8981283

10 files changed: +150 −126 lines changed


docs/src/models/basics.md

Lines changed: 1 addition & 1 deletion
@@ -216,7 +216,7 @@ m(5) # => 26
 Flux provides a set of helpers for custom layers, which you can enable by calling
 
 ```julia
-Flux.@functor Affine
+Flux.@layer Affine
 ```
 
 This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).
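For context, the `Affine` layer this docs page has just defined looks roughly like the sketch below (field names follow the surrounding basics.md text); the `@layer` call is the line this commit changes:

```julia
struct Affine
  W
  b
end

Affine(in::Integer, out::Integer) = Affine(randn(out, in), zeros(out))

(m::Affine)(x) = m.W * x .+ m.b

Flux.@layer Affine   # previously: Flux.@functor Affine
```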

src/Flux.jl

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@ using MacroTools: @forward
 
 @reexport using NNlib
 using MLUtils
+const stack = MLUtils.stack  # now exported by Base
 import Optimisers: Optimisers, trainable, destructure  # before v0.13, Flux owned these functions
 
 using Zygote, ChainRulesCore

src/layers/basic.jl

Lines changed: 1 addition & 1 deletion
@@ -338,7 +338,7 @@ struct SkipConnection{T,F}
   connection::F  # user can pass arbitrary connections here, such as (a,b) -> a + b
 end
 
-@layer SkipConnection  # should this be expand?
+@layer :expand SkipConnection
 
 function (skip::SkipConnection)(input)
   skip.connection(skip.layers(input), input)
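The `:expand` option chosen here only affects printing: at the REPL the layer is now unfolded like `Chain`, rather than shown on one line like `Dense`. A rough sketch of the intended effect (output not captured from this commit):

```julia
julia> SkipConnection(Dense(2 => 2), +)   # with @layer :expand SkipConnection
SkipConnection(
  Dense(2 => 2),                        # 6 parameters
  +,
)
```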

src/layers/macro.jl

Lines changed: 66 additions & 74 deletions
@@ -3,25 +3,50 @@
     @layer Dense
     @layer :expand Chain
     @layer BatchNorm trainable=(β,γ)
-    @layer Struct functor=(α,β) trainable=(β,)
+    @layer Struct children=(α,β) trainable=(β,)
 
 This macro replaces most uses of `@functor` in Flux 0.14. Its basic purpose is the same:
 When you define a new layer, this tells Flux to explore inside it
 to see the parameters it trains, and also to move them to the GPU, change precision, etc.
+Like `@functor`, this assumes your struct has the default constructor, to enable re-building.
 
 Some "keywords" allow control of the recursion:
 * If some fields look like parameters but should not be trained,
-  then `Optimisers.trainable` lets you specify fields to include, and ignore the rest.
-* We can likewise add restructions to `Functors.functor`, but not yet written.
-* In fact you can provide an arbitrary keyword with this syntax, and it will
-  overload this function alla `trainable`... that might be a terrible idea.
+  then `trainable` lets you specify fields to include, and ignore the rest.
+* You can likewise add restrictions to Functors' `children` (although this is seldom a good idea).
 
 It also handles overloads of `show` for pretty printing.
 * By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`.
 * If your layer is a container, more like `Chain` or `Parallel`, then `:expand` makes `show` unfold its contents.
-* To disable all `show` overloads, maybe we want a `:ignore` option too.
+* To disable all `show` overloads, there is an `:ignore` option too.
 
 (You probably still want to define 2-arg `show(io::IO, x::Layer)`, the macro does not touch this.)
+
+Note that re-running the macro with different options does not overwrite all methods; you will need to restart Julia.
+
+# Example
+```jldoctest
+julia> struct Trio; a; b; c end
+
+julia> tri = Trio(Dense([1.1 2.2]), Dense([3.3;;], false), Dropout(0.4))
+Trio(Dense(2 => 1), Dense(1 => 1; bias=false), Dropout(0.4))
+
+julia> Flux.destructure(tri)  # parameters not visible to Flux
+(Bool[], Restructure(Trio, ..., 0))
+
+julia> Flux.@layer :expand Trio
+
+julia> Flux.destructure(tri)  # now gpu, train!, etc will see inside too
+([1.1, 2.2, 0.0, 3.3], Restructure(Trio, ..., 4))
+
+julia> tri
+Trio(
+  Dense(2 => 1),                        # 3 parameters
+  Dense(1 => 1; bias=false),            # 1 parameters
+  Dropout(0.4),
+)                   # Total: 3 arrays, 4 parameters, 224 bytes.
+```
+
 """
 macro layer(exs...)
   out = quote end
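To make the keyword syntax documented above concrete, here is a small usage sketch; the `PartlyFrozen` struct and its field names are hypothetical, not part of this commit:

```julia
# Hypothetical layer: `W` should be trained, while `table` should still be moved
# to the GPU, converted by f32, etc., but excluded from training.
struct PartlyFrozen
  W
  table
end

(m::PartlyFrozen)(x) = m.W * x .+ m.table

Flux.@layer PartlyFrozen trainable=(W,)   # explore both fields, train only W
```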
@@ -40,10 +65,10 @@ macro layer(exs...)
   end
 
   # This function exists only for depwarns when you use @functor directly
-  push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing))  # scope is weird ?? can't use $ on func name?
+  push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing))
 
-  i = findfirst(ex -> Meta.isexpr(ex, :(=)) && ex.args[1] == :functor, rest)
-  if isnothing(i)
+  i = findfirst(ex -> Meta.isexpr(ex, :(=)) && ex.args[1] == :children, rest)
+  if isnothing(i)  # then default like @functor Layer
     push!(out.args, _macro_functor(esc(type)))
   else
     push!(out.args, _macro_functor(esc(type), rest[i].args[2]))
@@ -52,54 +77,70 @@ macro layer(exs...)
     j == i && continue
     ex = rest[j]
     Meta.isexpr(ex, :(=)) || error("expected keyword = fields")
-    if ex.args[1] == :trainable
-      push!(out.args, _macro_trainable(type, trainable, ex.args[2]))  # pass the function "trainable" not the symbol
+
+    name = if ex.args[1] == :trainable
+      :(Optimisers.trainable)
     else
-      error()
-      # @warn "defining a method for $(ex.args[1]) in your scope" # ??
-      # push!(out.args, _macro_trainable(type, esc(ex.args[1]), ex.args[2]))
+      @warn "trying to define a method for `$(ex.args[1])` in your scope... this is experimental" maxlog=1
+      esc(ex.args[1])
     end
+    push!(out.args, _macro_trainable(esc(type), name, ex.args[2]))
   end
 
   out
 end
 
-# Temporary depwarn function:
+# Temporary depwarn function, called within `params`, is also called by `show`.
 
 function _check_new_macro(x::T) where T
   Functors.isleaf(x) && return
-  @warn "you used @functor for this type, but should now use @layer" T maxlog=1 _id=hash(T)
+  @warn "This type should now use Flux.@layer instead of @functor" T maxlog=1 _id=hash(T)
 end
 _check_new_macro(::Tuple) = nothing  # defined by Functors.jl, not by users
 _check_new_macro(::NamedTuple) = nothing
-_check_new_macro(::Transpose) = nothing
-_check_new_macro(::Adjoint) = nothing
+_check_new_macro(::AbstractArray) = nothing
 _check_new_macro(::Ref) = nothing
 
 # @layer's code for Functors & Adapt
 # Unlike @functor, _default_functor doesn't need to eval anything
 
 function _macro_functor(type)
   quote
-    Functors.functor(::Type{T}, x) where {T<:$type} = _default_functor(T, x)
-    Adapt.adapt_structure(to, layer::$type) = fmap(adapt(to), layer)
+    Functors.functor(::Type{T}, x) where {T<:$type} = $_default_functor(T, x)
+    Adapt.adapt_structure(to, layer::$type) = $fmap($adapt(to), layer)
   end
 end
 
 function _macro_functor(type, fields)
-  error("the equivalent of @functor Layer (:x,) isn't written yet, sorry")
+  Meta.isexpr(fields, :tuple) || error("expected a tuple of field names")
+  symbols = Tuple(map(_noquotenode, fields.args))
+  quote
+    Functors.functor(::Type{T}, x) where {T<:$type} = $_custom_functor(T, x, Val($symbols))
+    Adapt.adapt_structure(to, layer::$type) = $fmap($adapt(to), layer)
+  end
 end
+_macro_functor(type, field::Union{Symbol,QuoteNode}) = _macro_functor(type, :(($field,)))  # lets you forget a comma
 
 function _default_functor(::Type{T}, x) where {T}
   if @generated
     F = fieldnames(T)
     args = map(sy -> :(getfield(x, $(QuoteNode(sy)))), F)
-    C = Base.typename(T).name  # constructor
+    C = Base.typename(T).wrapper  # constructor
     recon = VERSION > v"1.9-" ? :(Splat($C)) : :(Base.splat($C))
     :((NamedTuple{$F}(($(args...),)), $recon))
   else
     # Getting this parameterless type takes about 2μs, every time:
-    namedtuple(x), Base.splat(Base.typename(T).wrapper)
+    spl = VERSION > v"1.9-" ? Splat : Base.splat
+    namedtuple(x), spl(Base.typename(T).wrapper)
+  end
+end
+
+function _custom_functor(::Type{T}, x, ::Val{which}) where {T,which}
+  if false
+    # TODO write the @generated version
+  else
+    remake(nt) = Base.typename(T).wrapper(map(f -> f in which ? getfield(nt, f) : getfield(x, f), fieldnames(T))...)
+    NamedTuple{which}(map(s -> getfield(x, s), which)), remake
   end
 end

@@ -117,61 +158,12 @@ function _macro_trainable(type, fun, fields)
   quoted = map(QuoteNode, symbols)
   gets = [:(getfield(x, $f)) for f in quoted]
   quote
-    # $fun(x::$type) = NamedTuple{$names}(($(gets...),))
-    Flux.trainable(x::$type) = NamedTuple{$symbols}(($(gets...),))  # ?? scope is weird
+    $fun(x::$type) = NamedTuple{$symbols}(($(gets...),))
+    # Flux.trainable(x::$type) = NamedTuple{$symbols}(($(gets...),)) # ?? scope is weird
   end
 end
 _macro_trainable(type, fun, field::Union{Symbol,QuoteNode}) = _macro_trainable(type, fun, :(($field,)))  # lets you forget a comma
 
 _noquotenode(s::Symbol) = s
 _noquotenode(q::QuoteNode) = q.value  # lets you write trainable=(:x,:y) instead of (x,y)
 _noquotenode(ex) = error("expected a symbol, got $ex")
-
-
-
-
-
-
-# @big_show Chain
-# @big_show Parallel
-# @big_show SkipConnection
-# @big_show Recur
-# @big_show Maxout
-
-
-
-
-"""
-    @big_show MyContainer
-
-This macro lets you opt-in to Flux's fancy printing.
-
-When `model::MyContainer` is returned at the REPL it will be treated like `Chain`,
-and the printing routine will recursively unfold its children.
-This is triggered by adding a method to 3-arg `Base.show(io::IO, ::MIME"text/plain", l::MyContainer)`.
-
-Custom layers which do not contain other layers (more like `Dense` than like `Chain`)
-need not call this, and should simply define 2-arg `Base.show(io::IO, l::MyLayer)`.
-
-# Example
-```jldoctest
-julia> struct Trio{A,B,C}; a::A; b::B; c::C end
-
-julia> Flux.@functor Trio
-
-julia> Flux.@big_show Trio
-
-julia> tri = Trio(Dense(10=>5,tanh), Dense(5=>2), softmax)
-Trio(
-  Dense(10 => 5, tanh),  # 55 parameters
-  Dense(5 => 2),  # 12 parameters
-  NNlib.softmax,
-)  # Total: 4 arrays, 67 parameters, 492 bytes.
-```
-
-Note that there is no automatic method for 2-arg `show`, and thus
-something like `(tri, tri)` will print all the type parameters.
-
-However, `Chain(tri, tri)` will always use Flux's recursive printing,
-even without using this macro: `Chain` is the entry point.
-"""

src/layers/normalise.jl

Lines changed: 2 additions & 0 deletions
@@ -178,6 +178,8 @@ end
 testmode!(m::AlphaDropout, mode=true) =
   (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
 
+Base.show(io::IO, d::AlphaDropout) = print(io, "AlphaDropout(", d.p, ")")
+
 """
     LayerNorm(size..., λ=identity; affine=true, ϵ=1f-5)
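The new 2-arg `show` method is what gets used when an `AlphaDropout` is printed in compact contexts, for instance nested inside another layer. A quick sketch of the effect (not output captured from this commit):

```julia
julia> print(AlphaDropout(0.2))   # falls back to the new 2-arg show
AlphaDropout(0.2)
```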

src/layers/recurrent.jl

Lines changed: 0 additions & 1 deletion
@@ -136,7 +136,6 @@ function (m::Recur)(x)
 end
 
 @layer :expand Recur trainable=(cell,)
-# trainable(a::Recur) = (; cell = a.cell)
 
 Base.show(io::IO, m::Recur) = print(io, "Recur(", m.cell, ")")

src/layers/show.jl

Lines changed: 34 additions & 48 deletions
@@ -1,28 +1,30 @@
 @nospecialize  # just for this file, for startup time
 
-# This is called by @layer, on layers which should be treated like Chain
+# This is called by @layer :expand, on layers which should be treated like Chain, and returns an expression:
 function _macro_big_show(ex)
-    quote
-        # Entry point:
-        function Base.show(io::IO, m::MIME"text/plain", x::$ex)
-            if get(io, :typeinfo, nothing) === nothing  # e.g. top level in REPL
-                _big_show(io, x)
-            elseif !get(io, :compact, false)  # e.g. printed inside a Vector, but not a Matrix
-                _layer_show(io, x)
-            else
-                show(io, x)
-            end
-        end
-
-        # Don't show Chain(Tuple(...)), always splat that:
-        _show_children(x::$ex) = _flat_children(x)
+  quote
+    # Entry point:
+    function Base.show(io::IO, m::MIME"text/plain", x::$ex)
+      if get(io, :typeinfo, nothing) === nothing  # e.g. top level in REPL
+        _big_show(io, x)
+      elseif !get(io, :compact, false)  # e.g. printed inside a Vector, but not a Matrix
+        _layer_show(io, x)
+      else
+        show(io, x)
+      end
     end
+
+    # Don't show Chain(Tuple(...)), always splat that. And ignore Recur's non-trainable state:
+    Flux._show_children(x::$ex) = _flat_children(trainable(x))
+  end
 end
 
 function _big_show(io::IO, obj, indent::Int=0, name=nothing)
   pre, post = obj isa Chain{<:AbstractVector} ? ("([", "])") : ("(", ")")
   children = _show_children(obj)
   if all(_show_leaflike, children)
+    # This check may not be useful anymore: it tries to infer when to stop the recursion by looking for grandkids,
+    # but once all layers use @layer, they stop the recursion by defining a method for _big_show.
     _layer_show(io, obj, indent, name)
   else
     println(io, " "^indent, isnothing(name) ? "" : "$name = ", nameof(typeof(obj)), pre)
@@ -56,48 +58,32 @@ _show_leaflike(::Tuple{Vararg{<:AbstractArray}}) = true # e.g. parameters of LS
 # _show_leaflike(::Scale) = true  # appears inside LayerNorm
 _show_leaflike(::AbstractArray{<:Number}) = true  # e.g. transposed arrays
 
-_show_children(x) = trainable(x)  # except for layers which hide their Tuple:
-# _show_children(c::Chain) = c.layers
-# _show_children(m::Maxout) = m.layers
-# _show_children(p::Parallel) = (p.connection, p.layers...)
-# _show_children(f::PairwiseFusion) = (f.connection, f.layers...)
-
+_show_children(x) = trainable(x)
+# This used to have methods for Chain, Maxout, Parallel, PairwiseFusion. Now @layer instead
+# writes a method to use this function. It flattens the Tuple within Chain etc.
+# (The remaining special cases are for printing of layer names when a NamedTuple, above.)
 function _flat_children(x)
   alpha = map(f -> getfield(x, f), fieldnames(typeof(x)))
   beta = map(y -> y isa Union{Tuple, NamedTuple} ? y : (y,), alpha)
   gamma = ((beta...)...,)
 end
 
-# This is called by @layer, on layers which should be treated like Dense
+# This is called by @layer, on layers which should be treated like Dense, and returns an expression:
 function _macro_layer_show(ex)
-    quote
-        # Entry point:
-        function Base.show(io::IO, m::MIME"text/plain", x::$ex)
-            if !get(io, :compact, false)
-                _layer_show(io, x)
-            else
-                show(io, x)
-            end
-        end
-
-        # Exit from _big_show recursion, do we need this and _show_leaflike?
-        _big_show(io::IO, obj::$ex, indent::Int=0, name=nothing) = _layer_show(io, obj, indent, name)
-        # Since this isn't a container, do not recurse into its children, if any:
-        _show_leaflike(::$ex) = true
+  quote
+    # Entry point:
+    function Base.show(io::IO, m::MIME"text/plain", x::$ex)
+      if !get(io, :compact, false)
+        _layer_show(io, x)
+      else
+        show(io, x)
+      end
     end
+
+    # Exit from _big_show recursion:
+    Flux._big_show(io::IO, obj::$ex, indent::Int=0, name=nothing) = _layer_show(io, obj, indent, name)
+  end
 end
-# for T in [
-#   :Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding,
-#   :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm,
-#   ]
-#   @eval function Base.show(io::IO, m::MIME"text/plain", x::$T)
-#     if !get(io, :compact, false)
-#       _layer_show(io, x)
-#     else
-#       show(io, x)
-#     end
-#   end
-# end
 
 function _layer_show(io::IO, layer, indent::Int=0, name=nothing)
   _str = isnothing(name) ? "" : "$name = "
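As a rough illustration of the `_show_children` / `_flat_children` pair above: for a container defined with `@layer :expand`, the Tuple holding its sub-layers is splatted, so the printing routine lists the layers directly. A sketch (not output captured from this commit):

```julia
julia> c = Chain(Dense(2 => 3), Dense(3 => 1));

julia> Flux._show_children(c)   # Chain's inner Tuple is flattened by _flat_children
(Dense(2 => 3), Dense(3 => 1))
```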
