@@ -91,6 +91,12 @@ def _cfg(url='', **kwargs):
         url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x8-b4712904.pth'),
     'swsl_resnext101_32x16d': _cfg(
         url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth'),
+    'seresnext26d_32x4d': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26d_32x4d-80fa48a3.pth',
+        interpolation='bicubic'),
+    'seresnext26t_32x4d': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26t_32x4d-361bc1c4.pth',
+        interpolation='bicubic'),
 }
 
 
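The two new `seresnext26*` entries only override `url` and `interpolation`; everything else is inherited from `_cfg`'s defaults, which live above this hunk. A minimal sketch of that merge pattern, with the default keys assumed rather than taken from the hunk:

```python
def _cfg(url='', **kwargs):
    # Hypothetical baseline config; keyword overrides such as
    # interpolation='bicubic' replace the corresponding default.
    cfg = {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'interpolation': 'bilinear',  # the new SE-ResNeXt-26 weights expect 'bicubic'
    }
    cfg.update(kwargs)
    return cfg
```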
@@ -231,10 +237,11 @@ class ResNet(nn.Module):
 
     ResNet variants:
       * normal, b - 7x7 stem, stem_width = 64, same as torchvision ResNet, NVIDIA ResNet 'v1.5', Gluon v1b
-      * c - 3 layer deep 3x3 stem, stem_width = 32
-      * d - 3 layer deep 3x3 stem, stem_width = 32, average pool in downsample
-      * e - 3 layer deep 3x3 stem, stem_width = 64, average pool in downsample
-      * s - 3 layer deep 3x3 stem, stem_width = 64
+      * c - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64)
+      * d - 3 layer deep 3x3 stem, stem_width = 32 (32, 32, 64), average pool in downsample
+      * e - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128), average pool in downsample
+      * s - 3 layer deep 3x3 stem, stem_width = 64 (64, 64, 128)
+      * t - 3 layer deep 3x3 stem, stem_width = 32 (24, 48, 64), average pool in downsample
 
     ResNeXt
       * normal - 7x7 stem, stem_width = 64, standard cardinality and base widths
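For orientation, the new 't' letter joins the existing shorthand. Read against the entrypoints changed and added further down in this diff, the suffixes used here translate to constructor arguments roughly as follows (a summary sketch, not code from the PR):

```python
# Rough mapping from variant suffix to ResNet(...) kwargs, as used by the
# entrypoints in this diff (resnet26d/resnet50d/resnext50d_32x4d vs seresnext26t_32x4d):
variant_kwargs = {
    'b': dict(),                                                       # plain 7x7 stem
    'd': dict(stem_width=32, stem_type='deep', avg_down=True),         # deep stem + avg-pool downsample
    't': dict(stem_width=32, stem_type='deep_tiered', avg_down=True),  # tiered deep stem + avg-pool downsample
}
```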
@@ -263,10 +270,13 @@ class ResNet(nn.Module):
         Number of convolution groups for 3x3 conv in Bottleneck.
     base_width : int, default 64
         Factor determining bottleneck channels. `planes * base_width / 64 * cardinality`
-    deep_stem : bool, default False
-        Whether to replace the 7x7 conv1 with 3 3x3 convolution layers.
     stem_width : int, default 64
         Number of channels in stem convolutions
+    stem_type : str, default ''
+        The type of stem:
+          * '', default - a single 7x7 conv with a width of stem_width
+          * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2
+          * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width//4 * 6, stem_width * 2
     block_reduce_first: int, default 1
         Reduction factor for first convolution output width of residual blocks,
         1 for all archs except senets, where 2
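The channel triples quoted above for 'deep' and 'deep_tiered' follow from the stem logic added to `__init__` below; a standalone check of that arithmetic for the `stem_width=32` case used by the new models (sketch only, not part of the PR):

```python
def stem_channels(stem_type='', stem_width=64):
    # Mirrors the deep-stem width selection added in ResNet.__init__ below.
    if 'deep' not in stem_type:
        return (64,)  # single 7x7 conv; its width is self.inplanes (64)
    stem_chs_1 = stem_chs_2 = stem_width
    if 'tiered' in stem_type:
        stem_chs_1 = 3 * (stem_width // 4)
        stem_chs_2 = 6 * (stem_width // 4)
    return stem_chs_1, stem_chs_2, stem_width * 2

print(stem_channels('deep', 32))         # (32, 32, 64)  -> 'd' variants
print(stem_channels('deep_tiered', 32))  # (24, 48, 64)  -> seresnext26t_32x4d
```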
@@ -283,12 +293,13 @@ class ResNet(nn.Module):
         Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
     """
     def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
-                 cardinality=1, base_width=64, stem_width=64, deep_stem=False,
+                 cardinality=1, base_width=64, stem_width=64, stem_type='',
                  block_reduce_first=1, down_kernel_size=1, avg_down=False, dilated=False,
                  norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg',
                  zero_init_last_bn=True, block_args=None):
         block_args = block_args or dict()
         self.num_classes = num_classes
+        deep_stem = 'deep' in stem_type
         self.inplanes = stem_width * 2 if deep_stem else 64
         self.cardinality = cardinality
         self.base_width = base_width
@@ -298,16 +309,20 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
         super(ResNet, self).__init__()
 
         if deep_stem:
+            stem_chs_1 = stem_chs_2 = stem_width
+            if 'tiered' in stem_type:
+                stem_chs_1 = 3 * (stem_width // 4)
+                stem_chs_2 = 6 * (stem_width // 4)
             self.conv1 = nn.Sequential(*[
-                nn.Conv2d(in_chans, stem_width, 3, stride=2, padding=1, bias=False),
-                norm_layer(stem_width),
+                nn.Conv2d(in_chans, stem_chs_1, 3, stride=2, padding=1, bias=False),
+                norm_layer(stem_chs_1),
                 nn.ReLU(inplace=True),
-                nn.Conv2d(stem_width, stem_width, 3, stride=1, padding=1, bias=False),
-                norm_layer(stem_width),
+                nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False),
+                norm_layer(stem_chs_2),
                 nn.ReLU(inplace=True),
-                nn.Conv2d(stem_width, self.inplanes, 3, stride=1, padding=1, bias=False)])
+                nn.Conv2d(stem_chs_2, self.inplanes, 3, stride=1, padding=1, bias=False)])
         else:
-            self.conv1 = nn.Conv2d(in_chans, stem_width, kernel_size=7, stride=2, padding=3, bias=False)
+            self.conv1 = nn.Conv2d(in_chans, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
         self.bn1 = norm_layer(self.inplanes)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
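To make the new tiered branch concrete, here is a standalone reconstruction of the 'deep_tiered' stem for `stem_width=32` with a dummy forward pass (assumes `nn.BatchNorm2d` as `norm_layer`; the maxpool that follows in `__init__` is omitted):

```python
import torch
import torch.nn as nn

stem_width = 32
stem_chs_1 = 3 * (stem_width // 4)   # 24
stem_chs_2 = 6 * (stem_width // 4)   # 48
inplanes = stem_width * 2            # 64

# Same layer sequence as the deep-stem branch above, with the tiered widths.
stem = nn.Sequential(
    nn.Conv2d(3, stem_chs_1, 3, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(stem_chs_1),
    nn.ReLU(inplace=True),
    nn.Conv2d(stem_chs_1, stem_chs_2, 3, stride=1, padding=1, bias=False),
    nn.BatchNorm2d(stem_chs_2),
    nn.ReLU(inplace=True),
    nn.Conv2d(stem_chs_2, inplanes, 3, stride=1, padding=1, bias=False),
)
out = stem(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 64, 112, 112]) -- only the first conv strides
```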
@@ -324,7 +339,7 @@ def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
         self.num_features = 512 * block.expansion
         self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
 
-        last_bn_name = 'bn3' if 'Bottleneck' in block.__name__ else 'bn2'
+        last_bn_name = 'bn3' if 'Bottle' in block.__name__ else 'bn2'
         for n, m in self.named_modules():
             if isinstance(m, nn.Conv2d):
                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
@@ -440,7 +455,7 @@ def resnet26d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """
     default_cfg = default_cfgs['resnet26d']
     model = ResNet(
-        Bottleneck, [2, 2, 2, 2], stem_width=32, deep_stem=True, avg_down=True,
+        Bottleneck, [2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
@@ -466,7 +481,7 @@ def resnet50d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """
     default_cfg = default_cfgs['resnet50d']
     model = ResNet(
-        Bottleneck, [3, 4, 6, 3], stem_width=32, deep_stem=True, avg_down=True,
+        Bottleneck, [3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
@@ -574,7 +589,7 @@ def resnext50d_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     default_cfg = default_cfgs['resnext50d_32x4d']
     model = ResNet(
         Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
-        stem_width=32, deep_stem=True, avg_down=True,
+        stem_width=32, stem_type='deep', avg_down=True,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
@@ -854,3 +869,34 @@ def swsl_resnext101_32x16d(pretrained=True, **kwargs):
     if pretrained:
         load_pretrained(model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
     return model
+
+
+@register_model
+def seresnext26d_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Constructs a SE-ResNeXt-26-D model.
+    This is technically a 28 layer ResNet, sticking with the 'd' modifier from Gluon for now.
+    """
+    default_cfg = default_cfgs['seresnext26d_32x4d']
+    model = ResNet(
+        Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
+        stem_width=32, stem_type='deep', avg_down=True, use_se=True,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+@register_model
+def seresnext26t_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Constructs a SE-ResNeXt-26-T model with a tiered 3x3 deep stem.
+    """
+    default_cfg = default_cfgs['seresnext26t_32x4d']
+    model = ResNet(
+        Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
+        stem_width=32, stem_type='deep_tiered', avg_down=True, use_se=True,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
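A quick smoke test of the two new registry entries, assuming it is run in the context of this resnet.py module (with `pretrained=False` so no weights are downloaded):

```python
import torch

for fn in (seresnext26d_32x4d, seresnext26t_32x4d):
    model = fn(pretrained=False, num_classes=1000, in_chans=3)
    logits = model(torch.randn(2, 3, 224, 224))
    print(fn.__name__, logits.shape)  # expected: torch.Size([2, 1000])
```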