Skip to content

Commit f76fadb

Browse files
committed
Minor refactor of cnn_stages
And no, the final cleanup was not final. (1) Fix `width_mult` calculations. (2) Hopefully all the parameters now line up for all widths of the MobileNets. (3) `MNASNet` wasn't `@functor`ed. (4) Assorted docstring link fixes.
1 parent 2d310a9 commit f76fadb

File tree

12 files changed

+50
-57
lines changed

12 files changed

+50
-57
lines changed

src/convnets/builders/core.jl

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,19 @@
1-
# TODO potentially refactor other CNNs to use this
2-
function cnn_stages(get_layers, block_repeats::AbstractVector{<:Integer}, connection)
1+
function cnn_stages(get_layers, block_repeats::AbstractVector{<:Integer},
2+
connection = nothing)
33
# Construct each stage
44
stages = []
55
for (stage_idx, nblocks) in enumerate(block_repeats)
66
# Construct the blocks for each stage
77
blocks = map(1:nblocks) do block_idx
88
branches = get_layers(stage_idx, block_idx)
9+
if isnothing(connection)
10+
@assert length(branches)==1 "get_layers should return a single branch for
11+
each block if no connection is specified"
12+
end
913
return length(branches) == 1 ? only(branches) :
1014
Parallel(connection, branches...)
1115
end
1216
push!(stages, Chain(blocks...))
1317
end
1418
return stages
1519
end
16-
17-
function cnn_stages(get_layers, block_repeats::AbstractVector{<:Integer})
18-
# Construct each stage
19-
stages = []
20-
for (stage_idx, nblocks) in enumerate(block_repeats)
21-
# Construct the blocks for each stage
22-
blocks = map(1:nblocks) do block_idx
23-
branches = get_layers(stage_idx, block_idx)
24-
@assert length(branches)==1 "get_layers should return a single branch for each
25-
block if no connection is specified"
26-
return only(branches)
27-
end
28-
push!(stages, Chain(blocks...))
29-
end
30-
return stages
31-
end

src/convnets/builders/mbconv.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
22
width_mult::Real; norm_layer = BatchNorm, kwargs...)
33
block_fn, k, outplanes, stride, nrepeats, activation = block_configs[stage_idx]
44
outplanes = floor(Int, outplanes * width_mult)
5-
inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3]
5+
if stage_idx != 1
6+
inplanes = floor(Int, block_configs[stage_idx - 1][3] * width_mult)
7+
end
68
function get_layers(block_idx::Integer)
79
inplanes = block_idx == 1 ? inplanes : outplanes
810
stride = block_idx == 1 ? stride : 1
@@ -23,8 +25,9 @@ function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
2325
if !isnothing(reduction)
2426
reduction = !se_from_explanes ? reduction * expansion : reduction
2527
end
26-
inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3]
27-
inplanes = _round_channels(inplanes * width_mult, divisor)
28+
if stage_idx != 1
29+
inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult, divisor)
30+
end
2831
outplanes = _round_channels(outplanes * width_mult, divisor)
2932
function get_layers(block_idx::Integer)
3033
inplanes = block_idx == 1 ? inplanes : outplanes

src/convnets/efficientnets/core.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,17 @@ function efficientnet(block_configs::AbstractVector{<:Tuple}; inplanes::Integer,
55
inchannels::Integer = 3, nclasses::Integer = 1000)
66
layers = []
77
# stem of the model
8+
inplanes = _round_channels(inplanes * scalings[1])
89
append!(layers,
9-
conv_norm((3, 3), inchannels, _round_channels(inplanes * scalings[1], 8),
10-
swish; norm_layer, stride = 2, pad = SamePad()))
10+
conv_norm((3, 3), inchannels, inplanes, swish; norm_layer, stride = 2,
11+
pad = SamePad()))
1112
# building inverted residual blocks
1213
get_layers, block_repeats = mbconv_stack_builder(block_configs, inplanes; scalings,
1314
norm_layer)
1415
append!(layers, cnn_stages(get_layers, block_repeats, +))
1516
# building last layers
1617
append!(layers,
17-
conv_norm((1, 1), _round_channels(block_configs[end][3] * scalings[1], 8),
18+
conv_norm((1, 1), _round_channels(block_configs[end][3] * scalings[1]),
1819
headplanes, swish; pad = SamePad()))
1920
return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate))
2021
end

