Commit f001221

Merge pull request #198 from theabhirath/effnetv2
Implementation of EfficientNetv2 and MNASNet

2 parents 842fa99 + 992f6a6 · commit f001221
42 files changed: +1526 -1025 lines

.github/workflows/CI.yml

Lines changed: 3 additions & 3 deletions

@@ -27,10 +27,10 @@ jobs:
           - x64
         suite:
           - '["AlexNet", "VGG"]'
-          - '["GoogLeNet", "SqueezeNet", "MobileNet"]'
-          - '["EfficientNet"]'
+          - '["GoogLeNet", "SqueezeNet", "MobileNets"]'
+          - '"EfficientNet"'
           - 'r"/*/ResNet*"'
-          - '[r"ResNeXt", r"SEResNet"]'
+          - 'r"/*/SEResNet*"'
           - '[r"Res2Net", r"Res2NeXt"]'
           - '"Inception"'
           - '"DenseNet"'

Project.toml

Lines changed: 1 addition & 1 deletion

@@ -23,10 +23,10 @@ Flux = "0.13"
 Functors = "0.2, 0.3"
 CUDA = "3"
 ChainRulesCore = "1"
-PartialFunctions = "1"
 MLUtils = "0.2.10"
 NNlib = "0.8"
 NNlibCUDA = "0.2"
+PartialFunctions = "1"
 julia = "1.6"

 [publish]

src/Metalhead.jl

Lines changed: 25 additions & 13 deletions

@@ -19,6 +19,11 @@ include("layers/Layers.jl")
 using .Layers

 # CNN models
+## Builders
+include("convnets/builders/core.jl")
+include("convnets/builders/mbconv.jl")
+include("convnets/builders/resblocks.jl")
+## AlexNet and VGG
 include("convnets/alexnet.jl")
 include("convnets/vgg.jl")
 ## ResNets
@@ -28,19 +33,23 @@ include("convnets/resnets/resnext.jl")
 include("convnets/resnets/seresnet.jl")
 include("convnets/resnets/res2net.jl")
 ## Inceptions
-include("convnets/inception/googlenet.jl")
-include("convnets/inception/inceptionv3.jl")
-include("convnets/inception/inceptionv4.jl")
-include("convnets/inception/inceptionresnetv2.jl")
-include("convnets/inception/xception.jl")
+include("convnets/inceptions/googlenet.jl")
+include("convnets/inceptions/inceptionv3.jl")
+include("convnets/inceptions/inceptionv4.jl")
+include("convnets/inceptions/inceptionresnetv2.jl")
+include("convnets/inceptions/xception.jl")
+## EfficientNets
+include("convnets/efficientnets/core.jl")
+include("convnets/efficientnets/efficientnet.jl")
+include("convnets/efficientnets/efficientnetv2.jl")
 ## MobileNets
-include("convnets/mobilenet/mobilenetv1.jl")
-include("convnets/mobilenet/mobilenetv2.jl")
-include("convnets/mobilenet/mobilenetv3.jl")
+include("convnets/mobilenets/mobilenetv1.jl")
+include("convnets/mobilenets/mobilenetv2.jl")
+include("convnets/mobilenets/mobilenetv3.jl")
+include("convnets/mobilenets/mnasnet.jl")
 ## Others
 include("convnets/densenet.jl")
 include("convnets/squeezenet.jl")
-include("convnets/efficientnet.jl")
 include("convnets/convnext.jl")
 include("convnets/convmixer.jl")
@@ -61,13 +70,16 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
        WideResNet, ResNeXt, SEResNet, SEResNeXt, Res2Net, Res2NeXt,
        DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
        GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
-       SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet,
+       SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, MNASNet,
+       EfficientNet, EfficientNetv2,
        MLPMixer, ResMLP, gMLP, ViT, ConvMixer, ConvNeXt

 # use Flux._big_show to pretty print large models
-for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, :SEResNet, :SEResNeXt,
-          :Res2Net, :Res2NeXt, :GoogLeNet, :Inceptionv3, :Inceptionv4,
-          :Xception, :SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet,
+for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt,
+          :SEResNet, :SEResNeXt, :Res2Net, :Res2NeXt, :GoogLeNet, :DenseNet,
+          :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
+          :MobileNetv1, :MobileNetv2, :MobileNetv3, :MNASNet,
+          :EfficientNet, :EfficientNetv2,
           :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt)
     @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
 end
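For orientation, a minimal usage sketch of the two newly exported models. Only the names EfficientNetv2 and MNASNet are confirmed by this diff; the config symbols and keyword arguments below are assumptions modelled on the other Metalhead constructors, and the accepted values are defined in the efficientnetv2.jl and mnasnet.jl files added by this PR.

using Metalhead

# NOTE: config symbols and keywords here are assumptions, not confirmed by this diff.
effnet = EfficientNetv2(:small; nclasses = 1000)   # assumed config symbol
mnas = MNASNet(:B1; nclasses = 1000)               # assumed config symbol

x = rand(Float32, 224, 224, 3, 1)  # WHCN image batch, Flux convention
y = effnet(x)                      # 1000×1 matrix of class scores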

src/convnets/builders/core.jl

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+function cnn_stages(get_layers, block_repeats::AbstractVector{<:Integer},
+                    connection = nothing)
+    # Construct each stage
+    stages = []
+    for (stage_idx, nblocks) in enumerate(block_repeats)
+        # Construct the blocks for each stage
+        blocks = map(1:nblocks) do block_idx
+            branches = get_layers(stage_idx, block_idx)
+            if isnothing(connection)
+                @assert length(branches)==1 "get_layers should return a single branch for
+                each block if no connection is specified"
+            end
+            return length(branches) == 1 ? only(branches) :
+                   Parallel(connection, branches...)
+        end
+        push!(stages, Chain(blocks...))
+    end
+    return stages
+end
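To see how cnn_stages is meant to be driven, here is a self-contained sketch with a toy get_layers callback. Chain, Parallel, Conv, and SamePad come from Flux; the channel sizes are illustrative only.

using Flux

# Toy callback: every block returns two branches, so `cnn_stages` wraps
# them as Parallel(+, identity, conv) -- i.e. a residual block.
toy_layers(stage_idx, block_idx) = (identity,
                                    Conv((3, 3), 16 => 16, relu; pad = SamePad()))

stages = cnn_stages(toy_layers, [2, 3], +)  # 2 blocks in stage 1, 3 in stage 2
model = Chain(stages...)
model(rand(Float32, 32, 32, 16, 1))  # stride-1 blocks keep the 32×32×16×1 shape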

src/convnets/builders/mbconv.jl

Lines changed: 107 additions & 0 deletions

@@ -0,0 +1,107 @@
+function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
+                           width_mult::Real; norm_layer = BatchNorm, kwargs...)
+    block_fn, k, outplanes, stride, nrepeats, activation = block_configs[stage_idx]
+    outplanes = _round_channels(outplanes * width_mult)
+    if stage_idx != 1
+        inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult)
+    end
+    function get_layers(block_idx::Integer)
+        inplanes = block_idx == 1 ? inplanes : outplanes
+        stride = block_idx == 1 ? stride : 1
+        block = Chain(block_fn((k, k), inplanes, outplanes, activation;
+                               stride, pad = SamePad(), norm_layer, kwargs...)...)
+        return (block,)
+    end
+    return get_layers, nrepeats
+end
+
+function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
+                        scalings::NTuple{2, Real}; norm_layer = BatchNorm,
+                        divisor::Integer = 8, se_from_explanes::Bool = false,
+                        kwargs...)
+    width_mult, depth_mult = scalings
+    block_fn, k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx]
+    # calculate number of reduced channels for squeeze-excite layer from explanes instead of inplanes
+    if !isnothing(reduction)
+        reduction = !se_from_explanes ? reduction * expansion : reduction
+    end
+    if stage_idx != 1
+        inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult, divisor)
+    end
+    outplanes = _round_channels(outplanes * width_mult, divisor)
+    function get_layers(block_idx::Integer)
+        inplanes = block_idx == 1 ? inplanes : outplanes
+        explanes = _round_channels(inplanes * expansion, divisor)
+        stride = block_idx == 1 ? stride : 1
+        block = block_fn((k, k), inplanes, explanes, outplanes, activation; norm_layer,
+                         stride, reduction, kwargs...)
+        return stride == 1 && inplanes == outplanes ? (identity, block) : (block,)
+    end
+    return get_layers, ceil(Int, nrepeats * depth_mult)
+end
+
+function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
+                        width_mult::Real; norm_layer = BatchNorm, kwargs...)
+    return mbconv_builder(block_configs, inplanes, stage_idx, (width_mult, 1);
+                          norm_layer, kwargs...)
+end
+
+function fused_mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer;
+                              norm_layer = BatchNorm, kwargs...)
+    block_fn, k, outplanes, expansion, stride, nrepeats, activation = block_configs[stage_idx]
+    inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3]
+    function get_layers(block_idx::Integer)
+        inplanes = block_idx == 1 ? inplanes : outplanes
+        explanes = _round_channels(inplanes * expansion, 8)
+        stride = block_idx == 1 ? stride : 1
+        block = block_fn((k, k), inplanes, explanes, outplanes, activation;
+                         norm_layer, stride, kwargs...)
+        return stride == 1 && inplanes == outplanes ? (identity, block) : (block,)
+    end
+    return get_layers, nrepeats
+end
+
+# TODO - these builders need to be more flexible to potentially specify stuff like
+# activation functions and reductions that don't change
+function _get_builder(::typeof(dwsep_conv_bn), block_configs::AbstractVector{<:Tuple},
+                      inplanes::Integer, stage_idx::Integer;
+                      scalings::Union{Nothing, NTuple{2, Real}} = nothing,
+                      width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...)
+    @assert isnothing(scalings) "dwsep_conv_bn does not support the `scalings` argument"
+    return dwsepconv_builder(block_configs, inplanes, stage_idx, width_mult; norm_layer,
+                             kwargs...)
+end
+
+function _get_builder(::typeof(mbconv), block_configs::AbstractVector{<:Tuple},
+                      inplanes::Integer, stage_idx::Integer;
+                      scalings::Union{Nothing, NTuple{2, Real}} = nothing,
+                      width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...)
+    if isnothing(scalings)
+        return mbconv_builder(block_configs, inplanes, stage_idx, width_mult; norm_layer,
+                              kwargs...)
+    elseif isnothing(width_mult)
+        return mbconv_builder(block_configs, inplanes, stage_idx, scalings; norm_layer,
+                              kwargs...)
+    else
+        throw(ArgumentError("Only one of `scalings` and `width_mult` can be specified"))
+    end
+end
+
+function _get_builder(::typeof(fused_mbconv), block_configs::AbstractVector{<:Tuple},
+                      inplanes::Integer, stage_idx::Integer;
+                      scalings::Union{Nothing, NTuple{2, Real}} = nothing,
+                      width_mult::Union{Nothing, Number} = nothing, norm_layer)
+    @assert isnothing(width_mult) "fused_mbconv does not support the `width_mult` argument."
+    @assert isnothing(scalings)||scalings == (1, 1) "fused_mbconv does not support the `scalings` argument"
+    return fused_mbconv_builder(block_configs, inplanes, stage_idx; norm_layer)
+end
+
+function mbconv_stack_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer;
+                              scalings::Union{Nothing, NTuple{2, Real}} = nothing,
+                              width_mult::Union{Nothing, Number} = nothing,
+                              norm_layer = BatchNorm, kwargs...)
+    bxs = [_get_builder(block_configs[idx][1], block_configs, inplanes, idx; scalings,
+                        width_mult, norm_layer, kwargs...)
+           for idx in eachindex(block_configs)]
+    return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs)
+end
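A sketch of how these builders plug into cnn_stages. The tuple layout is read straight off mbconv_builder's destructuring above: (block_fn, kernel, outplanes, expansion, stride, nrepeats, reduction, activation). The concrete numbers below are illustrative, not a real network config; mbconv is the block constructor from Metalhead's internal Layers module, and swish comes from NNlib.

# Each tuple: (block_fn, k, outplanes, expansion, stride, nrepeats, reduction, activation)
block_configs = [
    (mbconv, 3, 16, 1, 1, 1, 4, swish),
    (mbconv, 3, 24, 6, 2, 2, 4, swish),
]
get_layers, block_repeats = mbconv_stack_builder(block_configs, 32; scalings = (1, 1))
stages = cnn_stages(get_layers, block_repeats, +)  # `+` gives the inverted-residual skip
backbone = Chain(stages...)

Note the design choice here: _get_builder dispatches on the block function stored as the first element of each config tuple, so a single block_configs vector can mix fused_mbconv and mbconv stages, which is what EfficientNetv2 requires.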

src/convnets/builders/resblocks.jl

Lines changed: 71 additions & 0 deletions

@@ -0,0 +1,71 @@
+function basicblock_builder(block_repeats::AbstractVector{<:Integer};
+                            inplanes::Integer = 64, reduction_factor::Integer = 1,
+                            expansion::Integer = 1, norm_layer = BatchNorm,
+                            revnorm::Bool = false, activation = relu,
+                            attn_fn = planes -> identity,
+                            drop_block_rate = nothing, drop_path_rate = nothing,
+                            stride_fn = resnet_stride, planes_fn = resnet_planes,
+                            downsample_tuple = (downsample_conv, downsample_identity))
+    # DropBlock, DropPath both take in rates based on a linear scaling schedule
+    # Also get `planes_vec` needed for block `inplanes` and `planes` calculations
+    pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats))
+    blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats))
+    planes_vec = collect(planes_fn(block_repeats))
+    # closure over `idxs`
+    function get_layers(stage_idx::Integer, block_idx::Integer)
+        # DropBlock, DropPath both take in rates based on a linear scaling schedule
+        # This is also needed for block `inplanes` and `planes` calculations
+        schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
+        planes = planes_vec[schedule_idx]
+        inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion
+        # `resnet_stride` is a callback that the user can tweak to change the stride of the
+        # blocks. It defaults to the standard behaviour as in the paper
+        stride = stride_fn(stage_idx, block_idx)
+        downsample_fn = stride != 1 || inplanes != planes * expansion ?
+                        downsample_tuple[1] : downsample_tuple[2]
+        drop_path = DropPath(pathschedule[schedule_idx])
+        drop_block = DropBlock(blockschedule[schedule_idx])
+        block = basicblock(inplanes, planes; stride, reduction_factor, activation,
+                           norm_layer, revnorm, attn_fn, drop_path, drop_block)
+        downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer,
+                                   revnorm)
+        return block, downsample
+    end
+    return get_layers
+end
+
+function bottleneck_builder(block_repeats::AbstractVector{<:Integer};
+                            inplanes::Integer = 64, cardinality::Integer = 1,
+                            base_width::Integer = 64, reduction_factor::Integer = 1,
+                            expansion::Integer = 4, norm_layer = BatchNorm,
+                            revnorm::Bool = false, activation = relu,
+                            attn_fn = planes -> identity,
+                            drop_block_rate = nothing, drop_path_rate = nothing,
+                            stride_fn = resnet_stride, planes_fn = resnet_planes,
+                            downsample_tuple = (downsample_conv, downsample_identity))
+    pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats))
+    blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats))
+    planes_vec = collect(planes_fn(block_repeats))
+    # closure over `idxs`
+    function get_layers(stage_idx::Integer, block_idx::Integer)
+        # DropBlock, DropPath both take in rates based on a linear scaling schedule
+        # This is also needed for block `inplanes` and `planes` calculations
+        schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
+        planes = planes_vec[schedule_idx]
+        inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion
+        # `resnet_stride` is a callback that the user can tweak to change the stride of the
+        # blocks. It defaults to the standard behaviour as in the paper
+        stride = stride_fn(stage_idx, block_idx)
+        downsample_fn = stride != 1 || inplanes != planes * expansion ?
+                        downsample_tuple[1] : downsample_tuple[2]
+        drop_path = DropPath(pathschedule[schedule_idx])
+        drop_block = DropBlock(blockschedule[schedule_idx])
+        block = bottleneck(inplanes, planes; stride, cardinality, base_width,
+                           reduction_factor, activation, norm_layer, revnorm,
+                           attn_fn, drop_path, drop_block)
+        downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer,
+                                   revnorm)
+        return block, downsample
+    end
+    return get_layers
+end
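Both ResNet builders return a get_layers of exactly the shape cnn_stages expects: each block comes back as a (block, downsample) pair of branches, which cnn_stages wraps in a Parallel. A minimal sketch of a basic-block stack; the bare + connection is a simplification (the actual ResNet code merges the branches with an activation-carrying connection), and basicblock_builder leans on internal helpers such as resnet_planes and resnet_stride for its defaults.

# ResNet-18-style backbone: two basic blocks in each of four stages.
block_repeats = [2, 2, 2, 2]
get_layers = basicblock_builder(block_repeats; inplanes = 64)
stages = cnn_stages(get_layers, block_repeats, +)  # Parallel(+, block, downsample)
backbone = Chain(stages...)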

src/convnets/convmixer.jl

Lines changed: 14 additions & 9 deletions

@@ -1,5 +1,5 @@
 """
-    convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9),
+    convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9),
               patch_size::Dims{2} = (7, 7), activation = gelu,
               inchannels::Integer = 3, nclasses::Integer = 1000)

@@ -13,20 +13,25 @@ Creates a ConvMixer model.
   - `kernel_size`: kernel size of the convolutional layers
   - `patch_size`: size of the patches
   - `activation`: activation function used after the convolutional layers
-  - `inchannels`: The number of channels in the input.
+  - `inchannels`: number of input channels
   - `nclasses`: number of classes in the output
 """
-function convmixer(planes::Integer, depth::Integer; kernel_size = (9, 9),
-                   patch_size::Dims{2} = (7, 7), activation = gelu,
+function convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9),
+                   patch_size::Dims{2} = (7, 7), activation = gelu, dropout_rate = nothing,
                    inchannels::Integer = 3, nclasses::Integer = 1000)
-    stem = conv_norm(patch_size, inchannels, planes, activation; preact = true,
-                     stride = patch_size[1])
-    blocks = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation;
+    layers = []
+    # stem of the model
+    append!(layers,
+            conv_norm(patch_size, inchannels, planes, activation; preact = true,
+                      stride = patch_size[1]))
+    # stages of the model
+    stages = [Chain(SkipConnection(Chain(conv_norm(kernel_size, planes, planes, activation;
                                                    preact = true, groups = planes,
                                                    pad = SamePad())), +),
                     conv_norm((1, 1), planes, planes, activation; preact = true)...)
               for _ in 1:depth]
-    return Chain(Chain(stem..., Chain(blocks...)), create_classifier(planes, nclasses))
+    append!(layers, stages)
+    return Chain(Chain(layers...), create_classifier(planes, nclasses; dropout_rate))
 end

 const CONVMIXER_CONFIGS = Dict(:base => ((1536, 20),
@@ -48,7 +53,7 @@ Creates a ConvMixer model.
 # Arguments

 - `config`: the size of the model, either `:base`, `:small` or `:large`
-- `inchannels`: The number of channels in the input.
+- `inchannels`: number of input channels
 - `nclasses`: number of classes in the output
 """
 struct ConvMixer
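One user-visible effect of this refactor: convmixer now threads a dropout_rate keyword through to create_classifier. A sketch with illustrative values (convmixer is internal, hence the module-qualified call; the sizes are not a standard config):

# Small ConvMixer with classifier dropout; all numbers here are illustrative.
m = Metalhead.convmixer(256, 8; kernel_size = (5, 5), patch_size = (2, 2),
                        dropout_rate = 0.2, nclasses = 10)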

src/convnets/convnext.jl

Lines changed: 1 addition & 1 deletion

@@ -84,7 +84,7 @@ Creates a ConvNeXt model.
 # Arguments

 - `config`: The size of the model, one of `tiny`, `small`, `base`, `large` or `xlarge`.
-- `inchannels`: The number of channels in the input.
+- `inchannels`: number of input channels
 - `nclasses`: number of output classes

 See also [`Metalhead.convnext`](#).
