7 changes: 4 additions & 3 deletions src/Metalhead.jl
@@ -20,26 +20,27 @@ using .Layers

# CNN models
## Builders
include("convnets/builders/core.jl")
include("convnets/builders/irmodel.jl")
include("convnets/builders/mbconv.jl")
include("convnets/builders/resblocks.jl")
include("convnets/builders/resnet.jl")
include("convnets/builders/stages.jl")
## AlexNet and VGG
include("convnets/alexnet.jl")
include("convnets/vgg.jl")
## ResNets
include("convnets/resnets/core.jl")
include("convnets/resnets/res2net.jl")
include("convnets/resnets/resnet.jl")
include("convnets/resnets/resnext.jl")
include("convnets/resnets/seresnet.jl")
include("convnets/resnets/res2net.jl")
## Inceptions
include("convnets/inceptions/googlenet.jl")
include("convnets/inceptions/inceptionv3.jl")
include("convnets/inceptions/inceptionv4.jl")
include("convnets/inceptions/inceptionresnetv2.jl")
include("convnets/inceptions/xception.jl")
## EfficientNets
include("convnets/efficientnets/core.jl")
include("convnets/efficientnets/efficientnet.jl")
include("convnets/efficientnets/efficientnetv2.jl")
## MobileNets
9 changes: 5 additions & 4 deletions src/convnets/alexnet.jl
@@ -1,15 +1,16 @@
"""
alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000)
alexnet(; dropout_rate = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000)

Create an AlexNet model
([reference](https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf)).

# Arguments

- `dropout_rate`: dropout rate for the classifier
- `inchannels`: the number of input channels
- `nclasses`: the number of output classes
"""
function alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000)
function alexnet(; dropout_rate = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000)
backbone = Chain(Conv((11, 11), inchannels => 64, relu; stride = 4, pad = 2),
MaxPool((3, 3); stride = 2),
Conv((5, 5), 64 => 192, relu; pad = 2),
@@ -19,9 +20,9 @@ function alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000)
Conv((3, 3), 256 => 256, relu; pad = 1),
MaxPool((3, 3); stride = 2))
classifier = Chain(AdaptiveMeanPool((6, 6)), MLUtils.flatten,
Dropout(0.5),
Dropout(dropout_rate),
Dense(256 * 6 * 6, 4096, relu),
Dropout(0.5),
Dropout(dropout_rate),
Dense(4096, 4096, relu),
Dense(4096, nclasses))
return Chain(backbone, classifier)
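A minimal usage sketch of the new `dropout_rate` keyword (not part of this diff; the 224×224 input size and the `(10, 1)` output shape are assumptions for a standard AlexNet forward pass):

```julia
using Metalhead, Flux

# Build the functional AlexNet with a custom classifier dropout rate;
# dropout_rate defaults to 0.5, matching the previously hard-coded value.
model = Metalhead.alexnet(; dropout_rate = 0.3, inchannels = 3, nclasses = 10)

# Forward pass on one 224×224 RGB image in WHCN layout.
x = rand(Float32, 224, 224, 3, 1)
size(model(x))  # expected (10, 1)
```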
41 changes: 41 additions & 0 deletions src/convnets/builders/irmodel.jl
@@ -0,0 +1,41 @@
function irmodelbuilder(scalings::NTuple{2, Real}, block_configs::AbstractVector{<:Tuple};
inplanes::Integer = 32, connection = +, activation = relu,
norm_layer = BatchNorm, divisor::Integer = 8,
tail_conv::Bool = true, expanded_classifier::Bool = false,
headplanes::Integer, dropout_rate = nothing,
inchannels::Integer = 3, nclasses::Integer = 1000, kwargs...)
width_mult, _ = scalings
# building first layer
inplanes = _round_channels(inplanes * width_mult, divisor)
layers = []
append!(layers,
conv_norm((3, 3), inchannels, inplanes, activation; stride = 2, pad = 1,
norm_layer))
# building inverted residual blocks
get_layers, block_repeats = mbconv_stage_builder(block_configs, inplanes, scalings;
norm_layer, divisor, kwargs...)
append!(layers, cnn_stages(get_layers, block_repeats, connection))
# building last layers
outplanes = _round_channels(block_configs[end][3] * width_mult, divisor)
if tail_conv
# special case, supported fully only for MobileNetv3
if expanded_classifier
midplanes = _round_channels(outplanes * block_configs[end][4], divisor)
append!(layers,
conv_norm((1, 1), outplanes, midplanes, activation; norm_layer))
classifier = create_classifier(midplanes, headplanes, nclasses,
(hardswish, identity); dropout_rate)
else
append!(layers,
conv_norm((1, 1), outplanes, headplanes, activation; norm_layer))
classifier = create_classifier(headplanes, nclasses; dropout_rate)
end
else
classifier = create_classifier(outplanes, nclasses; dropout_rate)
end
return Chain(Chain(layers...), classifier)
end

function irmodelbuilder(width_mult::Real, block_configs::AbstractVector{<:Tuple}; kwargs...)
return irmodelbuilder((width_mult, 1), block_configs; kwargs...)
end
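A hypothetical call sketch for the new builder; the block-configuration layout and the internal names (`irmodelbuilder`, `mbconv`, and whether they are reachable via `using Metalhead: ...`) are assumptions drawn from this PR, so treat it as illustrative rather than a supported API:

```julia
using Flux  # for relu6
using Metalhead: irmodelbuilder, mbconv  # internal names; availability assumed

# Two MobileNetv2-style stages in the (block_fn, kernel, outplanes, expansion,
# stride, nrepeats, reduction, activation) layout that mbconv_builder destructures.
block_configs = [
    (mbconv, 3, 16, 1, 1, 1, nothing, relu6),
    (mbconv, 3, 24, 6, 2, 2, nothing, relu6),
]

# width_mult-only dispatch: scalings become (1.0, 1), i.e. no depth scaling.
model = irmodelbuilder(1.0, block_configs; inplanes = 32, headplanes = 1280,
                       activation = relu6, nclasses = 1000)
```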
88 changes: 28 additions & 60 deletions src/convnets/builders/mbconv.jl
@@ -1,9 +1,14 @@
function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
width_mult::Real; norm_layer = BatchNorm, kwargs...)
# TODO - potentially make these builders more flexible to specify stuff like
# activation functions and reductions that don't change over the stages

function dwsepconv_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer,
stage_idx::Integer, scalings::NTuple{2, Real};
norm_layer = BatchNorm, divisor::Integer = 8, kwargs...)
width_mult, depth_mult = scalings
block_fn, k, outplanes, stride, nrepeats, activation = block_configs[stage_idx]
outplanes = _round_channels(outplanes * width_mult)
outplanes = _round_channels(outplanes * width_mult, divisor)
if stage_idx != 1
inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult)
inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult, divisor)
end
function get_layers(block_idx::Integer)
inplanes = block_idx == 1 ? inplanes : outplanes
@@ -12,13 +17,14 @@ function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
stride, pad = SamePad(), norm_layer, kwargs...)...)
return (block,)
end
return get_layers, nrepeats
return get_layers, ceil(Int, nrepeats * depth_mult)
end
_get_builder(::typeof(dwsep_conv_norm)) = dwsepconv_builder

function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
scalings::NTuple{2, Real}; norm_layer = BatchNorm,
divisor::Integer = 8, se_from_explanes::Bool = false,
kwargs...)
function mbconv_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer,
stage_idx::Integer, scalings::NTuple{2, Real};
norm_layer = BatchNorm, divisor::Integer = 8,
se_from_explanes::Bool = false, kwargs...)
width_mult, depth_mult = scalings
block_fn, k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx]
# calculate number of reduced channels for squeeze-excite layer from explanes instead of inplanes
@@ -39,69 +45,31 @@ function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
end
return get_layers, ceil(Int, nrepeats * depth_mult)
end
_get_builder(::typeof(mbconv)) = mbconv_builder

function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
width_mult::Real; norm_layer = BatchNorm, kwargs...)
return mbconv_builder(block_configs, inplanes, stage_idx, (width_mult, 1);
norm_layer, kwargs...)
end

function fused_mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer;
norm_layer = BatchNorm, kwargs...)
function fused_mbconv_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer,
stage_idx::Integer, scalings::NTuple{2, Real};
norm_layer = BatchNorm, divisor::Integer = 8, kwargs...)
width_mult, depth_mult = scalings
block_fn, k, outplanes, expansion, stride, nrepeats, activation = block_configs[stage_idx]
inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3]
outplanes = _round_channels(outplanes * width_mult, divisor)
function get_layers(block_idx::Integer)
inplanes = block_idx == 1 ? inplanes : outplanes
explanes = _round_channels(inplanes * expansion, 8)
explanes = _round_channels(inplanes * expansion, divisor)
stride = block_idx == 1 ? stride : 1
block = block_fn((k, k), inplanes, explanes, outplanes, activation;
norm_layer, stride, kwargs...)
return stride == 1 && inplanes == outplanes ? (identity, block) : (block,)
end
return get_layers, nrepeats
end

# TODO - these builders need to be more flexible to potentially specify stuff like
# activation functions and reductions that don't change
function _get_builder(::typeof(dwsep_conv_bn), block_configs::AbstractVector{<:Tuple},
inplanes::Integer, stage_idx::Integer;
scalings::Union{Nothing, NTuple{2, Real}} = nothing,
width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...)
@assert isnothing(scalings) "dwsep_conv_bn does not support the `scalings` argument"
return dwsepconv_builder(block_configs, inplanes, stage_idx, width_mult; norm_layer,
kwargs...)
end

function _get_builder(::typeof(mbconv), block_configs::AbstractVector{<:Tuple},
inplanes::Integer, stage_idx::Integer;
scalings::Union{Nothing, NTuple{2, Real}} = nothing,
width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...)
if isnothing(scalings)
return mbconv_builder(block_configs, inplanes, stage_idx, width_mult; norm_layer,
kwargs...)
elseif isnothing(width_mult)
return mbconv_builder(block_configs, inplanes, stage_idx, scalings; norm_layer,
kwargs...)
else
throw(ArgumentError("Only one of `scalings` and `width_mult` can be specified"))
end
end

function _get_builder(::typeof(fused_mbconv), block_configs::AbstractVector{<:Tuple},
inplanes::Integer, stage_idx::Integer;
scalings::Union{Nothing, NTuple{2, Real}} = nothing,
width_mult::Union{Nothing, Number} = nothing, norm_layer)
@assert isnothing(width_mult) "fused_mbconv does not support the `width_mult` argument."
@assert isnothing(scalings)||scalings == (1, 1) "fused_mbconv does not support the `scalings` argument"
return fused_mbconv_builder(block_configs, inplanes, stage_idx; norm_layer)
return get_layers, ceil(Int, nrepeats * depth_mult)
end
_get_builder(::typeof(fused_mbconv)) = fused_mbconv_builder

function mbconv_stack_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer;
scalings::Union{Nothing, NTuple{2, Real}} = nothing,
width_mult::Union{Nothing, Number} = nothing,
norm_layer = BatchNorm, kwargs...)
bxs = [_get_builder(block_configs[idx][1], block_configs, inplanes, idx; scalings,
width_mult, norm_layer, kwargs...)
function mbconv_stage_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer,
scalings::NTuple{2, Real}; kwargs...)
builders = _get_builder.(first.(block_configs))
bxs = [builders[idx](block_configs, inplanes, idx, scalings; kwargs...)
for idx in eachindex(block_configs)]
return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs)
end
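A short sketch of how the stage builder's two return values are consumed, reusing the hypothetical `block_configs` from the sketch above; the `(1.0, 1.2)` scalings are illustrative:

```julia
# get_layers(stage_idx, block_idx) yields the layer tuple for one block, and
# block_repeats[i] is the depth-scaled repeat count for stage i; e.g. a stage
# configured with 4 repeats becomes ceil(Int, 4 * 1.2) == 5 blocks.
get_layers, block_repeats = mbconv_stage_builder(block_configs, 32, (1.0, 1.2))
stages = cnn_stages(get_layers, block_repeats, +)  # as done inside irmodelbuilder
```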
9 changes: 9 additions & 0 deletions src/convnets/builders/resnet.jl
@@ -0,0 +1,9 @@
function resnetbuilder(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer},
connection, classifier_fn)
# Build stages of the ResNet
stage_blocks = cnn_stages(get_layers, block_repeats, connection)
backbone = Chain(stem, stage_blocks...)
# Add classifier to the backbone
nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3]
return Chain(backbone, classifier_fn(nfeaturemaps))
end
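`resnetbuilder` infers the feature-map count from the backbone with `Flux.outputsize` instead of hard-coding it; here is a standalone sketch of that trick (the toy backbone is an assumption, not this PR's ResNet stem):

```julia
using Flux

backbone = Chain(Conv((7, 7), 3 => 64; stride = 2, pad = 3),
                 AdaptiveMeanPool((1, 1)))
# padbatch = true appends a singleton batch dimension before shape inference.
nfeaturemaps = Flux.outputsize(backbone, (224, 224, 3); padbatch = true)[3]  # 64
```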
File renamed without changes.
26 changes: 21 additions & 5 deletions src/convnets/convnext.jl
@@ -46,11 +46,11 @@ function convnext(depths::AbstractVector{<:Integer}, planes::AbstractVector{<:In
"`planes` should have exactly one value for each block"
downsample_layers = []
push!(downsample_layers,
Chain(conv_norm((4, 4), inchannels => planes[1]; stride = 4,
Chain(conv_norm((4, 4), inchannels, planes[1]; stride = 4,
norm_layer = ChannelLayerNorm)...))
for m in 1:(length(depths) - 1)
push!(downsample_layers,
Chain(conv_norm((2, 2), planes[m] => planes[m + 1]; stride = 2,
Chain(conv_norm((2, 2), planes[m], planes[m + 1]; stride = 2,
norm_layer = ChannelLayerNorm, revnorm = true)...))
end
stages = []
@@ -68,6 +68,12 @@ function convnext(depths::AbstractVector{<:Integer}, planes::AbstractVector{<:In
return Chain(Chain(backbone...), classifier)
end

function convnext(config::Symbol; drop_path_rate = 0.0, layerscale_init = 1.0f-6,
inchannels::Integer = 3, nclasses::Integer = 1000)
return convnext(CONVNEXT_CONFIGS[config]...; drop_path_rate, layerscale_init,
inchannels, nclasses)
end

# Configurations for ConvNeXt models
const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]),
:small => ([3, 3, 27, 3], [96, 192, 384, 768]),
@@ -76,27 +82,37 @@ const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]),
:xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048]))

"""
ConvNeXt(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
ConvNeXt(config::Symbol; pretrain::Bool = true, inchannels::Integer = 3,
nclasses::Integer = 1000)

Creates a ConvNeXt model.
([reference](https://arxiv.org/abs/2201.03545))

# Arguments

- `config`: The size of the model, one of `tiny`, `small`, `base`, `large` or `xlarge`.
- `pretrain`: set to `true` to load pre-trained weights for ImageNet
- `inchannels`: number of input channels
- `nclasses`: number of output classes

!!! warning

`ConvNeXt` does not currently support pretrained weights.

See also [`Metalhead.convnext`](#).
"""
struct ConvNeXt
layers::Any
end
@functor ConvNeXt

function ConvNeXt(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000)
function ConvNeXt(config::Symbol; pretrain::Bool = true, inchannels::Integer = 3,
nclasses::Integer = 1000)
_checkconfig(config, keys(CONVNEXT_CONFIGS))
layers = convnext(CONVNEXT_CONFIGS[config]...; inchannels, nclasses)
layers = convnext(config; inchannels, nclasses)
if pretrain
layers = loadpretrain!(layers, "convnext_$config")
end
return ConvNeXt(layers)
end

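A usage sketch of the two ConvNeXt entry points touched here; since the docstring warns that pretrained weights are unavailable, `pretrain = false` is assumed in practice:

```julia
using Metalhead

# New convenience dispatch: build the functional model directly from a config symbol.
layers = Metalhead.convnext(:tiny; nclasses = 100)

# Exported wrapper type; keep pretrain = false until ConvNeXt weights are published.
model = ConvNeXt(:tiny; pretrain = false, nclasses = 100)
```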
21 changes: 0 additions & 21 deletions src/convnets/efficientnets/core.jl

This file was deleted.

16 changes: 12 additions & 4 deletions src/convnets/efficientnets/efficientnet.jl
@@ -31,6 +31,17 @@ const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)),
:b7 => (600, (2.0, 3.1)),
:b8 => (672, (2.2, 3.6)))

function efficientnet(config::Symbol; norm_layer = BatchNorm,
dropout_rate = nothing, inchannels::Integer = 3,
nclasses::Integer = 1000)
_checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS))
scalings = EFFICIENTNET_GLOBAL_CONFIGS[config][2]
return irmodelbuilder(scalings, EFFICIENTNET_BLOCK_CONFIGS; inplanes = 32,
norm_layer, activation = swish,
headplanes = EFFICIENTNET_BLOCK_CONFIGS[end][3] * 4,
dropout_rate, inchannels, nclasses)
end

"""
EfficientNet(config::Symbol; pretrain::Bool = false)

@@ -50,10 +61,7 @@ end

function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3,
nclasses::Integer = 1000)
_checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS))
scalings = EFFICIENTNET_GLOBAL_CONFIGS[config][2]
layers = efficientnet(EFFICIENTNET_BLOCK_CONFIGS; inplanes = 32, scalings,
inchannels, nclasses)
layers = efficientnet(config; inchannels, nclasses)
if pretrain
loadpretrain!(layers, string("efficientnet-", config))
end
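A usage sketch for the refactored constructor; the 224×224 resolution for `:b0` mirrors `EFFICIENTNET_GLOBAL_CONFIGS`, and the output-shape comment is an assumption:

```julia
using Metalhead

model = EfficientNet(:b0; nclasses = 10)  # pretrain defaults to false
x = rand(Float32, 224, 224, 3, 1)         # B0's nominal input resolution
size(model(x))                            # expected (10, 1)
```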
14 changes: 10 additions & 4 deletions src/convnets/efficientnets/efficientnetv2.jl
@@ -35,6 +35,15 @@ const EFFNETV2_CONFIGS = Dict(:small => [(fused_mbconv, 3, 24, 1, 1, 2, swish),
(mbconv, 3, 512, 6, 2, 32, 4, swish),
(mbconv, 3, 768, 6, 1, 8, 4, swish)])

function efficientnetv2(config::Symbol; norm_layer = BatchNorm, dropout_rate = nothing,
inchannels::Integer = 3, nclasses::Integer = 1000)
_checkconfig(config, keys(EFFNETV2_CONFIGS))
block_configs = EFFNETV2_CONFIGS[config]
return irmodelbuilder((1, 1), block_configs; activation = swish, norm_layer,
inplanes = block_configs[1][3], headplanes = 1280,
dropout_rate, inchannels, nclasses)
end

"""
EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1,
inchannels::Integer = 3, nclasses::Integer = 1000)
@@ -57,10 +66,7 @@ end

function EfficientNetv2(config::Symbol; pretrain::Bool = false,
inchannels::Integer = 3, nclasses::Integer = 1000)
_checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS))))
block_configs = EFFNETV2_CONFIGS[config]
layers = efficientnet(block_configs; inplanes = block_configs[1][3],
headplanes = 1280, inchannels, nclasses)
layers = efficientnetv2(config; inchannels, nclasses)
if pretrain
loadpretrain!(layers, string("efficientnetv2-", config))
end
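The analogous sketch for EfficientNetv2, which now routes through the new `efficientnetv2(config)` helper; names are taken from this PR and the calls are illustrative only:

```julia
using Metalhead

model = EfficientNetv2(:small; nclasses = 10)

# The non-exported functional form can also be called directly, e.g. with dropout:
small = Metalhead.efficientnetv2(:small; dropout_rate = 0.2, nclasses = 10)
```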