Skip to content

Commit 32012a4

Browse files
committed
Added SelecSLS model
1 parent 187ecba commit 32012a4

File tree

2 files changed

+364
-0
lines changed

2 files changed

+364
-0
lines changed

timm/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from .xception import *
88
from .nasnet import *
99
from .pnasnet import *
10+
from .selecsls import *
1011
from .gen_efficientnet import *
1112
from .inception_v3 import *
1213
from .gluon_resnet import *

timm/models/selecsls.py

Lines changed: 363 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,363 @@
1+
"""PyTorch SelecSLS on ImageNet
2+
3+
Based on ResNet implementation in this repository
4+
SelecSLS (core) Network Architecture as proposed in
5+
XNect: Real-time Multi-person 3D Human Pose Estimation with a Single RGB Camera, Mehta et al.
6+
https://arxiv.org/abs/1907.00837
7+
8+
Implementation by Dushyant Mehta (@mehtadushy)
9+
"""
10+
import math
11+
12+
import torch
13+
import torch.nn as nn
14+
import torch.nn.functional as F
15+
16+
from .registry import register_model
17+
from .helpers import load_pretrained
18+
from .adaptive_avgmax_pool import SelectAdaptivePool2d
19+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
20+
21+
22+
__all__ = ['SelecSLS'] # model_registry will add each entrypoint fn to this
23+
24+
25+
def _cfg(url='', **kwargs):
26+
return {
27+
'url': url,
28+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (3, 3),
29+
'crop_pct': 0.875, 'interpolation': 'bilinear',
30+
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
31+
'first_conv': 'stem', 'classifier': 'fc',
32+
**kwargs
33+
}
34+
35+
36+
default_cfgs = {
37+
'selecsls42': _cfg(
38+
url='',
39+
interpolation='bicubic'),
40+
'selecsls60': _cfg(
41+
url='',
42+
interpolation='bicubic'),
43+
'selecsls60NH': _cfg(
44+
url='',
45+
interpolation='bicubic'),
46+
'selecsls84': _cfg(
47+
url='',
48+
interpolation='bicubic'),
49+
}
50+
51+
52+
def conv_bn(inp, oup, stride):
53+
return nn.Sequential(
54+
nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
55+
nn.BatchNorm2d(oup),
56+
nn.ReLU(inplace=True)
57+
)
58+
59+
60+
def conv_1x1_bn(inp, oup):
61+
return nn.Sequential(
62+
nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
63+
nn.BatchNorm2d(oup),
64+
nn.ReLU(inplace=True)
65+
)
66+
67+
class SelecSLSBlock(nn.Module):
68+
def __init__(self, inp, skip, k, oup, isFirst, stride):
69+
super(SelecSLSBlock, self).__init__()
70+
self.stride = stride
71+
self.isFirst = isFirst
72+
assert stride in [1, 2]
73+
74+
#Process input with 4 conv blocks with the same number of input and output channels
75+
self.conv1 = nn.Sequential(
76+
nn.Conv2d(inp, k, 3, stride, 1,groups= 1, bias=False, dilation=1),
77+
nn.BatchNorm2d(k),
78+
nn.ReLU(inplace=True)
79+
)
80+
self.conv2 = nn.Sequential(
81+
nn.Conv2d(k, k, 1, 1, 0,groups= 1, bias=False, dilation=1),
82+
nn.BatchNorm2d(k),
83+
nn.ReLU(inplace=True)
84+
)
85+
self.conv3 = nn.Sequential(
86+
nn.Conv2d(k, k//2, 3, 1, 1,groups= 1, bias=False, dilation=1),
87+
nn.BatchNorm2d(k//2),
88+
nn.ReLU(inplace=True)
89+
)
90+
self.conv4 = nn.Sequential(
91+
nn.Conv2d(k//2, k, 1, 1, 0,groups= 1, bias=False, dilation=1),
92+
nn.BatchNorm2d(k),
93+
nn.ReLU(inplace=True)
94+
)
95+
self.conv5 = nn.Sequential(
96+
nn.Conv2d(k, k//2, 3, 1, 1,groups= 1, bias=False, dilation=1),
97+
nn.BatchNorm2d(k//2),
98+
nn.ReLU(inplace=True)
99+
)
100+
self.conv6 = nn.Sequential(
101+
nn.Conv2d(2*k + (0 if isFirst else skip), oup, 1, 1, 0,groups= 1, bias=False, dilation=1),
102+
nn.BatchNorm2d(oup),
103+
nn.ReLU(inplace=True)
104+
)
105+
106+
def forward(self, x):
107+
assert isinstance(x,list)
108+
assert len(x) in [1,2]
109+
110+
d1 = self.conv1(x[0])
111+
d2 = self.conv3(self.conv2(d1))
112+
d3 = self.conv5(self.conv4(d2))
113+
if self.isFirst:
114+
out = self.conv6(torch.cat([d1, d2, d3], 1))
115+
return [out, out]
116+
else:
117+
return [self.conv6(torch.cat([d1, d2, d3, x[1]], 1)) , x[1]]
118+
119+
class SelecSLS(nn.Module):
120+
"""SelecSLS42 / SelecSLS60 / SelecSLS84
121+
122+
Parameters
123+
----------
124+
cfg : network config
125+
String indicating the network config
126+
num_classes : int, default 1000
127+
Number of classification classes.
128+
in_chans : int, default 3
129+
Number of input (color) channels.
130+
drop_rate : float, default 0.
131+
Dropout probability before classifier, for training
132+
global_pool : str, default 'avg'
133+
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
134+
"""
135+
def __init__(self, cfg='selecsls60', num_classes=1000, in_chans=3,
136+
drop_rate=0.0, global_pool='avg'):
137+
self.num_classes = num_classes
138+
self.drop_rate = drop_rate
139+
super(SelecSLS, self).__init__()
140+
141+
self.stem = conv_bn(in_chans, 32, 2)
142+
#Core Network
143+
self.features = []
144+
if cfg=='selecsls42':
145+
self.block = SelecSLSBlock
146+
#Define configuration of the network after the initial neck
147+
self.selecSLS_config = [
148+
#inp,skip, k, oup, isFirst, stride
149+
[ 32, 0, 64, 64, True, 2],
150+
[ 64, 64, 64, 128, False, 1],
151+
[128, 0, 144, 144, True, 2],
152+
[144, 144, 144, 288, False, 1],
153+
[288, 0, 304, 304, True, 2],
154+
[304, 304, 304, 480, False, 1],
155+
]
156+
#Head can be replaced with alternative configurations depending on the problem
157+
self.head = nn.Sequential(
158+
conv_bn(480, 960, 2),
159+
conv_bn(960, 1024, 1),
160+
conv_bn(1024, 1024, 2),
161+
conv_1x1_bn(1024, 1280),
162+
)
163+
self.num_features = 1280
164+
elif cfg=='selecsls42NH':
165+
self.block = SelecSLSBlock
166+
#Define configuration of the network after the initial neck
167+
self.selecSLS_config = [
168+
#inp,skip, k, oup, isFirst, stride
169+
[ 32, 0, 64, 64, True, 2],
170+
[ 64, 64, 64, 128, False, 1],
171+
[128, 0, 144, 144, True, 2],
172+
[144, 144, 144, 288, False, 1],
173+
[288, 0, 304, 304, True, 2],
174+
[304, 304, 304, 480, False, 1],
175+
]
176+
#Head can be replaced with alternative configurations depending on the problem
177+
self.head = nn.Sequential(
178+
conv_bn(480, 960, 2),
179+
conv_bn(960, 1024, 1),
180+
conv_bn(1024, 1280, 2),
181+
conv_1x1_bn(1280, 1024),
182+
)
183+
self.num_features = 1024
184+
elif cfg=='selecsls60':
185+
self.block = SelecSLSBlock
186+
#Define configuration of the network after the initial neck
187+
self.selecSLS_config = [
188+
#inp,skip, k, oup, isFirst, stride
189+
[ 32, 0, 64, 64, True, 2],
190+
[ 64, 64, 64, 128, False, 1],
191+
[128, 0, 128, 128, True, 2],
192+
[128, 128, 128, 128, False, 1],
193+
[128, 128, 128, 288, False, 1],
194+
[288, 0, 288, 288, True, 2],
195+
[288, 288, 288, 288, False, 1],
196+
[288, 288, 288, 288, False, 1],
197+
[288, 288, 288, 416, False, 1],
198+
]
199+
#Head can be replaced with alternative configurations depending on the problem
200+
self.head = nn.Sequential(
201+
conv_bn(416, 756, 2),
202+
conv_bn(756, 1024, 1),
203+
conv_bn(1024, 1024, 2),
204+
conv_1x1_bn(1024, 1280),
205+
)
206+
self.num_features = 1280
207+
elif cfg=='selecsls60NH':
208+
self.block = SelecSLSBlock
209+
#Define configuration of the network after the initial neck
210+
self.selecSLS_config = [
211+
#inp,skip, k, oup, isFirst, stride
212+
[ 32, 0, 64, 64, True, 2],
213+
[ 64, 64, 64, 128, False, 1],
214+
[128, 0, 128, 128, True, 2],
215+
[128, 128, 128, 128, False, 1],
216+
[128, 128, 128, 288, False, 1],
217+
[288, 0, 288, 288, True, 2],
218+
[288, 288, 288, 288, False, 1],
219+
[288, 288, 288, 288, False, 1],
220+
[288, 288, 288, 416, False, 1],
221+
]
222+
#Head can be replaced with alternative configurations depending on the problem
223+
self.head = nn.Sequential(
224+
conv_bn(416, 756, 2),
225+
conv_bn(756, 1024, 1),
226+
conv_bn(1024, 1280, 2),
227+
conv_1x1_bn(1280, 1024),
228+
)
229+
self.num_features = 1024
230+
elif cfg=='selecsls84':
231+
self.block = SelecSLSBlock
232+
#Define configuration of the network after the initial neck
233+
self.selecSLS_config = [
234+
#inp,skip, k, oup, isFirst, stride
235+
[ 32, 0, 64, 64, True, 2],
236+
[ 64, 64, 64, 144, False, 1],
237+
[144, 0, 144, 144, True, 2],
238+
[144, 144, 144, 144, False, 1],
239+
[144, 144, 144, 144, False, 1],
240+
[144, 144, 144, 144, False, 1],
241+
[144, 144, 144, 304, False, 1],
242+
[304, 0, 304, 304, True, 2],
243+
[304, 304, 304, 304, False, 1],
244+
[304, 304, 304, 304, False, 1],
245+
[304, 304, 304, 304, False, 1],
246+
[304, 304, 304, 304, False, 1],
247+
[304, 304, 304, 512, False, 1],
248+
]
249+
#Head can be replaced with alternative configurations depending on the problem
250+
self.head = nn.Sequential(
251+
conv_bn(512, 960, 2),
252+
conv_bn(960, 1024, 1),
253+
conv_bn(1024, 1024, 2),
254+
conv_1x1_bn(1024, 1280),
255+
)
256+
self.num_features = 1280
257+
else:
258+
raise ValueError('Invalid net configuration '+cfg+' !!!')
259+
260+
for inp, skip, k, oup, isFirst, stride in self.selecSLS_config:
261+
self.features.append(self.block(inp, skip, k, oup, isFirst, stride))
262+
self.features = nn.Sequential(*self.features)
263+
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
264+
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
265+
266+
for n, m in self.named_modules():
267+
if isinstance(m, nn.Conv2d):
268+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
269+
elif isinstance(m, nn.BatchNorm2d):
270+
nn.init.constant_(m.weight, 1.)
271+
nn.init.constant_(m.bias, 0.)
272+
273+
def get_classifier(self):
274+
return self.fc
275+
276+
def reset_classifier(self, num_classes, global_pool='avg'):
277+
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
278+
self.num_classes = num_classes
279+
del self.fc
280+
if num_classes:
281+
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
282+
else:
283+
self.fc = None
284+
285+
def forward_features(self, x, pool=True):
286+
x = self.stem(x)
287+
x = self.features([x])
288+
x = self.head(x[0])
289+
290+
if pool:
291+
x = self.global_pool(x)
292+
x = x.view(x.size(0), -1)
293+
return x
294+
295+
def forward(self, x):
296+
x = self.forward_features(x)
297+
if self.drop_rate > 0.:
298+
x = F.dropout(x, p=self.drop_rate, training=self.training)
299+
x = self.fc(x)
300+
return x
301+
302+
303+
@register_model
304+
def selecsls42(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
305+
"""Constructs a SelecSLS42 model.
306+
"""
307+
default_cfg = default_cfgs['selecsls42']
308+
model = SelecSLS(
309+
cfg='selecsls42', num_classes=1000, in_chans=3, **kwargs)
310+
model.default_cfg = default_cfg
311+
if pretrained:
312+
load_pretrained(model, default_cfg, num_classes, in_chans)
313+
return model
314+
315+
@register_model
316+
def selecsls42NH(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
317+
"""Constructs a SelecSLS42NH model.
318+
"""
319+
default_cfg = default_cfgs['selecsls42NH']
320+
model = SelecSLS(
321+
cfg='selecsls42NH', num_classes=1000, in_chans=3,**kwargs)
322+
model.default_cfg = default_cfg
323+
if pretrained:
324+
load_pretrained(model, default_cfg, num_classes, in_chans)
325+
return model
326+
327+
@register_model
328+
def selecsls60(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
329+
"""Constructs a SelecSLS60 model.
330+
"""
331+
default_cfg = default_cfgs['selecsls60']
332+
model = SelecSLS(
333+
cfg='selecsls60', num_classes=1000, in_chans=3,**kwargs)
334+
model.default_cfg = default_cfg
335+
if pretrained:
336+
load_pretrained(model, default_cfg, num_classes, in_chans)
337+
return model
338+
339+
340+
@register_model
341+
def selecsls60NH(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
342+
"""Constructs a SelecSLS60NH model.
343+
"""
344+
default_cfg = default_cfgs['selecsls60NH']
345+
model = SelecSLS(
346+
cfg='selecsls60NH', num_classes=1000, in_chans=3,**kwargs)
347+
model.default_cfg = default_cfg
348+
if pretrained:
349+
load_pretrained(model, default_cfg, num_classes, in_chans)
350+
return model
351+
352+
@register_model
353+
def selecsls84(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
354+
"""Constructs a SelecSLS84 model.
355+
"""
356+
default_cfg = default_cfgs['selecsls84']
357+
model = SelecSLS(
358+
cfg='selecsls84', num_classes=1000, in_chans=3, **kwargs)
359+
model.default_cfg = default_cfg
360+
if pretrained:
361+
load_pretrained(model, default_cfg, num_classes, in_chans)
362+
return model
363+

0 commit comments

Comments
 (0)