|
| 1 | +function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, |
| 2 | + width_mult::Real; norm_layer = BatchNorm, kwargs...) |
| 3 | + block_fn, k, outplanes, stride, nrepeats, activation = block_configs[stage_idx] |
| 4 | + outplanes = _round_channels(outplanes * width_mult) |
| 5 | + if stage_idx != 1 |
| 6 | + inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult) |
| 7 | + end |
| 8 | + function get_layers(block_idx::Integer) |
| 9 | + inplanes = block_idx == 1 ? inplanes : outplanes |
| 10 | + stride = block_idx == 1 ? stride : 1 |
| 11 | + block = Chain(block_fn((k, k), inplanes, outplanes, activation; |
| 12 | + stride, pad = SamePad(), norm_layer, kwargs...)...) |
| 13 | + return (block,) |
| 14 | + end |
| 15 | + return get_layers, nrepeats |
| 16 | +end |
| 17 | + |
| 18 | +function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, |
| 19 | + scalings::NTuple{2, Real}; norm_layer = BatchNorm, |
| 20 | + divisor::Integer = 8, se_from_explanes::Bool = false, |
| 21 | + kwargs...) |
| 22 | + width_mult, depth_mult = scalings |
| 23 | + block_fn, k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx] |
| 24 | + # calculate number of reduced channels for squeeze-excite layer from explanes instead of inplanes |
| 25 | + if !isnothing(reduction) |
| 26 | + reduction = !se_from_explanes ? reduction * expansion : reduction |
| 27 | + end |
| 28 | + if stage_idx != 1 |
| 29 | + inplanes = _round_channels(block_configs[stage_idx - 1][3] * width_mult, divisor) |
| 30 | + end |
| 31 | + outplanes = _round_channels(outplanes * width_mult, divisor) |
| 32 | + function get_layers(block_idx::Integer) |
| 33 | + inplanes = block_idx == 1 ? inplanes : outplanes |
| 34 | + explanes = _round_channels(inplanes * expansion, divisor) |
| 35 | + stride = block_idx == 1 ? stride : 1 |
| 36 | + block = block_fn((k, k), inplanes, explanes, outplanes, activation; norm_layer, |
| 37 | + stride, reduction, kwargs...) |
| 38 | + return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) |
| 39 | + end |
| 40 | + return get_layers, ceil(Int, nrepeats * depth_mult) |
| 41 | +end |
| 42 | + |
| 43 | +function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer, |
| 44 | + width_mult::Real; norm_layer = BatchNorm, kwargs...) |
| 45 | + return mbconv_builder(block_configs, inplanes, stage_idx, (width_mult, 1); |
| 46 | + norm_layer, kwargs...) |
| 47 | +end |
| 48 | + |
| 49 | +function fused_mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer; |
| 50 | + norm_layer = BatchNorm, kwargs...) |
| 51 | + block_fn, k, outplanes, expansion, stride, nrepeats, activation = block_configs[stage_idx] |
| 52 | + inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1][3] |
| 53 | + function get_layers(block_idx::Integer) |
| 54 | + inplanes = block_idx == 1 ? inplanes : outplanes |
| 55 | + explanes = _round_channels(inplanes * expansion, 8) |
| 56 | + stride = block_idx == 1 ? stride : 1 |
| 57 | + block = block_fn((k, k), inplanes, explanes, outplanes, activation; |
| 58 | + norm_layer, stride, kwargs...) |
| 59 | + return stride == 1 && inplanes == outplanes ? (identity, block) : (block,) |
| 60 | + end |
| 61 | + return get_layers, nrepeats |
| 62 | +end |
| 63 | + |
| 64 | +# TODO - these builders need to be more flexible to potentially specify stuff like |
| 65 | +# activation functions and reductions that don't change |
| 66 | +function _get_builder(::typeof(dwsep_conv_bn), block_configs::AbstractVector{<:Tuple}, |
| 67 | + inplanes::Integer, stage_idx::Integer; |
| 68 | + scalings::Union{Nothing, NTuple{2, Real}} = nothing, |
| 69 | + width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...) |
| 70 | + @assert isnothing(scalings) "dwsep_conv_bn does not support the `scalings` argument" |
| 71 | + return dwsepconv_builder(block_configs, inplanes, stage_idx, width_mult; norm_layer, |
| 72 | + kwargs...) |
| 73 | +end |
| 74 | + |
| 75 | +function _get_builder(::typeof(mbconv), block_configs::AbstractVector{<:Tuple}, |
| 76 | + inplanes::Integer, stage_idx::Integer; |
| 77 | + scalings::Union{Nothing, NTuple{2, Real}} = nothing, |
| 78 | + width_mult::Union{Nothing, Number} = nothing, norm_layer, kwargs...) |
| 79 | + if isnothing(scalings) |
| 80 | + return mbconv_builder(block_configs, inplanes, stage_idx, width_mult; norm_layer, |
| 81 | + kwargs...) |
| 82 | + elseif isnothing(width_mult) |
| 83 | + return mbconv_builder(block_configs, inplanes, stage_idx, scalings; norm_layer, |
| 84 | + kwargs...) |
| 85 | + else |
| 86 | + throw(ArgumentError("Only one of `scalings` and `width_mult` can be specified")) |
| 87 | + end |
| 88 | +end |
| 89 | + |
| 90 | +function _get_builder(::typeof(fused_mbconv), block_configs::AbstractVector{<:Tuple}, |
| 91 | + inplanes::Integer, stage_idx::Integer; |
| 92 | + scalings::Union{Nothing, NTuple{2, Real}} = nothing, |
| 93 | + width_mult::Union{Nothing, Number} = nothing, norm_layer) |
| 94 | + @assert isnothing(width_mult) "fused_mbconv does not support the `width_mult` argument." |
| 95 | + @assert isnothing(scalings)||scalings == (1, 1) "fused_mbconv does not support the `scalings` argument" |
| 96 | + return fused_mbconv_builder(block_configs, inplanes, stage_idx; norm_layer) |
| 97 | +end |
| 98 | + |
| 99 | +function mbconv_stack_builder(block_configs::AbstractVector{<:Tuple}, inplanes::Integer; |
| 100 | + scalings::Union{Nothing, NTuple{2, Real}} = nothing, |
| 101 | + width_mult::Union{Nothing, Number} = nothing, |
| 102 | + norm_layer = BatchNorm, kwargs...) |
| 103 | + bxs = [_get_builder(block_configs[idx][1], block_configs, inplanes, idx; scalings, |
| 104 | + width_mult, norm_layer, kwargs...) |
| 105 | + for idx in eachindex(block_configs)] |
| 106 | + return (stage_idx, block_idx) -> first.(bxs)[stage_idx](block_idx), last.(bxs) |
| 107 | +end |
0 commit comments