|
| 1 | +abstract type _MBConfig end |
| 2 | + |
| 3 | +struct MBConvConfig <: _MBConfig |
| 4 | + kernel_size::Dims{2} |
| 5 | + inplanes::Integer |
| 6 | + outplanes::Integer |
| 7 | + expansion::Number |
| 8 | + stride::Integer |
| 9 | + nrepeats::Integer |
| 10 | +end |
| 11 | +function MBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer, |
| 12 | + expansion::Number, stride::Integer, nrepeats::Integer, |
| 13 | + width_mult::Number = 1, depth_mult::Number = 1) |
| 14 | + inplanes = _round_channels(inplanes * width_mult, 8) |
| 15 | + outplanes = _round_channels(outplanes * width_mult, 8) |
| 16 | + nrepeats = ceil(Int, nrepeats * depth_mult) |
| 17 | + return MBConvConfig((kernel_size, kernel_size), inplanes, outplanes, expansion, |
| 18 | + stride, nrepeats) |
| 19 | +end |
| 20 | + |
| 21 | +function efficientnetblock(m::MBConvConfig, norm_layer) |
| 22 | + layers = [] |
| 23 | + explanes = _round_channels(m.inplanes * m.expansion, 8) |
| 24 | + push!(layers, |
| 25 | + mbconv(m.kernel_size, m.inplanes, explanes, m.outplanes, swish; norm_layer, |
| 26 | + stride = m.stride, reduction = 4)) |
| 27 | + explanes = _round_channels(m.outplanes * m.expansion, 8) |
| 28 | + append!(layers, |
| 29 | + [mbconv(m.kernel_size, m.outplanes, explanes, m.outplanes, swish; norm_layer, |
| 30 | + stride = 1, reduction = 4) for _ in 1:(m.nrepeats - 1)]) |
| 31 | + return Chain(layers...) |
| 32 | +end |
| 33 | + |
| 34 | +struct FusedMBConvConfig <: _MBConfig |
| 35 | + kernel_size::Dims{2} |
| 36 | + inplanes::Integer |
| 37 | + outplanes::Integer |
| 38 | + expansion::Number |
| 39 | + stride::Integer |
| 40 | + nrepeats::Integer |
| 41 | +end |
| 42 | +function FusedMBConvConfig(kernel_size::Integer, inplanes::Integer, outplanes::Integer, |
| 43 | + expansion::Number, stride::Integer, nrepeats::Integer) |
| 44 | + return FusedMBConvConfig((kernel_size, kernel_size), inplanes, outplanes, expansion, |
| 45 | + stride, nrepeats) |
| 46 | +end |
| 47 | + |
| 48 | +function efficientnetblock(m::FusedMBConvConfig, norm_layer) |
| 49 | + layers = [] |
| 50 | + explanes = _round_channels(m.inplanes * m.expansion, 8) |
| 51 | + push!(layers, |
| 52 | + fused_mbconv(m.kernel_size, m.inplanes, explanes, m.outplanes, swish; |
| 53 | + norm_layer, stride = m.stride)) |
| 54 | + explanes = _round_channels(m.outplanes * m.expansion, 8) |
| 55 | + append!(layers, |
| 56 | + [fused_mbconv(m.kernel_size, m.outplanes, explanes, m.outplanes, swish; |
| 57 | + norm_layer, stride = 1) for _ in 1:(m.nrepeats - 1)]) |
| 58 | + return Chain(layers...) |
| 59 | +end |
| 60 | + |
| 61 | +function efficientnet(block_configs::AbstractVector{<:_MBConfig}; |
| 62 | + headplanes::Union{Nothing, Integer} = nothing, |
| 63 | + norm_layer = BatchNorm, dropout_rate = nothing, |
| 64 | + inchannels::Integer = 3, nclasses::Integer = 1000) |
| 65 | + layers = [] |
| 66 | + # stem of the model |
| 67 | + append!(layers, |
| 68 | + conv_norm((3, 3), inchannels, block_configs[1].inplanes, swish; norm_layer, |
| 69 | + stride = 2, pad = SamePad())) |
| 70 | + # building inverted residual blocks |
| 71 | + append!(layers, [efficientnetblock(cfg, norm_layer) for cfg in block_configs]) |
| 72 | + # building last layers |
| 73 | + outplanes = block_configs[end].outplanes |
| 74 | + headplanes = isnothing(headplanes) ? outplanes * 4 : headplanes |
| 75 | + append!(layers, |
| 76 | + conv_norm((1, 1), outplanes, headplanes, swish; pad = SamePad())) |
| 77 | + return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) |
| 78 | +end |
0 commit comments