Skip to content

Commit 6cc214b

Browse files
committed
Consistency in model entrypoints
* move pretrained entrypoint arg to first pos to be closer to torchvision/hub * change DPN weight URLS to my github location
1 parent b20bb58 commit 6cc214b

File tree

12 files changed

+125
-113
lines changed

12 files changed

+125
-113
lines changed

timm/models/densenet.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def _filter_pretrained(state_dict):
4343
return state_dict
4444

4545

46-
def densenet121(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
46+
def densenet121(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
4747
r"""Densenet-121 model from
4848
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
4949
"""
@@ -56,7 +56,7 @@ def densenet121(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
5656
return model
5757

5858

59-
def densenet169(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
59+
def densenet169(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
6060
r"""Densenet-169 model from
6161
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
6262
"""
@@ -69,7 +69,7 @@ def densenet169(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
6969
return model
7070

7171

72-
def densenet201(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
72+
def densenet201(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
7373
r"""Densenet-201 model from
7474
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
7575
"""
@@ -82,7 +82,7 @@ def densenet201(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
8282
return model
8383

8484

85-
def densenet161(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
85+
def densenet161(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
8686
r"""Densenet-201 model from
8787
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
8888
"""

timm/models/dpn.py

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import torch
1313
import torch.nn as nn
14+
import torch.nn.functional as F
1415
from collections import OrderedDict
1516

1617
from .helpers import load_pretrained
@@ -31,81 +32,87 @@ def _cfg(url=''):
3132

3233

3334
default_cfgs = {
34-
'dpn68': _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn68-66bebafa7.pth'),
35-
'dpn68b_extra': _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn68b_extra-84854c156.pth'),
36-
'dpn92_extra': _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth'),
37-
'dpn98': _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn98-5b90dec4d.pth'),
38-
'dpn131': _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn131-71dfe43e0.pth'),
39-
'dpn107_extra': _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn107_extra-1ac7121e2.pth')
35+
'dpn68': _cfg(
36+
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'),
37+
'dpn68b_extra': _cfg(
38+
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68b_extra-84854c156.pth'),
39+
'dpn92_extra': _cfg(
40+
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn92_extra-b040e4a9b.pth'),
41+
'dpn98': _cfg(
42+
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn98-5b90dec4d.pth'),
43+
'dpn131': _cfg(
44+
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn131-71dfe43e0.pth'),
45+
'dpn107_extra': _cfg(
46+
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn107_extra-1ac7121e2.pth')
4047
}
4148

4249

43-
def dpn68(num_classes=1000, in_chans=3, pretrained=False):
50+
def dpn68(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
4451
default_cfg = default_cfgs['dpn68']
4552
model = DPN(
4653
small=True, num_init_features=10, k_r=128, groups=32,
4754
k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64),
48-
num_classes=num_classes, in_chans=in_chans)
55+
num_classes=num_classes, in_chans=in_chans, **kwargs)
4956
model.default_cfg = default_cfg
5057
if pretrained:
5158
load_pretrained(model, default_cfg, num_classes, in_chans)
5259
return model
5360

5461

55-
def dpn68b(num_classes=1000, in_chans=3, pretrained=False):
62+
def dpn68b(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
5663
default_cfg = default_cfgs['dpn68b_extra']
5764
model = DPN(
5865
small=True, num_init_features=10, k_r=128, groups=32,
5966
b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64),
60-
num_classes=num_classes, in_chans=in_chans)
67+
num_classes=num_classes, in_chans=in_chans, **kwargs)
6168
model.default_cfg = default_cfg
6269
if pretrained:
6370
load_pretrained(model, default_cfg, num_classes, in_chans)
6471
return model
6572

6673

67-
def dpn92(num_classes=1000, in_chans=3, pretrained=False):
74+
def dpn92(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
6875
default_cfg = default_cfgs['dpn92_extra']
6976
model = DPN(
7077
num_init_features=64, k_r=96, groups=32,
7178
k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128),
72-
num_classes=num_classes, in_chans=in_chans)
79+
num_classes=num_classes, in_chans=in_chans, **kwargs)
7380
model.default_cfg = default_cfg
7481
if pretrained:
7582
load_pretrained(model, default_cfg, num_classes, in_chans)
7683
return model
7784

7885

79-
def dpn98(num_classes=1000, in_chans=3, pretrained=False):
86+
def dpn98(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
8087
default_cfg = default_cfgs['dpn98']
8188
model = DPN(
8289
num_init_features=96, k_r=160, groups=40,
8390
k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128),
84-
num_classes=num_classes, in_chans=in_chans)
91+
num_classes=num_classes, in_chans=in_chans, **kwargs)
8592
model.default_cfg = default_cfg
8693
if pretrained:
8794
load_pretrained(model, default_cfg, num_classes, in_chans)
8895
return model
8996

9097

91-
def dpn131(num_classes=1000, in_chans=3, pretrained=False):
98+
def dpn131(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
9299
default_cfg = default_cfgs['dpn131']
93100
model = DPN(
94101
num_init_features=128, k_r=160, groups=40,
95102
k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128),
96-
num_classes=num_classes, in_chans=in_chans)
103+
num_classes=num_classes, in_chans=in_chans, **kwargs)
97104
model.default_cfg = default_cfg
98105
if pretrained:
99106
load_pretrained(model, default_cfg, num_classes, in_chans)
100107
return model
101108

102109

103-
def dpn107(num_classes=1000, in_chans=3, pretrained=False):
110+
def dpn107(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
104111
default_cfg = default_cfgs['dpn107_extra']
105112
model = DPN(
106113
num_init_features=128, k_r=200, groups=50,
107114
k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128),
108-
num_classes=num_classes, in_chans=in_chans)
115+
num_classes=num_classes, in_chans=in_chans, **kwargs)
109116
model.default_cfg = default_cfg
110117
if pretrained:
111118
load_pretrained(model, default_cfg, num_classes, in_chans)
@@ -220,9 +227,11 @@ def forward(self, x):
220227
class DPN(nn.Module):
221228
def __init__(self, small=False, num_init_features=64, k_r=96, groups=32,
222229
b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128),
223-
num_classes=1000, in_chans=3, fc_act=nn.ELU(inplace=True)):
230+
num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg', fc_act=nn.ELU()):
224231
super(DPN, self).__init__()
225232
self.num_classes = num_classes
233+
self.drop_rate = drop_rate
234+
self.global_pool = global_pool
226235
self.b = b
227236
bw_factor = 1 if small else 4
228237

@@ -285,8 +294,9 @@ def __init__(self, small=False, num_init_features=64, k_r=96, groups=32,
285294
def get_classifier(self):
286295
return self.classifier
287296

288-
def reset_classifier(self, num_classes):
297+
def reset_classifier(self, num_classes, global_pool='avg'):
289298
self.num_classes = num_classes
299+
self.global_pool = global_pool
290300
del self.classifier
291301
if num_classes:
292302
self.classifier = nn.Conv2d(self.num_features, num_classes, kernel_size=1, bias=True)
@@ -296,11 +306,13 @@ def reset_classifier(self, num_classes):
296306
def forward_features(self, x, pool=True):
297307
x = self.features(x)
298308
if pool:
299-
x = select_adaptive_pool2d(x, pool_type='avg')
309+
x = select_adaptive_pool2d(x, pool_type=self.global_pool)
300310
return x
301311

302312
def forward(self, x):
303313
x = self.forward_features(x)
314+
if self.drop_rate > 0.:
315+
x = F.dropout(x, p=self.drop_rate, training=self.training)
304316
out = self.classifier(x)
305317
return out.view(out.size(0), -1)
306318

0 commit comments

Comments
 (0)