    @layer Dense
    @layer :expand Chain
    @layer BatchNorm trainable=(β,γ)
-   @layer Struct functor=(α,β) trainable=(β,)
+   @layer Struct children=(α,β) trainable=(β,)

This macro replaces most uses of `@functor` in Flux 0.14. Its basic purpose is the same:
when you define a new layer, this tells Flux to explore inside it
to see the parameters it trains, and also to move them to the GPU, change precision, etc.
+ Like `@functor`, this assumes your struct has the default constructor, to enable re-building.

Some "keywords" allow control of the recursion:
* If some fields look like parameters but should not be trained,
-   then `Optimisers.trainable` lets you specify fields to include, and ignore the rest.
-  * We can likewise add restructions to `Functors.functor`, but not yet written.
-  * In fact you can provide an arbitrary keyword with this syntax, and it will
-    overload this function alla `trainable`... that might be a terrible idea.
+   then `trainable` lets you specify fields to include, and ignore the rest (see the sketch after this list).
+ * You can likewise add restrictions to Functors.jl's `children` (although this is seldom a good idea).
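+
+ For example, a minimal sketch of the `trainable` keyword (the `Frozen` struct and its fields are
+ invented here for illustration, not part of Flux):
+
+ ```julia
+ struct Frozen
+   weight   # trained
+   mask     # ignored by the optimiser, but still moved to GPU, converted by `f32`, etc.
+ end
+
+ Flux.@layer Frozen trainable=(weight,)
+ ```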

It also handles overloads of `show` for pretty printing.
* By default, it adds methods to 3-arg `Base.show` to treat your layer much like `Dense` or `Conv`.
* If your layer is a container, more like `Chain` or `Parallel`, then `:expand` makes `show` unfold its contents.
- * To disable all `show` overloads, maybe we want a `:ignore` option too.
+ * To disable all `show` overloads, there is an `:ignore` option too.

(You probably still want to define 2-arg `show(io::IO, x::Layer)`; the macro does not touch this.)
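+
+ For instance, a hand-written 2-arg method might look like this minimal sketch (again using the
+ hypothetical `Frozen` struct from above, not a real Flux layer):
+
+ ```julia
+ Base.show(io::IO, m::Frozen) = print(io, "Frozen(", length(m.weight), ")")
+ ```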
+
+ Note that re-running the macro with different options does not overwrite all methods; you will need to restart Julia.
+
+ # Example
+ ```jldoctest
+ julia> struct Trio; a; b; c end
+
+ julia> tri = Trio(Dense([1.1 2.2]), Dense([3.3;;], false), Dropout(0.4))
+ Trio(Dense(2 => 1), Dense(1 => 1; bias=false), Dropout(0.4))
+
+ julia> Flux.destructure(tri)  # parameters not visible to Flux
+ (Bool[], Restructure(Trio, ..., 0))
+
+ julia> Flux.@layer :expand Trio
+
+ julia> Flux.destructure(tri)  # now gpu, train!, etc will see inside too
+ ([1.1, 2.2, 0.0, 3.3], Restructure(Trio, ..., 4))
+
+ julia> tri
+ Trio(
+   Dense(2 => 1),                        # 3 parameters
+   Dense(1 => 1; bias=false),            # 1 parameters
+   Dropout(0.4),
+ )                   # Total: 3 arrays, 4 parameters, 224 bytes.
+ ```
+

"""
26
51
macro layer (exs... )
27
52
out = quote end
@@ -40,10 +65,10 @@ macro layer(exs...)
  end

  # This function exists only for depwarns when you use @functor directly
-   push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing))  # scope is weird ?? can't use $ on func name?
+   push!(out.args, :(Flux._check_new_macro(::$(esc(type))) = nothing))

-   i = findfirst(ex -> Meta.isexpr(ex, :(=)) && ex.args[1] == :functor, rest)
-   if isnothing(i)
+   i = findfirst(ex -> Meta.isexpr(ex, :(=)) && ex.args[1] == :children, rest)
+   if isnothing(i)  # then default like @functor Layer
    push!(out.args, _macro_functor(esc(type)))
  else
    push!(out.args, _macro_functor(esc(type), rest[i].args[2]))
@@ -52,54 +77,70 @@ macro layer(exs...)
    j == i && continue
    ex = rest[j]
    Meta.isexpr(ex, :(=)) || error("expected keyword = fields")
-     if ex.args[1] == :trainable
-       push!(out.args, _macro_trainable(type, trainable, ex.args[2]))  # pass the function "trainable" not the symbol
+
+     name = if ex.args[1] == :trainable
+       :(Optimisers.trainable)
      else
-       error()
-       # @warn "defining a method for $(ex.args[1]) in your scope"  # ??
-       # push!(out.args, _macro_trainable(type, esc(ex.args[1]), ex.args[2]))
+       @warn "trying to define a method for `$(ex.args[1])` in your scope... this is experimental" maxlog=1
+       esc(ex.args[1])
      end
+     push!(out.args, _macro_trainable(esc(type), name, ex.args[2]))
  end

  out
end

- # Temporary depwarn function:
+ # Temporary depwarn function, called within `params`; it is also called by `show`.

function _check_new_macro(x::T) where T
  Functors.isleaf(x) && return
-   @warn "you used @functor for this type, but should now use @layer" T maxlog=1 _id=hash(T)
+   @warn "This type should now use Flux.@layer instead of @functor" T maxlog=1 _id=hash(T)
end
_check_new_macro(::Tuple) = nothing  # defined by Functors.jl, not by users
_check_new_macro(::NamedTuple) = nothing
- _check_new_macro(::Transpose) = nothing
- _check_new_macro(::Adjoint) = nothing
+ _check_new_macro(::AbstractArray) = nothing
_check_new_macro(::Ref) = nothing

# @layer's code for Functors & Adapt
# Unlike @functor, _default_functor doesn't need to eval anything

function _macro_functor(type)
  quote
-     Functors.functor(::Type{T}, x) where {T<:$type} = _default_functor(T, x)
-     Adapt.adapt_structure(to, layer::$type) = fmap(adapt(to), layer)
+     Functors.functor(::Type{T}, x) where {T<:$type} = $_default_functor(T, x)
+     Adapt.adapt_structure(to, layer::$type) = $fmap($adapt(to), layer)
  end
end

function _macro_functor(type, fields)
-   error("the equivalent of @functor Layer (:x,) isn't written yet, sorry")
+   Meta.isexpr(fields, :tuple) || error("expected a tuple of field names")
+   symbols = Tuple(map(_noquotenode, fields.args))
+   quote
+     Functors.functor(::Type{T}, x) where {T<:$type} = $_custom_functor(T, x, Val($symbols))
+     Adapt.adapt_structure(to, layer::$type) = $fmap($adapt(to), layer)
+   end
end
+ _macro_functor(type, field::Union{Symbol,QuoteNode}) = _macro_functor(type, :(($field,)))  # lets you forget a comma

function _default_functor(::Type{T}, x) where {T}
  if @generated
    F = fieldnames(T)
    args = map(sy -> :(getfield(x, $(QuoteNode(sy)))), F)
-     C = Base.typename(T).name  # constructor
+     C = Base.typename(T).wrapper  # constructor
    recon = VERSION > v"1.9-" ? :(Splat($C)) : :(Base.splat($C))
    :((NamedTuple{$F}(($(args...),)), $recon))
  else
    # Getting this parameterless type takes about 2μs, every time:
-     namedtuple(x), Base.splat(Base.typename(T).wrapper)
+     spl = VERSION > v"1.9-" ? Splat : Base.splat
+     namedtuple(x), spl(Base.typename(T).wrapper)
+   end
+ end
+
+ function _custom_functor(::Type{T}, x, ::Val{which}) where {T,which}
+   if false
+     # TODO write the @generated version
+   else
+     remake(nt) = Base.typename(T).wrapper(map(f -> f in which ? getfield(nt, f) : getfield(x, f), fieldnames(T))...)
+     NamedTuple{which}(map(s -> getfield(x, s), which)), remake
  end
end

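+
+ # Rough illustration of `_custom_functor` (the `Masked` struct below is hypothetical, for this comment only):
+ #
+ #   struct Masked; active; frozen; extra end
+ #   Flux.@layer Masked children=(active, frozen)
+ #
+ # Then `Functors.functor(Masked, m)` returns `((active = m.active, frozen = m.frozen), remake)`,
+ # where `remake` rebuilds a `Masked` from such a NamedTuple, keeping the original `m.extra`.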
@@ -117,61 +158,12 @@ function _macro_trainable(type, fun, fields)
  quoted = map(QuoteNode, symbols)
  gets = [:(getfield(x, $f)) for f in quoted]
  quote
-     # $fun(x::$type) = NamedTuple{$names}(($(gets...),))
-     Flux.trainable(x::$type) = NamedTuple{$symbols}(($(gets...),))  # ?? scope is weird
+     $fun(x::$type) = NamedTuple{$symbols}(($(gets...),))
+     # Flux.trainable(x::$type) = NamedTuple{$symbols}(($(gets...),))  # ?? scope is weird
  end
end
_macro_trainable(type, fun, field::Union{Symbol,QuoteNode}) = _macro_trainable(type, fun, :(($field,)))  # lets you forget a comma

_noquotenode(s::Symbol) = s
_noquotenode(q::QuoteNode) = q.value  # lets you write trainable=(:x,:y) instead of (x,y)
_noquotenode(ex) = error("expected a symbol, got $ex")
-
-
-
-
-
-
- # @big_show Chain
- # @big_show Parallel
- # @big_show SkipConnection
- # @big_show Recur
- # @big_show Maxout
-
-
-
-
- """
-     @big_show MyContainer
-
- This macro lets you opt-in to Flux's fancy printing.
-
- When `model::MyContainer` is returned at the REPL it will be treated like `Chain`,
- and the printing routine will recursively unfold its children.
- This is triggered by adding a method to 3-arg `Base.show(io::IO, ::MIME"text/plain", l::MyContainer)`.
-
- Custom layers which do not contain other layers (more like `Dense` than like `Chain`)
- need not call this, and should simply define 2-arg `Base.show(io::IO, l::MyLayer)`.
-
- # Example
- ```jldoctest
- julia> struct Trio{A,B,C}; a::A; b::B; c::C end
-
- julia> Flux.@functor Trio
-
- julia> Flux.@big_show Trio
-
- julia> tri = Trio(Dense(10=>5,tanh), Dense(5=>2), softmax)
- Trio(
-   Dense(10 => 5, tanh),                 # 55 parameters
-   Dense(5 => 2),                        # 12 parameters
-   NNlib.softmax,
- )                   # Total: 4 arrays, 67 parameters, 492 bytes.
- ```
-
- Note that there is no automatic method for 2-arg `show`, and thus
- something like `(tri, tri)` will print all the type parameters.
-
- However, `Chain(tri, tri)` will always use Flux's recursive printing,
- even without using this macro: `Chain` is the entry point.
- """