Skip to content

Commit 6504f25

Browse files
committed
Refactor EfficientNets
1 parent d3e4add commit 6504f25

22 files changed

+300
-337
lines changed

src/Metalhead.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,19 @@ include("convnets/resnets/resnext.jl")
2828
include("convnets/resnets/seresnet.jl")
2929
include("convnets/resnets/res2net.jl")
3030
## Inceptions
31-
include("convnets/inception/googlenet.jl")
32-
include("convnets/inception/inceptionv3.jl")
33-
include("convnets/inception/inceptionv4.jl")
34-
include("convnets/inception/inceptionresnetv2.jl")
35-
include("convnets/inception/xception.jl")
31+
include("convnets/inceptions/googlenet.jl")
32+
include("convnets/inceptions/inceptionv3.jl")
33+
include("convnets/inceptions/inceptionv4.jl")
34+
include("convnets/inceptions/inceptionresnetv2.jl")
35+
include("convnets/inceptions/xception.jl")
3636
## EfficientNets
37-
include("convnets/efficientnet/efficientnet.jl")
38-
include("convnets/efficientnet/efficientnetv2.jl")
37+
include("convnets/efficientnets/core.jl")
38+
include("convnets/efficientnets/efficientnet.jl")
39+
include("convnets/efficientnets/efficientnetv2.jl")
3940
## MobileNets
40-
include("convnets/mobilenet/mobilenetv1.jl")
41-
include("convnets/mobilenet/mobilenetv2.jl")
42-
include("convnets/mobilenet/mobilenetv3.jl")
41+
include("convnets/mobilenets/mobilenetv1.jl")
42+
include("convnets/mobilenets/mobilenetv2.jl")
43+
include("convnets/mobilenets/mobilenetv3.jl")
4344
## Others
4445
include("convnets/densenet.jl")
4546
include("convnets/squeezenet.jl")