src/convnets/mobilenets/mnasnet.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ function mnasnet(block_configs::AbstractVector{<:Tuple}; width_mult::Real = 1,
2525
norm_layer = (args...; kwargs...) -> BatchNorm(args...; momentum = _MNASNET_BN_MOMENTUM,
2626
kwargs...)
2727
# building first layer
28-
inplanes = _round_channels(inplanes * width_mult, 8)
28+
inplanes = _round_channels(inplanes * width_mult)
2929
layers = []
3030
append!(layers,
3131
conv_norm((3, 3), inchannels, inplanes, relu; stride = 2, pad = 1,
@@ -35,8 +35,8 @@ function mnasnet(block_configs::AbstractVector{<:Tuple}; width_mult::Real = 1,
3535
norm_layer)
3636
append!(layers, cnn_stages(get_layers, block_repeats, +))
3737
# building last layers
38-
outplanes = _round_channels(block_configs[end][3] * width_mult, 8)
39-
headplanes = _round_channels(max_width * max(1, width_mult), 8)
38+
outplanes = _round_channels(block_configs[end][3] * width_mult)
39+
headplanes = _round_channels(max_width * max(1, width_mult))
4040
append!(layers,
4141
conv_norm((1, 1), outplanes, headplanes, relu; norm_layer))
4242
return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate))
@@ -119,3 +119,8 @@ function MNASNet(config::Symbol; width_mult::Real = 1, pretrain::Bool = false,
119119
end
120120
return MNASNet(layers)
121121
end
122+
123+
(m::MNASNet)(x) = m.layers(x)
124+
125+
backbone(m::MNASNet) = m.layers[1]
126+
classifier(m::MNASNet) = m.layers[2]

src/convnets/mobilenets/mobilenetv1.jl

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,18 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)).
2323
"""
2424
function mobilenetv1(config::AbstractVector{<:Tuple}; width_mult::Real = 1,
2525
activation = relu, dropout_rate = nothing,
26-
inchannels::Integer = 3, nclasses::Integer = 1000)
26+
inplanes::Integer = 32, inchannels::Integer = 3,
27+
nclasses::Integer = 1000)
2728
layers = []
2829
# stem of the model
30+
inplanes = floor(Int, inplanes * width_mult)
2931
append!(layers,
30-
conv_norm((3, 3), inchannels, config[1][3], activation; stride = 2, pad = 1))
32+
conv_norm((3, 3), inchannels, inplanes, activation; stride = 2, pad = 1))
3133
# building inverted residual blocks
32-
get_layers, block_repeats = mbconv_stack_builder(config, config[1][3]; width_mult)
34+
get_layers, block_repeats = mbconv_stack_builder(config, inplanes; width_mult)
3335
append!(layers, cnn_stages(get_layers, block_repeats))
34-
return Chain(Chain(layers...),
35-
create_classifier(config[end][3], nclasses; dropout_rate))
36+
outplanes = floor(Int, config[end][3] * width_mult)
37+
return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate))
3638
end
3739

3840
# Layer configurations for MobileNetv1
@@ -45,14 +47,10 @@ end
4547
const MOBILENETV1_CONFIGS = [
4648
# f, k, c, s, n, a
4749
(dwsep_conv_bn, 3, 64, 1, 1, relu6),
48-
(dwsep_conv_bn, 3, 128, 2, 1, relu6),
49-
(dwsep_conv_bn, 3, 128, 1, 1, relu6),
50-
(dwsep_conv_bn, 3, 256, 2, 1, relu6),
51-
(dwsep_conv_bn, 3, 256, 1, 1, relu6),
52-
(dwsep_conv_bn, 3, 512, 2, 1, relu6),
53-
(dwsep_conv_bn, 3, 512, 1, 5, relu6),
54-
(dwsep_conv_bn, 3, 1024, 2, 1, relu6),
55-
(dwsep_conv_bn, 3, 1024, 1, 1, relu6),
50+
(dwsep_conv_bn, 3, 128, 2, 2, relu6),
51+
(dwsep_conv_bn, 3, 256, 2, 2, relu6),
52+
(dwsep_conv_bn, 3, 512, 2, 6, relu6),
53+
(dwsep_conv_bn, 3, 1024, 2, 2, relu6),
5654
]
5755

5856
"""

