1+ """
2+ Implementation of Prof-of-Concept Network: StarNet.
3+
4+ We make StarNet as simple as possible [to show the key contribution of element-wise multiplication]:
5+ - like NO layer-scale in network design,
6+ - and NO EMA during training,
7+ - which would improve the performance further.
8+
9+ Created by: Xu Ma (Email: ma.xu1@northeastern.edu)
10+ Modified Date: Mar/29/2024
11+ """
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, SelectAdaptivePool2d, Linear, LayerType, trunc_normal_
from ._builder import build_model_with_cfg
from ._features import feature_take_indices
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs

__all__ = ['StarNet']


class ConvBN(nn.Sequential):
    """A Conv2d followed by an optional BatchNorm2d."""

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 1,
            stride: int = 1,
            padding: int = 0,
            with_bn: bool = True,
            **kwargs
    ):
        super().__init__()
        self.add_module('conv', nn.Conv2d(
            in_channels, out_channels, kernel_size, stride=stride, padding=padding, **kwargs))
        if with_bn:
            self.add_module('bn', nn.BatchNorm2d(out_channels))
            nn.init.constant_(self.bn.weight, 1)
            nn.init.constant_(self.bn.bias, 0)


class Block(nn.Module):
    """StarNet block: a depthwise 7x7 conv, two parallel 1x1 branches fused by
    element-wise multiplication (the 'star' operation), a 1x1 projection back to
    the input width, and a second depthwise 7x7 conv, wrapped in a residual
    connection with optional drop path."""

    def __init__(
            self,
            dim: int,
            mlp_ratio: int = 3,
            drop_path: float = 0.,
            act_layer: LayerType = nn.ReLU6,
    ):
        super().__init__()
        self.dwconv = ConvBN(dim, dim, 7, 1, 3, groups=dim, with_bn=True)
        self.f1 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False)
        self.f2 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False)
        self.g = ConvBN(mlp_ratio * dim, dim, 1, with_bn=True)
        self.dwconv2 = ConvBN(dim, dim, 7, 1, 3, groups=dim, with_bn=False)
        self.act = act_layer()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.dwconv(x)
        x1, x2 = self.f1(x), self.f2(x)
        x = self.act(x1) * x2  # the 'star' operation: element-wise multiplication of the two branches
        x = self.dwconv2(self.g(x))
        x = residual + self.drop_path(x)
        return x


class StarNet(nn.Module):
    """StarNet, a proof-of-concept network from 'Rewrite the Stars' (arXiv:2403.19967)."""

    def __init__(
            self,
            base_dim: int = 32,
            depths: List[int] = [3, 3, 12, 5],
            mlp_ratio: int = 4,
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            act_layer: LayerType = nn.ReLU6,
            num_classes: int = 1000,
            in_chans: int = 3,
            global_pool: str = 'avg',
            output_stride: int = 32,
            **kwargs,
    ):
        super().__init__()
        assert output_stride == 32
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.grad_checkpointing = False
        self.feature_info = []
        stem_chs = 32

        # stem layer
        self.stem = nn.Sequential(
            ConvBN(in_chans, stem_chs, kernel_size=3, stride=2, padding=1),
            act_layer(),
        )
        prev_chs = stem_chs

        # build stages
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth
        stages = []
        cur = 0
        for i_layer in range(len(depths)):
            embed_dim = base_dim * 2 ** i_layer
            down_sampler = ConvBN(prev_chs, embed_dim, 3, stride=2, padding=1)
            blocks = [Block(embed_dim, mlp_ratio, dpr[cur + i], act_layer) for i in range(depths[i_layer])]
            cur += depths[i_layer]
            prev_chs = embed_dim
            stages.append(nn.Sequential(down_sampler, *blocks))
            self.feature_info.append(dict(
                num_chs=prev_chs, reduction=2 ** (i_layer + 2), module=f'stages.{i_layer}'))
        self.stages = nn.Sequential(*stages)
        # head
        self.num_features = self.head_hidden_size = prev_chs
        self.norm = nn.BatchNorm2d(self.num_features)
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
        self.head = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self) -> Set:
        return set()

    @torch.jit.ignore
    def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
        matcher = dict(
            stem=r'^stem\.\d+',
            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))],
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self) -> nn.Module:
        return self.head

    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
        self.num_classes = num_classes
        if global_pool is not None:
            # NOTE: cannot meaningfully change pooling of efficient head after creation
            self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
            self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
        self.head = Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()

    def forward_intermediates(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
        """ Forward features that returns intermediates.

        Args:
            x: Input image tensor
            indices: Take last n blocks if int, all if None, select matching indices if sequence
            norm: Apply norm layer to compatible intermediates
            stop_early: Stop iterating over blocks when last desired intermediate hit
            output_fmt: Shape of intermediate feature outputs
            intermediates_only: Only return intermediate features

        Returns:
            List of intermediate features if intermediates_only is True,
            otherwise a tuple of (final features, list of intermediates).
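
        Example (an illustrative sketch; starnet_s1 is registered at the bottom of this file):
            >>> model = starnet_s1()
            >>> feats = model.forward_intermediates(torch.randn(1, 3, 224, 224), intermediates_only=True)
            >>> [tuple(f.shape) for f in feats]  # one feature map per stage, stride 4/8/16/32
            [(1, 24, 56, 56), (1, 48, 28, 28), (1, 96, 14, 14), (1, 192, 7, 7)]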
184+ """
185+ assert output_fmt in ('NCHW' ,), 'Output shape must be NCHW.'
186+ intermediates = []
187+ take_indices , max_index = feature_take_indices (len (self .stages ), indices )
188+ last_idx = len (self .stages ) - 1
189+
190+ # forward pass
191+ x = self .stem (x )
192+ if torch .jit .is_scripting () or not stop_early : # can't slice blocks in torchscript
193+ stages = self .stages
194+ else :
195+ stages = self .stages [:max_index + 1 ]
196+
197+ for feat_idx , stage in enumerate (stages ):
198+ x = stage (x )
199+ if feat_idx in take_indices :
200+ if norm and feat_idx == last_idx :
201+ x_inter = self .norm (x ) # applying final norm last intermediate
202+ else :
203+ x_inter = x
204+ intermediates .append (x_inter )
205+
206+ if intermediates_only :
207+ return intermediates
208+
209+ x = self .norm (x )
210+
211+ return x , intermediates
212+
    def prune_intermediate_layers(
            self,
            indices: Union[int, List[int]] = 1,
            prune_norm: bool = False,
            prune_head: bool = True,
    ):
        """ Prune layers not required for specified intermediates.
        """
        take_indices, max_index = feature_take_indices(len(self.stages), indices)
        self.stages = self.stages[:max_index + 1]  # truncate stages beyond last requested intermediate
        if prune_norm:
            self.norm = nn.Identity()
        if prune_head:
            self.reset_classifier(0, '')
        return take_indices

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x, flatten=True)
        else:
            x = self.stages(x)
        x = self.norm(x)
        return x

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        x = self.global_pool(x)
        x = self.flatten(x)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        return x if pre_logits else self.head(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]:
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    return state_dict


def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.0.conv', 'classifier': 'head',
        'paper_ids': 'arXiv:2403.19967',
        'paper_name': 'Rewrite the Stars',
        'origin_url': 'https://github.com/ma-xu/Rewrite-the-Stars',
        **kwargs,
    }


default_cfgs = generate_default_cfgs({
    'starnet_s1.in1k': _cfg(
        # hf_hub_id='timm/',
        url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s1.pth.tar',
    ),
    'starnet_s2.in1k': _cfg(
        # hf_hub_id='timm/',
        url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s2.pth.tar',
    ),
    'starnet_s3.in1k': _cfg(
        # hf_hub_id='timm/',
        url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s3.pth.tar',
    ),
    'starnet_s4.in1k': _cfg(
        # hf_hub_id='timm/',
        url='https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s4.pth.tar',
    ),
    'starnet_s050.untrained': _cfg(),
    'starnet_s100.untrained': _cfg(),
    'starnet_s150.untrained': _cfg(),
})


def _create_starnet(variant: str, pretrained: bool = False, **kwargs: Any) -> StarNet:
    model = build_model_with_cfg(
        StarNet, variant, pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
        **kwargs,
    )
    return model

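# A feature-extraction usage sketch (relies on timm's standard features_only
# mechanism, wired up via feature_cfg above; illustrative, not part of this file's API):
#
#   model = starnet_s2(features_only=True)
#   feats = model(torch.randn(1, 3, 224, 224))  # list of 4 maps at stride 4/8/16/32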

@register_model
def starnet_s1(pretrained: bool = False, **kwargs: Any) -> StarNet:
    model_args = dict(base_dim=24, depths=[2, 2, 8, 3])
    return _create_starnet('starnet_s1', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def starnet_s2(pretrained: bool = False, **kwargs: Any) -> StarNet:
    model_args = dict(base_dim=32, depths=[1, 2, 6, 2])
    return _create_starnet('starnet_s2', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def starnet_s3(pretrained: bool = False, **kwargs: Any) -> StarNet:
    model_args = dict(base_dim=32, depths=[2, 2, 8, 4])
    return _create_starnet('starnet_s3', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def starnet_s4(pretrained: bool = False, **kwargs: Any) -> StarNet:
    model_args = dict(base_dim=32, depths=[3, 3, 12, 5])
    return _create_starnet('starnet_s4', pretrained=pretrained, **dict(model_args, **kwargs))


# very small networks #
@register_model
def starnet_s050(pretrained: bool = False, **kwargs: Any) -> StarNet:
    model_args = dict(base_dim=16, depths=[1, 1, 3, 1], mlp_ratio=3)
    return _create_starnet('starnet_s050', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def starnet_s100(pretrained: bool = False, **kwargs: Any) -> StarNet:
    model_args = dict(base_dim=20, depths=[1, 2, 4, 1], mlp_ratio=4)
    return _create_starnet('starnet_s100', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def starnet_s150(pretrained: bool = False, **kwargs: Any) -> StarNet:
    model_args = dict(base_dim=24, depths=[1, 2, 4, 2], mlp_ratio=3)
    return _create_starnet('starnet_s150', pretrained=pretrained, **dict(model_args, **kwargs))