1- function dwsepconv_builder (block_configs, inplanes:: Integer , stage_idx:: Integer ,
2- width_mult:: Real ; norm_layer = BatchNorm, kwargs... )
1+ """
2+ invresbuilder(::typeof(irblockfn), block_configs::AbstractVector{<:Tuple},
3+ inplanes::Integer, stage_idx::Integer, scalings::NTuple{2, Real};
4+ stochastic_depth_prob = nothing, norm_layer = BatchNorm,
5+ divisor::Integer = 8, kwargs...)
6+
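7+ Creates a block builder for `irblockfn` within a given stage, returning a function that constructs the layers of the block at a given index together with the number of block repeats in that stage.
8+ Note that this function is not intended to be called directly; instead it is invoked by
9+ [`mbconv_stage_builder`](@ref), which dispatches on the block function stored as the first element of each entry in `block_configs` and returns a builder over all stages.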
10+ Users wanting to provide a custom inverted residual block type can extend this
11+ function by defining `invresbuilder(::typeof(my_block), ...)`.
12+ """
13+ function invresbuilder (:: typeof (dwsep_conv_norm), block_configs:: AbstractVector{<:Tuple} ,
14+ inplanes:: Integer , stage_idx:: Integer , scalings:: NTuple{2, Real} ;
15+ stochastic_depth_prob = nothing , norm_layer = BatchNorm,
16+ divisor:: Integer = 8 , kwargs... )
17+ width_mult, depth_mult = scalings
318 block_fn, k, outplanes, stride, nrepeats, activation = block_configs[stage_idx]
4- outplanes = _round_channels (outplanes * width_mult)
19+ outplanes = _round_channels (outplanes * width_mult, divisor )
520 if stage_idx != 1
6- inplanes = _round_channels (block_configs[stage_idx - 1 ][3 ] * width_mult)
21+ inplanes = _round_channels (block_configs[stage_idx - 1 ][3 ] * width_mult, divisor )
722 end
823 function get_layers (block_idx:: Integer )
924 inplanes = block_idx == 1 ? inplanes : outplanes
@@ -12,15 +27,17 @@ function dwsepconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
1227 stride, pad = SamePad (), norm_layer, kwargs... )... )
1328 return (block,)
1429 end
15- return get_layers, nrepeats
30+ return get_layers, ceil (Int, nrepeats * depth_mult)
1631end
1732
18- function mbconv_builder (block_configs, inplanes :: Integer , stage_idx :: Integer ,
19- scalings:: NTuple{2, Real} ; norm_layer = BatchNorm,
20- divisor :: Integer = 8 , se_from_explanes :: Bool = false ,
21- kwargs... )
33+ function invresbuilder ( :: typeof (mbconv), block_configs :: AbstractVector{<:Tuple} ,
34+ inplanes :: Integer , stage_idx :: Integer , scalings:: NTuple{2, Real} ;
35+ stochastic_depth_prob = nothing , norm_layer = BatchNorm ,
36+ divisor :: Integer = 8 , se_from_explanes :: Bool = false , kwargs... )
2237 width_mult, depth_mult = scalings
23- block_fn, k, outplanes, expansion, stride, nrepeats, reduction, activation = block_configs[stage_idx]
38+ block_repeats = [ceil (Int, block_configs[idx][end - 2 ] * depth_mult)
39+ for idx in eachindex (block_configs)]
40+ block_fn, k, outplanes, expansion, stride, _, reduction, activation = block_configs[stage_idx]
2441 # calculate number of reduced channels for squeeze-excite layer from explanes instead of inplanes
2542 if ! isnothing (reduction)
2643 reduction = ! se_from_explanes ? reduction * expansion : reduction
@@ -29,79 +46,52 @@ function mbconv_builder(block_configs, inplanes::Integer, stage_idx::Integer,
2946 inplanes = _round_channels (block_configs[stage_idx - 1 ][3 ] * width_mult, divisor)
3047 end
3148 outplanes = _round_channels (outplanes * width_mult, divisor)
49+ sdschedule = linear_scheduler (stochastic_depth_prob; depth = sum (block_repeats))
3250 function get_layers (block_idx:: Integer )
3351 inplanes = block_idx == 1 ? inplanes : outplanes
3452 explanes = _round_channels (inplanes * expansion, divisor)
3553 stride = block_idx == 1 ? stride : 1
3654 block = block_fn ((k, k), inplanes, explanes, outplanes, activation; norm_layer,
3755 stride, reduction, kwargs... )
38- return stride == 1 && inplanes == outplanes ? (identity, block) : (block,)
56+ use_skip = stride == 1 && inplanes == outplanes
57+ if use_skip
58+ schedule_idx = sum (block_repeats[1 : (stage_idx - 1 )]) + block_idx
59+ drop_path = StochasticDepth (sdschedule[schedule_idx])
60+ return (drop_path, block)
61+ else
62+ return (block,)
63+ end
3964 end
40- return get_layers, ceil (Int, nrepeats * depth_mult)
41- end
42-
43- function mbconv_builder (block_configs, inplanes:: Integer , stage_idx:: Integer ,
44- width_mult:: Real ; norm_layer = BatchNorm, kwargs... )
45- return mbconv_builder (block_configs, inplanes, stage_idx, (width_mult, 1 );
46- norm_layer, kwargs... )
65+ return get_layers, block_repeats[stage_idx]
4766end
4867
49- function fused_mbconv_builder (block_configs, inplanes:: Integer , stage_idx:: Integer ;
50- norm_layer = BatchNorm, kwargs... )
51- block_fn, k, outplanes, expansion, stride, nrepeats, activation = block_configs[stage_idx]
68+ function invresbuilder (:: typeof (fused_mbconv), block_configs:: AbstractVector{<:Tuple} ,
69+ inplanes:: Integer , stage_idx:: Integer , scalings:: NTuple{2, Real} ;
70+ stochastic_depth_prob = nothing , norm_layer = BatchNorm,
71+ divisor:: Integer = 8 , kwargs... )
72+ width_mult, depth_mult = scalings
73+ block_repeats = [ceil (Int, block_configs[idx][end - 1 ] * depth_mult)
74+ for idx in eachindex (block_configs)]
75+ block_fn, k, outplanes, expansion, stride, _, activation = block_configs[stage_idx]
5276 inplanes = stage_idx == 1 ? inplanes : block_configs[stage_idx - 1 ][3 ]
77+ outplanes = _round_channels (outplanes * width_mult, divisor)
78+ sdschedule = linear_scheduler (stochastic_depth_prob; depth = sum (block_repeats))
5379 function get_layers (block_idx:: Integer )
5480 inplanes = block_idx == 1 ? inplanes : outplanes
55- explanes = _round_channels (inplanes * expansion, 8 )
81+ explanes = _round_channels (inplanes * expansion, divisor )
5682 stride = block_idx == 1 ? stride : 1
5783 block = block_fn ((k, k), inplanes, explanes, outplanes, activation;
5884 norm_layer, stride, kwargs... )
59- return stride == 1 && inplanes == outplanes ? (identity, block) : (block,)
85+ schedule_idx = sum (block_repeats[1 : (stage_idx - 1 )]) + block_idx
86+ drop_path = StochasticDepth (sdschedule[schedule_idx])
87+ return stride == 1 && inplanes == outplanes ? (drop_path, block) : (block,)
6088 end
61- return get_layers, nrepeats
62- end
63-
64- # TODO - these builders need to be more flexible to potentially specify stuff like
65- # activation functions and reductions that don't change
66- function _get_builder (:: typeof (dwsep_conv_bn), block_configs:: AbstractVector{<:Tuple} ,
67- inplanes:: Integer , stage_idx:: Integer ;
68- scalings:: Union{Nothing, NTuple{2, Real}} = nothing ,
69- width_mult:: Union{Nothing, Number} = nothing , norm_layer, kwargs... )
70- @assert isnothing (scalings) " dwsep_conv_bn does not support the `scalings` argument"
71- return dwsepconv_builder (block_configs, inplanes, stage_idx, width_mult; norm_layer,
72- kwargs... )
73- end
74-
75- function _get_builder (:: typeof (mbconv), block_configs:: AbstractVector{<:Tuple} ,
76- inplanes:: Integer , stage_idx:: Integer ;
77- scalings:: Union{Nothing, NTuple{2, Real}} = nothing ,
78- width_mult:: Union{Nothing, Number} = nothing , norm_layer, kwargs... )
79- if isnothing (scalings)
80- return mbconv_builder (block_configs, inplanes, stage_idx, width_mult; norm_layer,
81- kwargs... )
82- elseif isnothing (width_mult)
83- return mbconv_builder (block_configs, inplanes, stage_idx, scalings; norm_layer,
84- kwargs... )
85- else
86- throw (ArgumentError (" Only one of `scalings` and `width_mult` can be specified" ))
87- end
88- end
89-
90- function _get_builder (:: typeof (fused_mbconv), block_configs:: AbstractVector{<:Tuple} ,
91- inplanes:: Integer , stage_idx:: Integer ;
92- scalings:: Union{Nothing, NTuple{2, Real}} = nothing ,
93- width_mult:: Union{Nothing, Number} = nothing , norm_layer)
94- @assert isnothing (width_mult) " fused_mbconv does not support the `width_mult` argument."
95- @assert isnothing (scalings)|| scalings == (1 , 1 ) " fused_mbconv does not support the `scalings` argument"
96- return fused_mbconv_builder (block_configs, inplanes, stage_idx; norm_layer)
89+ return get_layers, block_repeats[stage_idx]
9790end
9891
99- function mbconv_stack_builder (block_configs:: AbstractVector{<:Tuple} , inplanes:: Integer ;
100- scalings:: Union{Nothing, NTuple{2, Real}} = nothing ,
101- width_mult:: Union{Nothing, Number} = nothing ,
102- norm_layer = BatchNorm, kwargs... )
103- bxs = [_get_builder (block_configs[idx][1 ], block_configs, inplanes, idx; scalings,
104- width_mult, norm_layer, kwargs... )
105- for idx in eachindex (block_configs)]
92+ function mbconv_stage_builder (block_configs:: AbstractVector{<:Tuple} , inplanes:: Integer ,
93+ scalings:: NTuple{2, Real} ; kwargs... )
94+ bxs = [invresbuilder (block_configs[idx][1 ], block_configs, inplanes, idx, scalings;
95+ kwargs... ) for idx in eachindex (block_configs)]
10696 return (stage_idx, block_idx) -> first .(bxs)[stage_idx](block_idx), last .(bxs)
10797end
0 commit comments