src/convnets/mobilenets/mobilenetv2.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ function mobilenetv2(block_configs::AbstractVector{<:Tuple}; width_mult::Real =
3737
divisor)
3838
append!(layers, cnn_stages(get_layers, block_repeats, +))
3939
# building last layers
40-
outplanes = _round_channels(block_configs[end][3], divisor)
40+
outplanes = _round_channels(block_configs[end][3] * width_mult, divisor)
4141
headplanes = _round_channels(max_width * max(1, width_mult), divisor)
4242
append!(layers, conv_norm((1, 1), outplanes, headplanes, relu6))
4343
return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate))

src/convnets/mobilenets/mobilenetv3.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1,
2828
max_width::Integer = 1024, dropout_rate = 0.2,
2929
inchannels::Integer = 3, nclasses::Integer = 1000)
3030
# building first layer
31-
inplanes = _round_channels(16 * width_mult, 8)
31+
inplanes = _round_channels(16 * width_mult)
3232
layers = []
3333
append!(layers,
3434
conv_norm((3, 3), inchannels, inplanes, hardswish; stride = 2, pad = 1))
@@ -38,12 +38,11 @@ function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1,
3838
se_round_fn = _round_channels)
3939
append!(layers, cnn_stages(get_layers, block_repeats, +))
4040
# building last layers
41-
explanes = _round_channels(configs[end][3] * width_mult, 8)
42-
midplanes = _round_channels(explanes * configs[end][4], 8)
43-
headplanes = _round_channels(max_width * width_mult, 8)
41+
explanes = _round_channels(configs[end][3] * width_mult)
42+
midplanes = _round_channels(explanes * configs[end][4])
4443
append!(layers, conv_norm((1, 1), explanes, midplanes, hardswish))
4544
return Chain(Chain(layers...),
46-
create_classifier(midplanes, headplanes, nclasses,
45+
create_classifier(midplanes, max_width, nclasses,
4746
(hardswish, identity); dropout_rate))
4847
end
4948

src/convnets/resnets/resnet.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
ResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000)
33
44
Creates a ResNet model with the specified depth.
5-
((reference)[https://arxiv.org/abs/1512.03385])
5+
([reference](https://arxiv.org/abs/1512.03385))
66
77
# Arguments
88
@@ -39,7 +39,7 @@ classifier(m::ResNet) = m.layers[2]
3939
Creates a Wide ResNet model with the specified depth. The model is the same as ResNet
4040
except for the bottleneck number of channels which is twice larger in every block.
4141
The number of channels in outer 1x1 convolutions is the same.
42-
((reference)[https://arxiv.org/abs/1605.07146])
42+
([reference](https://arxiv.org/abs/1605.07146))
4343
4444
# Arguments
4545

src/convnets/resnets/resnext.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
44
55
Creates a ResNeXt model with the specified depth, cardinality, and base width.
6-
((reference)[https://arxiv.org/abs/1611.05431])
6+
([reference](https://arxiv.org/abs/1611.05431))
77
88
# Arguments
99

src/convnets/resnets/seresnet.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
SEResNet(depth::Integer; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000)
33
44
Creates a SEResNet model with the specified depth.
5-
((reference)[https://arxiv.org/pdf/1709.01507.pdf])
5+
([reference](https://arxiv.org/pdf/1709.01507.pdf))
66
77
# Arguments
88
@@ -43,7 +43,7 @@ classifier(m::SEResNet) = m.layers[2]
4343
base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000)
4444
4545
Creates a SEResNeXt model with the specified depth, cardinality, and base width.
46-
((reference)[https://arxiv.org/pdf/1709.01507.pdf])
46+
([reference](https://arxiv.org/pdf/1709.01507.pdf))
4747
4848
# Arguments
4949

0 commit comments

Comments
 (0)