1111
1212import torch
1313import torch .nn as nn
14+ import torch .nn .functional as F
1415from collections import OrderedDict
1516
1617from .helpers import load_pretrained
@@ -31,81 +32,87 @@ def _cfg(url=''):
3132
3233
3334default_cfgs = {
34- 'dpn68' : _cfg (url = 'http://data.lip6.fr/cadene/pretrainedmodels/dpn68-66bebafa7.pth' ),
35- 'dpn68b_extra' : _cfg (url = 'http://data.lip6.fr/cadene/pretrainedmodels/dpn68b_extra-84854c156.pth' ),
36- 'dpn92_extra' : _cfg (url = 'http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth' ),
37- 'dpn98' : _cfg (url = 'http://data.lip6.fr/cadene/pretrainedmodels/dpn98-5b90dec4d.pth' ),
38- 'dpn131' : _cfg (url = 'http://data.lip6.fr/cadene/pretrainedmodels/dpn131-71dfe43e0.pth' ),
39- 'dpn107_extra' : _cfg (url = 'http://data.lip6.fr/cadene/pretrainedmodels/dpn107_extra-1ac7121e2.pth' )
35+ 'dpn68' : _cfg (
36+ url = 'https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth' ),
37+ 'dpn68b_extra' : _cfg (
38+ url = 'https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68b_extra-84854c156.pth' ),
39+ 'dpn92_extra' : _cfg (
40+ url = 'https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn92_extra-b040e4a9b.pth' ),
41+ 'dpn98' : _cfg (
42+ url = 'https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn98-5b90dec4d.pth' ),
43+ 'dpn131' : _cfg (
44+ url = 'https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn131-71dfe43e0.pth' ),
45+ 'dpn107_extra' : _cfg (
46+ url = 'https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn107_extra-1ac7121e2.pth' )
4047}
4148
4249
43- def dpn68 (num_classes = 1000 , in_chans = 3 , pretrained = False ):
50+ def dpn68 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
4451 default_cfg = default_cfgs ['dpn68' ]
4552 model = DPN (
4653 small = True , num_init_features = 10 , k_r = 128 , groups = 32 ,
4754 k_sec = (3 , 4 , 12 , 3 ), inc_sec = (16 , 32 , 32 , 64 ),
48- num_classes = num_classes , in_chans = in_chans )
55+ num_classes = num_classes , in_chans = in_chans , ** kwargs )
4956 model .default_cfg = default_cfg
5057 if pretrained :
5158 load_pretrained (model , default_cfg , num_classes , in_chans )
5259 return model
5360
5461
55- def dpn68b (num_classes = 1000 , in_chans = 3 , pretrained = False ):
62+ def dpn68b (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
5663 default_cfg = default_cfgs ['dpn68b_extra' ]
5764 model = DPN (
5865 small = True , num_init_features = 10 , k_r = 128 , groups = 32 ,
5966 b = True , k_sec = (3 , 4 , 12 , 3 ), inc_sec = (16 , 32 , 32 , 64 ),
60- num_classes = num_classes , in_chans = in_chans )
67+ num_classes = num_classes , in_chans = in_chans , ** kwargs )
6168 model .default_cfg = default_cfg
6269 if pretrained :
6370 load_pretrained (model , default_cfg , num_classes , in_chans )
6471 return model
6572
6673
67- def dpn92 (num_classes = 1000 , in_chans = 3 , pretrained = False ):
74+ def dpn92 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
6875 default_cfg = default_cfgs ['dpn92_extra' ]
6976 model = DPN (
7077 num_init_features = 64 , k_r = 96 , groups = 32 ,
7178 k_sec = (3 , 4 , 20 , 3 ), inc_sec = (16 , 32 , 24 , 128 ),
72- num_classes = num_classes , in_chans = in_chans )
79+ num_classes = num_classes , in_chans = in_chans , ** kwargs )
7380 model .default_cfg = default_cfg
7481 if pretrained :
7582 load_pretrained (model , default_cfg , num_classes , in_chans )
7683 return model
7784
7885
79- def dpn98 (num_classes = 1000 , in_chans = 3 , pretrained = False ):
86+ def dpn98 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
8087 default_cfg = default_cfgs ['dpn98' ]
8188 model = DPN (
8289 num_init_features = 96 , k_r = 160 , groups = 40 ,
8390 k_sec = (3 , 6 , 20 , 3 ), inc_sec = (16 , 32 , 32 , 128 ),
84- num_classes = num_classes , in_chans = in_chans )
91+ num_classes = num_classes , in_chans = in_chans , ** kwargs )
8592 model .default_cfg = default_cfg
8693 if pretrained :
8794 load_pretrained (model , default_cfg , num_classes , in_chans )
8895 return model
8996
9097
91- def dpn131 (num_classes = 1000 , in_chans = 3 , pretrained = False ):
98+ def dpn131 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
9299 default_cfg = default_cfgs ['dpn131' ]
93100 model = DPN (
94101 num_init_features = 128 , k_r = 160 , groups = 40 ,
95102 k_sec = (4 , 8 , 28 , 3 ), inc_sec = (16 , 32 , 32 , 128 ),
96- num_classes = num_classes , in_chans = in_chans )
103+ num_classes = num_classes , in_chans = in_chans , ** kwargs )
97104 model .default_cfg = default_cfg
98105 if pretrained :
99106 load_pretrained (model , default_cfg , num_classes , in_chans )
100107 return model
101108
102109
103- def dpn107 (num_classes = 1000 , in_chans = 3 , pretrained = False ):
110+ def dpn107 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
104111 default_cfg = default_cfgs ['dpn107_extra' ]
105112 model = DPN (
106113 num_init_features = 128 , k_r = 200 , groups = 50 ,
107114 k_sec = (4 , 8 , 20 , 3 ), inc_sec = (20 , 64 , 64 , 128 ),
108- num_classes = num_classes , in_chans = in_chans )
115+ num_classes = num_classes , in_chans = in_chans , ** kwargs )
109116 model .default_cfg = default_cfg
110117 if pretrained :
111118 load_pretrained (model , default_cfg , num_classes , in_chans )
@@ -220,9 +227,11 @@ def forward(self, x):
220227class DPN (nn .Module ):
221228 def __init__ (self , small = False , num_init_features = 64 , k_r = 96 , groups = 32 ,
222229 b = False , k_sec = (3 , 4 , 20 , 3 ), inc_sec = (16 , 32 , 24 , 128 ),
223- num_classes = 1000 , in_chans = 3 , fc_act = nn .ELU (inplace = True )):
230+ num_classes = 1000 , in_chans = 3 , drop_rate = 0. , global_pool = 'avg' , fc_act = nn .ELU ()):
224231 super (DPN , self ).__init__ ()
225232 self .num_classes = num_classes
233+ self .drop_rate = drop_rate
234+ self .global_pool = global_pool
226235 self .b = b
227236 bw_factor = 1 if small else 4
228237
@@ -285,8 +294,9 @@ def __init__(self, small=False, num_init_features=64, k_r=96, groups=32,
285294 def get_classifier (self ):
286295 return self .classifier
287296
288- def reset_classifier (self , num_classes ):
297+ def reset_classifier (self , num_classes , global_pool = 'avg' ):
289298 self .num_classes = num_classes
299+ self .global_pool = global_pool
290300 del self .classifier
291301 if num_classes :
292302 self .classifier = nn .Conv2d (self .num_features , num_classes , kernel_size = 1 , bias = True )
@@ -296,11 +306,13 @@ def reset_classifier(self, num_classes):
296306 def forward_features (self , x , pool = True ):
297307 x = self .features (x )
298308 if pool :
299- x = select_adaptive_pool2d (x , pool_type = 'avg' )
309+ x = select_adaptive_pool2d (x , pool_type = self . global_pool )
300310 return x
301311
302312 def forward (self , x ):
303313 x = self .forward_features (x )
314+ if self .drop_rate > 0. :
315+ x = F .dropout (x , p = self .drop_rate , training = self .training )
304316 out = self .classifier (x )
305317 return out .view (out .size (0 ), - 1 )
306318
0 commit comments