Commit 97fdcd1

mcabbott authored and isentropic committed
Add a macro to opt-in to fancy printing, and to everything else (FluxML#1932)
1 parent 960f573 commit 97fdcd1

File tree

15 files changed: +350 −71 lines

NEWS.md

+8
@@ -2,6 +2,13 @@
 
 See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a complete list of PRs merged before each release.
 
+## v0.14.13
+* New macro `Flux.@layer` which should be used in place of `@functor`.
+  This also adds `show` methods for pretty printing.
+
+## v0.14.12
+* New `SignDecay` optimiser, like `WeightDecay` but for L1 norm.
+
 ## v0.14.0 (July 2023)
 * Flux now requires julia v1.9 or later.
 * CUDA.jl is not a hard dependency anymore. Support is now provided through the extension mechanism, by loading `using Flux, CUDA`.
@@ -51,6 +58,7 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a compl
 
 ## v0.13.6
 * Use the package [OneHotArrays.jl](https://github.com/FluxML/OneHotArrays.jl) instead of having the same code here.
+* Added [`@autosize` macro](https://github.com/FluxML/Flux.jl/pull/2078)
 
 ## v0.13.4
 * Added [`PairwiseFusion` layer](https://github.com/FluxML/Flux.jl/pull/1983)
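
As a minimal sketch of the `Flux.@layer` entry above (the struct, constructor and field names are taken from the `Affine` example in this PR's documentation changes, not from NEWS.md itself):

```julia
using Flux

struct Affine
  W
  b
end

Affine((in, out)::Pair) = Affine(randn(Float32, out, in), zeros(Float32, out))

(m::Affine)(x) = m.W * x .+ m.b

Flux.@layer Affine  # previously `Flux.@functor Affine`; also opts in to the new `show` methods
```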

docs/src/models/advanced.md

+16 −9

@@ -18,8 +18,8 @@ function (m::CustomModel)(x)
   return m.chain(x) + x
 end
 
-# Call @functor to allow for training. Described below in more detail.
-Flux.@functor CustomModel
+# Call @layer to allow for training. Described below in more detail.
+Flux.@layer CustomModel
 ```
 
 You can then use the model like:
@@ -39,15 +39,15 @@ Taking reference from our example `Affine` layer from the [basics](@ref man-basi
 By default all the fields in the `Affine` type are collected as its parameters, however, in some cases it may be desired to hold other metadata in our "layers" that may not be needed for training, and are hence supposed to be ignored while the parameters are collected. With Flux, the way to mark some fields of our layer as trainable is through overloading the `trainable` function:
 
 ```julia-repl
-julia> Flux.@functor Affine
+julia> @layer Affine
 
 julia> a = Affine(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9])
 Affine(Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], Float32[7.0, 8.0, 9.0])
 
 julia> Flux.params(a) # default behavior
 Params([Float32[1.0 2.0; 3.0 4.0; 5.0 6.0], Float32[7.0, 8.0, 9.0]])
 
-julia> Flux.trainable(a::Affine) = (; a.W) # returns a NamedTuple using the field's name
+julia> Flux.trainable(a::Affine) = (; W = a.W) # returns a NamedTuple using the field's name
 
 julia> Flux.params(a)
 Params([Float32[1.0 2.0; 3.0 4.0; 5.0 6.0]])
@@ -67,7 +67,14 @@ julia> Flux.params(Affine(true, [10, 11, 12.0]))
 Params([])
 ```
 
-It is also possible to further restrict what fields are seen by writing `@functor Affine (W,)`. However, this is not recommended. This requires the `struct` to have a corresponding constructor that accepts only `W` as an argument, and the ignored fields will not be seen by functions like `gpu` (which is usually undesired).
+The exact same method of `trainable` can also be defined using the macro, for convenience:
+
+```julia
+Flux.@layer Affine trainable=(W,)
+```
+
+There is a second, more severe, kind of restriction possible. This is not recommended, but is included here for completeness. Calling `Functors.@functor Affine (W,)` means that no exploration of the model will ever visit the other fields: They will not be moved to the GPU by [`gpu`](@ref), and their precision will not be changed by `f32`. This requires the `struct` to have a corresponding constructor that accepts only `W` as an argument.
+
 
 ## Custom multiple input or output layer
 
@@ -95,9 +102,9 @@ Join(combine, paths...) = Join(combine, paths)
 ```
 Notice that we parameterized the type of the `paths` field. This is necessary for fast Julia code; in general, `T` might be a `Tuple` or `Vector`, but we don't need to pay attention to what it specifically is. The same goes for the `combine` field.
 
-The next step is to use [`Functors.@functor`](@ref) to make our struct behave like a Flux layer. This is important so that calling `params` on a `Join` returns the underlying weight arrays on each path.
+The next step is to use [`Flux.@layer`](@ref) to make our struct behave like a Flux layer. This is important so that calling `params` on a `Join` returns the underlying weight arrays on each path.
 ```julia
-Flux.@functor Join
+Flux.@layer Join
 ```
 
 Finally, we define the forward pass. For `Join`, this means applying each `path` in `paths` to each input array, then using `combine` to merge the results.
@@ -154,7 +161,7 @@ model(xs)
 
 Our custom `Split` layer will accept a single input, then pass the input through a separate path to produce multiple outputs.
 
-We start by following the same steps as the `Join` layer: define a struct, use [`Functors.@functor`](@ref), and define the forward pass.
+We start by following the same steps as the `Join` layer: define a struct, use [`@layer`](@ref), and define the forward pass.
 ```julia
 using Flux
 using CUDA
@@ -166,7 +173,7 @@ end
 
 Split(paths...) = Split(paths)
 
-Flux.@functor Split
+Flux.@layer Split
 
 (m::Split)(x::AbstractArray) = map(f -> f(x), m.paths)
 ```
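
The `trainable=(W,)` keyword and the hand-written `trainable` overload in this file are two spellings of the same restriction; a short sketch combining them with the docs' `Affine` struct (only one of the two should be active at a time):

```julia
using Flux

struct Affine
  W
  b
end
(m::Affine)(x) = m.W * x .+ m.b

Flux.@layer Affine trainable=(W,)   # the new macro keyword added by this PR

# Equivalent hand-written method, per the docs above (don't define both):
# Flux.trainable(m::Affine) = (; W = m.W)

a = Affine(Float32[1 2; 3 4; 5 6], Float32[7, 8, 9])
Flux.params(a)  # should now collect only the weight matrix, not the bias
```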

docs/src/models/basics.md

+7 −2

@@ -257,8 +257,8 @@ m(5) # => 26
 
 There is still one problem with this `Affine` layer, that Flux does not know to look inside it. This means that [`Flux.train!`](@ref) won't see its parameters, nor will [`gpu`](@ref) be able to move them to your GPU. These features are enabled by the [`@functor`](@ref Functors.@functor) macro:
 
-```
-Flux.@functor Affine
+```julia
+Flux.@layer Affine
 ```
 
 Finally, most Flux layers make bias optional, and allow you to supply the function used for generating random weights. We can easily add these refinements to the `Affine` layer as follows, using the helper function [`create_bias`](@ref Flux.create_bias):
@@ -272,3 +272,8 @@ end
 
 Affine(3 => 1, bias=false, init=ones) |> gpu
 ```
+
+```@docs
+Flux.@layer
+Flux.create_bias
+```
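
The "refinements" referred to in that hunk (optional bias, custom `init`) would look roughly like the constructor below; a sketch assuming the `Affine` struct already declared with `Flux.@layer` earlier on the page:

```julia
function Affine((in, out)::Pair; bias = true, init = Flux.glorot_uniform)
  W = init(out, in)
  b = Flux.create_bias(W, bias, out)  # `false` when bias=false, otherwise a zero vector matching W's eltype
  return Affine(W, b)
end

Affine(3 => 1, bias = false, init = ones) |> gpu   # the call shown in the diff context above
```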

src/Flux.jl

+4 −1

@@ -9,6 +9,7 @@ using MacroTools: @forward
 
 @reexport using NNlib
 using MLUtils
+const stack = MLUtils.stack # now exported by Base
 import Optimisers: Optimisers, trainable, destructure  # before v0.13, Flux owned these functions
 using Optimisers: freeze!, thaw!, adjust!
 using Random: default_rng
@@ -69,14 +70,16 @@ include("functor.jl")
 # Pirate error to catch a common mistake.
 Functors.functor(::Type{<:MLUtils.DataLoader}, x) = error("`DataLoader` does not support Functors.jl, thus functions like `Flux.gpu` will not act on its contents.")
 
+include("layers/show.jl")
+include("layers/macro.jl")
+
 include("layers/stateless.jl")
 include("layers/basic.jl")
 include("layers/conv.jl")
 include("layers/recurrent.jl")
 include("layers/normalise.jl")
 include("layers/upsample.jl")
 include("layers/attention.jl")
-include("layers/show.jl")
 
 include("loading.jl")
 

src/functor.jl

+1
@@ -81,6 +81,7 @@ function params!(p::Params, x, seen = IdSet())
   elseif x in seen
     nothing
   else
+    _check_new_macro(x)  # complains if you used @functor not @layer
     push!(seen, x)
     for child in trainable(x)
       params!(p, child, seen)

src/layers/attention.jl

+41 −3

@@ -74,7 +74,7 @@ struct MultiHeadAttention{P1, D, P2}
   out_proj::P2
 end
 
-@functor MultiHeadAttention
+@layer MultiHeadAttention
 
 function MultiHeadAttention(dims;
     nheads::Int = 8,
@@ -83,8 +83,8 @@ function MultiHeadAttention(dims;
     dropout_prob = 0.0)
 
   dims = normalize_mha_dims(dims)
-  @assert dims.qk % nheads == 0 "qk_dim should be divisible by nheads"
-  @assert dims.v % nheads == 0 "v_dim should be divisible by nheads"
+  dims.qk % nheads == 0 || throw(ArgumentError("qk_dim = $(dims.qk) should be divisible by nheads = $(nheads)"))
+  dims.v % nheads == 0 || throw(ArgumentError("v_dim = $(dims.v) should be divisible by nheads = $(nheads)"))
   q_proj = Dense(dims.q_in => dims.qk; bias, init)
   k_proj = Dense(dims.k_in => dims.qk; bias, init)
   v_proj = Dense(dims.v_in => dims.v; bias, init)
@@ -131,3 +131,41 @@ function (mha::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3,
   # [α] = [kv_len, q_len, nheads, batch_size]
   return x, α
 end
+
+function Base.show(io::IO, mha::MultiHeadAttention)
+  qk, q_in = size(mha.q_proj.weight)
+  qk, k_in = size(mha.k_proj.weight)
+  v, v_in = size(mha.v_proj.weight)
+  out, v = size(mha.out_proj.weight)
+  # @show q_in, k_in, v_in, qk, v, out
+  print(io, "MultiHeadAttention(")
+  if q_in == k_in == v_in == qk == v == out
+    print(io, q_in)
+  elseif q_in == k_in == v_in && qk == v
+    print(io, q_in, " => ", qk, " => ", out)
+  elseif q_in == k_in == v_in
+    print(io, q_in, " => (", qk, ", ", v, ") => ", out)
+  else
+    print(io, "(", q_in, ", ", k_in, ", ", v_in, ") => (", qk, ", ", v, ") => ", out)
+  end
+  print(io, "; nheads=", mha.nheads)
+  if mha.q_proj.bias !== false
+    print(io, ", bias=true")
+  end
+  if mha.attn_drop.p != 0
+    print(io, ", dropout_prob=", mha.attn_drop.p)  # can't we rename this?
+  end
+  print(io, ")")
+end
+
+
+#=
+
+# Test cases for printing:
+
+MultiHeadAttention((3, 4, 5) => (6, 7) => 8; nheads=1)
+MultiHeadAttention(3 => (6, 7) => 8; nheads=1)
+MultiHeadAttention(3 => 6 => 8; nheads=1)
+MultiHeadAttention(8; bias=true)
+
+=#
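
Reading the branches of the new `show` method against the commented test cases above suggests output along the following lines (derived by hand, not captured from a REPL, so details may differ):

```julia
MultiHeadAttention(8; bias=true)
# MultiHeadAttention(8; nheads=8, bias=true)              -- all six sizes equal

MultiHeadAttention(3 => 6 => 8; nheads=1)
# MultiHeadAttention(3 => 6 => 8; nheads=1)               -- equal inputs, qk == v

MultiHeadAttention(3 => (6, 7) => 8; nheads=1)
# MultiHeadAttention(3 => (6, 7) => 8; nheads=1)          -- equal inputs, distinct qk and v

MultiHeadAttention((3, 4, 5) => (6, 7) => 8; nheads=1)
# MultiHeadAttention((3, 4, 5) => (6, 7) => 8; nheads=1)  -- fully general fallback branch
```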

src/layers/basic.jl

+9 −9

@@ -46,7 +46,7 @@ end
 @forward Chain.layers Base.getindex, Base.length, Base.first, Base.last,
   Base.iterate, Base.lastindex, Base.keys, Base.firstindex
 
-@functor Chain
+@layer :expand Chain  # `:expand` opts-in to container-style pretty-printing
 
 (c::Chain)(x) = _applychain(c.layers, x)
 
@@ -165,7 +165,7 @@ function Dense((in, out)::Pair{<:Integer, <:Integer}, σ = identity;
   Dense(init(out, in), bias, σ)
 end
 
-@functor Dense
+@layer Dense
 
 function (a::Dense)(x::AbstractVecOrMat)
   _size_check(a, x, 1 => size(a.weight, 2))
@@ -251,7 +251,7 @@ end
 Scale(s1::Integer, s23::Integer...; bias = true, init = ones32, _act = identity) = Scale(init(s1, s23...), bias, _act)
 Scale(size_act...; bias = true, init = ones32) = Scale(size_act[1:end-1]...; bias, init, _act = size_act[end])
 
-@functor Scale
+@layer Scale
 
 function (a::Scale)(x::AbstractArray)
   σ = NNlib.fast_act(a.σ, x)  # replaces tanh => tanh_fast, etc
@@ -306,7 +306,7 @@ end
 Maxout(layers...) = Maxout(layers)
 Maxout(f::Function, n_alts::Integer) = Maxout((f() for _ in 1:n_alts)...)
 
-@functor Maxout
+@layer :expand Maxout
 
 function (mo::Maxout)(input::AbstractArray)
   # Perhaps surprisingly, pairwise max broadcast is often faster,
@@ -353,7 +353,7 @@ struct SkipConnection{T,F}
   connection::F  # user can pass arbitrary connections here, such as (a,b) -> a + b
 end
 
-@functor SkipConnection
+@layer :expand SkipConnection
 
 function (skip::SkipConnection)(input)
   skip.connection(skip.layers(input), input)
@@ -423,7 +423,7 @@ struct Bilinear{F,A,B}
   end
 end
 
-@functor Bilinear
+@layer Bilinear
 
 function Bilinear(((in1, in2), out)::Pair{<:Tuple, <:Integer}, σ = identity;
     bias = true, init = glorot_uniform)
@@ -522,7 +522,7 @@ function Parallel(connection; kw...)
   Parallel(connection, layers)
 end
 
-@functor Parallel
+@layer :expand Parallel
 
 (m::Parallel)(x) = m.connection(map(f -> f(x), Tuple(m.layers))...)
 (m::Parallel)(xs::Tuple) = m(xs...)
@@ -643,7 +643,7 @@ end
 end
 applypairwisefusion(layers::NamedTuple, connection, x) = applypairwisefusion(Tuple(layers), connection, x)
 
-@functor PairwiseFusion
+@layer :expand PairwiseFusion
 
 Base.getindex(m::PairwiseFusion, i) = m.layers[i]
 Base.getindex(m::PairwiseFusion, i::AbstractVector) = PairwiseFusion(m.connection, m.layers[i])
@@ -701,7 +701,7 @@ struct Embedding{W<:AbstractMatrix}
   weight::W
 end
 
-@functor Embedding
+@layer Embedding
 
 Embedding((in, out)::Pair{<:Integer, <:Integer}; init = randn32) = Embedding(init(out, in))
 
src/layers/conv.jl

+3 −3

@@ -187,7 +187,7 @@ function convfilter(filter::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer};
   init(filter..., cin÷groups, cout)
 end
 
-@functor Conv
+@layer Conv
 
 conv_dims(c::Conv, x::AbstractArray) =
   DenseConvDims(x, c.weight; stride = c.stride, padding = c.pad, dilation = c.dilation, groups = c.groups)
@@ -309,7 +309,7 @@ function ConvTranspose(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ =
   ConvTranspose(weight, bias, σ; stride, pad, dilation, groups)
 end
 
-@functor ConvTranspose
+@layer ConvTranspose
 
 function conv_transpose_dims(c::ConvTranspose, x::AbstractArray)
   # Calculate size of "input", from ∇conv_data()'s perspective...
@@ -460,7 +460,7 @@ function CrossCor(k::NTuple{N,Integer}, ch::Pair{<:Integer,<:Integer}, σ = iden
   return CrossCor(weight, bias, σ; stride, pad, dilation)
 end
 
-@functor CrossCor
+@layer CrossCor
 
 function crosscor(x, w, ddims::DenseConvDims)
   ddims = DenseConvDims(ddims, F=true)
