diff --git a/Project.toml b/Project.toml index 9664dc79d..2f7df3ed4 100644 --- a/Project.toml +++ b/Project.toml @@ -19,10 +19,10 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] BSON = "0.3.2" -Flux = "0.13" -Functors = "0.2, 0.3" CUDA = "3" ChainRulesCore = "1" +Flux = "0.13" +Functors = "0.2, 0.3" MLUtils = "0.2.10" NNlib = "0.8" NNlibCUDA = "0.2" diff --git a/docs/dev-guide/contributing.md b/docs/dev-guide/contributing.md index 75574b033..f126d7bb8 100644 --- a/docs/dev-guide/contributing.md +++ b/docs/dev-guide/contributing.md @@ -16,7 +16,7 @@ To add a new model architecture to Metalhead.jl, you can [open a PR](https://git - reuse layers from Flux as much as possible (e.g. use `Parallel` before defining a `Bottleneck` struct) - adhere as closely as possible to a reference such as a published paper (i.e. the structure of your model should follow intuitively from the paper) -- use generic functional builders (e.g. [`resnet`](#) is the core function that builds "ResNet-like" models) +- use generic functional builders (e.g. [`Metalhead.resnet`](@ref) is the core function that builds "ResNet-like" models) - use multiple dispatch to add convenience constructors that wrap your functional builder When in doubt, just open a PR! We are more than happy to help review your code to help it align with the rest of the library. After adding a model, you might consider adding some pre-trained weights (see below). diff --git a/docs/tutorials/quickstart.md b/docs/tutorials/quickstart.md index 43149933a..cfb99eb72 100644 --- a/docs/tutorials/quickstart.md +++ b/docs/tutorials/quickstart.md @@ -5,7 +5,7 @@ using Flux, Metalhead ``` -Using a model from Metalhead is as simple as selecting a model from the table of [available models](#). For example, below we use the pre-trained ResNet-18 model. +Using a model from Metalhead is as simple as selecting a model from the table of [available models](@ref). For example, below we use the pre-trained ResNet-18 model. 
{cell=quickstart} ```julia using Flux, Metalhead diff --git a/src/Metalhead.jl b/src/Metalhead.jl index f9be49db1..b67bf1faf 100644 --- a/src/Metalhead.jl +++ b/src/Metalhead.jl @@ -20,18 +20,20 @@ using .Layers # CNN models ## Builders -include("convnets/builders/core.jl") +include("convnets/builders/invresmodel.jl") include("convnets/builders/mbconv.jl") include("convnets/builders/resblocks.jl") +include("convnets/builders/resnet.jl") +include("convnets/builders/stages.jl") ## AlexNet and VGG include("convnets/alexnet.jl") include("convnets/vgg.jl") ## ResNets include("convnets/resnets/core.jl") +include("convnets/resnets/res2net.jl") include("convnets/resnets/resnet.jl") include("convnets/resnets/resnext.jl") include("convnets/resnets/seresnet.jl") -include("convnets/resnets/res2net.jl") ## Inceptions include("convnets/inceptions/googlenet.jl") include("convnets/inceptions/inceptionv3.jl") @@ -39,7 +41,6 @@ include("convnets/inceptions/inceptionv4.jl") include("convnets/inceptions/inceptionresnetv2.jl") include("convnets/inceptions/xception.jl") ## EfficientNets -include("convnets/efficientnets/core.jl") include("convnets/efficientnets/efficientnet.jl") include("convnets/efficientnets/efficientnetv2.jl") ## MobileNets @@ -71,16 +72,16 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19, DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201, GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception, SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, MNASNet, - EfficientNet, EfficientNetv2, - MLPMixer, ResMLP, gMLP, ViT, ConvMixer, ConvNeXt + EfficientNet, EfficientNetv2, ConvMixer, ConvNeXt, + MLPMixer, ResMLP, gMLP, ViT # use Flux._big_show to pretty print large models for T in (:AlexNet, :VGG, :SqueezeNet, :ResNet, :WideResNet, :ResNeXt, :SEResNet, :SEResNeXt, :Res2Net, :Res2NeXt, :GoogLeNet, :DenseNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception, :MobileNetv1, :MobileNetv2, :MobileNetv3, :MNASNet, - :EfficientNet, :EfficientNetv2, - :MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt) + :EfficientNet, :EfficientNetv2, :ConvMixer, :ConvNeXt, + :MLPMixer, :ResMLP, :gMLP, :ViT) @eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model) end diff --git a/src/convnets/alexnet.jl b/src/convnets/alexnet.jl index 3c713839e..51cfbf029 100644 --- a/src/convnets/alexnet.jl +++ b/src/convnets/alexnet.jl @@ -1,15 +1,16 @@ """ - alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000) + alexnet(; dropout_prob = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000) Create an AlexNet model ([reference](https://papers.nips.cc/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf)). # Arguments + - `dropout_prob`: dropout probability for the classifier - `inchannels`: The number of input channels. 
- `nclasses`: the number of output classes """ -function alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000) +function alexnet(; dropout_prob = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000) backbone = Chain(Conv((11, 11), inchannels => 64, relu; stride = 4, pad = 2), MaxPool((3, 3); stride = 2), Conv((5, 5), 64 => 192, relu; pad = 2), @@ -19,9 +20,9 @@ function alexnet(; inchannels::Integer = 3, nclasses::Integer = 1000) Conv((3, 3), 256 => 256, relu; pad = 1), MaxPool((3, 3); stride = 2)) classifier = Chain(AdaptiveMeanPool((6, 6)), MLUtils.flatten, - Dropout(0.5), + Dropout(dropout_prob), Dense(256 * 6 * 6, 4096, relu), - Dropout(0.5), + Dropout(dropout_prob), Dense(4096, 4096, relu), Dense(4096, nclasses)) return Chain(backbone, classifier) @@ -44,7 +45,7 @@ Create a `AlexNet`. `AlexNet` does not currently support pretrained weights. -See also [`alexnet`](#). +See also [`Metalhead.alexnet`](@ref). """ struct AlexNet layers::Any diff --git a/src/convnets/builders/core.jl b/src/convnets/builders/core.jl deleted file mode 100644 index f97f92ff9..000000000 --- a/src/convnets/builders/core.jl +++ /dev/null @@ -1,19 +0,0 @@ -function cnn_stages(get_layers, block_repeats::AbstractVector{<:Integer}, - connection = nothing) - # Construct each stage - stages = [] - for (stage_idx, nblocks) in enumerate(block_repeats) - # Construct the blocks for each stage - blocks = map(1:nblocks) do block_idx - branches = get_layers(stage_idx, block_idx) - if isnothing(connection) - @assert length(branches)==1 "get_layers should return a single branch for - each block if no connection is specified" - end - return length(branches) == 1 ? only(branches) : - Parallel(connection, branches...) - end - push!(stages, Chain(blocks...)) - end - return stages -end diff --git a/src/convnets/builders/invresmodel.jl b/src/convnets/builders/invresmodel.jl new file mode 100644 index 000000000..6faeca992 --- /dev/null +++ b/src/convnets/builders/invresmodel.jl @@ -0,0 +1,44 @@ +function build_invresmodel(scalings::NTuple{2, Real}, + block_configs::AbstractVector{<:Tuple}; + inplanes::Integer = 32, connection = +, activation = relu, + norm_layer = BatchNorm, divisor::Integer = 8, + tail_conv::Bool = true, expanded_classifier::Bool = false, + stochastic_depth_prob = nothing, headplanes::Integer, + dropout_prob = nothing, inchannels::Integer = 3, + nclasses::Integer = 1000, kwargs...) + width_mult, _ = scalings + # building first layer + inplanes = _round_channels(inplanes * width_mult, divisor) + layers = [] + append!(layers, + conv_norm((3, 3), inchannels, inplanes, activation; stride = 2, pad = 1, + norm_layer)) + # building inverted residual blocks + get_layers, block_repeats = mbconv_stage_builder(block_configs, inplanes, scalings; + stochastic_depth_prob, norm_layer, + divisor, kwargs...) 
+ append!(layers, cnn_stages(get_layers, block_repeats, connection)) + # building last layers + outplanes = _round_channels(block_configs[end][3] * width_mult, divisor) + if tail_conv + # special case, supported fully only for MobileNetv3 + if expanded_classifier + midplanes = _round_channels(outplanes * block_configs[end][4], divisor) + append!(layers, + conv_norm((1, 1), outplanes, midplanes, activation; norm_layer)) + classifier = create_classifier(midplanes, headplanes, nclasses, + (hardswish, identity); dropout_prob) + else + append!(layers, + conv_norm((1, 1), outplanes, headplanes, activation; norm_layer)) + classifier = create_classifier(headplanes, nclasses; dropout_prob) + end + else + classifier = create_classifier(outplanes, nclasses; dropout_prob) + end + return Chain(Chain(layers...), classifier) +end +function build_invresmodel(width_mult::Real, block_configs::AbstractVector{<:Tuple}; + kwargs...) + return build_invresmodel((width_mult, 1), block_configs; kwargs...) +end diff --git a/src/convnets/builders/mbconv.jl b/src/convnets/builders/mbconv.jl index 31a936add..7b68be300 100644 --- a/src/convnets/builders/mbconv.jl +++ b/src/convnets/builders/mbconv.jl @@ -1,9 +1,24 @@ -function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, - width_mult::Real; norm_layer = BatchNorm, kwargs...) +""" + invresbuilder(::typeof(irblockfn), block_configs::AbstractVector{<:Tuple}, + inplanes::Integer, stage_idx::Integer, scalings::NTuple{2, Real}; + stochastic_depth_prob = nothing, norm_layer = BatchNorm, + divisor::Integer = 8, kwargs...) + +Creates a block builder for `irblockfn` within a given stage. +Note that this function is not intended to be called directly, but instead passed to +[`mbconv_stage_builder`](@ref) which will return a builder over all stages. +Users wanting to provide a custom inverted residual block type can extend this +function by defining `invresbuilder(::typeof(my_block), ...)`. +""" +function invresbuilder(::typeof(dwsep_conv_norm), block_configs::AbstractVector{<:Tuple}, + inplanes::Integer, stage_idx::Integer, scalings::NTuple{2, Real}; + stochastic_depth_prob = nothing, norm_layer = BatchNorm, + divisor::Integer = 8, kwargs...) + width_mult, depth_mult = scalings block_fn, k, outplanes, stride, nrepeats, activation = block_configs[stage_idx] - outplanes = _round_channels(outplanes * width_mult) + outplanes = _round_channels(outplanes * width_mult, divisor) if stage_idx != 1 - inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult) + inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult, divisor) end function get_layers(block_idx::Integer) inplanes = block_idx == 1 ? inplanes : outplanes @@ -12,15 +27,17 @@ function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, stride, pad = SamePad(), norm_layer, kwargs...)...) return (block,) end - return get_layers, nrepeats + return get_layers, ceil(Int, nrepeats * depth_mult) end -function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, - scalings::NTuple{2, Real}; norm_layer = BatchNorm, - divisor::Integer = 8, se_from_explanes::Bool = false, - kwargs...) +function invresbuilder(::typeof(mbconv), block_configs::AbstractVector{<:Tuple}, + inplanes::Integer, stage_idx::Integer, scalings::NTuple{2, Real}; + stochastic_depth_prob = nothing, norm_layer = BatchNorm, + divisor::Integer = 8, se_from_explanes::Bool = false, kwargs...) 
width_mult, depth_mult = scalings - block_fn, k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx] + block_repeats = [ceil(Int, block_configs[idx][end - 2] * depth_mult) + for idx in eachindex(block_configs)] + block_fn, k, outplanes, expansion, stride, _, reduction, activation = block_configs[stage_idx] # calculate number of reduced channels for squeeze-excite layer from explanes instead of inplanes if !isnothing(reduction) reduction = !se_from_explanes ? reduction * expansion : reduction @@ -29,79 +46,52 @@ function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult, divisor) end outplanes = _round_channels(outplanes * width_mult, divisor) + sdschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats)) function get_layers(block_idx::Integer) inplanes = block_idx == 1 ? inplanes : outplanes explanes = _round_channels(inplanes * expansion, divisor) stride = block_idx == 1 ? stride : 1 block = block_fn((k, k), inplanes, explanes, outplanes, activation; norm_layer, stride, reduction, kwargs...) - return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) + use_skip = stride == 1 && inplanes == outplanes + if use_skip + schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx + drop_path = StochasticDepth(sdschedule[schedule_idx]) + return (drop_path, block) + else + return (block,) + end end - return get_layers, ceil(Int, nrepeats * depth_mult) -end - -function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, - width_mult::Real; norm_layer = BatchNorm, kwargs...) - return mbconv_builder(block_configs, inplanes, stage_idx, (width_mult, 1); - norm_layer, kwargs...) + return get_layers, block_repeats[stage_idx] end -function fused_mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer; - norm_layer = BatchNorm, kwargs...) - block_fn, k, outplanes, expansion, stride, nrepeats, activation = block_configs[stage_idx] +function invresbuilder(::typeof(fused_mbconv), block_configs::AbstractVector{<:Tuple}, + inplanes::Integer, stage_idx::Integer, scalings::NTuple{2, Real}; + stochastic_depth_prob = nothing, norm_layer = BatchNorm, + divisor::Integer = 8, kwargs...) + width_mult, depth_mult = scalings + block_repeats = [ceil(Int, block_configs[idx][end - 1] * depth_mult) + for idx in eachindex(block_configs)] + block_fn, k, outplanes, expansion, stride, _, activation = block_configs[stage_idx] inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3] + outplanes = _round_channels(outplanes * width_mult, divisor) + sdschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats)) function get_layers(block_idx::Integer) inplanes = block_idx == 1 ? inplanes : outplanes - explanes = _round_channels(inplanes * expansion, 8) + explanes = _round_channels(inplanes * expansion, divisor) stride = block_idx == 1 ? stride : 1 block = block_fn((k, k), inplanes, explanes, outplanes, activation; norm_layer, stride, kwargs...) - return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) + schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx + drop_path = StochasticDepth(sdschedule[schedule_idx]) + return stride == 1 && inplanes == outplanes ? 
(drop_path, block) : (block,) end - return get_layers, nrepeats -end - -# TODO - these builders need to be more flexible to potentially specify stuff like -# activation functions and reductions that don't change -function _get_builder(::typeof(dwsep_conv_bn), block_configs::AbstractVector{<:Tuple}, - inplanes::Integer, stage_idx::Integer; - scalings::Union{Nothing, NTuple{2, Real}} = nothing, - width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...) - @assert isnothing(scalings) "dwsep_conv_bn does not support the `scalings` argument" - return dwsepconv_builder(block_configs, inplanes, stage_idx, width_mult; norm_layer, - kwargs...) -end - -function _get_builder(::typeof(mbconv), block_configs::AbstractVector{<:Tuple}, - inplanes::Integer, stage_idx::Integer; - scalings::Union{Nothing, NTuple{2, Real}} = nothing, - width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...) - if isnothing(scalings) - return mbconv_builder(block_configs, inplanes, stage_idx, width_mult; norm_layer, - kwargs...) - elseif isnothing(width_mult) - return mbconv_builder(block_configs, inplanes, stage_idx, scalings; norm_layer, - kwargs...) - else - throw(ArgumentError("Only one of `scalings` and `width_mult` can be specified")) - end -end - -function _get_builder(::typeof(fused_mbconv), block_configs::AbstractVector{<:Tuple}, - inplanes::Integer, stage_idx::Integer; - scalings::Union{Nothing, NTuple{2, Real}} = nothing, - width_mult::Union{Nothing, Number} = nothing, norm_layer) - @assert isnothing(width_mult) "fused_mbconv does not support the `width_mult` argument." - @assert isnothing(scalings)||scalings == (1, 1) "fused_mbconv does not support the `scalings` argument" - return fused_mbconv_builder(block_configs, inplanes, stage_idx; norm_layer) + return get_layers, block_repeats[stage_idx] end -function mbconv_stack_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer; - scalings::Union{Nothing, NTuple{2, Real}} = nothing, - width_mult::Union{Nothing, Number} = nothing, - norm_layer = BatchNorm, kwargs...) - bxs = [_get_builder(block_configs[idx][1], block_configs, inplanes, idx; scalings, - width_mult, norm_layer, kwargs...) - for idx in eachindex(block_configs)] +function mbconv_stage_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer, + scalings::NTuple{2, Real}; kwargs...) + bxs = [invresbuilder(block_configs[idx][1], block_configs, inplanes, idx, scalings; + kwargs...) for idx in eachindex(block_configs)] return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs) end diff --git a/src/convnets/builders/resblocks.jl b/src/convnets/builders/resblocks.jl index 8343bf811..bdf36c0f3 100644 --- a/src/convnets/builders/resblocks.jl +++ b/src/convnets/builders/resblocks.jl @@ -1,70 +1,191 @@ +""" + basicblock_builder(block_repeats::AbstractVector{<:Integer}; + inplanes::Integer = 64, reduction_factor::Integer = 1, + expansion::Integer = 1, norm_layer = BatchNorm, + revnorm::Bool = false, activation = relu, + attn_fn = planes -> identity, + dropblock_prob = nothing, stochastic_depth_prob = nothing, + stride_fn = resnet_stride, planes_fn = resnet_planes, + downsample_tuple = (downsample_conv, downsample_identity)) + +Builder for creating a basic block for a ResNet model. 
+([reference](https://arxiv.org/abs/1512.03385)) + +# Arguments + + - `block_repeats`: number of repeats of a block in each stage + + - `inplanes`: number of input channels + - `reduction_factor`: reduction factor for the number of channels in each stage + - `expansion`: expansion factor for the number of channels for the block + - `norm_layer`: normalization layer to use + - `revnorm`: set to `true` to place normalization layer before the convolution + - `activation`: activation function to use + - `attn_fn`: attention function to use + - `dropblock_prob`: dropblock probability. Set to `nothing` to disable `DropBlock` + - `stochastic_depth_prob`: stochastic depth probability. Set to `nothing` to disable `StochasticDepth` + - `stride_fn`: callback for computing the stride of the block + - `planes_fn`: callback for computing the number of channels in each block + - `downsample_tuple`: two-element tuple of downsample functions to use. The first one + is used when the number of channels changes in the block, the second one is used + when the number of channels stays the same. +""" function basicblock_builder(block_repeats::AbstractVector{<:Integer}; inplanes::Integer = 64, reduction_factor::Integer = 1, expansion::Integer = 1, norm_layer = BatchNorm, revnorm::Bool = false, activation = relu, - attn_fn = planes -> identity, - drop_block_rate = nothing, drop_path_rate = nothing, - stride_fn = resnet_stride, planes_fn = resnet_planes, + attn_fn = planes -> identity, dropblock_prob = nothing, + stochastic_depth_prob = nothing, stride_fn = resnet_stride, + planes_fn = resnet_planes, downsample_tuple = (downsample_conv, downsample_identity)) - # DropBlock, DropPath both take in rates based on a linear scaling schedule + # DropBlock, StochasticDepth both take in probabilities based on a linear scaling schedule # Also get `planes_vec` needed for block `inplanes` and `planes` calculations - pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) - blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) + sdschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats)) + dbschedule = linear_scheduler(dropblock_prob; depth = sum(block_repeats)) planes_vec = collect(planes_fn(block_repeats)) # closure over `idxs` function get_layers(stage_idx::Integer, block_idx::Integer) - # DropBlock, DropPath both take in rates based on a linear scaling schedule + # DropBlock, StochasticDepth both take in probabilities based on a linear scaling schedule # This is also needed for block `inplanes` and `planes` calculations schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx planes = planes_vec[schedule_idx] inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion - # `resnet_stride` is a callback that the user can tweak to change the stride of the - # blocks. It defaults to the standard behaviour as in the paper stride = stride_fn(stage_idx, block_idx) downsample_fn = stride != 1 || inplanes != planes * expansion ? 
downsample_tuple[1] : downsample_tuple[2] - drop_path = DropPath(pathschedule[schedule_idx]) - drop_block = DropBlock(blockschedule[schedule_idx]) + drop_path = StochasticDepth(sdschedule[schedule_idx]) + drop_block = DropBlock(dbschedule[schedule_idx]) block = basicblock(inplanes, planes; stride, reduction_factor, activation, norm_layer, revnorm, attn_fn, drop_path, drop_block) downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, revnorm) - return block, downsample + return (downsample, block) end return get_layers end +""" + bottleneck_builder(block_repeats::AbstractVector{<:Integer}; + inplanes::Integer = 64, cardinality::Integer = 1, + base_width::Integer = 64, reduction_factor::Integer = 1, + expansion::Integer = 4, norm_layer = BatchNorm, + revnorm::Bool = false, activation = relu, + attn_fn = planes -> identity, dropblock_prob = nothing, + stochastic_depth_prob = nothing, stride_fn = resnet_stride, + planes_fn = resnet_planes, + downsample_tuple = (downsample_conv, downsample_identity)) + +Builder for creating a bottleneck block for a ResNet/ResNeXt model. +([reference](https://arxiv.org/abs/1611.05431)) + +# Arguments + + - `block_repeats`: number of repeats of a block in each stage + - `inplanes`: number of input channels + - `cardinality`: number of groups for the convolutional layer + - `base_width`: base width for the convolutional layer + - `reduction_factor`: reduction factor for the number of channels in each stage + - `expansion`: expansion factor for the number of channels for the block + - `norm_layer`: normalization layer to use + - `revnorm`: set to `true` to place normalization layer before the convolution + - `activation`: activation function to use + - `attn_fn`: attention function to use + - `dropblock_prob`: dropblock probability. Set to `nothing` to disable `DropBlock` + - `stochastic_depth_prob`: stochastic depth probability. Set to `nothing` to disable `StochasticDepth` + - `stride_fn`: callback for computing the stride of the block + - `planes_fn`: callback for computing the number of channels in each block + - `downsample_tuple`: two-element tuple of downsample functions to use. The first one + is used when the number of channels changes in the block, the second one is used + when the number of channels stays the same. 
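As a rough usage sketch (the builders above are internal helpers, so they are qualified with the module name; the stage layout and keyword values here are only illustrative), the closure returned by `bottleneck_builder` maps a `(stage_idx, block_idx)` pair to the branches of one residual block:

```julia
using Metalhead

# ResNet-50-style layout: 3, 4, 6 and 3 bottleneck blocks per stage
get_layers = Metalhead.bottleneck_builder([3, 4, 6, 3]; inplanes = 64,
                                          stochastic_depth_prob = 0.1)

# First block of stage 2: returns (downsample branch, residual branch),
# which `cnn_stages` then wraps in `Parallel(connection, ...)`
downsample, block = get_layers(2, 1)
```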
+""" function bottleneck_builder(block_repeats::AbstractVector{<:Integer}; inplanes::Integer = 64, cardinality::Integer = 1, base_width::Integer = 64, reduction_factor::Integer = 1, expansion::Integer = 4, norm_layer = BatchNorm, revnorm::Bool = false, activation = relu, - attn_fn = planes -> identity, - drop_block_rate = nothing, drop_path_rate = nothing, - stride_fn = resnet_stride, planes_fn = resnet_planes, + attn_fn = planes -> identity, dropblock_prob = nothing, + stochastic_depth_prob = nothing, stride_fn = resnet_stride, + planes_fn = resnet_planes, downsample_tuple = (downsample_conv, downsample_identity)) - pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats)) - blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats)) + sdschedule = linear_scheduler(stochastic_depth_prob; depth = sum(block_repeats)) + dbschedule = linear_scheduler(dropblock_prob; depth = sum(block_repeats)) planes_vec = collect(planes_fn(block_repeats)) # closure over `idxs` function get_layers(stage_idx::Integer, block_idx::Integer) - # DropBlock, DropPath both take in rates based on a linear scaling schedule + # DropBlock, StochasticDepth both take in rates based on a linear scaling schedule # This is also needed for block `inplanes` and `planes` calculations schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx planes = planes_vec[schedule_idx] inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion - # `resnet_stride` is a callback that the user can tweak to change the stride of the - # blocks. It defaults to the standard behaviour as in the paper stride = stride_fn(stage_idx, block_idx) downsample_fn = stride != 1 || inplanes != planes * expansion ? downsample_tuple[1] : downsample_tuple[2] - drop_path = DropPath(pathschedule[schedule_idx]) - drop_block = DropBlock(blockschedule[schedule_idx]) + drop_path = StochasticDepth(sdschedule[schedule_idx]) + drop_block = DropBlock(dbschedule[schedule_idx]) block = bottleneck(inplanes, planes; stride, cardinality, base_width, reduction_factor, activation, norm_layer, revnorm, attn_fn, drop_path, drop_block) downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, revnorm) + return (downsample, block) + end + return get_layers +end + +""" + bottle2neck_builder(block_repeats::AbstractVector{<:Integer}; + inplanes::Integer = 64, cardinality::Integer = 1, + base_width::Integer = 26, scale::Integer = 4, + expansion::Integer = 4, norm_layer = BatchNorm, + revnorm::Bool = false, activation = relu, + attn_fn = planes -> identity, stride_fn = resnet_stride, + planes_fn = resnet_planes, + downsample_tuple = (downsample_conv, downsample_identity)) + +Builder for creating a bottle2neck block for a Res2Net model. 
+([reference](https://arxiv.org/abs/1904.01169)) + +# Arguments + + - `block_repeats`: number of repeats of a block in each stage + - `inplanes`: number of input channels + - `cardinality`: number of groups for the convolutional layer + - `base_width`: base width for the convolutional layer + - `scale`: scale for the number of channels in each block + - `expansion`: expansion factor for the number of channels for the block + - `norm_layer`: normalization layer to use + - `revnorm`: set to `true` to place normalization layer before the convolution + - `activation`: activation function to use + - `attn_fn`: attention function to use + - `stride_fn`: callback for computing the stride of the block + - `planes_fn`: callback for computing the number of channels in each block + - `downsample_tuple`: two-element tuple of downsample functions to use. The first one + is used when the number of channels changes in the block, the second one is used + when the number of channels stays the same. +""" +function bottle2neck_builder(block_repeats::AbstractVector{<:Integer}; + inplanes::Integer = 64, cardinality::Integer = 1, + base_width::Integer = 26, scale::Integer = 4, + expansion::Integer = 4, norm_layer = BatchNorm, + revnorm::Bool = false, activation = relu, + attn_fn = planes -> identity, stride_fn = resnet_stride, + planes_fn = resnet_planes, + downsample_tuple = (downsample_conv, downsample_identity)) + planes_vec = collect(planes_fn(block_repeats)) + # closure over `idxs` + function get_layers(stage_idx::Integer, block_idx::Integer) + # This is needed for block `inplanes` and `planes` calculations + schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx + planes = planes_vec[schedule_idx] + inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion + stride = stride_fn(stage_idx, block_idx) + downsample_fn = (stride != 1 || inplanes != planes * expansion) ? + downsample_tuple[1] : downsample_tuple[2] + is_first = (stride > 1 || downsample_fn != downsample_tuple[2]) ? true : false + block = bottle2neck(inplanes, planes; stride, cardinality, base_width, scale, + activation, is_first, norm_layer, revnorm, attn_fn) + downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, + revnorm) return block, downsample end return get_layers diff --git a/src/convnets/builders/resnet.jl b/src/convnets/builders/resnet.jl new file mode 100644 index 000000000..580baaa34 --- /dev/null +++ b/src/convnets/builders/resnet.jl @@ -0,0 +1,44 @@ +""" + build_resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer}, + connection, classifier_fn) + +Creates a generic ResNet-like model. + +!!! info + + This is a very generic, flexible but low level function that can be used to create any of the ResNet + variants. For a more user friendly function, see [`Metalhead.resnet`](@ref). + +# Arguments + + - `img_dims`: The dimensions of the input image. This is used to determine the number of feature + maps to be passed to the classifier. This should be a tuple of the form `(height, width, channels)`. + - `stem`: The stem of the ResNet model. The stem should be created outside of this function and + passed in as an argument. This is done to allow for more flexibility in creating the stem. + [`resnet_stem`](@ref) is a helper function that Metalhead provides which is recommended for + creating the stem. 
+ - `get_layers` is a function that takes in two inputs - the `stage_idx`, or the index of + the stage, and the `block_idx`, or the index of the block within the stage. It returns a + tuple of layers. If the tuple returned by `get_layers` has more than one element, then + `connection` is used to splat this tuple into `Parallel` - if not, then the only element of + the tuple is directly inserted into the network. `get_layers` is a very specific function + and should not be created on its own. Instead, use one of the builders provided by Metalhead + to create it. + - `block_repeats`: This is a `Vector` of integers that specifies the number of repeats of each + block in each stage. + - `connection`: This is a function that determines the residual connection in the model. For + `resnets`, either of [`Metalhead.addact`](@ref) or [`Metalhead.actadd`](@ref) is recommended. + - `classifier_fn`: This is a function that takes in the number of feature maps and returns a + classifier. This is usually built as a closure using a function like [`Metalhead.create_classifier`](@ref). + For example, if the number of output classes is `nclasses`, then the function can be defined as + `channels -> create_classifier(channels, nclasses)`. +""" +function build_resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer}, + connection, classifier_fn) + # Build stages of the ResNet + stage_blocks = cnn_stages(get_layers, block_repeats, connection) + backbone = Chain(stem, stage_blocks...) + # Add classifier to the backbone + nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3] + return Chain(backbone, classifier_fn(nfeaturemaps)) +end diff --git a/src/convnets/builders/stages.jl b/src/convnets/builders/stages.jl new file mode 100644 index 000000000..13706fd02 --- /dev/null +++ b/src/convnets/builders/stages.jl @@ -0,0 +1,39 @@ +""" + cnn_stages(get_layers, block_repeats::AbstractVector{<:Integer}, [connection = nothing]) + +Creates a convolutional neural network backbone by calling the function `get_layers` +repeatedly. + +# Arguments + + - `get_layers` is a function that takes in two inputs - the `stage_idx`, or the index of + the stage, and the `block_idx`, or the index of the block within the stage. It returns a + tuple of layers. If the tuple returned by `get_layers` has more than one element, then + `connection` is used - if not, then the only element of the tuple is directly inserted + into the network. + - `block_repeats` is a `Vector` of integers, where each element specifies the number of + times the `get_layers` function should be called for that stage. + - `connection` defaults to `nothing` and is an optional argument that specifies the + connection type between the layers. It is passed to `Parallel` and is useful for residual + network structures. For example, for ResNet, the connection is simply `+`. If `connection` + is `nothing`, then every call to `get_layers` _must_ return a tuple of length 1. +""" +function cnn_stages(get_layers, block_repeats::AbstractVector{<:Integer}, + connection = nothing) + # Construct each stage + stages = [] + for (stage_idx, nblocks) in enumerate(block_repeats) + # Construct the blocks for each stage + blocks = map(1:nblocks) do block_idx + branches = get_layers(stage_idx, block_idx) + if isnothing(connection) + @assert length(branches)==1 "get_layers should return a single branch for + each block if no connection is specified" + end + return length(branches) == 1 ? only(branches) : + Parallel(connection, branches...) 
+ end + push!(stages, Chain(blocks...)) + end + return stages +end diff --git a/src/convnets/convmixer.jl b/src/convnets/convmixer.jl index bc1a71a5f..83d8d4638 100644 --- a/src/convnets/convmixer.jl +++ b/src/convnets/convmixer.jl @@ -17,7 +17,7 @@ Creates a ConvMixer model. - `nclasses`: number of classes in the output """ function convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9), - patch_size::Dims{2} = (7, 7), activation = gelu, dropout_rate = nothing, + patch_size::Dims{2} = (7, 7), activation = gelu, dropout_prob = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] # stem of the model @@ -31,7 +31,7 @@ function convmixer(planes::Integer, depth::Integer; kernel_size::Dims{2} = (9, 9 conv_norm((1, 1), planes, planes, activation; preact = true)...) for _ in 1:depth] append!(layers, stages) - return Chain(Chain(layers...), create_classifier(planes, nclasses; dropout_rate)) + return Chain(Chain(layers...), create_classifier(planes, nclasses; dropout_prob)) end const CONVMIXER_CONFIGS = Dict(:base => ((1536, 20), @@ -45,7 +45,8 @@ const CONVMIXER_CONFIGS = Dict(:base => ((1536, 20), patch_size = (7, 7)))) """ - ConvMixer(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000) + ConvMixer(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, + nclasses::Integer = 1000) Creates a ConvMixer model. ([reference](https://arxiv.org/abs/2201.09792)) @@ -53,18 +54,25 @@ Creates a ConvMixer model. # Arguments - `config`: the size of the model, either `:base`, `:small` or `:large` + - `pretrain`: whether to load the pre-trained weights for ImageNet - `inchannels`: number of input channels - `nclasses`: number of classes in the output + +See also [`Metalhead.convmixer`](@ref). """ struct ConvMixer layers::Any end @functor ConvMixer -function ConvMixer(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000) +function ConvMixer(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, + nclasses::Integer = 1000) _checkconfig(config, keys(CONVMIXER_CONFIGS)) layers = convmixer(CONVMIXER_CONFIGS[config][1]...; CONVMIXER_CONFIGS[config][2]..., inchannels, nclasses) + if pretrain + loadpretrain!(layers, "convmixer$config") + end return ConvMixer(layers) end diff --git a/src/convnets/convnext.jl b/src/convnets/convnext.jl index 15271cfed..30b314556 100644 --- a/src/convnets/convnext.jl +++ b/src/convnets/convnext.jl @@ -1,5 +1,5 @@ """ - convnextblock(planes::Integer, drop_path_rate = 0.0, layerscale_init = 1.0f-6) + convnextblock(planes::Integer, stochastic_depth_prob = 0.0, layerscale_init = 1.0f-6) Creates a single block of ConvNeXt. ([reference](https://arxiv.org/abs/2201.03545)) @@ -7,23 +7,23 @@ Creates a single block of ConvNeXt. # Arguments - `planes`: number of input channels. - - `drop_path_rate`: Stochastic depth rate. - - `layerscale_init`: Initial value for [`LayerScale`](#) + - `stochastic_depth_prob`: Stochastic depth probability. 
+ - `layerscale_init`: Initial value for [`LayerScale`](@ref) """ -function convnextblock(planes::Integer, drop_path_rate = 0.0, layerscale_init = 1.0f-6) - layers = SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3), - swapdims((3, 1, 2, 4)), - LayerNorm(planes; ϵ = 1.0f-6), - mlp_block(planes, 4 * planes), - LayerScale(planes, layerscale_init), - swapdims((2, 3, 1, 4)), - DropPath(drop_path_rate)), +) - return layers +function convnextblock(planes::Integer, stochastic_depth_prob = 0.0, + layerscale_init = 1.0f-6) + return SkipConnection(Chain(DepthwiseConv((7, 7), planes => planes; pad = 3), + swapdims((3, 1, 2, 4)), + LayerNorm(planes; ϵ = 1.0f-6), + mlp_block(planes, 4 * planes), + LayerScale(planes, layerscale_init), + swapdims((2, 3, 1, 4)), + StochasticDepth(stochastic_depth_prob)), +) end """ convnext(depths::AbstractVector{<:Integer}, planes::AbstractVector{<:Integer}; - drop_path_rate = 0.0, layerscale_init = 1.0f-6, inchannels::Integer = 3, + stochastic_depth_prob = 0.0, layerscale_init = 1.0f-6, inchannels::Integer = 3, nclasses::Integer = 1000) Creates the layers for a ConvNeXt model. @@ -33,32 +33,33 @@ Creates the layers for a ConvNeXt model. - `depths`: list with configuration for depth of each block - `planes`: list with configuration for number of output channels in each block - - `drop_path_rate`: Stochastic depth rate. - - `layerscale_init`: Initial value for [`LayerScale`](#) + - `stochastic_depth_prob`: Stochastic depth probability. + - `layerscale_init`: Initial value for [`LayerScale`](@ref) ([reference](https://arxiv.org/abs/2103.17239)) - `inchannels`: number of input channels. - `nclasses`: number of output classes """ function convnext(depths::AbstractVector{<:Integer}, planes::AbstractVector{<:Integer}; - drop_path_rate = 0.0, layerscale_init = 1.0f-6, inchannels::Integer = 3, + stochastic_depth_prob = 0.0, layerscale_init = 1.0f-6, + inchannels::Integer = 3, nclasses::Integer = 1000) @assert length(depths) == length(planes) "`planes` should have exactly one value for each block" downsample_layers = [] push!(downsample_layers, - Chain(conv_norm((4, 4), inchannels => planes[1]; stride = 4, + Chain(conv_norm((4, 4), inchannels, planes[1]; stride = 4, norm_layer = ChannelLayerNorm)...)) for m in 1:(length(depths) - 1) push!(downsample_layers, - Chain(conv_norm((2, 2), planes[m] => planes[m + 1]; stride = 2, + Chain(conv_norm((2, 2), planes[m], planes[m + 1]; stride = 2, norm_layer = ChannelLayerNorm, revnorm = true)...)) end stages = [] - dp_rates = linear_scheduler(drop_path_rate; depth = sum(depths)) + sdschedule = linear_scheduler(stochastic_depth_prob; depth = sum(depths)) cur = 0 for i in eachindex(depths) push!(stages, - [convnextblock(planes[i], dp_rates[cur + j], layerscale_init) + [convnextblock(planes[i], sdschedule[cur + j], layerscale_init) for j in 1:depths[i]]) cur += depths[i] end @@ -68,6 +69,12 @@ function convnext(depths::AbstractVector{<:Integer}, planes::AbstractVector{<:In return Chain(Chain(backbone...), classifier) end +function convnext(config::Symbol; stochastic_depth_prob = 0.0, layerscale_init = 1.0f-6, + inchannels::Integer = 3, nclasses::Integer = 1000) + return convnext(CONVNEXT_CONFIGS[config]...; stochastic_depth_prob, layerscale_init, + inchannels, nclasses) +end + # Configurations for ConvNeXt models const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]), :small => ([3, 3, 27, 3], [96, 192, 384, 768]), @@ -76,7 +83,8 @@ const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 
768]), :xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048])) """ - ConvNeXt(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000) + ConvNeXt(config::Symbol; pretrain::Bool = true, inchannels::Integer = 3, + nclasses::Integer = 1000) Creates a ConvNeXt model. ([reference](https://arxiv.org/abs/2201.03545)) @@ -84,19 +92,28 @@ Creates a ConvNeXt model. # Arguments - `config`: The size of the model, one of `tiny`, `small`, `base`, `large` or `xlarge`. + - `pretrain`: set to `true` to load pre-trained weights for ImageNet - `inchannels`: number of input channels - `nclasses`: number of output classes -See also [`Metalhead.convnext`](#). +!!! warning + + `ConvNeXt` does not currently support pretrained weights. + +See also [`Metalhead.convnext`](@ref). """ struct ConvNeXt layers::Any end @functor ConvNeXt -function ConvNeXt(config::Symbol; inchannels::Integer = 3, nclasses::Integer = 1000) +function ConvNeXt(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, + nclasses::Integer = 1000) _checkconfig(config, keys(CONVNEXT_CONFIGS)) - layers = convnext(CONVNEXT_CONFIGS[config]...; inchannels, nclasses) + layers = convnext(config; inchannels, nclasses) + if pretrain + layers = loadpretrain!(layers, "convnext_$config") + end return ConvNeXt(layers) end diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index a7c367c1c..4799e22a7 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -55,7 +55,7 @@ function dense_block(inplanes::Integer, growth_rates) end """ - densenet(inplanes, growth_rates; reduction = 0.5, dropout_rate = nothing, + densenet(inplanes, growth_rates; reduction = 0.5, dropout_prob = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) Create a DenseNet model @@ -65,12 +65,12 @@ Create a DenseNet model - `inplanes`: the number of input feature maps to the first dense block - `growth_rates`: the growth rates of output feature maps within each - [`dense_block`](#) (a vector of vectors) + [`dense_block`](@ref) (a vector of vectors) - `reduction`: the factor by which the number of feature maps is scaled across each transition - - `dropout_rate`: the dropout rate for the classifier head. Set to `nothing` to disable dropout. + - `dropout_prob`: the dropout probability for the classifier head. Set to `nothing` to disable dropout. - `nclasses`: the number of output classes """ -function densenet(inplanes::Integer, growth_rates; reduction = 0.5, dropout_rate = nothing, +function densenet(inplanes::Integer, growth_rates; reduction = 0.5, dropout_prob = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) layers = [] append!(layers, @@ -85,7 +85,7 @@ function densenet(inplanes::Integer, growth_rates; reduction = 0.5, dropout_rate inplanes = floor(Int, outplanes * reduction) end push!(layers, BatchNorm(outplanes, relu)) - return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate)) + return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_prob)) end """ @@ -97,15 +97,15 @@ Create a DenseNet model # Arguments - `nblocks`: number of dense blocks between transitions - - `growth_rate`: the output feature map growth rate of dense blocks (i.e. `k` in the ref) + - `growth_rate`: the output feature map growth probability of dense blocks (i.e. 
`k` in the ref) - `reduction`: the factor by which the number of feature maps is scaled across each transition - `nclasses`: the number of output classes """ function densenet(nblocks::AbstractVector{<:Integer}; growth_rate::Integer = 32, - reduction = 0.5, dropout_rate = nothing, inchannels::Integer = 3, + reduction = 0.5, dropout_prob = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) return densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks]; - reduction, dropout_rate, inchannels, nclasses) + reduction, dropout_prob, inchannels, nclasses) end const DENSENET_CONFIGS = Dict(121 => [6, 12, 24, 16], @@ -125,7 +125,7 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet. `DenseNet` does not currently support pretrained weights. -See also [`Metalhead.densenet`](#). +See also [`Metalhead.densenet`](@ref). """ struct DenseNet layers::Any diff --git a/src/convnets/efficientnets/core.jl b/src/convnets/efficientnets/core.jl deleted file mode 100644 index d853e0c1a..000000000 --- a/src/convnets/efficientnets/core.jl +++ /dev/null @@ -1,21 +0,0 @@ -function efficientnet(block_configs::AbstractVector{<:Tuple}; inplanes::Integer, - scalings::NTuple{2, Real} = (1, 1), - headplanes::Integer = block_configs[end][3] * 4, - norm_layer = BatchNorm, dropout_rate = nothing, - inchannels::Integer = 3, nclasses::Integer = 1000) - layers = [] - # stem of the model - inplanes = _round_channels(inplanes * scalings[1]) - append!(layers, - conv_norm((3, 3), inchannels, inplanes, swish; norm_layer, stride = 2, - pad = SamePad())) - # building inverted residual blocks - get_layers, block_repeats = mbconv_stack_builder(block_configs, inplanes; scalings, - norm_layer) - append!(layers, cnn_stages(get_layers, block_repeats, +)) - # building last layers - append!(layers, - conv_norm((1, 1), _round_channels(block_configs[end][3] * scalings[1]), - headplanes, swish; pad = SamePad())) - return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) -end diff --git a/src/convnets/efficientnets/efficientnet.jl b/src/convnets/efficientnets/efficientnet.jl index 5eb81b21d..2657a3884 100644 --- a/src/convnets/efficientnets/efficientnet.jl +++ b/src/convnets/efficientnets/efficientnet.jl @@ -32,16 +32,51 @@ const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)), :b8 => (672, (2.2, 3.6))) """ - EfficientNet(config::Symbol; pretrain::Bool = false) + efficientnet(config::Symbol; norm_layer = BatchNorm, stochastic_depth_prob = 0.2, + dropout_prob = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) + +Create an EfficientNet model. ([reference](https://arxiv.org/abs/1905.11946v5)). + +# Arguments + + - `config`: size of the model. Can be one of `[:b0, :b1, :b2, :b3, :b4, :b5, :b6, :b7, :b8]`. + - `norm_layer`: normalization layer to use. + - `stochastic_depth_prob`: probability of stochastic depth. Set to `nothing` to disable + stochastic depth. + - `dropout_prob`: probability of dropout in the classifier head. Set to `nothing` to disable + dropout. + - `inchannels`: number of input channels. + - `nclasses`: number of output classes. 
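As a brief sketch of how this builder can be called directly (it is not exported, so it is qualified with the module name; the 224×224 input size for `:b0` is taken from `EFFICIENTNET_GLOBAL_CONFIGS` above, and the class count is only illustrative):

```julia
using Flux, Metalhead

m = Metalhead.efficientnet(:b0; nclasses = 10)
x = rand(Float32, 224, 224, 3, 1)
size(m(x))  # (10, 1)
```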
+""" +function efficientnet(config::Symbol; norm_layer = BatchNorm, stochastic_depth_prob = 0.2, + dropout_prob = nothing, inchannels::Integer = 3, + nclasses::Integer = 1000) + _checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS)) + scalings = EFFICIENTNET_GLOBAL_CONFIGS[config][2] + return build_invresmodel(scalings, EFFICIENTNET_BLOCK_CONFIGS; inplanes = 32, + norm_layer, stochastic_depth_prob, activation = swish, + headplanes = EFFICIENTNET_BLOCK_CONFIGS[end][3] * 4, + dropout_prob, inchannels, nclasses) +end + +""" + EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, + nclasses::Integer = 1000) Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)). -See also [`efficientnet`](#). # Arguments - - `config`: name of default configuration - (can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`) + - `config`: size of the model. Can be one of `[:b0, :b1, :b2, :b3, :b4, :b5, :b6, :b7, :b8]`. - `pretrain`: set to `true` to load the pre-trained weights for ImageNet + - `inchannels`: number of input channels. + - `nclasses`: number of output classes. + +!!! warning + + EfficientNet does not currently support pretrained weights. + +See also [`Metalhead.efficientnet`](@ref). """ struct EfficientNet layers::Any @@ -50,10 +85,7 @@ end function EfficientNet(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) - _checkconfig(config, keys(EFFICIENTNET_GLOBAL_CONFIGS)) - scalings = EFFICIENTNET_GLOBAL_CONFIGS[config][2] - layers = efficientnet(EFFICIENTNET_BLOCK_CONFIGS; inplanes = 32, scalings, - inchannels, nclasses) + layers = efficientnet(config; inchannels, nclasses) if pretrain loadpretrain!(layers, string("efficientnet-", config)) end diff --git a/src/convnets/efficientnets/efficientnetv2.jl b/src/convnets/efficientnets/efficientnetv2.jl index ff64eea23..dab159e68 100644 --- a/src/convnets/efficientnets/efficientnetv2.jl +++ b/src/convnets/efficientnets/efficientnetv2.jl @@ -36,8 +36,35 @@ const EFFNETV2_CONFIGS = Dict(:small => [(fused_mbconv, 3, 24, 1, 1, 2, swish), (mbconv, 3, 768, 6, 1, 8, 4, swish)]) """ - EfficientNetv2(config::Symbol; pretrain::Bool = false, width_mult::Real = 1, - inchannels::Integer = 3, nclasses::Integer = 1000) + efficientnetv2(config::Symbol; norm_layer = BatchNorm, stochastic_depth_prob = 0.2, + dropout_prob = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) + +Create an EfficientNetv2 model. ([reference](https://arxiv.org/abs/2104.00298)). + +# Arguments + + - `config`: size of the network (one of `[:small, :medium, :large, :xlarge]`) + - `norm_layer`: normalization layer to use. + - `stochastic_depth_prob`: probability of stochastic depth. Set to `nothing` to disable + stochastic depth. + - `dropout_prob`: probability of dropout in the classifier head. Set to `nothing` to disable + dropout. + - `inchannels`: number of input channels. + - `nclasses`: number of output classes. 
+""" +function efficientnetv2(config::Symbol; norm_layer = BatchNorm, stochastic_depth_prob = 0.2, + dropout_prob = nothing, inchannels::Integer = 3, + nclasses::Integer = 1000) + _checkconfig(config, keys(EFFNETV2_CONFIGS)) + block_configs = EFFNETV2_CONFIGS[config] + return build_invresmodel((1, 1), block_configs; activation = swish, norm_layer, + inplanes = block_configs[1][3], headplanes = 1280, + stochastic_depth_prob, dropout_prob, inchannels, nclasses) +end + +""" + EfficientNetv2(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, + nclasses::Integer = 1000) Create an EfficientNetv2 model ([reference](https://arxiv.org/abs/2104.00298)). @@ -45,10 +72,14 @@ Create an EfficientNetv2 model ([reference](https://arxiv.org/abs/2104.00298)). - `config`: size of the network (one of `[:small, :medium, :large, :xlarge]`) - `pretrain`: whether to load the pre-trained weights for ImageNet - - `width_mult`: Controls the number of output feature maps in each block (with 1 - being the default in the paper) - `inchannels`: number of input channels - `nclasses`: number of output classes + +!!! warning + + `EfficientNetv2` does not currently support pretrained weights. + +See also [`efficientnet`](#). """ struct EfficientNetv2 layers::Any @@ -57,10 +88,7 @@ end function EfficientNetv2(config::Symbol; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) - _checkconfig(config, sort(collect(keys(EFFNETV2_CONFIGS)))) - block_configs = EFFNETV2_CONFIGS[config] - layers = efficientnet(block_configs; inplanes = block_configs[1][3], - headplanes = 1280, inchannels, nclasses) + layers = efficientnetv2(config; inchannels, nclasses) if pretrain loadpretrain!(layers, string("efficientnetv2-", config)) end diff --git a/src/convnets/inceptions/googlenet.jl b/src/convnets/inceptions/googlenet.jl index 54f814479..c343eec72 100644 --- a/src/convnets/inceptions/googlenet.jl +++ b/src/convnets/inceptions/googlenet.jl @@ -1,5 +1,5 @@ """ - _inceptionblock(inplanes, out_1x1, red_3x3, out_3x3, red_5x5, out_3x3, pool_proj) + inceptionblock(inplanes, out_1x1, red_3x3, out_3x3, red_5x5, out_3x3, pool_proj) Create an inception module for use in GoogLeNet ([reference](https://arxiv.org/abs/1409.4842v1)). @@ -14,7 +14,7 @@ Create an inception module for use in GoogLeNet - `out_5x5`: the number of output feature maps for the 5x5 convolution (branch 3) - `pool_proj`: the number of output feature maps for the pooling projection (branch 4) """ -function _inceptionblock(inplanes, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, pool_proj) +function inceptionblock(inplanes, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, pool_proj) branch1 = Chain(Conv((1, 1), inplanes => out_1x1)) branch2 = Chain(Conv((1, 1), inplanes => red_3x3), Conv((3, 3), red_3x3 => out_3x3; pad = 1)) @@ -27,33 +27,35 @@ function _inceptionblock(inplanes, out_1x1, red_3x3, out_3x3, red_5x5, out_5x5, end """ - googlenet(; nclasses::Integer = 1000) + googlenet(; dropout_prob = 0.4, inchannels::Integer = 3, nclasses::Integer = 1000) Create an Inception-v1 model (commonly referred to as GoogLeNet) ([reference](https://arxiv.org/abs/1409.4842v1)). # Arguments + - `dropout_prob`: the dropout probability in the classifier head. Set to `nothing` to disable dropout. 
+ - `inchannels`: the number of input channels - `nclasses`: the number of output classes """ -function googlenet(; dropout_rate = 0.4, inchannels::Integer = 3, nclasses::Integer = 1000) +function googlenet(; dropout_prob = 0.4, inchannels::Integer = 3, nclasses::Integer = 1000) backbone = Chain(Conv((7, 7), inchannels => 64; stride = 2, pad = 3), MaxPool((3, 3); stride = 2, pad = 1), Conv((1, 1), 64 => 64), Conv((3, 3), 64 => 192; pad = 1), MaxPool((3, 3); stride = 2, pad = 1), - _inceptionblock(192, 64, 96, 128, 16, 32, 32), - _inceptionblock(256, 128, 128, 192, 32, 96, 64), + inceptionblock(192, 64, 96, 128, 16, 32, 32), + inceptionblock(256, 128, 128, 192, 32, 96, 64), MaxPool((3, 3); stride = 2, pad = 1), - _inceptionblock(480, 192, 96, 208, 16, 48, 64), - _inceptionblock(512, 160, 112, 224, 24, 64, 64), - _inceptionblock(512, 128, 128, 256, 24, 64, 64), - _inceptionblock(512, 112, 144, 288, 32, 64, 64), - _inceptionblock(528, 256, 160, 320, 32, 128, 128), + inceptionblock(480, 192, 96, 208, 16, 48, 64), + inceptionblock(512, 160, 112, 224, 24, 64, 64), + inceptionblock(512, 128, 128, 256, 24, 64, 64), + inceptionblock(512, 112, 144, 288, 32, 64, 64), + inceptionblock(528, 256, 160, 320, 32, 128, 128), MaxPool((3, 3); stride = 2, pad = 1), - _inceptionblock(832, 256, 160, 320, 32, 128, 128), - _inceptionblock(832, 384, 192, 384, 48, 128, 128)) - return Chain(backbone, create_classifier(1024, nclasses; dropout_rate)) + inceptionblock(832, 256, 160, 320, 32, 128, 128), + inceptionblock(832, 384, 192, 384, 48, 128, 128)) + return Chain(backbone, create_classifier(1024, nclasses; dropout_prob)) end """ @@ -71,7 +73,7 @@ Create an Inception-v1 model (commonly referred to as `GoogLeNet`) `GoogLeNet` does not currently support pretrained weights. -See also [`googlenet`](#). +See also [`Metalhead.googlenet`](@ref). """ struct GoogLeNet layers::Any diff --git a/src/convnets/inceptions/inceptionresnetv2.jl b/src/convnets/inceptions/inceptionresnetv2.jl index bd88648e9..98d686062 100644 --- a/src/convnets/inceptions/inceptionresnetv2.jl +++ b/src/convnets/inceptions/inceptionresnetv2.jl @@ -64,18 +64,18 @@ function block8(scale = 1.0f0; activation = identity) end """ - inceptionresnetv2(; inchannels::Integer = 3, dropout_rate = nothing, nclasses::Integer = 1000) + inceptionresnetv2(; inchannels::Integer = 3, dropout_prob = nothing, nclasses::Integer = 1000) Creates an InceptionResNetv2 model. ([reference](https://arxiv.org/abs/1602.07261)) # Arguments + - `dropout_prob`: probability of dropout in classifier head. Set to `nothing` to disable dropout. - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. Set to `nothing` to disable dropout. - `nclasses`: the number of output classes. """ -function inceptionresnetv2(; dropout_rate = nothing, inchannels::Integer = 3, +function inceptionresnetv2(; dropout_prob = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)..., basic_conv_bn((3, 3), 32, 32)..., @@ -92,7 +92,7 @@ function inceptionresnetv2(; dropout_rate = nothing, inchannels::Integer = 3, [block8(0.20f0) for _ in 1:9]..., block8(; activation = relu), basic_conv_bn((1, 1), 2080, 1536)...) - return Chain(backbone, create_classifier(1536, nclasses; dropout_rate)) + return Chain(backbone, create_classifier(1536, nclasses; dropout_prob)) end """ @@ -111,6 +111,8 @@ Creates an InceptionResNetv2 model. !!! warning `InceptionResNetv2` does not currently support pretrained weights. 
+ +See also [`Metalhead.inceptionresnetv2`](@ref). """ struct InceptionResNetv2 layers::Any diff --git a/src/convnets/inceptions/inceptionv3.jl b/src/convnets/inceptions/inceptionv3.jl index 32fbbede5..41d7ae18e 100644 --- a/src/convnets/inceptions/inceptionv3.jl +++ b/src/convnets/inceptions/inceptionv3.jl @@ -125,15 +125,17 @@ function inceptionv3_e(inplanes) end """ - inceptionv3(; inchannels::Integer = 3, nclasses::Integer = 1000) + inceptionv3(; dropout_prob = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000) Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). # Arguments + - `dropout_prob`: the dropout probability in the classifier head. Set to `nothing` to disable dropout. + - `inchannels`: number of input feature maps - `nclasses`: the number of output classes """ -function inceptionv3(; dropout_rate = 0.2, inchannels::Integer = 3, +function inceptionv3(; dropout_prob = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000) backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)..., basic_conv_bn((3, 3), 32, 32)..., @@ -153,14 +155,13 @@ function inceptionv3(; dropout_rate = 0.2, inchannels::Integer = 3, inceptionv3_d(768), inceptionv3_e(1280), inceptionv3_e(2048)) - return Chain(backbone, create_classifier(2048, nclasses; dropout_rate)) + return Chain(backbone, create_classifier(2048, nclasses; dropout_prob)) end """ Inceptionv3(; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) Create an Inception-v3 model ([reference](https://arxiv.org/abs/1512.00567v3)). -See also [`inceptionv3`](#). # Arguments @@ -171,6 +172,8 @@ See also [`inceptionv3`](#). !!! warning `Inceptionv3` does not currently support pretrained weights. + +See also [`Metalhead.inceptionv3`](@ref). """ struct Inceptionv3 layers::Any diff --git a/src/convnets/inceptions/inceptionv4.jl b/src/convnets/inceptions/inceptionv4.jl index 13d40da25..964afc362 100644 --- a/src/convnets/inceptions/inceptionv4.jl +++ b/src/convnets/inceptions/inceptionv4.jl @@ -85,18 +85,18 @@ function inceptionv4_c() end """ - inceptionv4(; inchannels::Integer = 3, dropout_rate = nothing, nclasses::Integer = 1000) + inceptionv4(; dropout_prob = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) Create an Inceptionv4 model. ([reference](https://arxiv.org/abs/1602.07261)) # Arguments + - `dropout_prob`: probability of dropout in classifier head. Set to `nothing` to disable dropout. - `inchannels`: number of input channels. - - `dropout_rate`: rate of dropout in classifier head. Set to `nothing` to disable dropout. - `nclasses`: the number of output classes. """ -function inceptionv4(; dropout_rate = nothing, inchannels::Integer = 3, +function inceptionv4(; dropout_prob = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) backbone = Chain(basic_conv_bn((3, 3), inchannels, 32; stride = 2)..., basic_conv_bn((3, 3), 32, 32)..., @@ -107,7 +107,7 @@ function inceptionv4(; dropout_rate = nothing, inchannels::Integer = 3, [inceptionv4_b() for _ in 1:7]..., reduction_b(), # mixed_7a [inceptionv4_c() for _ in 1:3]...) - return Chain(backbone, create_classifier(1536, nclasses; dropout_rate)) + return Chain(backbone, create_classifier(1536, nclasses; dropout_prob)) end """ @@ -126,6 +126,8 @@ Creates an Inceptionv4 model. !!! warning `Inceptionv4` does not currently support pretrained weights. + +See also [`Metalhead.inceptionv4`](@ref). 
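A short usage sketch, assuming `Inceptionv4` takes the same keyword-only `pretrain`/`inchannels`/`nclasses` constructor as `Inceptionv3` above (Inception-family models are conventionally evaluated at 299×299, although the global pooling head accepts other sizes):

```julia
using Flux, Metalhead

m = Inceptionv4(; nclasses = 10)        # pretrained weights are not yet available
size(m(rand(Float32, 299, 299, 3, 1)))  # (10, 1)
```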
""" struct Inceptionv4 layers::Any diff --git a/src/convnets/inceptions/xception.jl b/src/convnets/inceptions/xception.jl index 171bddd19..9dfd73f86 100644 --- a/src/convnets/inceptions/xception.jl +++ b/src/convnets/inceptions/xception.jl @@ -34,7 +34,7 @@ function xception_block(inchannels::Integer, outchannels::Integer, nrepeats::Int end push!(layers, relu) append!(layers, - dwsep_conv_bn((3, 3), inc, outc; pad = 1, use_norm = (false, false))) + dwsep_conv_norm((3, 3), inc, outc; pad = 1, norm_layer = identity)) push!(layers, BatchNorm(outc)) end layers = start_with_relu ? layers : layers[2:end] @@ -43,18 +43,18 @@ function xception_block(inchannels::Integer, outchannels::Integer, nrepeats::Int end """ - xception(; dropout_rate = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) + xception(; dropout_prob = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) Creates an Xception model. ([reference](https://arxiv.org/abs/1610.02357)) # Arguments - - `dropout_rate`: rate of dropout in classifier head. Set to `nothing` to disable dropout. + - `dropout_prob`: probability of dropout in classifier head. Set to `nothing` to disable dropout. - `inchannels`: number of input channels. - `nclasses`: the number of output classes. """ -function xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000) +function xception(; dropout_prob = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000) backbone = Chain(conv_norm((3, 3), inchannels, 32; stride = 2)..., conv_norm((3, 3), 32, 64)..., xception_block(64, 128, 2; stride = 2, start_with_relu = false), @@ -62,9 +62,9 @@ function xception(; dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integ xception_block(256, 728, 2; stride = 2), [xception_block(728, 728, 3) for _ in 1:8]..., xception_block(728, 1024, 2; stride = 2, grow_at_start = false), - dwsep_conv_bn((3, 3), 1024, 1536; pad = 1)..., - dwsep_conv_bn((3, 3), 1536, 2048; pad = 1)...) - return Chain(backbone, create_classifier(2048, nclasses; dropout_rate)) + dwsep_conv_norm((3, 3), 1024, 1536; pad = 1)..., + dwsep_conv_norm((3, 3), 1536, 2048; pad = 1)...) + return Chain(backbone, create_classifier(2048, nclasses; dropout_prob)) end """ @@ -82,6 +82,8 @@ Creates an Xception model. !!! warning `Xception` does not currently support pretrained weights. + +See also [`Metalhead.xception`](@ref). """ struct Xception layers::Any diff --git a/src/convnets/mobilenets/mnasnet.jl b/src/convnets/mobilenets/mnasnet.jl index 2f6db2acf..98cd9d759 100644 --- a/src/convnets/mobilenets/mnasnet.jl +++ b/src/convnets/mobilenets/mnasnet.jl @@ -1,48 +1,5 @@ -# momentum used for BatchNorm as per Tensorflow implementation -const _MNASNET_BN_MOMENTUM = 0.0003f0 - -""" - mnasnet(block_configs::AbstractVector{<:Tuple}; width_mult::Real = 1, - max_width = 1280, dropout_rate = 0.2, inchannels::Integer = 3, - nclasses::Integer = 1000) - -Create an MNASNet model with the specified configuration. -([reference](https://arxiv.org/abs/1807.11626)). - -# Arguments - - - `width_mult`: Controls the number of output feature maps in each block - (with 1 being the default in the paper) - - `max_width`: The maximum number of feature maps in any layer of the network - - `dropout_rate`: rate of dropout in the classifier head. Set to `nothing` to disable dropout. - - `inchannels`: The number of input channels. 
- - `nclasses`: The number of output classes -""" -function mnasnet(block_configs::AbstractVector{<:Tuple}; width_mult::Real = 1, - max_width::Integer = 1280, inplanes::Integer = 32, dropout_rate = 0.2, - inchannels::Integer = 3, nclasses::Integer = 1000) - # norm layer for MNASNet is different from other models - norm_layer = (args...; kwargs...) -> BatchNorm(args...; momentum = _MNASNET_BN_MOMENTUM, - kwargs...) - # building first layer - inplanes = _round_channels(inplanes * width_mult) - layers = [] - append!(layers, - conv_norm((3, 3), inchannels, inplanes, relu; stride = 2, pad = 1, - norm_layer)) - # building inverted residual blocks - get_layers, block_repeats = mbconv_stack_builder(block_configs, inplanes; width_mult, - norm_layer) - append!(layers, cnn_stages(get_layers, block_repeats, +)) - # building last layers - outplanes = _round_channels(block_configs[end][3] * width_mult) - append!(layers, - conv_norm((1, 1), outplanes, max_width, relu; norm_layer)) - return Chain(Chain(layers...), create_classifier(max_width, nclasses; dropout_rate)) -end - # Layer configurations for MNasNet -# f: block function - we use `dwsep_conv_bn` for the first block and `mbconv` for the rest +# f: block function - we use `dwsep_conv_norm` for the first block and `mbconv` for the rest # k: kernel size # c: output channels # e: expansion factor - only used for `mbconv` @@ -53,7 +10,7 @@ end # Data is organised as (f, k, c, (e,) s, n, (r,) a) const MNASNET_CONFIGS = Dict(:B1 => (32, [ - (dwsep_conv_bn, 3, 16, 1, 1, relu), + (dwsep_conv_norm, 3, 16, 1, 1, relu), (mbconv, 3, 24, 3, 2, 3, nothing, relu), (mbconv, 5, 40, 3, 2, 3, nothing, relu), (mbconv, 5, 80, 6, 2, 3, nothing, relu), @@ -63,7 +20,7 @@ const MNASNET_CONFIGS = Dict(:B1 => (32, ]), :A1 => (32, [ - (dwsep_conv_bn, 3, 16, 1, 1, relu), + (dwsep_conv_norm, 3, 16, 1, 1, relu), (mbconv, 3, 24, 6, 2, 2, nothing, relu), (mbconv, 5, 40, 3, 2, 3, 4, relu), (mbconv, 3, 80, 6, 2, 4, nothing, relu), @@ -73,26 +30,57 @@ const MNASNET_CONFIGS = Dict(:B1 => (32, ]), :small => (8, [ - (dwsep_conv_bn, 3, 8, 1, 1, relu), + (dwsep_conv_norm, 3, 8, 1, 1, relu), (mbconv, 3, 16, 3, 2, 1, nothing, relu), (mbconv, 3, 16, 6, 2, 2, nothing, relu), (mbconv, 5, 32, 6, 2, 4, 4, relu), (mbconv, 3, 32, 6, 1, 3, 4, relu), (mbconv, 5, 88, 6, 2, 3, 4, relu), - (mbconv, 3, 144, 6, 1, 1, nothing, relu)])) + (mbconv, 3, 144, 6, 1, 1, nothing, relu), + ])) """ - MNASNet(width_mult = 1; inchannels::Integer = 3, pretrain::Bool = false, - nclasses::Integer = 1000) + mnasnet(config::Symbol; width_mult::Real = 1, max_width::Integer = 1280, + dropout_prob = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000) + +Create an MNasNet model. ([reference](https://arxiv.org/abs/1807.11626)) + +# Arguments + + - `config`: configuration of the model. One of `B1`, `A1` or `small`. `B1` is without + squeeze-and-excite layers, `A1` is with squeeze-and-excite layers, and `small` is a smaller + version of `A1`. + - `width_mult`: Controls the number of output feature maps in each block + (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4) + - `max_width`: Controls the maximum number of output feature maps in each block + - `dropout_prob`: Dropout probability for the classifier head. Set to `nothing` to disable dropout. + - `inchannels`: Number of input channels. + - `nclasses`: Number of output classes. 
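+
+For example (the width multiplier and class count below are illustrative):
+
+```julia
+using Metalhead
+model = Metalhead.mnasnet(:B1; width_mult = 0.75, nclasses = 10)
+```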
+""" +function mnasnet(config::Symbol; width_mult::Real = 1, max_width::Integer = 1280, + dropout_prob = 0.2, inchannels::Integer = 3, nclasses::Integer = 1000) + _checkconfig(config, keys(MNASNET_CONFIGS)) + # momentum used for BatchNorm is as per Tensorflow implementation + norm_layer = (args...; kwargs...) -> BatchNorm(args...; momentum = 0.0003f0, kwargs...) + inplanes, block_configs = MNASNET_CONFIGS[config] + return build_invresmodel(width_mult, block_configs; inplanes, norm_layer, + headplanes = max_width, dropout_prob, inchannels, nclasses) +end + +""" + MNASNet(config::Symbol; width_mult::Real = 1, pretrain::Bool = false, + inchannels::Integer = 3, nclasses::Integer = 1000) Creates a MNASNet model with the specified configuration. ([reference](https://arxiv.org/abs/1807.11626)) # Arguments + - `config`: configuration of the model. One of `B1`, `A1` or `small`. `B1` is without + squeeze-and-excite layers, `A1` is with squeeze-and-excite layers, and `small` is a smaller + version of `A1`. - `width_mult`: Controls the number of output feature maps in each block - (with 1 being the default in the paper; - this is usually a value between 0.1 and 1.4) + (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4) - `pretrain`: Whether to load the pre-trained weights for ImageNet - `inchannels`: The number of input channels. - `nclasses`: The number of output classes @@ -101,7 +89,7 @@ Creates a MNASNet model with the specified configuration. `MNASNet` does not currently support pretrained weights. -See also [`mnasnet`](#). +See also [`Metalhead.mnasnet`](@ref). """ struct MNASNet layers::Any @@ -111,8 +99,7 @@ end function MNASNet(config::Symbol; width_mult::Real = 1, pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, keys(MNASNET_CONFIGS)) - inplanes, block_configs = MNASNET_CONFIGS[config] - layers = mnasnet(block_configs; width_mult, inplanes, inchannels, nclasses) + layers = mnasnet(config; width_mult, inchannels, nclasses) if pretrain loadpretrain!(layers, "mnasnet$(width_mult)") end diff --git a/src/convnets/mobilenets/mobilenetv1.jl b/src/convnets/mobilenets/mobilenetv1.jl index 24240d0c0..ca17dc4ac 100644 --- a/src/convnets/mobilenets/mobilenetv1.jl +++ b/src/convnets/mobilenets/mobilenetv1.jl @@ -1,58 +1,41 @@ +# Layer configurations for MobileNetv1 +# f: block function - we use `dwsep_conv_norm` for all blocks +# k: kernel size +# c: output channels +# s: stride +# n: number of repeats +# a: activation function +# Data is organised as (f, k, c, s, n, a) +const MOBILENETV1_CONFIGS = [ + (dwsep_conv_norm, 3, 64, 1, 1, relu6), + (dwsep_conv_norm, 3, 128, 2, 2, relu6), + (dwsep_conv_norm, 3, 256, 2, 2, relu6), + (dwsep_conv_norm, 3, 512, 2, 6, relu6), + (dwsep_conv_norm, 3, 1024, 2, 2, relu6), +] + """ - mobilenetv1(width_mult::Real, config::AbstractVector{<:Tuple}; - activation = relu, dropout_rate = nothing, + mobilenetv1(width_mult::Real = 1; inplanes::Integer = 32, dropout_prob = nothing, inchannels::Integer = 3, nclasses::Integer = 1000) -Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)). +Create a MobileNetv1 model. ([reference](https://arxiv.org/abs/1704.04861v1)). 
# Arguments - `width_mult`: Controls the number of output feature maps in each block - (with 1 being the default in the paper) - - - `configs`: A "list of tuples" configuration for each layer that details: - - + `dw`: Set true to use a depthwise separable convolution or false for regular convolution - + `o`: The number of output feature maps - + `s`: The stride of the convolutional kernel - + `r`: The number of time this configuration block is repeated - - `activate`: The activation function to use throughout the network - - `dropout_rate`: The dropout rate to use in the classifier head. Set to `nothing` to disable. - - `inchannels`: The number of input channels. The default value is 3. - - `nclasses`: The number of output classes + (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4) + - `inplanes`: Number of input channels to the first convolution layer + - `dropout_prob`: Dropout probability for the classifier head. Set to `nothing` to disable dropout. + - `inchannels`: Number of input channels. + - `nclasses`: Number of output classes. """ -function mobilenetv1(config::AbstractVector{<:Tuple}; width_mult::Real = 1, - activation = relu, dropout_rate = nothing, - inplanes::Integer = 32, inchannels::Integer = 3, - nclasses::Integer = 1000) - layers = [] - # stem of the model - inplanes = _round_channels(inplanes * width_mult) - append!(layers, - conv_norm((3, 3), inchannels, inplanes, activation; stride = 2, pad = 1)) - # building inverted residual blocks - get_layers, block_repeats = mbconv_stack_builder(config, inplanes; width_mult) - append!(layers, cnn_stages(get_layers, block_repeats)) - outplanes = _round_channels(config[end][3] * width_mult) - return Chain(Chain(layers...), create_classifier(outplanes, nclasses; dropout_rate)) +function mobilenetv1(width_mult::Real = 1; inplanes::Integer = 32, dropout_prob = nothing, + inchannels::Integer = 3, nclasses::Integer = 1000) + return build_invresmodel(width_mult, MOBILENETV1_CONFIGS; inplanes, inchannels, + activation = relu6, connection = nothing, tail_conv = false, + headplanes = 1024, dropout_prob, nclasses) end -# Layer configurations for MobileNetv1 -# f: block function - we use `dwsep_conv_bn` for all blocks -# k: kernel size -# c: output channels -# s: stride -# n: number of repeats -# a: activation function -const MOBILENETV1_CONFIGS = [ - # f, k, c, s, n, a - (dwsep_conv_bn, 3, 64, 1, 1, relu6), - (dwsep_conv_bn, 3, 128, 2, 2, relu6), - (dwsep_conv_bn, 3, 256, 2, 2, relu6), - (dwsep_conv_bn, 3, 512, 2, 6, relu6), - (dwsep_conv_bn, 3, 1024, 2, 2, relu6), -] - """ MobileNetv1(width_mult::Real = 1; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) @@ -63,8 +46,7 @@ Create a MobileNetv1 model with the baseline configuration # Arguments - `width_mult`: Controls the number of output feature maps in each block - (with 1 being the default in the paper; - this is usually a value between 0.1 and 1.4) + (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4) - `pretrain`: Whether to load the pre-trained weights for ImageNet - `inchannels`: The number of input channels. - `nclasses`: The number of output classes @@ -73,7 +55,7 @@ Create a MobileNetv1 model with the baseline configuration `MobileNetv1` does not currently support pretrained weights. -See also [`mobilenetv1`](#). +See also [`Metalhead.mobilenetv1`](@ref). 
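+
+A quick usage sketch (illustrative width multiplier and class count):
+
+```julia
+using Metalhead
+model = MobileNetv1(0.75; nclasses = 10)
+```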
""" struct MobileNetv1 layers::Any @@ -82,7 +64,7 @@ end function MobileNetv1(width_mult::Real = 1; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) - layers = mobilenetv1(MOBILENETV1_CONFIGS; width_mult, inchannels, nclasses) + layers = mobilenetv1(width_mult; inchannels, nclasses) if pretrain loadpretrain!(layers, string("MobileNetv1")) end diff --git a/src/convnets/mobilenets/mobilenetv2.jl b/src/convnets/mobilenets/mobilenetv2.jl index f3e26862c..6d0973130 100644 --- a/src/convnets/mobilenets/mobilenetv2.jl +++ b/src/convnets/mobilenets/mobilenetv2.jl @@ -1,48 +1,3 @@ -""" - mobilenetv2(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, - max_width::Integer = 1280, divisor::Integer = 8, dropout_rate = 0.2, - inchannels::Integer = 3, nclasses::Integer = 1000) - -Create a MobileNetv2 model. -([reference](https://arxiv.org/abs/1801.04381)). - -# Arguments - - - `configs`: A "list of tuples" configuration for each layer that details: - - + `t`: The expansion factor that controls the number of feature maps in the bottleneck layer - + `c`: The number of output feature maps - + `n`: The number of times a block is repeated - + `s`: The stride of the convolutional kernel - - - `width_mult`: Controls the number of output feature maps in each block - (with 1 being the default in the paper) - - `max_width`: The maximum number of feature maps in any layer of the network - - `divisor`: The divisor used to round the number of feature maps in each block - - `dropout_rate`: rate of dropout in the classifier head. Set to `nothing` to disable dropout. - - `inchannels`: The number of input channels. - - `nclasses`: The number of output classes -""" -function mobilenetv2(block_configs::AbstractVector{<:Tuple}; width_mult::Real = 1, - max_width::Integer = 1280, divisor::Integer = 8, - inplanes::Integer = 32, dropout_rate = 0.2, - inchannels::Integer = 3, nclasses::Integer = 1000) - # building first layer - inplanes = _round_channels(inplanes * width_mult, divisor) - layers = [] - append!(layers, - conv_norm((3, 3), inchannels, inplanes; pad = 1, stride = 2)) - # building inverted residual blocks - get_layers, block_repeats = mbconv_stack_builder(block_configs, inplanes; width_mult, - divisor) - append!(layers, cnn_stages(get_layers, block_repeats, +)) - # building last layers - outplanes = _round_channels(block_configs[end][3] * width_mult, divisor) - headplanes = _round_channels(max_width * max(1, width_mult), divisor) - append!(layers, conv_norm((1, 1), outplanes, headplanes, relu6)) - return Chain(Chain(layers...), create_classifier(headplanes, nclasses; dropout_rate)) -end - # Layer configurations for MobileNetv2 # f: block function - we use `mbconv` for all blocks # k: kernel size @@ -52,17 +7,41 @@ end # n: number of repeats # r: reduction factor # a: activation function +# Data is organised as (f, k, c, e, s, n, r, a) const MOBILENETV2_CONFIGS = [ - # f, k, c, e, s, n, r, a - (mbconv, 3, 16, 1, 1, 1, nothing, swish), - (mbconv, 3, 24, 6, 2, 2, nothing, swish), - (mbconv, 3, 32, 6, 2, 3, nothing, swish), - (mbconv, 3, 64, 6, 2, 4, nothing, swish), - (mbconv, 3, 96, 6, 1, 3, nothing, swish), - (mbconv, 3, 160, 6, 2, 3, nothing, swish), - (mbconv, 3, 320, 6, 1, 1, nothing, swish), + (mbconv, 3, 16, 1, 1, 1, nothing, relu6), + (mbconv, 3, 24, 6, 2, 2, nothing, relu6), + (mbconv, 3, 32, 6, 2, 3, nothing, relu6), + (mbconv, 3, 64, 6, 2, 4, nothing, relu6), + (mbconv, 3, 96, 6, 1, 3, nothing, relu6), + (mbconv, 3, 160, 6, 2, 3, nothing, relu6), + (mbconv, 3, 320, 6, 1, 1, 
nothing, relu6), ] +""" + mobilenetv2(width_mult::Real = 1; max_width::Integer = 1280, + inplanes::Integer = 32, dropout_prob = 0.2, + inchannels::Integer = 3, nclasses::Integer = 1000) + +Create a MobileNetv2 model. ([reference](https://arxiv.org/abs/1801.04381v1)). + +# Arguments + + - `width_mult`: Controls the number of output feature maps in each block + (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4) + - `max_width`: The maximum width of the network. + - `inplanes`: Number of input channels to the first convolution layer + - `dropout_prob`: Dropout probability for the classifier head. Set to `nothing` to disable dropout. + - `inchannels`: Number of input channels. + - `nclasses`: Number of output classes. +""" +function mobilenetv2(width_mult::Real = 1; max_width::Integer = 1280, + inplanes::Integer = 32, dropout_prob = 0.2, + inchannels::Integer = 3, nclasses::Integer = 1000) + return build_invresmodel(width_mult, MOBILENETV2_CONFIGS; activation = relu6, inplanes, + headplanes = max_width, dropout_prob, inchannels, nclasses) +end + """ MobileNetv2(width_mult = 1.0; inchannels::Integer = 3, pretrain::Bool = false, nclasses::Integer = 1000) @@ -73,8 +52,7 @@ Create a MobileNetv2 model with the specified configuration. # Arguments - `width_mult`: Controls the number of output feature maps in each block - (with 1 being the default in the paper; - this is usually a value between 0.1 and 1.4) + (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4) - `pretrain`: Whether to load the pre-trained weights for ImageNet - `inchannels`: The number of input channels. - `nclasses`: The number of output classes @@ -83,7 +61,7 @@ Create a MobileNetv2 model with the specified configuration. `MobileNetv2` does not currently support pretrained weights. -See also [`mobilenetv2`](#). +See also [`Metalhead.mobilenetv2`](@ref). """ struct MobileNetv2 layers::Any @@ -92,7 +70,7 @@ end function MobileNetv2(width_mult::Real = 1; pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) - layers = mobilenetv2(MOBILENETV2_CONFIGS; width_mult, inchannels, nclasses) + layers = mobilenetv2(width_mult; inchannels, nclasses) if pretrain loadpretrain!(layers, string("MobileNetv2")) end diff --git a/src/convnets/mobilenets/mobilenetv3.jl b/src/convnets/mobilenets/mobilenetv3.jl index 2614c7c2f..07e4501eb 100644 --- a/src/convnets/mobilenets/mobilenetv3.jl +++ b/src/convnets/mobilenets/mobilenetv3.jl @@ -1,51 +1,3 @@ -""" - mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, - max_width::Integer = 1024, dropout_rate = 0.2, - inchannels::Integer = 3, nclasses::Integer = 1000) - -Create a MobileNetv3 model. -([reference](https://arxiv.org/abs/1905.02244)). - -# Arguments - - - `configs`: a "list of tuples" configuration for each layer that details: - - + `k::Integer` - The size of the convolutional kernel - + `c::Float` - The multiplier factor for deciding the number of feature maps in the hidden layer - + `t::Integer` - The number of output feature maps for a given block - + `r::Integer` - The reduction factor (`>= 1` or `nothing` to skip) for squeeze and excite layers - + `s::Integer` - The stride of the convolutional kernel - + `a` - The activation function used in the bottleneck (typically `hardswish` or `relu`) - - - `width_mult`: Controls the number of output feature maps in each block - (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4.) 
- - `max_width`: The maximum number of feature maps in any layer of the network - - `dropout_rate`: The dropout rate to use in the classifier head. Set to `nothing` to disable. - - `inchannels`: The number of input channels. - - `nclasses`: the number of output classes -""" -function mobilenetv3(configs::AbstractVector{<:Tuple}; width_mult::Real = 1, - max_width::Integer = 1024, dropout_rate = 0.2, - inchannels::Integer = 3, nclasses::Integer = 1000) - # building first layer - inplanes = _round_channels(16 * width_mult) - layers = [] - append!(layers, - conv_norm((3, 3), inchannels, inplanes, hardswish; stride = 2, pad = 1)) - # building inverted residual blocks - get_layers, block_repeats = mbconv_stack_builder(configs, inplanes; width_mult, - se_from_explanes = true, - se_round_fn = _round_channels) - append!(layers, cnn_stages(get_layers, block_repeats, +)) - # building last layers - explanes = _round_channels(configs[end][3] * width_mult) - midplanes = _round_channels(explanes * configs[end][4]) - append!(layers, conv_norm((1, 1), explanes, midplanes, hardswish)) - return Chain(Chain(layers...), - create_classifier(midplanes, max_width, nclasses, - (hardswish, identity); dropout_rate)) -end - # Layer configurations for small and large models for MobileNetv3 # f: mbconv block function - we use `mbconv` for all blocks # k: kernel size @@ -56,27 +8,57 @@ end # r: squeeze and excite reduction factor # a: activation function # Data is organised as (f, k, c, e, s, n, r, a) -const MOBILENETV3_CONFIGS = Dict(:small => [ - (mbconv, 3, 16, 1, 2, 1, 4, relu), - (mbconv, 3, 24, 4.5, 2, 1, nothing, relu), - (mbconv, 3, 24, 3.67, 1, 1, nothing, relu), - (mbconv, 5, 40, 4, 2, 1, 4, hardswish), - (mbconv, 5, 40, 6, 1, 2, 4, hardswish), - (mbconv, 5, 48, 3, 1, 2, 4, hardswish), - (mbconv, 5, 96, 6, 1, 3, 4, hardswish), - ], - :large => [ - (mbconv, 3, 16, 1, 1, 1, nothing, relu), - (mbconv, 3, 24, 4, 2, 1, nothing, relu), - (mbconv, 3, 24, 3, 1, 1, nothing, relu), - (mbconv, 5, 40, 3, 2, 1, 4, relu), - (mbconv, 5, 40, 3, 1, 2, 4, relu), - (mbconv, 3, 80, 6, 2, 1, nothing, hardswish), - (mbconv, 3, 80, 2.5, 1, 1, nothing, hardswish), - (mbconv, 3, 80, 2.3, 1, 2, nothing, hardswish), - (mbconv, 3, 112, 6, 1, 2, 4, hardswish), - (mbconv, 5, 160, 6, 1, 3, 4, hardswish), - ]) +const MOBILENETV3_CONFIGS = Dict(:small => (1024, + [ + (mbconv, 3, 16, 1, 2, 1, 4, relu), + (mbconv, 3, 24, 4.5, 2, 1, nothing, relu), + (mbconv, 3, 24, 3.67, 1, 1, nothing, relu), + (mbconv, 5, 40, 4, 2, 1, 4, hardswish), + (mbconv, 5, 40, 6, 1, 2, 4, hardswish), + (mbconv, 5, 48, 3, 1, 2, 4, hardswish), + (mbconv, 5, 96, 6, 1, 3, 4, hardswish), + ]), + :large => (1280, + [ + (mbconv, 3, 16, 1, 1, 1, nothing, relu), + (mbconv, 3, 24, 4, 2, 1, nothing, relu), + (mbconv, 3, 24, 3, 1, 1, nothing, relu), + (mbconv, 5, 40, 3, 2, 3, 4, relu), + (mbconv, 3, 80, 6, 2, 1, nothing, + hardswish), + (mbconv, 3, 80, 2.5, 1, 1, nothing, + hardswish), + (mbconv, 3, 80, 2.3, 1, 2, nothing, + hardswish), + (mbconv, 3, 112, 6, 1, 2, 4, hardswish), + (mbconv, 5, 160, 6, 1, 3, 4, hardswish), + ])) + +""" + mobilenetv3(config::Symbol; width_mult::Real = 1, dropout_prob = 0.2, + inchannels::Integer = 3, nclasses::Integer = 1000) + +Create a MobileNetv3 model with the specified configuration. +([reference](https://arxiv.org/abs/1905.02244)). + +# Arguments + + - `config`: The configuration of the model. Can be either `small` or `large`. 
+ - `width_mult`: Controls the number of output feature maps in each block + (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4) + - `dropout_prob`: Dropout probability for the classifier head. Set to `nothing` to disable dropout. + - `inchannels`: The number of input channels. + - `nclasses`: The number of output classes. +""" +function mobilenetv3(config::Symbol; width_mult::Real = 1, dropout_prob = 0.2, + inchannels::Integer = 3, nclasses::Integer = 1000) + _checkconfig(config, [:small, :large]) + max_width, block_configs = MOBILENETV3_CONFIGS[config] + return build_invresmodel(width_mult, block_configs; inplanes = 16, + headplanes = max_width, activation = relu, + se_from_explanes = true, se_round_fn = _round_channels, + expanded_classifier = true, dropout_prob, inchannels, nclasses) +end """ MobileNetv3(config::Symbol; width_mult::Real = 1, pretrain::Bool = false, @@ -90,8 +72,7 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet. - `config`: :small or :large for the size of the model (see paper). - `width_mult`: Controls the number of output feature maps in each block - (with 1 being the default in the paper; - this is usually a value between 0.1 and 1.4) + (with 1 being the default in the paper; this is usually a value between 0.1 and 1.4) - `pretrain`: whether to load the pre-trained weights for ImageNet - `inchannels`: number of input channels - `nclasses`: the number of output classes @@ -100,7 +81,7 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet. `MobileNetv3` does not currently support pretrained weights. -See also [`mobilenetv3`](#). +See also [`Metalhead.mobilenetv3`](@ref). """ struct MobileNetv3 layers::Any @@ -109,10 +90,7 @@ end function MobileNetv3(config::Symbol; width_mult::Real = 1, pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) - _checkconfig(config, [:small, :large]) - max_width = config == :large ? 1280 : 1024 - layers = mobilenetv3(MOBILENETV3_CONFIGS[config]; width_mult, max_width, inchannels, - nclasses) + layers = mobilenetv3(config; width_mult, inchannels, nclasses) if pretrain loadpretrain!(layers, string("MobileNetv3", config)) end diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index dc143b57e..c5b2f1c7d 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -1,3 +1,5 @@ +## ResNet blocks + """ basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1, reduction_factor::Integer = 1, activation = relu, @@ -6,6 +8,8 @@ attn_fn = planes -> identity) Creates a basic residual block (see [reference](https://arxiv.org/abs/1512.03385v1)). +This function creates the layers. For more configuration options and to see the function +used to build the block for the model, see [`Metalhead.basicblock_builder`](@ref). # Arguments @@ -19,7 +23,7 @@ Creates a basic residual block (see [reference](https://arxiv.org/abs/1512.03385 - `revnorm`: set to `true` to place the normalisation layer before the convolution - `drop_block`: the drop block layer - `drop_path`: the drop path layer - - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. + - `attn_fn`: the attention function to use. See [`squeeze_excite`](@ref) for an example. 
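+
+For example, a stride-2 basic block taking 64 feature maps to 128 (channel counts chosen
+for illustration):
+
+```julia
+using Metalhead
+block = Metalhead.basicblock(64, 128; stride = 2)
+```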
""" function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1, reduction_factor::Integer = 1, activation = relu, @@ -27,9 +31,9 @@ function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1, drop_block = identity, drop_path = identity, attn_fn = planes -> identity) first_planes = planes ÷ reduction_factor - conv_bn1 = conv_norm((3, 3), inplanes => first_planes, identity; norm_layer, revnorm, + conv_bn1 = conv_norm((3, 3), inplanes, first_planes, identity; norm_layer, revnorm, stride, pad = 1) - conv_bn2 = conv_norm((3, 3), first_planes => planes, identity; norm_layer, revnorm, + conv_bn2 = conv_norm((3, 3), first_planes, planes, identity; norm_layer, revnorm, pad = 1) layers = [conv_bn1..., drop_block, activation, conv_bn2..., attn_fn(planes), drop_path] @@ -45,6 +49,8 @@ end attn_fn = planes -> identity) Creates a bottleneck residual block (see [reference](https://arxiv.org/abs/1512.03385v1)). +This function creates the layers. For more configuration options and to see the function +used to build the block for the model, see [`Metalhead.bottleneck_builder`](@ref). # Arguments @@ -60,7 +66,7 @@ Creates a bottleneck residual block (see [reference](https://arxiv.org/abs/1512. - `revnorm`: set to `true` to place the normalisation layer before the convolution - `drop_block`: the drop block layer - `drop_path`: the drop path layer - - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. + - `attn_fn`: the attention function to use. See [`squeeze_excite`](@ref) for an example. """ function bottleneck(inplanes::Integer, planes::Integer; stride::Integer, cardinality::Integer = 1, base_width::Integer = 64, @@ -71,65 +77,153 @@ function bottleneck(inplanes::Integer, planes::Integer; stride::Integer, width = fld(planes * base_width, 64) * cardinality first_planes = width ÷ reduction_factor outplanes = planes * 4 - conv_bn1 = conv_norm((1, 1), inplanes => first_planes, activation; norm_layer, + conv_bn1 = conv_norm((1, 1), inplanes, first_planes, activation; norm_layer, revnorm) - conv_bn2 = conv_norm((3, 3), first_planes => width, identity; norm_layer, revnorm, + conv_bn2 = conv_norm((3, 3), first_planes, width, identity; norm_layer, revnorm, stride, pad = 1, groups = cardinality) - conv_bn3 = conv_norm((1, 1), width => outplanes, identity; norm_layer, revnorm) + conv_bn3 = conv_norm((1, 1), width, outplanes, identity; norm_layer, revnorm) layers = [conv_bn1..., conv_bn2..., drop_block, activation, conv_bn3..., attn_fn(outplanes), drop_path] return Chain(filter!(!=(identity), layers)...) end -# Downsample layer using convolutions. +""" + bottle2neck(inplanes::Integer, planes::Integer; stride::Integer = 1, + cardinality::Integer = 1, base_width::Integer = 26, + scale::Integer = 4, activation = relu, norm_layer = BatchNorm, + revnorm::Bool = false, attn_fn = planes -> identity) + +Creates a bottleneck block as described in the Res2Net paper. ([reference](https://arxiv.org/abs/1904.01169)) +This function creates the layers. For more configuration options and to see the function +used to build the block for the model, see [`Metalhead.bottle2neck_builder`](@ref). + +# Arguments + + - `inplanes`: number of input feature maps + - `planes`: number of feature maps for the block + - `stride`: the stride of the block + - `cardinality`: the number of groups in the 3x3 convolutions. + - `base_width`: the number of output feature maps for each convolutional group. + - `scale`: the number of feature groups in the block. 
See the [paper](https://arxiv.org/abs/1904.01169) + for more details. + - `activation`: the activation function to use. + - `norm_layer`: the normalization layer to use. + - `revnorm`: set to `true` to place the batch norm before the convolution + - `attn_fn`: the attention function to use. See [`squeeze_excite`](@ref) for an example. +""" +function bottle2neck(inplanes::Integer, planes::Integer; stride::Integer = 1, + cardinality::Integer = 1, base_width::Integer = 26, + scale::Integer = 4, activation = relu, is_first::Bool = false, + norm_layer = BatchNorm, revnorm::Bool = false, + attn_fn = planes -> identity) + width = fld(planes * base_width, 64) * cardinality + outplanes = planes * 4 + pool = is_first && scale > 1 ? MeanPool((3, 3); stride, pad = 1) : identity + conv_bns = [Chain(conv_norm((3, 3), width, width, activation; norm_layer, stride, + pad = 1, groups = cardinality)...) + for _ in 1:max(1, scale - 1)] + reslayer = is_first ? Parallel(cat_channels, pool, conv_bns...) : + Parallel(cat_channels, identity, Chain(PairwiseFusion(+, conv_bns...))) + tuplify = is_first ? x -> tuple(x...) : x -> tuple(x[1], tuple(x[2:end]...)) + layers = [ + conv_norm((1, 1), inplanes, width * scale, activation; + norm_layer, revnorm)..., + chunk$(; size = width, dims = 3), tuplify, reslayer, + conv_norm((1, 1), width * scale, outplanes, activation; + norm_layer, revnorm)..., + attn_fn(outplanes), + ] + return Chain(filter!(!=(identity), layers)...) +end + +## Downsample layers + +""" + downsample_conv(inplanes::Integer, outplanes::Integer; stride::Integer = 1, + norm_layer = BatchNorm, revnorm::Bool = false) + +Creates a 1x1 convolutional downsample layer as used in ResNet. + +# Arguments + + - `inplanes`: number of input feature maps + - `outplanes`: number of output feature maps + - `stride`: the stride of the convolution + - `norm_layer`: the normalization layer to use. + - `revnorm`: set to `true` to place the normalisation layer before the convolution +""" function downsample_conv(inplanes::Integer, outplanes::Integer; stride::Integer = 1, norm_layer = BatchNorm, revnorm::Bool = false) - return Chain(conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, revnorm, + return Chain(conv_norm((1, 1), inplanes, outplanes, identity; norm_layer, revnorm, pad = SamePad(), stride)...) end -# Downsample layer using max pooling +""" + downsample_pool(inplanes::Integer, outplanes::Integer; stride::Integer = 1, + norm_layer = BatchNorm, revnorm::Bool = false) + +Creates a pooling-based downsample layer as described in the +[Bag of Tricks](https://arxiv.org/abs/1812.01187v1) paper. This adds an average pooling layer +of size `(2, 2)` with `stride` followed by a 1x1 convolution. + +# Arguments + + - `inplanes`: number of input feature maps + - `outplanes`: number of output feature maps + - `stride`: the stride of the convolution + - `norm_layer`: the normalization layer to use. + - `revnorm`: set to `true` to place the normalisation layer before the convolution +""" function downsample_pool(inplanes::Integer, outplanes::Integer; stride::Integer = 1, norm_layer = BatchNorm, revnorm::Bool = false) pool = stride == 1 ? identity : MeanPool((2, 2); stride, pad = SamePad()) return Chain(pool, - conv_norm((1, 1), inplanes => outplanes, identity; norm_layer, + conv_norm((1, 1), inplanes, outplanes, identity; norm_layer, revnorm)...) end -# Downsample layer which is an identity projection. Uses max pooling -# when the output size is more than the input size. 
# TODO - figure out how to make this work when outplanes < inplanes +""" + downsample_identity(inplanes::Integer, outplanes::Integer; kwargs...) + +Creates an identity downsample layer. This returns `identity` if `inplanes == outplanes`. +If `outplanes > inplanes`, it maps the input to `outplanes` channels using a 1x1 max pooling +layer and zero padding. + +!!! warning + + This does not currently support the scenario where `inplanes > outplanes`. + +# Arguments + + - `inplanes`: number of input feature maps + - `outplanes`: number of output feature maps + +Note that kwargs are ignored and only included for compatibility with other downsample layers. +""" function downsample_identity(inplanes::Integer, outplanes::Integer; kwargs...) if outplanes > inplanes return Chain(MaxPool((1, 1); stride = 2), y -> cat_channels(y, - zeros(eltype(y), - size(y, 1), - size(y, 2), + zeros(eltype(y), size(y, 1), size(y, 2), outplanes - inplanes, size(y, 4)))) else return identity end end -# Shortcut configurations for the ResNet models +# Shortcut configurations for the ResNet variants const RESNET_SHORTCUTS = Dict(:A => (downsample_identity, downsample_identity), :B => (downsample_conv, downsample_identity), :C => (downsample_conv, downsample_conv), :D => (downsample_pool, downsample_identity)) -# Stride for each block in the ResNet model -function resnet_stride(stage_idx::Integer, block_idx::Integer) - return stage_idx == 1 || block_idx != 1 ? 1 : 2 -end - # returns `DropBlock`s for each stage of the ResNet as in timm. # TODO - add experimental options for DropBlock as part of the API (#188) -# function _drop_blocks(drop_block_rate::AbstractFloat) +# function _drop_blocks(dropblock_prob::AbstractFloat) # return [ # identity, identity, -# DropBlock(drop_block_rate, 5, 0.25), DropBlock(drop_block_rate, 3, 1.00), +# DropBlock(dropblock_prob, 5, 0.25), DropBlock(dropblock_prob, 3, 1.00), # ] # end @@ -137,7 +231,7 @@ end resnet_stem(; stem_type = :default, inchannels::Integer = 3, replace_stem_pool = false, norm_layer = BatchNorm, activation = relu) -Builds a stem to be used in a ResNet model. See the `stem` argument of [`resnet`](#) for details +Builds a stem to be used in a ResNet model. See the `stem` argument of [`resnet`](@ref) for details on how to use this function. # Arguments @@ -147,9 +241,9 @@ on how to use this function. + `:default`: Builds a stem based on the default ResNet stem, which consists of a single 7x7 convolution with stride 2 and a normalisation layer followed by a 3x3 max pooling layer with stride 2. - + `:deep`: This borrows ideas from other papers (InceptionResNet-v2, for example) in using - a deeper stem with 3 successive 3x3 convolutions having normalisation layers after each - one. This is followed by a 3x3 max pooling layer with stride 2. + + `:deep`: This borrows ideas from other papers ([InceptionResNetv2](https://arxiv.org/abs/1602.07261), + for example) in using a deeper stem with 3 successive 3x3 convolutions having normalisation + layers after each one. This is followed by a 3x3 max pooling layer with stride 2. + `:deep_tiered`: A variant of the `:deep` stem that has a larger width in the second convolution. This is an experimental variant from the `timm` library in Python that shows peformance improvements over the `:deep` stem in some cases. @@ -163,20 +257,27 @@ on how to use this function. 
function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, replace_pool::Bool = false, activation = relu, norm_layer = BatchNorm, revnorm::Bool = false) - _checkconfig(stem_type, [:default, :deep, :deep_tiered]) + # Check for valid stem types + deep_stem = if stem_type === :deep || stem_type === :deep_tiered + true + elseif stem_type === :default + false + else + throw(ArgumentError("Unsupported stem type $stem_type. Must be one of + [:default, :deep, :deep_tiered]")) + end # Main stem - deep_stem = stem_type == :deep || stem_type == :deep_tiered inplanes = deep_stem ? stem_width * 2 : 64 # Deep stem that uses three successive 3x3 convolutions instead of a single 7x7 convolution if deep_stem - if stem_type == :deep + if stem_type === :deep stem_channels = (stem_width, stem_width) - elseif stem_type == :deep_tiered + elseif stem_type === :deep_tiered stem_channels = (3 * (stem_width ÷ 4), stem_width) end - conv1 = Chain(conv_norm((3, 3), inchannels => stem_channels[1], activation; + conv1 = Chain(conv_norm((3, 3), inchannels, stem_channels[1], activation; norm_layer, revnorm, stride = 2, pad = 1)..., - conv_norm((3, 3), stem_channels[1] => stem_channels[2], activation; + conv_norm((3, 3), stem_channels[1], stem_channels[2], activation; norm_layer, pad = 1)..., Conv((3, 3), stem_channels[2] => inplanes; pad = 1, bias = false)) else @@ -185,61 +286,137 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3, bn1 = norm_layer(inplanes, activation) # Stem pooling stempool = replace_pool ? - Chain(conv_norm((3, 3), inplanes => inplanes, activation; norm_layer, + Chain(conv_norm((3, 3), inplanes, inplanes, activation; norm_layer, revnorm, stride = 2, pad = 1)...) : MaxPool((3, 3); stride = 2, pad = 1) return Chain(conv1, bn1, stempool) end +# Callbacks for channel and stride calculations for each block in a ResNet + +""" + resnet_planes(block_repeats::AbstractVector{<:Integer}) + +Default callback for determining the number of channels in each block in a ResNet model. + +# Arguments + +`block_repeats`: A `Vector` of integers specifying the number of times each block is repeated +in each stage of the ResNet model. For example, `[3, 4, 6, 3]` is the configuration used in +ResNet-50, which has 3 blocks in the first stage, 4 blocks in the second stage, 6 blocks in the +third stage and 3 blocks in the fourth stage. +""" function resnet_planes(block_repeats::AbstractVector{<:Integer}) return Iterators.flatten((64 * 2^(stage_idx - 1) for _ in 1:stages) for (stage_idx, stages) in enumerate(block_repeats)) end -function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Integer}, - connection, classifier_fn) - # Build stages of the ResNet - stage_blocks = cnn_stages(get_layers, block_repeats, connection) - backbone = Chain(stem, stage_blocks...) - # Add classifier to the backbone - nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3] - return Chain(backbone, classifier_fn(nfeaturemaps)) +""" + resnet_stride(stage_idx::Integer, block_idx::Integer) + +Default callback for determining the stride of a block in a ResNet model. +Returns `2` for the first block in every stage except the first stage and `1` for all other +blocks. + +# Arguments + + - `stage_idx`: The index of the stage in the ResNet model. + - `block_idx`: The index of the block in the stage. +""" +function resnet_stride(stage_idx::Integer, block_idx::Integer) + return stage_idx == 1 || block_idx != 1 ? 
1 : 2 end +""" + resnet(block_type, block_repeats::AbstractVector{<:Integer}, + downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity); + cardinality::Integer = 1, base_width::Integer = 64, + inplanes::Integer = 64, reduction_factor::Integer = 1, + connection = addact, activation = relu, + norm_layer = BatchNorm, revnorm::Bool = false, + attn_fn = planes -> identity, pool_layer = AdaptiveMeanPool((1, 1)), + use_conv::Bool = false, dropblock_prob = nothing, + stochastic_depth_prob = nothing, dropout_prob = nothing, + imsize::Dims{2} = (256, 256), inchannels::Integer = 3, + nclasses::Integer = 1000, kwargs...) + +Creates a generic ResNet-like model that is used to create the higher level models like ResNet, +Wide ResNet, ResNeXt and Res2Net. For an _even_ more generic model API, see [`Metalhead.build_resnet`](@ref). + +# Arguments + + - `block_type`: The type of block to be used in the model. This can be one of [`Metalhead.basicblock`](@ref), + [`Metalhead.bottleneck`](@ref) and [`Metalhead.bottle2neck`](@ref). `basicblock` is used in the + original ResNet paper for ResNet-18 and ResNet-34, and `bottleneck` is used in the original ResNet-50 + and ResNet-101 models, as well as for the Wide ResNet and ResNeXt models. `bottle2neck` is introduced in + the `Res2Net` paper. + - `block_repeats`: A `Vector` of integers specifying the number of times each block is repeated + in each stage of the ResNet model. For example, `[3, 4, 6, 3]` is the configuration used in + ResNet-50, which has 3 blocks in the first stage, 4 blocks in the second stage, 6 blocks in the + third stage and 3 blocks in the fourth stage. + - `downsample_opt`: A `NTuple` of two callbacks that are used to determine the downsampling + operation to be used in the model. The first callback is used to determine the convolutional + operation to be used in the downsampling operation and the second callback is used to determine + the identity operation to be used in the downsampling operation. + - `cardinality`: The number of groups to be used in the 3x3 convolutional layer in the bottleneck + block. This is usually modified from the default value of `1` in the ResNet models to `32` or `64` + in the `ResNeXt` models. + - `base_width`: The base width of the convolutional layer in the blocks of the model. + - `inplanes`: The number of input channels in the first convolutional layer. + - `reduction_factor`: The reduction factor used in the model. + - `connection`: This is a function that determines the residual connection in the model. For + `resnets`, either of [`Metalhead.addact`](@ref) or [`Metalhead.actadd`](@ref) is recommended. + - `norm_layer`: The normalisation layer to be used in the model. + - `revnorm`: set to `true` to place the normalisation layers before the convolutions + - `attn_fn`: A callback that is used to determine the attention function to be used in the model. + See [`Metalhead.Layers.squeeze_excite`](@ref) for an example. + - `pool_layer`: A fully-instantiated pooling layer passed in to be used by the classifier head. + For example, `AdaptiveMeanPool((1, 1))` is used in the ResNet family by default, but something + like `MeanPool((3, 3))` should also work provided the dimensions after applying the pooling + layer are compatible with the rest of the classifier head. + - `use_conv`: Set to true to use convolutions instead of identity operations in the model. + - `dropblock_prob`: `DropBlock` probability to be used in the model. Set to `nothing` to disable + DropBlock. 
See [`Metalhead.DropBlock`](@ref) for more details. + - `stochastic_depth_prob`: `StochasticDepth` probability to be used in the model. Set to `nothing` to disable + StochasticDepth. See [`Metalhead.StochasticDepth`](@ref) for more details. + - `dropout_prob`: `Dropout` probability to be used in the classifier head. Set to `nothing` to + disable Dropout. +""" function resnet(block_type, block_repeats::AbstractVector{<:Integer}, downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity); cardinality::Integer = 1, base_width::Integer = 64, - inplanes::Integer = 64, - reduction_factor::Integer = 1, imsize::Dims{2} = (256, 256), - inchannels::Integer = 3, stem_fn = resnet_stem, connection = addact, - activation = relu, norm_layer = BatchNorm, revnorm::Bool = false, + inplanes::Integer = 64, reduction_factor::Integer = 1, + connection = addact, activation = relu, + norm_layer = BatchNorm, revnorm::Bool = false, attn_fn = planes -> identity, pool_layer = AdaptiveMeanPool((1, 1)), - use_conv::Bool = false, drop_block_rate = nothing, drop_path_rate = nothing, - dropout_rate = nothing, nclasses::Integer = 1000, kwargs...) + use_conv::Bool = false, dropblock_prob = nothing, + stochastic_depth_prob = nothing, dropout_prob = nothing, + imsize::Dims{2} = (256, 256), inchannels::Integer = 3, + nclasses::Integer = 1000, kwargs...) # Build stem - stem = stem_fn(; inchannels) + stem = resnet_stem(; inchannels) # Block builder if block_type == basicblock @assert cardinality==1 "Cardinality must be 1 for `basicblock`" @assert base_width==64 "Base width must be 64 for `basicblock`" get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor, activation, norm_layer, revnorm, attn_fn, - drop_block_rate, drop_path_rate, + dropblock_prob, stochastic_depth_prob, stride_fn = resnet_stride, planes_fn = resnet_planes, downsample_tuple = downsample_opt, kwargs...) elseif block_type == bottleneck get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width, reduction_factor, activation, norm_layer, revnorm, - attn_fn, drop_block_rate, drop_path_rate, + attn_fn, dropblock_prob, stochastic_depth_prob, stride_fn = resnet_stride, planes_fn = resnet_planes, downsample_tuple = downsample_opt, kwargs...) elseif block_type == bottle2neck - @assert isnothing(drop_block_rate) "DropBlock not supported for `bottle2neck`. - Set `drop_block_rate` to nothing." - @assert isnothing(drop_path_rate) "DropPath not supported for `bottle2neck`. - Set `drop_path_rate` to nothing." + @assert isnothing(dropblock_prob) "DropBlock not supported for `bottle2neck`. + Set `dropblock_prob` to nothing." + @assert isnothing(stochastic_depth_prob) "StochasticDepth not supported for `bottle2neck`. + Set `stochastic_depth_prob` to nothing." @assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`. Set `reduction_factor` to 1." 
get_layers = bottle2neck_builder(block_repeats; inplanes, cardinality, base_width, @@ -251,10 +428,10 @@ function resnet(block_type, block_repeats::AbstractVector{<:Integer}, # TODO: write better message when we have link to dev docs for resnet throw(ArgumentError("Unknown block type $block_type")) end - classifier_fn = nfeatures -> create_classifier(nfeatures, nclasses; dropout_rate, + classifier_fn = nfeatures -> create_classifier(nfeatures, nclasses; dropout_prob, pool_layer, use_conv) - return resnet((imsize..., inchannels), stem, get_layers, block_repeats, - connection$activation, classifier_fn) + return build_resnet((imsize..., inchannels), stem, get_layers, block_repeats, + connection$activation, classifier_fn) end function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...) return resnet(block_fn, block_repeats, RESNET_SHORTCUTS[downsample_opt]; kwargs...) diff --git a/src/convnets/resnets/res2net.jl b/src/convnets/resnets/res2net.jl index e308e1125..33b9fb961 100644 --- a/src/convnets/resnets/res2net.jl +++ b/src/convnets/resnets/res2net.jl @@ -1,81 +1,3 @@ -""" - bottle2neck(inplanes::Integer, planes::Integer; stride::Integer = 1, - cardinality::Integer = 1, base_width::Integer = 26, - scale::Integer = 4, activation = relu, norm_layer = BatchNorm, - revnorm::Bool = false, attn_fn = planes -> identity) - -Creates a bottleneck block as described in the Res2Net paper. -([reference](https://arxiv.org/abs/1904.01169)) - -# Arguments - - - `inplanes`: number of input feature maps - - `planes`: number of feature maps for the block - - `stride`: the stride of the block - - `cardinality`: the number of groups in the 3x3 convolutions. - - `base_width`: the number of output feature maps for each convolutional group. - - `scale`: the number of feature groups in the block. See the [paper](https://arxiv.org/abs/1904.01169) - for more details. - - `activation`: the activation function to use. - - `norm_layer`: the normalization layer to use. - - `revnorm`: set to `true` to place the batch norm before the convolution - - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example. -""" -function bottle2neck(inplanes::Integer, planes::Integer; stride::Integer = 1, - cardinality::Integer = 1, base_width::Integer = 26, - scale::Integer = 4, activation = relu, is_first::Bool = false, - norm_layer = BatchNorm, revnorm::Bool = false, - attn_fn = planes -> identity) - width = fld(planes * base_width, 64) * cardinality - outplanes = planes * 4 - pool = is_first && scale > 1 ? MeanPool((3, 3); stride, pad = 1) : identity - conv_bns = [Chain(conv_norm((3, 3), width => width, activation; norm_layer, stride, - pad = 1, groups = cardinality)...) - for _ in 1:max(1, scale - 1)] - reslayer = is_first ? Parallel(cat_channels, pool, conv_bns...) : - Parallel(cat_channels, identity, Chain(PairwiseFusion(+, conv_bns...))) - tuplify = is_first ? x -> tuple(x...) : x -> tuple(x[1], tuple(x[2:end]...)) - layers = [ - conv_norm((1, 1), inplanes => width * scale, activation; - norm_layer, revnorm)..., - chunk$(; size = width, dims = 3), tuplify, reslayer, - conv_norm((1, 1), width * scale => outplanes, activation; - norm_layer, revnorm)..., - attn_fn(outplanes), - ] - return Chain(filter(!=(identity), layers)...) 
-end - -function bottle2neck_builder(block_repeats::AbstractVector{<:Integer}; - inplanes::Integer = 64, cardinality::Integer = 1, - base_width::Integer = 26, scale::Integer = 4, - expansion::Integer = 4, norm_layer = BatchNorm, - revnorm::Bool = false, activation = relu, - attn_fn = planes -> identity, - stride_fn = resnet_stride, planes_fn = resnet_planes, - downsample_tuple = (downsample_conv, downsample_identity)) - planes_vec = collect(planes_fn(block_repeats)) - # closure over `idxs` - function get_layers(stage_idx::Integer, block_idx::Integer) - # This is needed for block `inplanes` and `planes` calculations - schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx - planes = planes_vec[schedule_idx] - inplanes = schedule_idx == 1 ? inplanes : planes_vec[schedule_idx - 1] * expansion - # `resnet_stride` is a callback that the user can tweak to change the stride of the - # blocks. It defaults to the standard behaviour as in the paper - stride = stride_fn(stage_idx, block_idx) - downsample_fn = (stride != 1 || inplanes != planes * expansion) ? - downsample_tuple[1] : downsample_tuple[2] - is_first = (stride > 1 || downsample_fn != downsample_tuple[2]) ? true : false - block = bottle2neck(inplanes, planes; stride, cardinality, base_width, scale, - activation, is_first, norm_layer, revnorm, attn_fn) - downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, - revnorm) - return block, downsample - end - return get_layers -end - """ Res2Net(depth::Integer; pretrain::Bool = false, scale::Integer = 4, base_width::Integer = 26, inchannels::Integer = 3, @@ -93,6 +15,12 @@ Creates a Res2Net model with the specified depth, scale, and base width. - `base_width`: the number of feature maps in each group. - `inchannels`: the number of input channels. - `nclasses`: the number of output classes + +!!! warning + + `Res2Net` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](@ref). """ struct Res2Net layers::Any @@ -134,6 +62,12 @@ Creates a Res2NeXt model with the specified depth, scale, base width and cardina - `cardinality`: the number of groups in the 3x3 convolutions. - `inchannels`: the number of input channels. - `nclasses`: the number of output classes + +!!! warning + + `Res2NeXt` does not currently support pretrained weights. + +Advanced users who want more configuration options will be better served by using [`resnet`](@ref). """ struct Res2NeXt layers::Any diff --git a/src/convnets/resnets/resnet.jl b/src/convnets/resnets/resnet.jl index b65b71072..a8c12eda8 100644 --- a/src/convnets/resnets/resnet.jl +++ b/src/convnets/resnets/resnet.jl @@ -11,7 +11,7 @@ Creates a ResNet model with the specified depth. - `inchannels`: The number of input channels. - `nclasses`: the number of output classes -Advanced users who want more configuration options will be better served by using [`resnet`](#). +Advanced users who want more configuration options will be better served by using [`resnet`](@ref). """ struct ResNet layers::Any @@ -48,7 +48,7 @@ The number of channels in outer 1x1 convolutions is the same. - `inchannels`: The number of input channels. - `nclasses`: The number of output classes -Advanced users who want more configuration options will be better served by using [`resnet`](#). +Advanced users who want more configuration options will be better served by using [`resnet`](@ref). 
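+
+For example, a bottleneck-based model can be assembled directly with the generic builder
+(the block counts, `base_width` and class count below are illustrative and not necessarily
+the exact Wide ResNet settings):
+
+```julia
+using Metalhead
+model = Metalhead.resnet(Metalhead.bottleneck, [3, 4, 6, 3]; base_width = 128, nclasses = 10)
+```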
""" struct WideResNet layers::Any diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl index 2a8fbd561..6d65708d0 100644 --- a/src/convnets/resnets/resnext.jl +++ b/src/convnets/resnets/resnext.jl @@ -20,15 +20,13 @@ Creates a ResNeXt model with the specified depth, cardinality, and base width. - `inchannels`: the number of input channels. - `nclasses`: the number of output classes -Advanced users who want more configuration options will be better served by using [`resnet`](#). +Advanced users who want more configuration options will be better served by using [`resnet`](@ref). """ struct ResNeXt layers::Any end @functor ResNeXt -(m::ResNeXt)(x) = m.layers(x) - function ResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = 32, base_width::Integer = 4, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(depth, keys(LRESNET_CONFIGS)) @@ -41,5 +39,7 @@ function ResNeXt(depth::Integer; pretrain::Bool = false, cardinality::Integer = return ResNeXt(layers) end +(m::ResNeXt)(x) = m.layers(x) + backbone(m::ResNeXt) = m.layers[1] classifier(m::ResNeXt) = m.layers[2] diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl index 487665518..44e32083d 100644 --- a/src/convnets/resnets/seresnet.jl +++ b/src/convnets/resnets/seresnet.jl @@ -15,7 +15,7 @@ Creates a SEResNet model with the specified depth. `SEResNet` does not currently support pretrained weights. -Advanced users who want more configuration options will be better served by using [`resnet`](#). +Advanced users who want more configuration options will be better served by using [`resnet`](@ref). """ struct SEResNet layers::Any @@ -58,7 +58,7 @@ Creates a SEResNeXt model with the specified depth, cardinality, and base width. `SEResNeXt` does not currently support pretrained weights. -Advanced users who want more configuration options will be better served by using [`resnet`](#). +Advanced users who want more configuration options will be better served by using [`resnet`](@ref). """ struct SEResNeXt layers::Any diff --git a/src/convnets/squeezenet.jl b/src/convnets/squeezenet.jl index 5c688a645..5d63a565d 100644 --- a/src/convnets/squeezenet.jl +++ b/src/convnets/squeezenet.jl @@ -21,17 +21,18 @@ function fire(inplanes::Integer, squeeze_planes::Integer, expand1x1_planes::Inte end """ - squeezenet(; inchannels::Integer = 3, nclasses::Integer = 1000) + squeezenet(; dropout_prob = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000) -Create a SqueezeNet +Create a SqueezeNet model. ([reference](https://arxiv.org/abs/1602.07360v4)). # Arguments + - `dropout_prob`: dropout probability for the classifier head. Set to `nothing` to disable dropout. - `inchannels`: number of input channels. - `nclasses`: the number of output classes. 
""" -function squeezenet(; inchannels::Integer = 3, nclasses::Integer = 1000) +function squeezenet(; dropout_prob = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000) backbone = Chain(Conv((3, 3), inchannels => 64, relu; stride = 2), MaxPool((3, 3); stride = 2), fire(64, 16, 64, 64), @@ -44,14 +45,14 @@ function squeezenet(; inchannels::Integer = 3, nclasses::Integer = 1000) fire(384, 48, 192, 192), fire(384, 64, 256, 256), fire(512, 64, 256, 256)) - classifier = Chain(Dropout(0.5), Conv((1, 1), 512 => nclasses, relu), + classifier = Chain(Dropout(dropout_prob), Conv((1, 1), 512 => nclasses, relu), AdaptiveMeanPool((1, 1)), MLUtils.flatten) return Chain(backbone, classifier) end """ SqueezeNet(; pretrain::Bool = false, inchannels::Integer = 3, - nclasses::Integer = 1000) + nclasses::Integer = 1000) Create a SqueezeNet ([reference](https://arxiv.org/abs/1602.07360v4)). @@ -62,7 +63,7 @@ Create a SqueezeNet - `inchannels`: number of input channels. - `nclasses`: the number of output classes. -See also [`squeezenet`](#). +See also [`Metalhead.squeezenet`](@ref). """ struct SqueezeNet layers::Any diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index 163c13b68..35d01e0b5 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -36,7 +36,7 @@ Create VGG convolution layers # Arguments - `config`: vector of tuples `(output_channels, num_convolutions)` - for each block (see [`Metalhead.vgg_block`](#)) + for each block (see [`Metalhead.vgg_block`](@ref)) - `batchnorm`: set to `true` to include batch normalization after each convolution - `inchannels`: number of input channels """ @@ -53,7 +53,7 @@ function vgg_convolutional_layers(config::AbstractVector{<:Tuple}, batchnorm::Bo end """ - vgg_classifier_layers(imsize, nclasses, fcsize, dropout_rate) + vgg_classifier_layers(imsize, nclasses, fcsize, dropout_prob) Create VGG classifier (fully connected) layers ([reference](https://arxiv.org/abs/1409.1556v6)). @@ -61,23 +61,23 @@ Create VGG classifier (fully connected) layers # Arguments - `imsize`: tuple `(width, height, channels)` indicating the size after - the convolution layers (see [`Metalhead.vgg_convolutional_layers`](#)) + the convolution layers (see [`Metalhead.vgg_convolutional_layers`](@ref)) - `nclasses`: number of output classes - `fcsize`: input and output size of the intermediate fully connected layer - - `dropout_rate`: the dropout level between each fully connected layer + - `dropout_prob`: the dropout level between each fully connected layer """ function vgg_classifier_layers(imsize::NTuple{3, <:Integer}, nclasses::Integer, - fcsize::Integer, dropout_rate) + fcsize::Integer, dropout_prob) return Chain(MLUtils.flatten, Dense(prod(imsize), fcsize, relu), - Dropout(dropout_rate), + Dropout(dropout_prob), Dense(fcsize, fcsize, relu), - Dropout(dropout_rate), + Dropout(dropout_prob), Dense(fcsize, nclasses)) end """ - vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dropout_rate) + vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dropout_prob) Create a VGG model ([reference](https://arxiv.org/abs/1409.1556v6)). 
@@ -86,19 +86,19 @@ Create a VGG model - `imsize`: input image width and height as a tuple - `config`: the configuration for the convolution layers - (see [`Metalhead.vgg_convolutional_layers`](#)) + (see [`Metalhead.vgg_convolutional_layers`](@ref)) - `inchannels`: number of input channels - `batchnorm`: set to `true` to use batch normalization after each convolution - `nclasses`: number of output classes - `fcsize`: intermediate fully connected layer size - (see [`Metalhead.vgg_classifier_layers`](#)) - - `dropout_rate`: dropout level between fully connected layers + (see [`Metalhead.vgg_classifier_layers`](@ref)) + - `dropout_prob`: dropout level between fully connected layers """ function vgg(imsize::Dims{2}; config, batchnorm::Bool = false, fcsize::Integer = 4096, - dropout_rate = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000) + dropout_prob = 0.0, inchannels::Integer = 3, nclasses::Integer = 1000) conv = vgg_convolutional_layers(config, batchnorm, inchannels) imsize = outputsize(conv, (imsize..., inchannels); padbatch = true)[1:3] - class = vgg_classifier_layers(imsize, nclasses, fcsize, dropout_rate) + class = vgg_classifier_layers(imsize, nclasses, fcsize, dropout_prob) return Chain(Chain(conv...), class) end @@ -110,7 +110,7 @@ const VGG_CONV_CONFIGS = Dict(:A => [(64, 1), (128, 1), (256, 2), (512, 2), (512 const VGG_CONFIGS = Dict(11 => :A, 13 => :B, 16 => :D, 19 => :E) """ - VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, dropout_rate) + VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, dropout_prob) Construct a VGG model with the specified input image size. Typically, the image size is `(224, 224)`. @@ -122,17 +122,17 @@ Construct a VGG model with the specified input image size. Typically, the image - `batchnorm`: set to `true` to use batch normalization after each convolution - `nclasses`: number of output classes - `fcsize`: intermediate fully connected layer size - (see [`Metalhead.vgg_classifier_layers`](#)) - - `dropout_rate`: dropout level between fully connected layers + (see [`Metalhead.vgg_classifier_layers`](@ref)) + - `dropout_prob`: dropout level between fully connected layers """ struct VGG layers::Any end @functor VGG -function VGG(imsize::Dims{2}; config, batchnorm::Bool = false, dropout_rate = 0.5, +function VGG(imsize::Dims{2}; config, batchnorm::Bool = false, dropout_prob = 0.5, inchannels::Integer = 3, nclasses::Integer = 1000) - layers = vgg(imsize; config, inchannels, batchnorm, nclasses, dropout_rate) + layers = vgg(imsize; config, inchannels, batchnorm, nclasses, dropout_prob) return VGG(layers) end @@ -156,7 +156,7 @@ Create a VGG style model with specified `depth`. - `inchannels`: number of input channels - `nclasses`: number of output classes -See also [`vgg`](#). +See also [`vgg`](@ref). 
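> **Reviewer note (not part of the patch):** to check the `dropout_rate` to `dropout_prob` rename end to end on the VGG side, a short sketch of the two constructors documented above. `Metalhead.VGG_CONV_CONFIGS[:D]` is taken from the config tables shown in this file and is assumed to be the VGG-16 layout.

```julia
using Flux, Metalhead

# Depth-based constructor defined just below; uses the default dropout_prob = 0.5.
vgg16 = VGG(16; batchnorm = true)

# Image-size constructor, exercising the renamed keyword explicitly.
vgg_small = VGG((224, 224); config = Metalhead.VGG_CONV_CONFIGS[:D],
                batchnorm = false, dropout_prob = 0.4, nclasses = 10)
```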
""" function VGG(depth::Integer; pretrain::Bool = false, batchnorm::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 9bdf1f913..423ff41ec 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -19,10 +19,10 @@ include("attention.jl") export MHAttention include("conv.jl") -export conv_norm, basic_conv_bn, dwsep_conv_bn +export conv_norm, basic_conv_bn, dwsep_conv_norm include("drop.jl") -export DropBlock, DropPath +export DropBlock, StochasticDepth include("embeddings.jl") export PatchEmbedding, ViPosEmbedding, ClassTokens diff --git a/src/layers/attention.jl b/src/layers/attention.jl index b8fd38165..b8560f12f 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -1,6 +1,6 @@ """ MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, - attn_dropout_rate = 0., proj_dropout_rate = 0.) + attn_dropout_prob = 0., proj_dropout_prob = 0.) Multi-head self-attention layer. @@ -9,8 +9,8 @@ Multi-head self-attention layer. - `planes`: number of input channels - `nheads`: number of heads - `qkv_bias`: whether to use bias in the layer to get the query, key and value - - `attn_dropout_rate`: dropout rate after the self-attention layer - - `proj_dropout_rate`: dropout rate after the projection layer + - `attn_dropout_prob`: dropout probability after the self-attention layer + - `proj_dropout_prob`: dropout probability after the projection layer """ struct MHAttention{P, Q, R} nheads::Int @@ -21,11 +21,11 @@ end @functor MHAttention function MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, - attn_dropout_rate = 0.0, proj_dropout_rate = 0.0) + attn_dropout_prob = 0.0, proj_dropout_prob = 0.0) @assert planes % nheads==0 "planes should be divisible by nheads" qkv_layer = Dense(planes, planes * 3; bias = qkv_bias) - attn_drop = Dropout(attn_dropout_rate) - proj = Chain(Dense(planes, planes), Dropout(proj_dropout_rate)) + attn_drop = Dropout(attn_dropout_prob) + proj = Chain(Dense(planes, planes), Dropout(proj_dropout_prob)) return MHAttention(nheads, qkv_layer, attn_drop, proj) end diff --git a/src/layers/classifier.jl b/src/layers/classifier.jl index bebdc4099..90ab5ba33 100644 --- a/src/layers/classifier.jl +++ b/src/layers/classifier.jl @@ -1,7 +1,7 @@ """ create_classifier(inplanes::Integer, nclasses::Integer, activation = identity; use_conv::Bool = false, pool_layer = AdaptiveMeanPool((1, 1)), - dropout_rate = nothing) + dropout_prob = nothing) Creates a classifier head to be used for models. @@ -13,11 +13,11 @@ Creates a classifier head to be used for models. - `use_conv`: whether to use a 1x1 convolutional layer instead of a `Dense` layer. - `pool_layer`: pooling layer to use. This is passed in with the layer instantiated with any arguments that are needed i.e. as `AdaptiveMeanPool((1, 1))`, for example. - - `dropout_rate`: dropout rate used in the classifier head. Set to `nothing` to disable dropout. + - `dropout_prob`: dropout probability used in the classifier head. Set to `nothing` to disable dropout. 
""" function create_classifier(inplanes::Integer, nclasses::Integer, activation = identity; use_conv::Bool = false, pool_layer = AdaptiveMeanPool((1, 1)), - dropout_rate = nothing) + dropout_prob = nothing) # Decide whether to flatten the input or not flatten_in_pool = !use_conv && pool_layer !== identity if use_conv @@ -31,7 +31,7 @@ function create_classifier(inplanes::Integer, nclasses::Integer, activation = id push!(classifier, pool_layer) end # Dropout is applied after the pooling layer - isnothing(dropout_rate) ? nothing : push!(classifier, Dropout(dropout_rate)) + isnothing(dropout_prob) ? nothing : push!(classifier, Dropout(dropout_prob)) # Fully-connected layer if use_conv push!(classifier, Conv((1, 1), inplanes => nclasses, activation)) @@ -45,7 +45,7 @@ end create_classifier(inplanes::Integer, hidden_planes::Integer, nclasses::Integer, activations::NTuple{2} = (relu, identity); use_conv::NTuple{2, Bool} = (false, false), - pool_layer = AdaptiveMeanPool((1, 1)), dropout_rate = nothing) + pool_layer = AdaptiveMeanPool((1, 1)), dropout_prob = nothing) Creates a classifier head to be used for models with an extra hidden layer. @@ -62,12 +62,12 @@ Creates a classifier head to be used for models with an extra hidden layer. layer. - `pool_layer`: pooling layer to use. This is passed in with the layer instantiated with any arguments that are needed i.e. as `AdaptiveMeanPool((1, 1))`, for example. - - `dropout_rate`: dropout rate used in the classifier head. Set to `nothing` to disable dropout. + - `dropout_prob`: dropout probability used in the classifier head. Set to `nothing` to disable dropout. """ function create_classifier(inplanes::Integer, hidden_planes::Integer, nclasses::Integer, activations::NTuple{2, Any} = (relu, identity); use_conv::NTuple{2, Bool} = (false, false), - pool_layer = AdaptiveMeanPool((1, 1)), dropout_rate = nothing) + pool_layer = AdaptiveMeanPool((1, 1)), dropout_prob = nothing) fc_layers = [uc ? Conv$(1, 1) : Dense for uc in use_conv] # Decide whether to flatten the input or not flatten_in_pool = !use_conv[1] && pool_layer !== identity @@ -86,7 +86,7 @@ function create_classifier(inplanes::Integer, hidden_planes::Integer, nclasses:: push!(classifier, fc_layers[1](inplanes => hidden_planes, activations[1])) end # Dropout is applied after the first dense layer - isnothing(dropout_rate) ? nothing : push!(classifier, Dropout(dropout_rate)) + isnothing(dropout_prob) ? nothing : push!(classifier, Dropout(dropout_prob)) # second fully-connected layer push!(classifier, fc_layers[2](hidden_planes => nclasses, activations[2])) return Chain(classifier...) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index e49611280..cbee8ef03 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1,14 +1,10 @@ """ conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, activation = relu; norm_layer = BatchNorm, revnorm::Bool = false, - eps::Float32 = 1.0f-5, preact::Bool = false, use_norm::Bool = true, - stride::Integer = 1, pad::Integer = 0, dilation::Integer = 1, - groups::Integer = 1, [bias, weight, init]) + preact::Bool = false, stride::Integer = 1, pad::Integer = 0, + dilation::Integer = 1, groups::Integer = 1, [bias, weight, init]) - conv_norm(kernel_size::Dims{2}, inplanes => outplanes, activation = identity; - kwargs...) - -Create a convolution + batch normalization pair with activation. +Create a convolution + normalisation layer pair with activation. # Arguments @@ -16,28 +12,28 @@ Create a convolution + batch normalization pair with activation. 
- `inplanes`: number of input feature maps - `outplanes`: number of output feature maps - `activation`: the activation function for the final layer - - `norm_layer`: the normalization layer used + - `norm_layer`: the normalisation layer used. Note that using `identity` as the normalisation + layer will result in no normalisation being applied. (This is only compatible with `preact` + and `revnorm` both set to `false`.) - `revnorm`: set to `true` to place the normalisation layer before the convolution - - `preact`: set to `true` to place the activation function before the batch norm + - `preact`: set to `true` to place the activation function before the normalisation layer (only compatible with `revnorm = false`) - - `use_norm`: set to `false` to disable normalization - (only compatible with `revnorm = false` and `preact = false`) + - `bias`: bias for the convolution kernel. This is set to `false` by default if + `norm_layer` is not `identity` and `true` otherwise. - `stride`: stride of the convolution kernel - `pad`: padding of the convolution kernel - `dilation`: dilation of the convolution kernel - `groups`: groups for the convolution kernel - - `bias`: bias for the convolution kernel. This is set to `false` by default if - `use_norm = true`. - - `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) + - `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](@ref)) """ function conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, activation = relu; norm_layer = BatchNorm, revnorm::Bool = false, - eps::Float32 = 1.0f-5, preact::Bool = false, use_norm::Bool = true, - bias = !use_norm, kwargs...) + preact::Bool = false, bias = !(norm_layer !== identity), kwargs...) # no normalization layer - if !use_norm + if !(norm_layer !== identity) if preact || revnorm - throw(ArgumentError("`preact` only supported with `use_norm = true`")) + throw(ArgumentError("`preact` only supported with `norm_layer !== identity`. + Check if a non-`identity` norm layer is intended.")) else # early return if no norm layer is required return [Conv(kernel_size, inplanes => outplanes, activation; kwargs...)] @@ -45,10 +41,10 @@ function conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, end # channels for norm layer and activation functions for both conv and norm if revnorm - activations = (conv = activation, bn = identity) + activations = (conv = activation, norm = identity) normplanes = inplanes else - activations = (conv = identity, bn = activation) + activations = (conv = identity, norm = activation) normplanes = outplanes end # handle pre-activation @@ -56,24 +52,34 @@ function conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, if revnorm throw(ArgumentError("`preact` and `revnorm` cannot be set at the same time")) else - activations = (conv = activation, bn = identity) + activations = (conv = activation, norm = identity) end end # layers layers = [Conv(kernel_size, inplanes => outplanes, activations.conv; bias, kwargs...), - norm_layer(normplanes, activations.bn; ϵ = eps)] + norm_layer(normplanes, activations.norm)] return revnorm ? reverse(layers) : layers end -function conv_norm(kernel_size::Dims{2}, ch::Pair{<:Integer, <:Integer}, - activation = identity; kwargs...) - inplanes, outplanes = ch - return conv_norm(kernel_size, inplanes, outplanes, activation; kwargs...) 
-end -# conv + bn layer combination as used by the inception model family matching -# the default values used in TensorFlow +""" + basic_conv_bn(kernel_size::Dims{2}, inplanes, outplanes, activation = relu; + kwargs...) + +Returns a convolution + batch normalisation pair with activation as used by the +Inception family of models with default values matching those used in the official +TensorFlow implementation. + +# Arguments + + - `kernel_size`: size of the convolution kernel (tuple) + - `inplanes`: number of input feature maps + - `outplanes`: number of output feature maps + - `activation`: the activation function for the final layer + - `kwargs`: keyword arguments passed to [`conv_norm`](@ref) +""" function basic_conv_bn(kernel_size::Dims{2}, inplanes, outplanes, activation = relu; kwargs...) - return conv_norm(kernel_size, inplanes, outplanes, activation; norm_layer = BatchNorm, - eps = 1.0f-3, kwargs...) + # TensorFlow uses a default epsilon of 1e-3 for BatchNorm + norm_layer = (args...; kwargs...) -> BatchNorm(args...; ϵ = 1.0f-3, kwargs...) + return conv_norm(kernel_size, inplanes, outplanes, activation; norm_layer, kwargs...) end diff --git a/src/layers/drop.jl b/src/layers/drop.jl index 8a82d5b16..f7f5d95bd 100644 --- a/src/layers/drop.jl +++ b/src/layers/drop.jl @@ -8,6 +8,8 @@ end ChainRulesCore.@non_differentiable _dropblock_mask(rng, x, gamma, clipped_block_size) +# TODO add experimental `DropBlock` options from timm such as gaussian noise and +# more precise `DropBlock` to deal with edges (#188) """ dropblock([rng = rng_from_array(x)], x::AbstractArray{T, 4}, drop_block_prob, block_size, gamma_scale, active::Bool = true) @@ -20,16 +22,15 @@ regions of size `block_size` in the input. Otherwise, it simply returns the inpu - `rng`: can be used to pass in a custom RNG instead of the default. Custom RNGs are only supported on the CPU. - `x`: input array - - `drop_block_prob`: probability of dropping a block. If `nothing` is passed, it returns + - `drop_block_prob`: probability of dropping a block. If `nothing` is passed, it returns `identity`. - `block_size`: size of the block to drop - `gamma_scale`: multiplicative factor for `gamma` used. For the calculations, refer to [the paper](https://arxiv.org/abs/1810.12890). -If you are an end-user, you do not want this function. Use [`DropBlock`](#) instead. +If you are not a package developer, you most likely do not want this function. Use [`DropBlock`](@ref) +instead. """ -# TODO add experimental `DropBlock` options from timm such as gaussian noise and -# more precise `DropBlock` to deal with edges (#188) function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob, block_size::Integer, gamma_scale) where {T} H, W, _, _ = size(x) @@ -46,7 +47,8 @@ end # Dispatch for GPU dropblock_mask(rng::CUDA.RNG, x::CuArray, gamma, bs) = _dropblock_mask(rng, x, gamma, bs) function dropblock_mask(rng, x::CuArray, gamma, bs) - throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropblock only supports CUDA.RNG for CuArrays.")) + throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropblock only supports + CUDA.RNG for CuArrays.")) end # Dispatch for CPU dropblock_mask(rng, x, gamma, bs) = _dropblock_mask(rng, x, gamma, bs) @@ -61,21 +63,13 @@ It can be used in two ways: either with all blocks having the same survival prob or with a linear scaling rule across the blocks. This is performed only at training time. At test time, the `DropBlock` layer is equivalent to `identity`. -!!! 
warning - - In the case of the linear scaling rule, the calculations of survival probabilities for each - block may lead to a survival probability > 1 for a given block. This will lead to - `DropBlock` erroring. This usually happens with a low number of blocks and a high base - survival probability, so in such cases it is recommended to use a fixed base survival - probability across blocks. If this is not desired, then a lower base survival probability - is recommended. - ([reference](https://arxiv.org/abs/1810.12890)) # Arguments - `drop_block_prob`: probability of dropping a block. If `nothing` is passed, it returns - `identity`. + `identity`. Note that some literature uses the term "survival probability" instead, + which is equivalent to `1 - drop_block_prob`. - `block_size`: size of the block to drop - `gamma_scale`: multiplicative factor for `gamma` used. For the calculation of gamma, refer to [the paper](https://arxiv.org/abs/1810.12890). @@ -93,7 +87,8 @@ end trainable(a::DropBlock) = (;) function _dropblock_checks(x::AbstractArray{<:Any, 4}, drop_block_prob, gamma_scale) - @assert 0≤drop_block_prob≤1 "drop_block_prob must be between 0 and 1, got $drop_block_prob" + @assert 0≤drop_block_prob≤1 "drop_block_prob must be between 0 and 1, got + $drop_block_prob" @assert 0≤gamma_scale≤1 "gamma_scale must be between 0 and 1, got $gamma_scale" end function _dropblock_checks(x, drop_block_prob, gamma_scale) @@ -108,7 +103,7 @@ function (m::DropBlock)(x) end function Flux.testmode!(m::DropBlock, mode = true) - return (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) + return (m.active = (isnothing(mode) || mode === :auto) ? nothing : !mode; m) end function DropBlock(drop_block_prob = 0.1, block_size::Integer = 7, gamma_scale = 1.0, @@ -126,36 +121,39 @@ function Base.show(io::IO, d::DropBlock) return print(io, ")") end -# TODO look into "row" mode for stochastic depth """ - DropPath(p; [rng = rng_from_array(x)]) + StochasticDepth(p, mode = :row; rng = rng_from_array()) -Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `0 ≤ p ≤ 1` and -`identity` if p is `nothing`. +Implements Stochastic Depth. This is a `Dropout` layer from Flux that drops values +with probability `p`. ([reference](https://arxiv.org/abs/1603.09382)) This layer can be used to drop certain blocks in a residual structure and allow them to propagate completely through the skip connection. It can be used in two ways: either with all blocks having the same survival probability or with a linear scaling rule across the -blocks. This is performed only at training time. At test time, the `DropPath` layer is +blocks. This is performed only at training time. At test time, the `StochasticDepth` layer is equivalent to `identity`. -!!! warning - - In the case of the linear scaling rule, the calculations of survival probabilities for each - block may lead to a survival probability > 1 for a given block. This will lead to - `DropPath` erroring. This usually happens with a low number of blocks and a high base - survival probability, so in such cases it is recommended to use a fixed base survival - probability across blocks. If this is not desired, then a lower base survival probability - is recommended. - # Arguments - - `p`: rate of Stochastic Depth. + - `p`: probability of Stochastic Depth. Note that some literature uses the term "survival + probability" instead, which is equivalent to `1 - p`. + - `mode`: Either `:batch` or `:row`. 
`:batch` randomly zeroes the entire input, `:row` zeroes
+    randomly selected rows from the batch. The default is `:row`.
  - `rng`: can be used to pass in a custom RNG instead of the default. See `Flux.Dropout` for
    more information on the behaviour of this argument. Custom RNGs are only supported on the
    CPU.
"""
-function DropPath(p; rng = rng_from_array())
-    return isnothing(p) ? identity : Dropout(p; dims = 4, rng)
+function StochasticDepth(p, mode = :row; rng = rng_from_array())
+    if isnothing(p)
+        return identity
+    else
+        if mode === :batch
+            return Dropout(p; dims = 5, rng)
+        elseif mode === :row
+            return Dropout(p; dims = 4, rng)
+        else
+            throw(ArgumentError("mode must be either `:batch` or `:row`, got $mode"))
+        end
+    end
end
diff --git a/src/layers/mbconv.jl b/src/layers/mbconv.jl
index e37a98406..13eef8ae1 100644
--- a/src/layers/mbconv.jl
+++ b/src/layers/mbconv.jl
@@ -1,17 +1,16 @@
"""
-    dwsep_conv_bn(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer,
-                  activation = relu; eps::Float32 = 1.0f-5, revnorm::Bool = false,
-                  stride::Integer = 1, use_norm::NTuple{2, Bool} = (true, true),
-                  pad::Integer = 0, [bias, weight, init])
+    dwsep_conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer,
+                  activation = relu; norm_layer = BatchNorm, stride::Integer = 1,
+                  bias::Bool = !(norm_layer !== identity), pad::Integer = 0, [bias, weight, init])

Create a depthwise separable convolution chain as used in MobileNetv1.
This is sequence of layers:

  - a `kernel_size` depthwise convolution from `inplanes => inplanes`
-  - a (batch) normalisation layer + `activation` (if `use_norm[1] == true`; otherwise
+  - a (batch) normalisation layer + `activation` (if `norm_layer !== identity`; otherwise
    `activation` is applied to the convolution output)
  - a `kernel_size` convolution from `inplanes => outplanes`
-  - a (batch) normalisation layer + `activation` (if `use_norm[2] == true`; otherwise
+  - a (batch) normalisation layer + `activation` (if `norm_layer !== identity`; otherwise
    `activation` is applied to the convolution output)

See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1).
@@ -22,28 +21,21 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1).
  - `inplanes`: number of input feature maps
  - `outplanes`: number of output feature maps
  - `activation`: the activation function for the final layer
-  - `revnorm`: set to `true` to place the batch norm before the convolution
-  - `use_norm`: a tuple of two booleans to specify whether to use normalization for the first and
-    second convolution
-  - `bias`: a tuple of two booleans to specify whether to use bias for the first and second
-    convolution. This is set to `(false, false)` by default if `use_norm[0] == true` and
-    `use_norm[1] == true`.
+  - `norm_layer`: the normalisation layer used. Note that using `identity` as the normalisation
+    layer will result in no normalisation being applied.
+  - `bias`: whether to use bias in the convolution layers.
  - `stride`: stride of the first convolution kernel
  - `pad`: padding of the first convolution kernel
-  - `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#))
+  - `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](@ref))
"""
-function dwsep_conv_bn(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer,
-                       activation = relu; eps::Float32 = 1.0f-5, revnorm::Bool = false,
-                       stride::Integer = 1, use_norm::NTuple{2, Bool} = (true, true),
-                       bias::NTuple{2, Bool} = (!use_norm[1], !use_norm[2]), kwargs...)
- return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; eps, - revnorm, use_norm = use_norm[1], stride, bias = bias[1], - groups = inplanes, kwargs...), - conv_norm((1, 1), inplanes, outplanes, activation; eps, - revnorm, use_norm = use_norm[2], bias = bias[2])) +function dwsep_conv_norm(kernel_size::Dims{2}, inplanes::Integer, outplanes::Integer, + activation = relu; norm_layer = BatchNorm, stride::Integer = 1, + bias::Bool = !(norm_layer !== identity), kwargs...) + return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; norm_layer, + stride, bias, groups = inplanes, kwargs...), # depthwise convolution + conv_norm((1, 1), inplanes, outplanes, activation; norm_layer, bias)) # pointwise convolution end -# TODO add support for stochastic depth to mbconv and fused_mbconv """ mbconv(kernel_size::Dims{2}, inplanes::Integer, explanes::Integer, outplanes::Integer, activation = relu; stride::Integer, @@ -76,7 +68,7 @@ First introduced in the MobileNetv2 paper. - `activation`: The activation function for the first two convolution layer - `stride`: The stride of the convolutional kernel, has to be either 1 or 2 - `reduction`: The reduction factor for the number of hidden feature maps - in a squeeze and excite layer (see [`squeeze_excite`](#)) + in a squeeze and excite layer (see [`squeeze_excite`](@ref)) - `se_round_fn`: The function to round the number of reduced feature maps in the squeeze and excite layer - `norm_layer`: The normalization layer to use diff --git a/src/layers/mlp.jl b/src/layers/mlp.jl index e6336de9c..7aa9bc803 100644 --- a/src/layers/mlp.jl +++ b/src/layers/mlp.jl @@ -1,7 +1,7 @@ -# TODO @theabhirath figure out consistent behaviour for dropout rates - 0.0 vs `nothing` +# TODO @theabhirath figure out consistent behaviour for dropout probs - 0.0 vs `nothing` """ mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; - dropout_rate = 0., activation = gelu) + dropout_prob = 0., activation = gelu) Feedforward block used in many MLPMixer-like and vision-transformer models. @@ -10,18 +10,18 @@ Feedforward block used in many MLPMixer-like and vision-transformer models. - `inplanes`: Number of dimensions in the input. - `hidden_planes`: Number of dimensions in the intermediate layer. - `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`. - - `dropout_rate`: Dropout rate. + - `dropout_prob`: Dropout probability. - `activation`: Activation function to use. """ function mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; - dropout_rate = 0.0, activation = gelu) - return Chain(Dense(inplanes, hidden_planes, activation), Dropout(dropout_rate), - Dense(hidden_planes, outplanes), Dropout(dropout_rate)) + dropout_prob = 0.0, activation = gelu) + return Chain(Dense(inplanes, hidden_planes, activation), Dropout(dropout_prob), + Dense(hidden_planes, outplanes), Dropout(dropout_prob)) end """ gated_mlp(gate_layer, inplanes::Integer, hidden_planes::Integer, - outplanes::Integer = inplanes; dropout_rate = 0.0, activation = gelu) + outplanes::Integer = inplanes; dropout_prob = 0.0, activation = gelu) Feedforward block based on the implementation in the paper "Pay Attention to MLPs". ([reference](https://arxiv.org/abs/2105.08050)) @@ -32,17 +32,17 @@ Feedforward block based on the implementation in the paper "Pay Attention to MLP - `inplanes`: Number of dimensions in the input. - `hidden_planes`: Number of dimensions in the intermediate layer. 
- `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`. - - `dropout_rate`: Dropout rate. + - `dropout_prob`: Dropout probability. - `activation`: Activation function to use. """ function gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer, - outplanes::Integer = inplanes; dropout_rate = 0.0, + outplanes::Integer = inplanes; dropout_prob = 0.0, activation = gelu) @assert hidden_planes % 2==0 "`hidden_planes` must be even for gated MLP" return Chain(Dense(inplanes, hidden_planes, activation), - Dropout(dropout_rate), + Dropout(dropout_prob), gate_layer(hidden_planes), Dense(hidden_planes ÷ 2, outplanes), - Dropout(dropout_rate)) + Dropout(dropout_prob)) end gated_mlp_block(::typeof(identity), args...; kwargs...) = mlp_block(args...; kwargs...) diff --git a/src/layers/selayers.jl b/src/layers/selayers.jl index 044d61dbf..40a6c027d 100644 --- a/src/layers/selayers.jl +++ b/src/layers/selayers.jl @@ -1,43 +1,29 @@ """ - squeeze_excite(inplanes::Integer, squeeze_planes::Integer; - norm_layer = planes -> identity, activation = relu, - gate_activation = sigmoid) - - squeeze_excite(inplanes::Integer; reduction::Real = 16, - norm_layer = planes -> identity, activation = relu, - gate_activation = sigmoid) + squeeze_excite(inplanes::Integer; reduction::Real = 16, round_fn = _round_channels, + norm_layer = identity, activation = relu, gate_activation = sigmoid) Creates a squeeze-and-excitation layer used in MobileNets, EfficientNets and SE-ResNets. # Arguments - `inplanes`: The number of input feature maps - - `squeeze_planes`: The number of feature maps in the intermediate layers. Alternatively, - specify the keyword arguments `reduction` and `rd_divisior`, which determine the number - of feature maps in the intermediate layers from the number of input feature maps as: - `squeeze_planes = _round_channels(inplanes ÷ reduction)`. (See [`_round_channels`](#) for details.) + - `reduction`: The reduction factor for the number of hidden feature maps in the + squeeze and excite layer. The number of hidden feature maps is calculated as + `round_fn(inplanes / reduction)`. + - `round_fn`: The function to round the number of reduced feature maps. - `activation`: The activation function for the first convolution layer - `gate_activation`: The activation function for the gate layer - `norm_layer`: The normalization layer to be used after the convolution layers - `rd_planes`: The number of hidden feature maps in a squeeze and excite layer """ -# TODO look into a `get_norm_act` layer that will return a closure over the norm layer -# with the activation function passed in when the norm layer is not `identity` -function squeeze_excite(inplanes::Integer, squeeze_planes::Integer; - norm_layer = planes -> identity, activation = relu, - gate_activation = sigmoid) - layers = [AdaptiveMeanPool((1, 1)), - Conv((1, 1), inplanes => squeeze_planes), - norm_layer(squeeze_planes), - activation, - Conv((1, 1), squeeze_planes => inplanes), - norm_layer(inplanes), - gate_activation] - return SkipConnection(Chain(filter!(!=(identity), layers)...), .*) -end -function squeeze_excite(inplanes::Integer; reduction::Real = 16, - round_fn = _round_channels, kwargs...) - return squeeze_excite(inplanes, round_fn(inplanes / reduction); kwargs...) 
+function squeeze_excite(inplanes::Integer; reduction::Real = 16, round_fn = _round_channels, + norm_layer = identity, activation = relu, gate_activation = sigmoid) + squeeze_planes = round_fn(inplanes ÷ reduction) + return SkipConnection(Chain(AdaptiveMeanPool((1, 1)), + conv_norm((1, 1), inplanes, squeeze_planes, activation; + norm_layer)..., + conv_norm((1, 1), squeeze_planes, inplanes, + gate_activation; norm_layer)...), .*) end """ diff --git a/src/mixers/core.jl b/src/mixers/core.jl index 875136b2e..4c9c31def 100644 --- a/src/mixers/core.jl +++ b/src/mixers/core.jl @@ -1,6 +1,6 @@ """ mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels::Integer = 3, norm_layer = LayerNorm, - patch_size::Dims{2} = (16, 16), embedplanes = 512, drop_path_rate = 0., + patch_size::Dims{2} = (16, 16), embedplanes = 512, stochastic_depth_prob = 0., depth::Integer = 12, nclasses::Integer = 1000, kwargs...) Creates a model with the MLPMixer architecture. @@ -9,26 +9,28 @@ Creates a model with the MLPMixer architecture. # Arguments - `block`: the type of mixer block to use in the model - architecture dependent - (a constructor of the form `block(embedplanes, npatches; drop_path_rate, kwargs...)`) + (a constructor of the form `block(embedplanes, npatches; stochastic_depth_prob, kwargs...)`) - `imsize`: the size of the input image - `inchannels`: the number of input channels - `norm_layer`: the normalization layer to use in the model - `patch_size`: the size of the patches - `embedplanes`: the number of channels after the patch embedding (denotes the hidden dimension) - - `drop_path_rate`: Stochastic depth rate + - `stochastic_depth_prob`: Stochastic depth probability - `depth`: the number of blocks in the model - `nclasses`: number of output classes - `kwargs`: additional arguments (if any) to pass to the mixer block. Will use the defaults if not specified. """ function mlpmixer(block, imsize::Dims{2} = (224, 224); norm_layer = LayerNorm, - patch_size::Dims{2} = (16, 16), embedplanes = 512, drop_path_rate = 0.0, + patch_size::Dims{2} = (16, 16), embedplanes = 512, + stochastic_depth_prob = 0.0, depth::Integer = 12, inchannels::Integer = 3, nclasses::Integer = 1000, kwargs...) npatches = prod(imsize .÷ patch_size) - dp_rates = linear_scheduler(drop_path_rate; depth) + sdschedule = linear_scheduler(stochastic_depth_prob; depth) layers = Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes), - Chain([block(embedplanes, npatches; drop_path_rate = dp_rates[i], + Chain([block(embedplanes, npatches; + stochastic_depth_prob = sdschedule[i], kwargs...) for i in 1:depth]...)) classifier = Chain(norm_layer(embedplanes), seconddimmean, Dense(embedplanes, nclasses)) diff --git a/src/mixers/gmlp.jl b/src/mixers/gmlp.jl index ab89baadc..845786686 100644 --- a/src/mixers/gmlp.jl +++ b/src/mixers/gmlp.jl @@ -44,7 +44,7 @@ end """ spatial_gating_block(planes::Integer, npatches::Integer; mlp_ratio = 4.0, norm_layer = LayerNorm, mlp_layer = gated_mlp_block, - dropout_rate = 0.0, drop_path_rate = 0.0, activation = gelu) + dropout_prob = 0.0, stochastic_depth_prob = 0.0, activation = gelu) Creates a feedforward block based on the gMLP model architecture described in the paper. 
([reference](https://arxiv.org/abs/2105.08050)) @@ -56,19 +56,20 @@ Creates a feedforward block based on the gMLP model architecture described in th - `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number of planes in the block - `norm_layer`: the normalisation layer to use - - `dropout_rate`: the dropout rate to use in the MLP blocks - - `drop_path_rate`: Stochastic depth rate + - `dropout_prob`: the dropout probability to use in the MLP blocks + - `stochastic_depth_prob`: Stochastic depth probability - `activation`: the activation function to use in the MLP blocks """ function spatial_gating_block(planes::Integer, npatches::Integer; mlp_ratio = 4.0, norm_layer = LayerNorm, mlp_layer = gated_mlp_block, - dropout_rate = 0.0, drop_path_rate = 0.0, activation = gelu) + dropout_prob = 0.0, stochastic_depth_prob = 0.0, + activation = gelu) channelplanes = floor(Int, mlp_ratio * planes) sgu = inplanes -> SpatialGatingUnit(inplanes, npatches; norm_layer) return SkipConnection(Chain(norm_layer(planes), mlp_layer(sgu, planes, channelplanes; activation, - dropout_rate), - DropPath(drop_path_rate)), +) + dropout_prob), + StochasticDepth(stochastic_depth_prob)), +) end """ @@ -86,7 +87,7 @@ Creates a model with the gMLP architecture. - `inchannels`: the number of input channels - `nclasses`: number of output classes -See also [`Metalhead.mlpmixer`](#). +See also [`Metalhead.mlpmixer`](@ref). """ struct gMLP layers::Any @@ -94,10 +95,13 @@ end @functor gMLP function gMLP(config::Symbol; imsize::Dims{2} = (224, 224), patch_size::Dims{2} = (16, 16), - inchannels::Integer = 3, nclasses::Integer = 1000) + pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, keys(MIXER_CONFIGS)) layers = mlpmixer(spatial_gating_block, imsize; mlp_layer = gated_mlp_block, patch_size, MIXER_CONFIGS[config]..., inchannels, nclasses) + if pretrain + loadpretrain!(layers, string("gmlp", config)) + end return gMLP(layers) end diff --git a/src/mixers/mlpmixer.jl b/src/mixers/mlpmixer.jl index 37cc271fb..24656bd63 100644 --- a/src/mixers/mlpmixer.jl +++ b/src/mixers/mlpmixer.jl @@ -1,6 +1,6 @@ """ mixerblock(planes::Integer, npatches::Integer; mlp_layer = mlp_block, - mlp_ratio = (0.5, 4.0), dropout_rate = 0.0, drop_path_rate = 0.0, + mlp_ratio = (0.5, 4.0), dropout_prob = 0.0, stochastic_depth_prob = 0.0, activation = gelu) Creates a feedforward block for the MLPMixer architecture. @@ -13,24 +13,24 @@ Creates a feedforward block for the MLPMixer architecture. - `mlp_ratio`: number(s) that determine(s) the number of hidden channels in the token mixing MLP and/or the channel mixing MLP as a ratio to the number of planes in the block. 
- `mlp_layer`: the MLP layer to use in the block - - `dropout_rate`: the dropout rate to use in the MLP blocks - - `drop_path_rate`: Stochastic depth rate + - `dropout_prob`: the dropout probability to use in the MLP blocks + - `stochastic_depth_prob`: Stochastic depth probability - `activation`: the activation function to use in the MLP blocks """ function mixerblock(planes::Integer, npatches::Integer; mlp_layer = mlp_block, - mlp_ratio::NTuple{2, Number} = (0.5, 4.0), dropout_rate = 0.0, - drop_path_rate = 0.0, activation = gelu) + mlp_ratio::NTuple{2, Number} = (0.5, 4.0), dropout_prob = 0.0, + stochastic_depth_prob = 0.0, activation = gelu) tokenplanes, channelplanes = floor.(Int, planes .* mlp_ratio) return Chain(SkipConnection(Chain(LayerNorm(planes), swapdims((2, 1, 3)), mlp_layer(npatches, tokenplanes; activation, - dropout_rate), + dropout_prob), swapdims((2, 1, 3)), - DropPath(drop_path_rate)), +), + StochasticDepth(stochastic_depth_prob)), +), SkipConnection(Chain(LayerNorm(planes), mlp_layer(planes, channelplanes; activation, - dropout_rate), - DropPath(drop_path_rate)), +)) + dropout_prob), + StochasticDepth(stochastic_depth_prob)), +)) end """ @@ -45,11 +45,11 @@ Creates a model with the MLPMixer architecture. - `config`: the size of the model - one of `small`, `base`, `large` or `huge` - `patch_size`: the size of the patches - `imsize`: the size of the input image - - `drop_path_rate`: Stochastic depth rate + - `stochastic_depth_prob`: Stochastic depth probability - `inchannels`: the number of input channels - `nclasses`: number of output classes -See also [`Metalhead.mlpmixer`](#). +See also [`Metalhead.mlpmixer`](@ref). """ struct MLPMixer layers::Any @@ -57,11 +57,14 @@ end @functor MLPMixer function MLPMixer(config::Symbol; imsize::Dims{2} = (224, 224), - patch_size::Dims{2} = (16, 16), + patch_size::Dims{2} = (16, 16), pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, keys(MIXER_CONFIGS)) layers = mlpmixer(mixerblock, imsize; patch_size, MIXER_CONFIGS[config]..., inchannels, nclasses) + if pretrain + loadpretrain!(layers, string("mlpmixer", config)) + end return MLPMixer(layers) end diff --git a/src/mixers/resmlp.jl b/src/mixers/resmlp.jl index 21ad89d65..b1ff44ea1 100644 --- a/src/mixers/resmlp.jl +++ b/src/mixers/resmlp.jl @@ -1,5 +1,5 @@ """ - resmixerblock(planes, npatches; dropout_rate = 0., drop_path_rate = 0., mlp_ratio = 4.0, + resmixerblock(planes, npatches; dropout_prob = 0., stochastic_depth_prob = 0., mlp_ratio = 4.0, activation = gelu, layerscale_init = 1e-4) Creates a block for the ResMixer architecture. @@ -12,25 +12,25 @@ Creates a block for the ResMixer architecture. 
  - `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number
    of planes in the block
  - `mlp_layer`: the MLP block to use
-  - `dropout_rate`: the dropout rate to use in the MLP blocks
-  - `drop_path_rate`: Stochastic depth rate
+  - `dropout_prob`: the dropout probability to use in the MLP blocks
+  - `stochastic_depth_prob`: Stochastic depth probability
  - `activation`: the activation function to use in the MLP blocks
  - `layerscale_init`: initialisation constant for the LayerScale
"""
function resmixerblock(planes::Integer, npatches::Integer; mlp_layer = mlp_block,
-                       mlp_ratio = 4.0, layerscale_init = 1e-4, dropout_rate = 0.0,
-                       drop_path_rate = 0.0, activation = gelu)
+                       mlp_ratio = 4.0, layerscale_init = 1e-4, dropout_prob = 0.0,
+                       stochastic_depth_prob = 0.0, activation = gelu)
    return Chain(SkipConnection(Chain(Flux.Scale(planes),
                                      swapdims((2, 1, 3)),
                                      Dense(npatches, npatches),
                                      swapdims((2, 1, 3)),
                                      LayerScale(planes, layerscale_init),
-                                      DropPath(drop_path_rate)), +),
+                                      StochasticDepth(stochastic_depth_prob)), +),
                 SkipConnection(Chain(Flux.Scale(planes),
                                      mlp_layer(planes, floor(Int, mlp_ratio * planes);
-                                                dropout_rate, activation),
+                                                dropout_prob, activation),
                                      LayerScale(planes, layerscale_init),
-                                      DropPath(drop_path_rate)), +))
+                                      StochasticDepth(stochastic_depth_prob)), +))
end

"""
@@ -48,7 +48,7 @@ Creates a model with the ResMLP architecture.

  - `inchannels`: the number of input channels
  - `nclasses`: number of output classes

-See also [`Metalhead.mlpmixer`](#).
+See also [`Metalhead.mlpmixer`](@ref).
"""
struct ResMLP
    layers::Any
@@ -56,11 +56,14 @@ end
@functor ResMLP

function ResMLP(config::Symbol; imsize::Dims{2} = (224, 224),
-                patch_size::Dims{2} = (16, 16),
+                patch_size::Dims{2} = (16, 16), pretrain::Bool = false,
                inchannels::Integer = 3, nclasses::Integer = 1000)
    _checkconfig(config, keys(MIXER_CONFIGS))
    layers = mlpmixer(resmixerblock, imsize; mlp_ratio = 4.0, patch_size,
                      MIXER_CONFIGS[config]..., inchannels, nclasses)
+    if pretrain
+        loadpretrain!(layers, string("resmlp", config))
+    end
    return ResMLP(layers)
end
diff --git a/src/utilities.jl b/src/utilities.jl
index 13b8ec385..6f97e81d2 100644
--- a/src/utilities.jl
+++ b/src/utilities.jl
@@ -14,9 +14,9 @@ end

Convenience function for applying an activation function to the output after
summing up the input arrays. Useful as the `connection` argument for the block
-function in [`resnet`](#).
+function in [`resnet`](@ref).

-See also [`reluadd`](#).
+See also [`reluadd`](@ref).
"""
addact(activation = relu, xs...) = activation(sum(xs))

@@ -25,9 +25,9 @@ addact(activation = relu, xs...) = activation(sum(xs))

Convenience function for adding input arrays after applying an activation
function to them. Useful as the `connection` argument for the block function in
-[`resnet`](#).
+[`resnet`](@ref).

-See also [`addrelu`](#).
+See also [`addrelu`](@ref).
"""
actadd(activation = relu, xs...) = sum(activation.(x) for x in xs)

@@ -66,17 +66,18 @@ function _maybe_big_show(io, model)
end

"""
-    linear_scheduler(drop_rate = 0.0; start_value = 0.0, depth)
-    linear_scheduler(drop_rate::Nothing; depth::Integer)
+    linear_scheduler(drop_prob = 0.0; start_value = 0.0, depth)
+    linear_scheduler(drop_prob::Nothing; depth::Integer)

-Returns the dropout rates for a given depth using the linear scaling rule. If the
-`drop_rate` is `nothing`, it returns a `Vector` of length `depth` with all values
-equal to `nothing`.
+Returns the dropout probabilities for a given depth using the linear scaling rule.
Note +that this returns evenly spaced values between `start_value` and `drop_prob`, not including +`drop_prob`. If `drop_prob` is `nothing`, it returns a `Vector` of length `depth` with all +values equal to `nothing`. """ -function linear_scheduler(drop_rate = 0.0; depth::Integer, start_value = 0.0) - return LinRange(start_value, drop_rate, depth) +function linear_scheduler(drop_prob = 0.0; depth::Integer, start_value = 0.0) + return LinRange(start_value, drop_prob, depth + 1)[1:depth] end -linear_scheduler(drop_rate::Nothing; depth::Integer) = fill(drop_rate, depth) +linear_scheduler(drop_prob::Nothing; depth::Integer) = fill(drop_prob, depth) # Utility function for depth and configuration checks in models function _checkconfig(config, configs) diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index 75bfb5b07..02c57941f 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -1,5 +1,5 @@ """ - transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout_rate = 0.) + transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout_prob = 0.) Transformer as used in the base ViT architecture. ([reference](https://arxiv.org/abs/2010.11929)). @@ -10,26 +10,26 @@ Transformer as used in the base ViT architecture. - `depth`: number of attention blocks - `nheads`: number of attention heads - `mlp_ratio`: ratio of MLP layers to the number of input channels - - `dropout_rate`: dropout rate + - `dropout_prob`: dropout probability """ function transformer_encoder(planes::Integer, depth::Integer, nheads::Integer; - mlp_ratio = 4.0, dropout_rate = 0.0) + mlp_ratio = 4.0, dropout_prob = 0.0) layers = [Chain(SkipConnection(prenorm(planes, MHAttention(planes, nheads; - attn_dropout_rate = dropout_rate, - proj_dropout_rate = dropout_rate)), + attn_dropout_prob = dropout_prob, + proj_dropout_prob = dropout_prob)), +), SkipConnection(prenorm(planes, mlp_block(planes, floor(Int, mlp_ratio * planes); - dropout_rate)), +)) + dropout_prob)), +)) for _ in 1:depth] return Chain(layers) end """ vit(imsize::Dims{2} = (256, 256); inchannels::Integer = 3, patch_size::Dims{2} = (16, 16), - embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout_rate = 0.1, - emb_dropout_rate = 0.1, pool = :class, nclasses::Integer = 1000) + embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout_prob = 0.1, + emb_dropout_prob = 0.1, pool = :class, nclasses::Integer = 1000) Creates a Vision Transformer (ViT) model. ([reference](https://arxiv.org/abs/2010.11929)). @@ -43,25 +43,25 @@ Creates a Vision Transformer (ViT) model. 
- `depth`: number of blocks in the transformer - `nheads`: number of attention heads in the transformer - `mlpplanes`: number of hidden channels in the MLP block in the transformer - - `dropout_rate`: dropout rate - - `emb_dropout`: dropout rate for the positional embedding layer + - `dropout_prob`: dropout probability + - `emb_dropout`: dropout probability for the positional embedding layer - `pool`: pooling type, either :class or :mean - `nclasses`: number of classes in the output """ function vit(imsize::Dims{2} = (256, 256); inchannels::Integer = 3, patch_size::Dims{2} = (16, 16), embedplanes::Integer = 768, - depth::Integer = 6, nheads::Integer = 16, mlp_ratio = 4.0, dropout_rate = 0.1, - emb_dropout_rate = 0.1, pool::Symbol = :class, nclasses::Integer = 1000) + depth::Integer = 6, nheads::Integer = 16, mlp_ratio = 4.0, dropout_prob = 0.1, + emb_dropout_prob = 0.1, pool::Symbol = :class, nclasses::Integer = 1000) @assert pool in [:class, :mean] "Pool type must be either `:class` (class token) or `:mean` (mean pooling)" npatches = prod(imsize .÷ patch_size) return Chain(Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes), ClassTokens(embedplanes), ViPosEmbedding(embedplanes, npatches + 1), - Dropout(emb_dropout_rate), + Dropout(emb_dropout_prob), transformer_encoder(embedplanes, depth, nheads; mlp_ratio, - dropout_rate), - pool == :class ? x -> x[:, 1, :] : seconddimmean), + dropout_prob), + pool === :class ? x -> x[:, 1, :] : seconddimmean), Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast))) end @@ -92,7 +92,7 @@ Creates a Vision Transformer (ViT) model. - `pool`: pooling type, either :class or :mean - `nclasses`: number of classes in the output -See also [`Metalhead.vit`](#). +See also [`Metalhead.vit`](@ref). """ struct ViT layers::Any @@ -100,9 +100,12 @@ end @functor ViT function ViT(config::Symbol; imsize::Dims{2} = (256, 256), patch_size::Dims{2} = (16, 16), - inchannels::Integer = 3, nclasses::Integer = 1000) + pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000) _checkconfig(config, keys(VIT_CONFIGS)) layers = vit(imsize; inchannels, patch_size, nclasses, VIT_CONFIGS[config]...) + if pretrain + loadpretrain!(layers, string("vit", config)) + end return ViT(layers) end diff --git a/test/convnets.jl b/test/convnets.jl index 0c796e24c..7a6ed42d2 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -42,12 +42,12 @@ end ] @testset for layers in layer_list drop_list = [ - (dropout_rate = 0.1, drop_path_rate = 0.1, drop_block_rate = 0.1), - (dropout_rate = 0.5, drop_path_rate = 0.5, drop_block_rate = 0.5), - (dropout_rate = 0.8, drop_path_rate = 0.8, drop_block_rate = 0.8), + (dropout_prob = 0.1, stochastic_depth_prob = 0.1, dropblock_prob = 0.1), + (dropout_prob = 0.5, stochastic_depth_prob = 0.5, dropblock_prob = 0.5), + (dropout_prob = 0.8, stochastic_depth_prob = 0.8, dropblock_prob = 0.8), ] - @testset for drop_rates in drop_list - m = Metalhead.resnet(block_fn, layers; drop_rates...) + @testset for drop_probs in drop_list + m = Metalhead.resnet(block_fn, layers; drop_probs...) @test size(m(x_224)) == (1000, 1) @test gradtest(m, x_224) _gc()
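> **Reviewer note (not part of the patch):** two of the renamed utilities from this diff, shown together as a hedged sketch. Module paths are assumptions based on the package layout; the scheduler values follow from the `LinRange(start_value, drop_prob, depth + 1)[1:depth]` implementation above.

```julia
using Flux, Metalhead
using Metalhead: Layers

# The scheduler now excludes the end point itself: for depth = 4 and
# drop_prob = 0.2 this should yield 0.0, 0.05, 0.1 and 0.15.
probs = Metalhead.linear_scheduler(0.2; depth = 4)

# StochasticDepth replaces DropPath: `:row` drops individual samples in the
# batch, `:batch` drops the whole input in one go.
sd_row = Layers.StochasticDepth(0.1, :row)
sd_batch = Layers.StochasticDepth(0.1, :batch)
```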