11"""
2- basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activation = relu,
3- norm_layer = BatchNorm, revnorm = false,
4- drop_block = identity, drop_path = identity,
5- attn_fn = planes -> identity)
2+ basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1,
3+ reduction_factor::Integer = 1, activation = relu,
4+ norm_layer = BatchNorm, revnorm::Bool = false,
5+ drop_block = identity, drop_path = identity,
6+ attn_fn = planes -> identity)
67
78Creates a basic residual block (see [reference](https://arxiv.org/abs/1512.03385v1)).
89
@@ -11,10 +12,11 @@ Creates a basic residual block (see [reference](https://arxiv.org/abs/1512.03385
   - `inplanes`: number of input feature maps
   - `planes`: number of feature maps for the block
   - `stride`: the stride of the block
-  - `reduction_factor`: the factor by which the input feature maps
-    are reduced before the first convolution.
+  - `reduction_factor`: the factor by which the input feature maps are reduced before
+    the first convolution.
   - `activation`: the activation function to use.
   - `norm_layer`: the normalization layer to use.
+  - `revnorm`: set to `true` to place the normalisation layer before the convolution
   - `drop_block`: the drop block layer
   - `drop_path`: the drop path layer
   - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
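
As a quick illustration of the newly typed signature (a minimal sketch, not from the diff: it assumes the internal constructor stays reachable as `Metalhead.basicblock` and that the block maps `inplanes` channels to `planes` channels):

```julia
using Flux, Metalhead

# sketch: builds the two-conv residual branch of a basic block; the skip connection
# is added later by the block builder, so this is just a plain Flux chain here
block = Metalhead.basicblock(64, 128; stride = 2)

x = rand(Float32, 32, 32, 64, 1)   # WHCN input with 64 channels
y = block(x)
size(y)                            # expected (16, 16, 128, 1), since stride = 2 halves H and W
```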
@@ -36,11 +38,12 @@ function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1,
 end

 """
-    bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64,
-               reduction_factor = 1, activation = relu,
-               norm_layer = BatchNorm, revnorm = false,
-               drop_block = identity, drop_path = identity,
-               attn_fn = planes -> identity)
+    bottleneck(inplanes::Integer, planes::Integer; stride::Integer,
+               cardinality::Integer = 1, base_width::Integer = 64,
+               reduction_factor::Integer = 1, activation = relu,
+               norm_layer = BatchNorm, revnorm::Bool = false,
+               drop_block = identity, drop_path = identity,
+               attn_fn = planes -> identity)

 Creates a bottleneck residual block (see [reference](https://arxiv.org/abs/1512.03385v1)).

@@ -55,6 +58,7 @@ Creates a bottleneck residual block (see [reference](https://arxiv.org/abs/1512.
     convolution.
   - `activation`: the activation function to use.
   - `norm_layer`: the normalization layer to use.
+  - `revnorm`: set to `true` to place the normalisation layer before the convolution
   - `drop_block`: the drop block layer
   - `drop_path`: the drop path layer
   - `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
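
Note that the new `bottleneck` signature has no default for `stride`, so it becomes a required keyword. A hedged example of what a ResNeXt-style call might look like (again assuming access via `Metalhead.bottleneck`):

```julia
using Metalhead

# sketch: stride must now be passed explicitly; cardinality / base_width select the
# grouped (ResNeXt) variant of the 1x1 -> 3x3 -> 1x1 bottleneck
block = Metalhead.bottleneck(256, 128; stride = 1, cardinality = 32, base_width = 4)
```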
@@ -112,7 +116,7 @@ function downsample_identity(inplanes::Integer, outplanes::Integer; kwargs...)
 end

 # Shortcut configurations for the ResNet models
-const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity),
+const RESNET_SHORTCUTS = Dict(:A => (downsample_identity, downsample_identity),
                             :B => (downsample_conv, downsample_identity),
                             :C => (downsample_conv, downsample_conv),
                             :D => (downsample_pool, downsample_identity))
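
The rename is purely cosmetic; lookups still key on the same shortcut symbols. A sketch of how the tuple is consumed (ordering taken from the builder code further down in this diff):

```julia
using Metalhead

# sketch: for :B, blocks whose shape changes get a conv (projection) shortcut and
# all other blocks get an identity shortcut
downsample_mismatch, downsample_match = Metalhead.RESNET_SHORTCUTS[:B]
```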
@@ -153,7 +157,8 @@ on how to use this function.
   shows performance improvements over the `:deep` stem in some cases.

   - `inchannels`: The number of channels in the input.
-  - `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution + normalization with a stride of two.
+  - `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution +
+    normalization with a stride of two.
   - `norm_layer`: The normalisation layer used in the stem.
   - `activation`: The activation function used in the stem.
 """
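
For reference, the stem documented here might be constructed along these lines (a sketch; the positional `stem_type` symbol and the exact keyword set are assumed rather than shown in this hunk):

```julia
using Flux, Metalhead

# sketch: ResNet-D style stem - a deep three-conv stem, with the max pool replaced
# by a strided 3x3 conv + norm via `replace_pool`
stem = Metalhead.resnet_stem(:deep; inchannels = 3, replace_pool = true,
                             norm_layer = BatchNorm, activation = relu)
```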
@@ -253,8 +258,6 @@ function bottleneck_builder(block_repeats::AbstractVector{<:Integer};
         stride = stride_fn(stage_idx, block_idx)
         downsample_fn = (stride != 1 || inplanes != planes * expansion) ?
                         downsample_tuple[1] : downsample_tuple[2]
-        # DropBlock, DropPath both take in rates based on a linear scaling schedule
-        schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
         drop_path = DropPath(pathschedule[schedule_idx])
         drop_block = DropBlock(blockschedule[schedule_idx])
         block = bottleneck(inplanes, planes; stride, cardinality, base_width,
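
The removed comment was the only mention of the linear scaling schedule behind `pathschedule`/`blockschedule`. Purely as an illustration of the idea (standalone code with made-up values, not Metalhead's internals):

```julia
# sketch: drop rates ramp linearly from 0 up to the user-supplied rate across all blocks
block_repeats = [3, 4, 6, 3]
drop_path_rate = 0.1
pathschedule = LinRange(0, drop_path_rate, sum(block_repeats))

# the removed schedule_idx located this block's slot in that flattened schedule
stage_idx, block_idx = 3, 2
schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx   # == 9
pathschedule[schedule_idx]
```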
@@ -289,35 +292,46 @@ function resnet(img_dims, stem, get_layers, block_repeats::AbstractVector{<:Inte
     return Chain(backbone, classifier_fn(nfeaturemaps))
 end

-function resnet(block_type::Symbol, block_repeats::AbstractVector{<:Integer};
-                downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity),
+function resnet(block_type, block_repeats::AbstractVector{<:Integer},
+                downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity);
                 cardinality::Integer = 1, base_width::Integer = 64, inplanes::Integer = 64,
                 reduction_factor::Integer = 1, imsize::Dims{2} = (256, 256),
-                inchannels::Integer = 3, stem_fn = resnet_stem,
-                connection = addact, activation = relu, norm_layer = BatchNorm,
-                revnorm::Bool = false, attn_fn = planes -> identity,
-                pool_layer = AdaptiveMeanPool((1, 1)), use_conv::Bool = false,
-                drop_block_rate = 0.0, drop_path_rate = 0.0, dropout_rate = 0.0,
-                nclasses::Integer = 1000)
+                inchannels::Integer = 3, stem_fn = resnet_stem, connection = addact,
+                activation = relu, norm_layer = BatchNorm, revnorm::Bool = false,
+                attn_fn = planes -> identity, pool_layer = AdaptiveMeanPool((1, 1)),
+                use_conv::Bool = false, drop_block_rate = 0.0, drop_path_rate = 0.0,
+                dropout_rate = 0.0, nclasses::Integer = 1000, kwargs...)
     # Build stem
     stem = stem_fn(; inchannels)
     # Block builder
-    if block_type == :basicblock
+    if block_type == basicblock
         @assert cardinality==1 "Cardinality must be 1 for `basicblock`"
         @assert base_width==64 "Base width must be 64 for `basicblock`"
         get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor,
                                         activation, norm_layer, revnorm, attn_fn,
                                         drop_block_rate, drop_path_rate,
                                         stride_fn = resnet_stride,
                                         planes_fn = resnet_planes,
-                                        downsample_tuple = downsample_opt)
-    elseif block_type == :bottleneck
+                                        downsample_tuple = downsample_opt,
+                                        kwargs...)
+    elseif block_type == bottleneck
         get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width,
-                                        reduction_factor, activation, norm_layer,
-                                        revnorm, attn_fn, drop_block_rate, drop_path_rate,
+                                        reduction_factor, activation, norm_layer, revnorm,
+                                        attn_fn, drop_block_rate, drop_path_rate,
                                         stride_fn = resnet_stride,
                                         planes_fn = resnet_planes,
-                                        downsample_tuple = downsample_opt)
+                                        downsample_tuple = downsample_opt,
+                                        kwargs...)
+    elseif block_type == bottle2neck
+        @assert drop_block_rate==0.0 "DropBlock not supported for `bottle2neck`"
+        @assert drop_path_rate==0.0 "DropPath not supported for `bottle2neck`"
+        @assert reduction_factor==1 "Reduction factor not supported for `bottle2neck`"
+        get_layers = bottle2neck_builder(block_repeats; inplanes, cardinality, base_width,
+                                         activation, norm_layer, revnorm, attn_fn,
+                                         stride_fn = resnet_stride,
+                                         planes_fn = resnet_planes,
+                                         downsample_tuple = downsample_opt,
+                                         kwargs...)
     else
         # TODO: write better message when we have link to dev docs for resnet
         throw(ArgumentError("Unknown block type $block_type"))
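
With `block_type` now compared against the block functions themselves and `kwargs...` forwarded to the builders, a Res2Net-flavoured call might look roughly like this (values and keyword names are illustrative, not confirmed by the diff):

```julia
using Metalhead

# sketch: the trailing kwargs... means block-specific options flow through resnet()
# to bottle2neck_builder instead of being rejected by the outer signature
model = Metalhead.resnet(Metalhead.bottle2neck, [3, 4, 6, 3]; base_width = 26)
```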
@@ -328,12 +342,16 @@ function resnet(block_type::Symbol, block_repeats::AbstractVector{<:Integer};
                   connection$activation, classifier_fn)
 end
 function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...)
-    return resnet(block_fn, block_repeats, shortcut_dict[downsample_opt]; kwargs...)
+    return resnet(block_fn, block_repeats, RESNET_SHORTCUTS[downsample_opt]; kwargs...)
 end

 # block-layer configurations for ResNet-like models
-const RESNET_CONFIGS = Dict(18 => (:basicblock, [2, 2, 2, 2]),
-                            34 => (:basicblock, [3, 4, 6, 3]),
-                            50 => (:bottleneck, [3, 4, 6, 3]),
-                            101 => (:bottleneck, [3, 4, 23, 3]),
-                            152 => (:bottleneck, [3, 8, 36, 3]))
+const RESNET_CONFIGS = Dict(18 => (basicblock, [2, 2, 2, 2]),
+                            34 => (basicblock, [3, 4, 6, 3]),
+                            50 => (bottleneck, [3, 4, 6, 3]),
+                            101 => (bottleneck, [3, 4, 23, 3]),
+                            152 => (bottleneck, [3, 8, 36, 3]))
+
+const LRESNET_CONFIGS = Dict(50 => (bottleneck, [3, 4, 6, 3]),
+                             101 => (bottleneck, [3, 4, 23, 3]),
+                             152 => (bottleneck, [3, 8, 36, 3]))
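
These tables are keyed by depth, so a canonical configuration presumably reduces to a lookup plus the builder call above (a sketch, assuming the constants stay internal to the package):

```julia
using Metalhead

# sketch: depth-keyed lookup, then the block function and repeats go straight to resnet()
block_fn, block_repeats = Metalhead.RESNET_CONFIGS[50]   # (bottleneck, [3, 4, 6, 3])
model = Metalhead.resnet(block_fn, block_repeats, :C; nclasses = 10)
```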