@@ -138,6 +138,9 @@ def _cfg(url='', **kwargs):
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth'),
     'mixnet_l': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth'),
+    'mixnet_xl': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl-ac5fbe8d.pth'),
+    'mixnet_xxl': _cfg(),
     'tf_mixnet_s': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth'),
     'tf_mixnet_m': _cfg(
@@ -312,21 +315,59 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
     else:
         assert False, 'Unknown block type (%s)' % block_type
 
-    # return a list of block args expanded by num_repeat and
-    # scaled by depth_multiplier
-    num_repeat = int(math.ceil(num_repeat * depth_multiplier))
-    return [deepcopy(block_args) for _ in range(num_repeat)]
+    return block_args, num_repeat
 
 
-def _decode_arch_def(arch_def, depth_multiplier=1.0):
+def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
+    """Per-stage depth scaling.
+    Scales the block repeats in each stage. This depth scaling impl maintains
+    compatibility with the EfficientNet scaling method, while allowing sensible
+    scaling for other models that may have multiple block arg definitions in each stage.
+    """
+
+    # We scale the total repeat count for each stage; there may be multiple
+    # block arg defs per stage, so we need to sum.
+    num_repeat = sum(repeats)
+    if depth_trunc == 'round':
+        # Truncating to int by rounding allows stages with few repeats to remain
+        # proportionally smaller for longer. This is a good choice when stage definitions
+        # include single-repeat stages that we'd prefer to keep that way as long as possible.
+        num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
+    else:
+        # The default for EfficientNet truncates repeats to int via 'ceil'.
+        # Any multiplier > 1.0 will result in an increased depth for every stage.
+        num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
+
+    # Proportionally distribute the scaled repeat count to each block definition in the stage.
+    # Allocation is done in reverse so the first block is less likely to be scaled;
+    # the first block makes less sense to repeat in most of the arch definitions.
+    repeats_scaled = []
+    for r in repeats[::-1]:
+        rs = max(1, round(r / num_repeat * num_repeat_scaled))
+        repeats_scaled.append(rs)
+        num_repeat -= r
+        num_repeat_scaled -= rs
+    repeats_scaled = repeats_scaled[::-1]
+
+    # Apply the calculated scaling to each block arg in the stage
+    sa_scaled = []
+    for ba, rep in zip(stack_args, repeats_scaled):
+        sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
+    return sa_scaled
+
+
+def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
     arch_args = []
     for stack_idx, block_strings in enumerate(arch_def):
         assert isinstance(block_strings, list)
         stack_args = []
+        repeats = []
         for block_str in block_strings:
             assert isinstance(block_str, str)
-            stack_args.extend(_decode_block_str(block_str, depth_multiplier))
-        arch_args.append(stack_args)
+            ba, rep = _decode_block_str(block_str)
+            stack_args.append(ba)
+            repeats.append(rep)
+        arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
     return arch_args
 
 
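As a quick sanity check of the new scaling behaviour, the sketch below re-states the repeat allocation from _scale_stage_depth in isolation (illustration only, not part of the diff) and prints the scaled repeat counts for a couple of hypothetical stages. It shows why the generator below switches to depth_trunc='round': rounding leaves single-repeat stages alone under small multipliers, while the EfficientNet-style 'ceil' grows every stage as soon as the multiplier exceeds 1.0.

import math

def scale_stage_repeats(repeats, depth_multiplier=1.0, depth_trunc='ceil'):
    # Standalone restatement of the allocation in _scale_stage_depth above,
    # returning only the per-block-def repeat counts for a single stage.
    num_repeat = sum(repeats)
    if depth_trunc == 'round':
        num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
    else:
        num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
    repeats_scaled = []
    for r in repeats[::-1]:  # reverse order: the first block def is scaled last
        rs = max(1, round(r / num_repeat * num_repeat_scaled))
        repeats_scaled.append(rs)
        num_repeat -= r
        num_repeat_scaled -= rs
    return repeats_scaled[::-1]

print(scale_stage_repeats([1], 1.2, 'round'))     # [1]    single-repeat stage left alone
print(scale_stage_repeats([1], 1.2, 'ceil'))      # [2]    'ceil' grows every stage for any multiplier > 1.0
print(scale_stage_repeats([1, 3], 1.1, 'round'))  # [1, 3] total rounds back down to 4
print(scale_stage_repeats([1, 3], 1.1, 'ceil'))   # [1, 4] the extra repeat lands on the larger block def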
@@ -1261,7 +1302,7 @@ def _gen_mixnet_s(channel_multiplier=1.0, num_classes=1000, **kwargs):
     return model
 
 
-def _gen_mixnet_m(channel_multiplier=1.0, num_classes=1000, **kwargs):
+def _gen_mixnet_m(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
     """Creates a MixNet Medium-Large model.
 
     Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
@@ -1283,7 +1324,7 @@ def _gen_mixnet_m(channel_multiplier=1.0, num_classes=1000, **kwargs):
         # 7x7
     ]
     model = GenEfficientNet(
-        _decode_arch_def(arch_def),
+        _decode_arch_def(arch_def, depth_multiplier=depth_multiplier, depth_trunc='round'),
         num_classes=num_classes,
         stem_size=24,
         num_features=1536,
@@ -1876,6 +1917,36 @@ def mixnet_l(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     return model
 
 
+@register_model
+def mixnet_xl(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Creates a MixNet Extra-Large model.
+    Not a paper spec, experimental def by RW w/ depth scaling.
+    """
+    default_cfg = default_cfgs['mixnet_xl']
+    # kwargs['drop_connect_rate'] = 0.2
+    model = _gen_mixnet_m(
+        channel_multiplier=1.6, depth_multiplier=1.2, num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+@register_model
+def mixnet_xxl(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Creates a MixNet Double Extra Large model.
+    Not a paper spec, experimental def by RW w/ depth scaling.
+    """
+    default_cfg = default_cfgs['mixnet_xxl']
+    # kwargs['drop_connect_rate'] = 0.2
+    model = _gen_mixnet_m(
+        channel_multiplier=2.4, depth_multiplier=1.3, num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
 @register_model
 def tf_mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """Creates a MixNet Small model. Tensorflow compatible variant
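For completeness, a minimal usage sketch for the two newly registered variants follows. The import path is an assumption about where this module lives in the repository the weight URLs above point at; mixnet_xl can pull the newly published weights, while mixnet_xxl has no weights URL in its cfg yet and so builds randomly initialized.

import torch
# Assumed import path; adjust to wherever this gen_efficientnet module actually lives.
from timm.models.gen_efficientnet import mixnet_xl, mixnet_xxl

xl = mixnet_xl(pretrained=True)     # fetches mixnet_xl-ac5fbe8d.pth via load_pretrained
xxl = mixnet_xxl(pretrained=False)  # no weights URL defined yet, so random init only

xl.eval()
with torch.no_grad():
    logits = xl(torch.randn(1, 3, 224, 224))  # default _cfg resolution assumed to be 224x224
print(logits.shape)  # torch.Size([1, 1000])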