Commit b60a4fb

add starnet
1 parent 081e6c2 commit b60a4fb

File tree

2 files changed: +345 −0 lines changed


timm/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@
 from .senet import *
 from .sequencer import *
 from .sknet import *
+from .starnet import *
 from .swiftformer import *
 from .swin_transformer import *
 from .swin_transformer_v2 import *
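
With this wildcard import in place, the @register_model entry points in starnet.py below are picked up by timm's model registry. A quick check, as a sketch assuming a timm build that includes this commit:

import timm

# the starnet variants become discoverable once timm.models imports the module
print(timm.list_models('starnet*'))
# expected: ['starnet_s050', 'starnet_s1', 'starnet_s100', 'starnet_s150',
#            'starnet_s2', 'starnet_s3', 'starnet_s4']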

timm/models/starnet.py

Lines changed: 344 additions & 0 deletions
@@ -0,0 +1,344 @@
"""
Implementation of Proof-of-Concept Network: StarNet.

We keep StarNet as simple as possible [to show the key contribution of element-wise multiplication]:
    - NO layer-scale in the network design,
    - NO EMA during training,
both of which could improve performance further.

Created by: Xu Ma (Email: ma.xu1@northeastern.edu)
Modified Date: Mar/29/2024
"""
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, SelectAdaptivePool2d, Linear, LayerType, trunc_normal_
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs

__all__ = ['StarNet']


class ConvBN(nn.Sequential):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 1,
            stride: int = 1,
            padding: int = 0,
            with_bn: bool = True,
            **kwargs
    ):
        super().__init__()
        self.add_module('conv', nn.Conv2d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding, **kwargs))
        if with_bn:
            self.add_module('bn', nn.BatchNorm2d(out_channels))
            nn.init.constant_(self.bn.weight, 1)
            nn.init.constant_(self.bn.bias, 0)


class Block(nn.Module):
    def __init__(
            self,
            dim: int,
            mlp_ratio: int = 3,
            drop_path: float = 0.,
            act_layer: LayerType = nn.ReLU6,
    ):
        super().__init__()
        self.dwconv = ConvBN(dim, dim, 7, 1, 3, groups=dim, with_bn=True)
        self.f1 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False)
        self.f2 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False)
        self.g = ConvBN(mlp_ratio * dim, dim, 1, with_bn=True)
        self.dwconv2 = ConvBN(dim, dim, 7, 1, 3, groups=dim, with_bn=False)
        self.act = act_layer()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        x1, x2 = self.f1(x), self.f2(x)
        x = self.act(x1) * x2  # star operation: element-wise product of the two branches
        x = self.dwconv2(self.g(x))
        x = residual + self.drop_path(x)
        return x


class StarNet(nn.Module):
    def __init__(
            self,
            base_dim: int = 32,
            depths: List[int] = [3, 3, 12, 5],
            mlp_ratio: int = 4,
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            act_layer: LayerType = nn.ReLU6,
            num_classes: int = 1000,
            in_chans: int = 3,
            global_pool: str = 'avg',
            output_stride: int = 32,
            **kwargs,
    ):
        super().__init__()
        assert output_stride == 32
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.grad_checkpointing = False
        self.feature_info = []
        stem_chs = 32

        # stem layer
        self.stem = nn.Sequential(
            ConvBN(in_chans, stem_chs, kernel_size=3, stride=2, padding=1),
            act_layer(),
        )
        prev_chs = stem_chs

        # build stages
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth
        stages = []
        cur = 0
        for i_layer in range(len(depths)):
            embed_dim = base_dim * 2 ** i_layer
            down_sampler = ConvBN(prev_chs, embed_dim, 3, stride=2, padding=1)
            blocks = [Block(embed_dim, mlp_ratio, dpr[cur + i], act_layer) for i in range(depths[i_layer])]
            cur += depths[i_layer]
            prev_chs = embed_dim
            stages.append(nn.Sequential(down_sampler, *blocks))
            self.feature_info.append(dict(
                num_chs=prev_chs, reduction=2**(i_layer+2), module=f'stages.{i_layer}'))
        self.stages = nn.Sequential(*stages)
        # head
        self.num_features = self.head_hidden_size = prev_chs
        self.norm = nn.BatchNorm2d(self.num_features)
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
        self.head = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self) -> Set:
        return set()

    @torch.jit.ignore
    def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
        matcher = dict(
            stem=r'^stem\.\d+',
            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        self.num_classes = num_classes
        if global_pool is not None:
            # NOTE: cannot meaningfully change pooling of efficient head after creation
            self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
            self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
        self.head = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to compatible intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features
        Returns:
            List of intermediate features, or a tuple of (final features, intermediates)
        """
        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
        intermediates = []
        take_indices, max_index = feature_take_indices(len(self.stages), indices)
        last_idx = len(self.stages) - 1

        # forward pass
        x = self.stem(x)
        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
            stages = self.stages
        else:
            stages = self.stages[:max_index + 1]

        for feat_idx, stage in enumerate(stages):
            x = stage(x)
            if feat_idx in take_indices:
                if norm and feat_idx == last_idx:
                    x_inter = self.norm(x)  # apply final norm to the last intermediate
                else:
                    x_inter = x
                intermediates.append(x_inter)

        if intermediates_only:
            return intermediates

        x = self.norm(x)

        return x, intermediates

    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.stages), indices)
        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
        if prune_norm:
            self.norm = nn.Identity()
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x, flatten=True)
        else:
            x = self.stages(x)
        x = self.norm(x)
        return x

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        x = self.global_pool(x)
        x = self.flatten(x)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        return x if pre_logits else self.head(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]:
    # original checkpoints load as-is; only unwrap a nested 'state_dict' if present
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    out_dict = state_dict
    return out_dict


def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.0.conv', 'classifier': 'head',
        'paper_ids': 'arXiv:2403.19967',
        'paper_name': 'Rewrite the Stars',
        'origin_url': 'https://github.com/ma-xu/Rewrite-the-Stars',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'starnet_s1.in1k': _cfg(
        # hf_hub_id='timm/',
        url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s1.pth.tar',
    ),
    'starnet_s2.in1k': _cfg(
        # hf_hub_id='timm/',
        url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s2.pth.tar',
    ),
    'starnet_s3.in1k': _cfg(
        # hf_hub_id='timm/',
        url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s3.pth.tar',
    ),
    'starnet_s4.in1k': _cfg(
        # hf_hub_id='timm/',
        url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s4.pth.tar',
    ),
    'starnet_s050.untrained': _cfg(),
    'starnet_s100.untrained': _cfg(),
    'starnet_s150.untrained': _cfg(),
})


def _create_starnet(variant: str, pretrained: bool = False, **kwargs: Any) -> StarNet:
    model = build_model_with_cfg(
        StarNet, variant, pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
        **kwargs,
    )
    return model


@register_model
def starnet_s1(pretrained: bool = False, **kwargs: Any) -> StarNet:
    model_args = dict(base_dim=24, depths=[2, 2, 8, 3])
    return _create_starnet('starnet_s1', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def starnet_s2(pretrained: bool = False, **kwargs: Any) -> StarNet:
    model_args = dict(base_dim=32, depths=[1, 2, 6, 2])
    return _create_starnet('starnet_s2', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def starnet_s3(pretrained: bool = False, **kwargs: Any) -> StarNet:
    model_args = dict(base_dim=32, depths=[2, 2, 8, 4])
    return _create_starnet('starnet_s3', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def starnet_s4(pretrained: bool = False, **kwargs: Any) -> StarNet:
    model_args = dict(base_dim=32, depths=[3, 3, 12, 5])
    return _create_starnet('starnet_s4', pretrained=pretrained, **dict(model_args, **kwargs))


# very small networks #
@register_model
def starnet_s050(pretrained: bool = False, **kwargs: Any) -> StarNet:
    model_args = dict(base_dim=16, depths=[1, 1, 3, 1], mlp_ratio=3)
    return _create_starnet('starnet_s050', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def starnet_s100(pretrained: bool = False, **kwargs: Any) -> StarNet:
    model_args = dict(base_dim=20, depths=[1, 2, 4, 1], mlp_ratio=4)
    return _create_starnet('starnet_s100', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def starnet_s150(pretrained: bool = False, **kwargs: Any) -> StarNet:
    model_args = dict(base_dim=24, depths=[1, 2, 4, 2], mlp_ratio=3)
    return _create_starnet('starnet_s150', pretrained=pretrained, **dict(model_args, **kwargs))
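
The core of the design is the "star operation" in Block.forward: two parallel 1x1 branches are fused by element-wise multiplication, act(f1(x)) * f2(x), rather than by addition or concatenation. A minimal shape check of a single block (illustrative only, not part of the commit; assumes the file above lands as timm.models.starnet):

import torch
from timm.models.starnet import Block

blk = Block(dim=64, mlp_ratio=3).eval()
x = torch.randn(2, 64, 56, 56)
with torch.no_grad():
    y = blk(x)
print(y.shape)  # torch.Size([2, 64, 56, 56]); blocks preserve shape, downsampling happens in each stage's ConvBN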

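End to end, the registered variants build through _create_starnet, so the standard timm entry points apply. A minimal usage sketch, assuming a timm install that includes this commit (pretrained weights pull from the origin repo's release URLs in default_cfgs):

import timm
import torch

model = timm.create_model('starnet_s1', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)  # shape (1, 1000)
    feats = model.forward_intermediates(x, intermediates_only=True)
for f in feats:
    print(f.shape)  # strides 4/8/16/32; for starnet_s1: (1, 24, 56, 56) ... (1, 192, 7, 7)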