Merged · Changes from 5 commits
4 changes: 2 additions & 2 deletions Project.toml
@@ -19,10 +19,10 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
BSON = "0.3.2"
-Flux = "0.13"
-Functors = "0.2, 0.3"
CUDA = "3"
ChainRulesCore = "1"
+Flux = "0.13"
+Functors = "0.2, 0.3"
MLUtils = "0.2.10"
NNlib = "0.8"
NNlibCUDA = "0.2"
2 changes: 1 addition & 1 deletion docs/dev-guide/contributing.md
@@ -16,7 +16,7 @@ To add a new model architecture to Metalhead.jl, you can [open a PR](https://git

- reuse layers from Flux as much as possible (e.g. use `Parallel` before defining a `Bottleneck` struct)
- adhere as closely as possible to a reference such as a published paper (i.e. the structure of your model should follow intuitively from the paper)
-- use generic functional builders (e.g. [`resnet`](#) is the core function that builds "ResNet-like" models)
+- use generic functional builders (e.g. [`Metalhead.resnet`](@ref) is the core function that builds "ResNet-like" models)
- use multiple dispatch to add convenience constructors that wrap your functional builder

When in doubt, just open a PR! We are more than happy to review your code and help it align with the rest of the library. After adding a model, you might consider adding some pre-trained weights (see below).
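To make the builder-plus-constructor guidance above concrete, here is a minimal sketch of the pattern in plain Flux; the `mymodel`/`MyModel` names and configs are hypothetical, not part of Metalhead's API:

```julia
using Flux

# Generic functional builder: a pure function from configuration to a Chain.
function mymodel(depths::AbstractVector{<:Integer}; inchannels::Integer = 3,
                 nclasses::Integer = 1000)
    layers = []
    for (inc, outc) in zip([inchannels; depths[1:(end - 1)]], depths)
        push!(layers, Conv((3, 3), inc => outc, relu; pad = 1))
    end
    return Chain(Chain(layers...),
                 Chain(GlobalMeanPool(), Flux.flatten, Dense(depths[end] => nclasses)))
end

# Convenience constructor wrapping the builder, selected via dispatch on a config name.
struct MyModel
    layers::Any
end
function MyModel(config::Symbol; kwargs...)
    configs = Dict(:small => [16, 32], :large => [32, 64, 128])
    return MyModel(mymodel(configs[config]; kwargs...))
end
(m::MyModel)(x) = m.layers(x)
```

With this split, `MyModel(:small)` stays a thin wrapper, and new variants only touch the config table, never the builder.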
2 changes: 1 addition & 1 deletion docs/tutorials/quickstart.md
@@ -5,7 +5,7 @@
using Flux, Metalhead
```

-Using a model from Metalhead is as simple as selecting a model from the table of [available models](#). For example, below we use the pre-trained ResNet-18 model.
+Using a model from Metalhead is as simple as selecting a model from the table of [available models](@ref). For example, below we use the pre-trained ResNet-18 model.
{cell=quickstart}
```julia
using Flux, Metalhead
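The quickstart code block above is truncated in this view. For completeness, a minimal sketch of the step it describes, assuming the `ResNet(18; pretrain = true)` constructor from this version of Metalhead:

```julia
using Flux, Metalhead

# Load ResNet-18 with pre-trained ImageNet weights.
model = ResNet(18; pretrain = true)

# Run a dummy 224×224 RGB batch through the model (WHCN layout).
x = rand(Float32, 224, 224, 3, 1)
logits = model(x)   # size (1000, 1)
```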
7 changes: 4 additions & 3 deletions src/Metalhead.jl
@@ -20,26 +20,27 @@ using .Layers

# CNN models
## Builders
include("convnets/builders/core.jl")
include("convnets/builders/irmodel.jl")
include("convnets/builders/mbconv.jl")
include("convnets/builders/resblocks.jl")
include("convnets/builders/resnet.jl")
include("convnets/builders/stages.jl")
## AlexNet and VGG
include("convnets/alexnet.jl")
include("convnets/vgg.jl")
## ResNets
include("convnets/resnets/core.jl")
include("convnets/resnets/res2net.jl")
include("convnets/resnets/resnet.jl")
include("convnets/resnets/resnext.jl")
include("convnets/resnets/seresnet.jl")
include("convnets/resnets/res2net.jl")
## Inceptions
include("convnets/inceptions/googlenet.jl")
include("convnets/inceptions/inceptionv3.jl")
include("convnets/inceptions/inceptionv4.jl")
include("convnets/inceptions/inceptionresnetv2.jl")
include("convnets/inceptions/xception.jl")
## EfficientNets
include("convnets/efficientnets/core.jl")
include("convnets/efficientnets/efficientnet.jl")
include("convnets/efficientnets/efficientnetv2.jl")
## MobileNets
11 changes: 6 additions & 5 deletions src/convnets/alexnet.jl
@@ -1,15 +1,16 @@
"""
-    alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000)
+    alexnet(; dropout_prob = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000)

Create an AlexNet model
([reference](https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf)).

# Arguments

+- `dropout_prob`: dropout probability for the classifier
- `inchannels`: the number of input channels
- `nclasses`: the number of output classes
"""
-function alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000)
+function alexnet(; dropout_prob = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000)
backbone = Chain(Conv((11, 11), inchannels => 64, relu; stride = 4, pad = 2),
MaxPool((3, 3); stride = 2),
Conv((5, 5), 64 => 192, relu; pad = 2),
@@ -19,9 +20,9 @@ function alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000)
Conv((3, 3), 256 => 256, relu; pad = 1),
MaxPool((3, 3); stride = 2))
classifier = Chain(AdaptiveMeanPool((6, 6)), MLUtils.flatten,
-Dropout(0.5),
+Dropout(dropout_prob),
Dense(256 * 6 * 6, 4096, relu),
-Dropout(0.5),
+Dropout(dropout_prob),
Dense(4096, 4096, relu),
Dense(4096, nclasses))
return Chain(backbone, classifier)
@@ -44,7 +45,7 @@ Create an `AlexNet`.

`AlexNet` does not currently support pretrained weights.

-See also [`alexnet`](#).
+See also [`Metalhead.alexnet`](@ref).
"""
struct AlexNet
layers::Any
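Since `dropout_prob` is now user-facing, a quick usage sketch of the functional builder from this diff (the shapes assume a standard 224×224 input; `nclasses` here is arbitrary):

```julia
using Flux, Metalhead

# Build the functional AlexNet with a non-default classifier dropout probability.
m = Metalhead.alexnet(; dropout_prob = 0.2, nclasses = 10)

x = rand(Float32, 224, 224, 3, 1)   # dummy WHCN batch
size(m(x))                          # (10, 1)
```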
19 changes: 0 additions & 19 deletions src/convnets/builders/core.jl

This file was deleted.

41 changes: 41 additions & 0 deletions src/convnets/builders/irmodel.jl
@@ -0,0 +1,41 @@
function build_irmodel(scalings::NTuple{2, Real}, block_configs::AbstractVector{<:Tuple};
                       inplanes::Integer = 32, connection = +, activation = relu,
                       norm_layer = BatchNorm, divisor::Integer = 8, tail_conv::Bool = true,
                       expanded_classifier::Bool = false, stochastic_depth_prob = nothing,
                       headplanes::Integer, dropout_prob = nothing, inchannels::Integer = 3,
                       nclasses::Integer = 1000, kwargs...)
    width_mult, _ = scalings
    # building first layer
    inplanes = _round_channels(inplanes * width_mult, divisor)
    layers = []
    append!(layers,
            conv_norm((3, 3), inchannels, inplanes, activation; stride = 2, pad = 1,
                      norm_layer))
    # building inverted residual blocks
    get_layers, block_repeats = mbconv_stage_builder(block_configs, inplanes, scalings;
                                                     stochastic_depth_prob, norm_layer,
                                                     divisor, kwargs...)
    append!(layers, cnn_stages(get_layers, block_repeats, connection))
    # building last layers
    outplanes = _round_channels(block_configs[end][3] * width_mult, divisor)
    if tail_conv
        # special case, supported fully only for MobileNetv3
        if expanded_classifier
            midplanes = _round_channels(outplanes * block_configs[end][4], divisor)
            append!(layers,
                    conv_norm((1, 1), outplanes, midplanes, activation; norm_layer))
            classifier = create_classifier(midplanes, headplanes, nclasses,
                                           (hardswish, identity); dropout_prob)
        else
            append!(layers,
                    conv_norm((1, 1), outplanes, headplanes, activation; norm_layer))
            classifier = create_classifier(headplanes, nclasses; dropout_prob)
        end
    else
        classifier = create_classifier(outplanes, nclasses; dropout_prob)
    end
    return Chain(Chain(layers...), classifier)
end

function build_irmodel(width_mult::Real, block_configs::AbstractVector{<:Tuple}; kwargs...)
    return build_irmodel((width_mult, 1), block_configs; kwargs...)
end
127 changes: 61 additions & 66 deletions src/convnets/builders/mbconv.jl
@@ -1,9 +1,29 @@
function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
width_mult::Real; norm_layer = BatchNorm, kwargs...)
"""
irblockbuilder(::typeof(irblockfn), block_configs::AbstractVector{<:Tuple},
inplanes::Integer, stage_idx::Integer, scalings::NTuple{2, Real};
stochastic_depth_prob = nothing, norm_layer = BatchNorm,
divisor::Integer = 8, kwargs...)

Constructs a collection of inverted residual blocks for a given stage. Note that
this function is not intended to be called directly; it is used by the [`mbconv_stage_builder`](@ref)
function. It only needs to be extended to support a custom inverted residual block type.

# Arguments

- `irblockfn`: the inverted residual block function to use in the block builder. Metalhead
defines methods for [`dwsep_conv_norm`](@ref), [`mbconv`](@ref) and [`fused_mbconv`](@ref)
as inverted residual blocks.
"""
function irblockbuilder(::typeof(dwsep_conv_norm), block_configs::AbstractVector{<:Tuple},
inplanes::Integer, stage_idx::Integer, scalings::NTuple{2, Real};
stochastic_depth_prob = nothing, norm_layer = BatchNorm,
divisor::Integer = 8, kwargs...)

> Review thread:
>
> **darsnack** (Member): Do we need to have unused keywords? As far as I can tell, right now it is silently ignored, but not including extraneous keywords would throw a MethodError (as it should).
>
> **theabhirath** (Member, Author): These are being used, though, to pass stuff like the `se_round_fn` in for MobileNetv3.
>
> **darsnack** (Sep 4, 2022): I specifically mean stuff like `stochastic_depth_prob`.
>
> **theabhirath** (Sep 4, 2022): This particular one needs to be there because it will cause a MethodError for `dwsep_conv_norm` if passed through (and the default model builder passes it through). See also #200 (comment).

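A toy illustration of the tradeoff discussed in this thread (plain Julia, not Metalhead code): a signature with a trailing `kwargs...` silently swallows unknown keywords, while one without it raises the `MethodError` mentioned above:

```julia
f_strict(x; norm_layer = identity) = norm_layer(x)
f_loose(x; norm_layer = identity, kwargs...) = norm_layer(x)

f_loose(1.0; stochastic_depth_prob = 0.2)   # works; the unknown keyword is ignored
f_strict(1.0; stochastic_depth_prob = 0.2)  # MethodError: unsupported keyword argument
```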
width_mult, depth_mult = scalings
block_fn, k, outplanes, stride, nrepeats, activation = block_configs[stage_idx]
outplanes = _round_channels(outplanes * width_mult)
outplanes = _round_channels(outplanes * width_mult, divisor)
if stage_idx != 1
inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult)
inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult, divisor)
end
function get_layers(block_idx::Integer)
inplanes = block_idx == 1 ? inplanes : outplanes
@@ -12,15 +32,17 @@ function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
stride, pad = SamePad(), norm_layer, kwargs...)...)
return (block,)
end
return get_layers, nrepeats
return get_layers, ceil(Int, nrepeats * depth_mult)
end

function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
scalings::NTuple{2, Real}; norm_layer = BatchNorm,
divisor::Integer = 8, se_from_explanes::Bool = false,
kwargs...)
function irblockbuilder(::typeof(mbconv), block_configs::AbstractVector{<:Tuple},
inplanes::Integer, stage_idx::Integer, scalings::NTuple{2, Real};
stochastic_depth_prob = nothing, norm_layer = BatchNorm,
divisor::Integer = 8, se_from_explanes::Bool = false, kwargs...)
width_mult, depth_mult = scalings
block_fn, k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx]
block_repeats = [ceil(Int, block_configs[idx][end - 2] * depth_mult)
for idx in eachindex(block_configs)]
block_fn, k, outplanes, expansion, stride, _, reduction, activation = block_configs[stage_idx]
# calculate number of reduced channels for squeeze-excite layer from explanes instead of inplanes
if !isnothing(reduction)
reduction = !se_from_explanes ? reduction * expansion : reduction
@@ -29,79 +51,52 @@ function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult, divisor)
end
outplanes = _round_channels(outplanes * width_mult, divisor)
sdschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats))
function get_layers(block_idx::Integer)
inplanes = block_idx == 1 ? inplanes : outplanes
explanes = _round_channels(inplanes * expansion, divisor)
stride = block_idx == 1 ? stride : 1
block = block_fn((k, k), inplanes, explanes, outplanes, activation; norm_layer,
stride, reduction, kwargs...)
return stride == 1 && inplanes == outplanes ? (identity, block) : (block,)
use_skip = stride == 1 && inplanes == outplanes
if use_skip
schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
drop_path = StochasticDepth(sdschedule[schedule_idx])
return (drop_path, block)
else
return (block,)
end
end
return get_layers, ceil(Int, nrepeats * depth_mult)
end

function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
width_mult::Real; norm_layer = BatchNorm, kwargs...)
return mbconv_builder(block_configs, inplanes, stage_idx, (width_mult, 1);
norm_layer, kwargs...)
return get_layers, block_repeats[stage_idx]
end

function fused_mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer;
norm_layer = BatchNorm, kwargs...)
block_fn, k, outplanes, expansion, stride, nrepeats, activation = block_configs[stage_idx]
function irblockbuilder(::typeof(fused_mbconv), block_configs::AbstractVector{<:Tuple},
inplanes::Integer, stage_idx::Integer, scalings::NTuple{2, Real};
stochastic_depth_prob = nothing, norm_layer = BatchNorm,
divisor::Integer = 8, kwargs...)
width_mult, depth_mult = scalings
block_repeats = [ceil(Int, block_configs[idx][end - 1] * depth_mult)

> Review thread:
>
> **Contributor**: For MobileNet the config list is
>
> ```julia
> const MOBILENETV2_CONFIGS = [
>     (mbconv, 3, 16, 1, 1, 1, nothing, relu6),
>     (mbconv, 3, 24, 6, 2, 2, nothing, relu6),
>     (mbconv, 3, 32, 6, 2, 3, nothing, relu6),
>     (mbconv, 3, 64, 6, 2, 4, nothing, relu6),
>     (mbconv, 3, 96, 6, 1, 3, nothing, relu6),
>     (mbconv, 3, 160, 6, 2, 3, nothing, relu6),
>     (mbconv, 3, 320, 6, 1, 1, nothing, relu6),
> ]
> ```
>
> so won't `block_repeats` just be a list of `nothing`? And, as the name suggests, doesn't it tell how many times a particular block of layers is repeated?

for idx in eachindex(block_configs)]
block_fn, k, outplanes, expansion, stride, _, activation = block_configs[stage_idx]
inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3]
outplanes = _round_channels(outplanes * width_mult, divisor)
sdschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats))
function get_layers(block_idx::Integer)
inplanes = block_idx == 1 ? inplanes : outplanes
explanes = _round_channels(inplanes * expansion, 8)
explanes = _round_channels(inplanes * expansion, divisor)
stride = block_idx == 1 ? stride : 1
block = block_fn((k, k), inplanes, explanes, outplanes, activation;
norm_layer, stride, kwargs...)
return stride == 1 && inplanes == outplanes ? (identity, block) : (block,)
end
return get_layers, nrepeats
end

# TODO - these builders need to be more flexible to potentially specify stuff like
# activation functions and reductions that don't change
function _get_builder(::typeof(dwsep_conv_bn), block_configs::AbstractVector{<:Tuple},
inplanes::Integer, stage_idx::Integer;
scalings::Union{Nothing, NTuple{2, Real}} = nothing,
width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...)
@assert isnothing(scalings) "dwsep_conv_bn does not support the `scalings` argument"
return dwsepconv_builder(block_configs, inplanes, stage_idx, width_mult; norm_layer,
kwargs...)
end

function _get_builder(::typeof(mbconv), block_configs::AbstractVector{<:Tuple},
inplanes::Integer, stage_idx::Integer;
scalings::Union{Nothing, NTuple{2, Real}} = nothing,
width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...)
if isnothing(scalings)
return mbconv_builder(block_configs, inplanes, stage_idx, width_mult; norm_layer,
kwargs...)
elseif isnothing(width_mult)
return mbconv_builder(block_configs, inplanes, stage_idx, scalings; norm_layer,
kwargs...)
else
throw(ArgumentError("Only one of `scalings` and `width_mult` can be specified"))
schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
drop_path = StochasticDepth(sdschedule[schedule_idx])
return stride == 1 && inplanes == outplanes ? (drop_path, block) : (block,)
end
return get_layers, block_repeats[stage_idx]
end

function _get_builder(::typeof(fused_mbconv), block_configs::AbstractVector{<:Tuple},
inplanes::Integer, stage_idx::Integer;
scalings::Union{Nothing, NTuple{2, Real}} = nothing,
width_mult::Union{Nothing, Number} = nothing, norm_layer)
@assert isnothing(width_mult) "fused_mbconv does not support the `width_mult` argument."
@assert isnothing(scalings)||scalings == (1, 1) "fused_mbconv does not support the `scalings` argument"
return fused_mbconv_builder(block_configs, inplanes, stage_idx; norm_layer)
end

function mbconv_stack_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer;
scalings::Union{Nothing, NTuple{2, Real}} = nothing,
width_mult::Union{Nothing, Number} = nothing,
norm_layer = BatchNorm, kwargs...)
bxs = [_get_builder(block_configs[idx][1], block_configs, inplanes, idx; scalings,
width_mult, norm_layer, kwargs...)
for idx in eachindex(block_configs)]
function mbconv_stage_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer,
scalings::NTuple{2, Real}; kwargs...)
bxs = [irblockbuilder(block_configs[idx][1], block_configs, inplanes, idx, scalings;
kwargs...) for idx in eachindex(block_configs)]
return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs)
end
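Finally, the stochastic-depth bookkeeping added in this file is easier to see in isolation. Below is a minimal sketch with a stand-in for `linear_scheduler` (Metalhead's actual implementation may differ), showing how `sdschedule` and the flat `schedule_idx` cooperate:

```julia
# Stand-in: drop probabilities grow linearly with depth; `nothing` disables them.
linear_scheduler(prob::Real; depth::Integer) = LinRange(0, prob, depth)
linear_scheduler(::Nothing; depth::Integer) = fill(nothing, depth)

block_repeats = [2, 3, 4]   # blocks per stage, as computed by the builders above
sdschedule = linear_scheduler(0.2; depth = sum(block_repeats))

# Flat index of block `block_idx` within stage `stage_idx`, mirroring the builders.
schedule_idx(stage_idx, block_idx) = sum(block_repeats[1:(stage_idx - 1)]) + block_idx

sdschedule[schedule_idx(2, 1)]   # drop probability for the first block of stage 2
```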