src/convnets/convmixer.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9),
2+
convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9),
33
patch_size::Dims{2} = (7, 7), activation = gelu,
44
inchannels::Integer = 3, nclasses::Integer = 1000)
55
@@ -16,7 +16,7 @@ Creates a ConvMixer model.
1616
- `inchannels`: number of input channels
1717
- `nclasses`: number of classes in the output
1818
"""
19-
function convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9),
19+
function convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9),
2020
patch_size::Dims{2} = (7, 7), activation = gelu,
2121
inchannels::Integer = 3, nclasses::Integer = 1000)
2222
stem = conv_norm(patch_size, inchannels, planes, activation; preact = true,

src/convnets/densenet.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ Create a Densenet bottleneck layer
1212
"""
1313
function dense_bottleneck(inplanes::Integer, outplanes::Integer; expansion::Integer = 4)
1414
inner_channels = expansion * outplanes
15-
return SkipConnection(Chain(conv_norm((1, 1), inplanes, inner_channels; bias = false,
15+
return SkipConnection(Chain(conv_norm((1, 1), inplanes, inner_channels;
1616
revnorm = true)...,
1717
conv_norm((3, 3), inner_channels, outplanes; pad = 1,
18-
bias = false, revnorm = true)...),
18+
revnorm = true)...),
1919
cat_channels)
2020
end
2121

@@ -31,7 +31,7 @@ Create a DenseNet transition sequence
3131
- `outplanes`: number of output feature maps
3232
"""
3333
function transition(inplanes::Integer, outplanes::Integer)
34-
return Chain(conv_norm((1, 1), inplanes, outplanes; bias = false, revnorm = true)...,
34+
return Chain(conv_norm((1, 1), inplanes, outplanes; revnorm = true)...,
3535
MeanPool((2, 2)))
3636
end
3737

@@ -72,7 +72,7 @@ function densenet(inplanes::Integer, growth_rates; reduction = 0.5, inchannels::
7272
nclasses::Integer = 1000)
7373
layers = []
7474
append!(layers,
75-
conv_norm((7, 7), inchannels, inplanes; stride = 2, pad = (3, 3), bias = false))
75+
conv_norm((7, 7), inchannels, inplanes; stride = 2, pad = (3, 3)))
7676
push!(layers, MaxPool((3, 3); stride = 2, pad = (1, 1)))
7777
outplanes = 0
7878
for (i, rates) in enumerate(growth_rates)

src/convnets/efficientnet/efficientnet.jl

Lines changed: 0 additions & 113 deletions
This file was deleted.

src/convnets/efficientnet/efficientnetv2.jl

Lines changed: 0 additions & 124 deletions
This file was deleted.

src/convnets/efficientnets/core.jl

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
abstract type _MBConfig end
2+
3+
struct MBConvConfig <: _MBConfig
4+
kernel_size::Dims{2}
5+
inplanes::Integer
6+
outplanes::Integer
7+
expansion::Number
8+
stride::Integer
9+
nrepeats::Integer
10+
end
11+
function MBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer,
12+
expansion::Number, stride::Integer, nrepeats::Integer,
13+
width_mult::Number = 1, depth_mult::Number = 1)
14+
inplanes = _round_channels(inplanes * width_mult, 8)
15+
outplanes = _round_channels(outplanes * width_mult, 8)
16+
nrepeats = ceil(Int, nrepeats * depth_mult)
17+
return MBConvConfig((kernel_size, kernel_size), inplanes, outplanes, expansion,
18+
stride, nrepeats)
19+
end
20+
21+
function efficientnetblock(m::MBConvConfig, norm_layer)
22+
layers = []
23+
explanes = _round_channels(m.inplanes * m.expansion, 8)
24+
push!(layers,
25+
mbconv(m.kernel_size, m.inplanes, explanes, m.outplanes, swish; norm_layer,
26+
stride = m.stride, reduction = 4))
27+
explanes = _round_channels(m.outplanes * m.expansion, 8)
28+
append!(layers,
29+
[mbconv(m.kernel_size, m.outplanes, explanes, m.outplanes, swish; norm_layer,
30+
stride = 1, reduction = 4) for _ in 1:(m.nrepeats - 1)])
31+
return Chain(layers...)
32+
end
33+
34+
struct FusedMBConvConfig <: _MBConfig
35+
kernel_size::Dims{2}
36+
inplanes::Integer
37+
outplanes::Integer
38+
expansion::Number
39+
stride::Integer
40+
nrepeats::Integer
41+
end
42+
function FusedMBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer,
43+
expansion::Number, stride::Integer, nrepeats::Integer)
44+
return FusedMBConvConfig((kernel_size, kernel_size), inplanes, outplanes, expansion,
45+
stride, nrepeats)
46+
end
47+
48+
function efficientnetblock(m::FusedMBConvConfig, norm_layer)
49+
layers = []
50+
explanes = _round_channels(m.inplanes * m.expansion, 8)
51+
push!(layers,
52+
fused_mbconv(m.kernel_size, m.inplanes, explanes, m.outplanes, swish;
53+
norm_layer, stride = m.stride))
54+
explanes = _round_channels(m.outplanes * m.expansion, 8)
55+
append!(layers,
56+
[fused_mbconv(m.kernel_size, m.outplanes, explanes, m.outplanes, swish;
57+
norm_layer, stride = 1) for _ in 1:(m.nrepeats - 1)])
58+
return Chain(layers...)
59+
end
60+
61+
function efficientnet(block_configs::AbstractVector{<:_MBConfig};
62+
headplanes::Union{Nothing, Integer} = nothing,
63+
norm_layer = BatchNorm, dropout_rate = nothing,
64+
inchannels::Integer = 3, nclasses::Integer = 1000)
65+
layers = []
66+
# stem of the model
67+
append!(layers,
68+
conv_norm((3, 3), inchannels, block_configs[1].inplanes, swish; norm_layer,
69+
stride = 2, pad = SamePad()))
70+
# building inverted residual blocks
71+
append!(layers, [efficientnetblock(cfg, norm_layer) for cfg in block_configs])
72+
# building last layers
73+
outplanes = block_configs[end].outplanes
74+
headplanes = isnothing(headplanes) ? outplanes * 4 : headplanes
75+
append!(layers,
76+
conv_norm((1, 1), outplanes, headplanes, swish; pad = SamePad()))
77+
return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate))
78+
end

0 commit comments

Comments
 (0)