11"""
2- irblockbuilder (::typeof(irblockfn), block_configs::AbstractVector{<:Tuple},
2+ invresbuilder (::typeof(irblockfn), block_configs::AbstractVector{<:Tuple},
33 inplanes::Integer, stage_idx::Integer, scalings::NTuple{2, Real};
44 stochastic_depth_prob = nothing, norm_layer = BatchNorm,
55 divisor::Integer = 8, kwargs...)
66
7- Constructs a collection of inverted residual blocks for a given stage. Note that
8- this function is not intended to be called directly, but rather by the [`mbconv_stage_builder`](@ref)
9- function. This function must only be extended if the user wishes to extend a custom inverted
10- residual block type.
7+ Creates a block builder for `irblockfn` within a given stage.
8+ Note that this function is not intended to be called directly, but instead passed to
9+ [`mbconv_stage_builder`](@ref) which will return a builder over all stages.
10+ Users wanting to provide a custom inverted residual block type can extend this
11+ function by defining `invresbuilder(::typeof(my_block), ...)`.
1112"""
12- function irblockbuilder (:: typeof (dwsep_conv_norm), block_configs:: AbstractVector{<:Tuple} ,
13- inplanes:: Integer , stage_idx:: Integer , scalings:: NTuple{2, Real} ;
14- stochastic_depth_prob = nothing , norm_layer = BatchNorm,
15- divisor:: Integer = 8 , kwargs... )
13+ function invresbuilder (:: typeof (dwsep_conv_norm), block_configs:: AbstractVector{<:Tuple} ,
14+ inplanes:: Integer , stage_idx:: Integer , scalings:: NTuple{2, Real} ;
15+ stochastic_depth_prob = nothing , norm_layer = BatchNorm,
16+ divisor:: Integer = 8 , kwargs... )
1617 width_mult, depth_mult = scalings
1718 block_fn, k, outplanes, stride, nrepeats, activation = block_configs[stage_idx]
1819 outplanes = _round_channels (outplanes * width_mult, divisor)
@@ -29,10 +30,10 @@ function irblockbuilder(::typeof(dwsep_conv_norm), block_configs::AbstractVector
2930 return get_layers, ceil (Int, nrepeats * depth_mult)
3031end
3132
32- function irblockbuilder (:: typeof (mbconv), block_configs:: AbstractVector{<:Tuple} ,
33- inplanes:: Integer , stage_idx:: Integer , scalings:: NTuple{2, Real} ;
34- stochastic_depth_prob = nothing , norm_layer = BatchNorm,
35- divisor:: Integer = 8 , se_from_explanes:: Bool = false , kwargs... )
33+ function invresbuilder (:: typeof (mbconv), block_configs:: AbstractVector{<:Tuple} ,
34+ inplanes:: Integer , stage_idx:: Integer , scalings:: NTuple{2, Real} ;
35+ stochastic_depth_prob = nothing , norm_layer = BatchNorm,
36+ divisor:: Integer = 8 , se_from_explanes:: Bool = false , kwargs... )
3637 width_mult, depth_mult = scalings
3738 block_repeats = [ceil (Int, block_configs[idx][end - 2 ] * depth_mult)
3839 for idx in eachindex (block_configs)]
@@ -64,10 +65,10 @@ function irblockbuilder(::typeof(mbconv), block_configs::AbstractVector{<:Tuple}
6465 return get_layers, block_repeats[stage_idx]
6566end
6667
67- function irblockbuilder (:: typeof (fused_mbconv), block_configs:: AbstractVector{<:Tuple} ,
68- inplanes:: Integer , stage_idx:: Integer , scalings:: NTuple{2, Real} ;
69- stochastic_depth_prob = nothing , norm_layer = BatchNorm,
70- divisor:: Integer = 8 , kwargs... )
68+ function invresbuilder (:: typeof (fused_mbconv), block_configs:: AbstractVector{<:Tuple} ,
69+ inplanes:: Integer , stage_idx:: Integer , scalings:: NTuple{2, Real} ;
70+ stochastic_depth_prob = nothing , norm_layer = BatchNorm,
71+ divisor:: Integer = 8 , kwargs... )
7172 width_mult, depth_mult = scalings
7273 block_repeats = [ceil (Int, block_configs[idx][end - 1 ] * depth_mult)
7374 for idx in eachindex (block_configs)]
9091
9192function mbconv_stage_builder (block_configs:: AbstractVector{<:Tuple} , inplanes:: Integer ,
9293 scalings:: NTuple{2, Real} ; kwargs... )
93- bxs = [irblockbuilder (block_configs[idx][1 ], block_configs, inplanes, idx, scalings;
94- kwargs... ) for idx in eachindex (block_configs)]
94+ bxs = [invresbuilder (block_configs[idx][1 ], block_configs, inplanes, idx, scalings;
95+ kwargs... ) for idx in eachindex (block_configs)]
9596 return (stage_idx, block_idx) -> first .(bxs)[stage_idx](block_idx), last .(bxs)
9697end
0 commit comments