__all__ = ['SHViT']


-class Residule(nn.Module):
+class Residual(nn.Module):
    def __init__(self, m: nn.Module):
        super().__init__()
        self.m = m
@@ -38,7 +38,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

    @torch.no_grad()
    def fuse(self) -> nn.Module:
-        if isinstance(self.m, Conv2d_BN):
+        if isinstance(self.m, Conv2dNorm):
            m = self.m.fuse()
            assert(m.groups == m.in_channels)
            identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
@@ -49,7 +49,7 @@ def fuse(self) -> nn.Module:
        return self


-class Conv2d_BN(nn.Sequential):
+class Conv2dNorm(nn.Sequential):
    def __init__(
            self,
            in_channels: int,
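
The `fuse()` paths above fold an inference-time BatchNorm into the preceding convolution. Below is a minimal sketch of the standard folding math, assuming `Conv2dNorm` wraps a bias-free `nn.Conv2d` followed by `nn.BatchNorm2d`; the helper name is illustrative, not the module's own API:

```python
import torch
import torch.nn as nn

@torch.no_grad()
def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # Standard BN folding: scale conv weights by gamma / sqrt(var + eps)
    # and absorb the BN shift into a new conv bias.
    scale = bn.weight / (bn.running_var + bn.eps).sqrt()
    fused = nn.Conv2d(
        conv.in_channels, conv.out_channels, conv.kernel_size,
        conv.stride, conv.padding, conv.dilation, conv.groups, bias=True,
    )
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    fused.bias.copy_(bn.bias - bn.running_mean * scale)
    return fused

conv = nn.Conv2d(8, 16, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(16).eval()  # eval mode so running stats are used
x = torch.randn(1, 8, 4, 4)
print(torch.allclose(fuse_conv_bn(conv, bn)(x), bn(conv(x)), atol=1e-5))  # True
```
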
@@ -89,7 +89,7 @@ def fuse(self) -> nn.Conv2d:
        return m


-class BN_Linear(nn.Sequential):
+class NormLinear(nn.Sequential):
    def __init__(
            self,
            in_features: int,
@@ -124,12 +124,12 @@ class PatchMerging(nn.Module):
    def __init__(self, dim: int, out_dim: int, act_layer: LayerType = nn.ReLU):
        super().__init__()
        hid_dim = int(dim * 4)
-        self.conv1 = Conv2d_BN(dim, hid_dim)
+        self.conv1 = Conv2dNorm(dim, hid_dim)
        self.act1 = act_layer()
-        self.conv2 = Conv2d_BN(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim)
+        self.conv2 = Conv2dNorm(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim)
        self.act2 = act_layer()
        self.se = SqueezeExcite(hid_dim, 0.25)
-        self.conv3 = Conv2d_BN(hid_dim, out_dim)
+        self.conv3 = Conv2dNorm(hid_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
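
For reference, `PatchMerging` is an inverted-bottleneck downsample: 1x1 expand (4x), 3x3 stride-2 depthwise conv, squeeze-excite, 1x1 project. A self-contained stand-in for the same stack, assuming `Conv2dNorm` is equivalent to a bias-free conv plus BatchNorm and using timm's `SqueezeExcite`:

```python
import torch
import torch.nn as nn
from timm.layers import SqueezeExcite

def conv_bn(c_in, c_out, k=1, s=1, p=0, g=1):
    # Bias-free conv followed by BN, mirroring what Conv2dNorm presumably does.
    return nn.Sequential(
        nn.Conv2d(c_in, c_out, k, s, p, groups=g, bias=False),
        nn.BatchNorm2d(c_out),
    )

dim, out_dim = 128, 256
hid = dim * 4
merge = nn.Sequential(
    conv_bn(dim, hid),                  # 1x1 expand
    nn.ReLU(),
    conv_bn(hid, hid, 3, 2, 1, g=hid),  # 3x3 depthwise, stride 2
    nn.ReLU(),
    SqueezeExcite(hid, 0.25),
    conv_bn(hid, out_dim),              # 1x1 project
)
x = torch.randn(1, dim, 14, 14)
print(merge(x).shape)  # torch.Size([1, 256, 7, 7]) -- spatial halved, channels changed
```
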
@@ -144,9 +144,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
class FFN(nn.Module):
    def __init__(self, dim: int, embed_dim: int, act_layer: LayerType = nn.ReLU):
        super().__init__()
-        self.pw1 = Conv2d_BN(dim, embed_dim)
+        self.pw1 = Conv2dNorm(dim, embed_dim)
        self.act = act_layer()
-        self.pw2 = Conv2d_BN(embed_dim, dim, bn_weight_init=0)
+        self.pw2 = Conv2dNorm(embed_dim, dim, bn_weight_init=0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pw1(x)
@@ -173,8 +173,8 @@ def __init__(

        self.pre_norm = norm_layer(pdim)

-        self.qkv = Conv2d_BN(pdim, qk_dim * 2 + pdim)
-        self.proj = nn.Sequential(act_layer(), Conv2d_BN(dim, dim, bn_weight_init=0))
+        self.qkv = Conv2dNorm(pdim, qk_dim * 2 + pdim)
+        self.proj = nn.Sequential(act_layer(), Conv2dNorm(dim, dim, bn_weight_init=0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, _, H, W = x.shape
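
The `qkv`/`proj` change sits inside SHSA, which per the SHViT design runs single-head attention on only the first `pdim` channels and passes the remaining channels through untouched. A minimal functional sketch under that reading; the split sizes `(qk_dim, qk_dim, pdim)` follow the qkv projection above, while the function name and random weights are illustrative:

```python
import torch
import torch.nn.functional as F

def shsa_sketch(x, qkv_weight, pdim, qk_dim):
    B, C, H, W = x.shape
    x1, x2 = x.split([pdim, C - pdim], dim=1)      # attend on pdim channels only
    qkv = F.conv2d(x1, qkv_weight)                 # 1x1 conv: pdim -> 2*qk_dim + pdim
    q, k, v = qkv.split([qk_dim, qk_dim, pdim], dim=1)
    q, k, v = (t.flatten(2) for t in (q, k, v))    # (B, c, N) with N = H*W
    attn = (q.transpose(-2, -1) @ k) * qk_dim ** -0.5
    attn = attn.softmax(dim=-1)                    # (B, N, N), a single head
    out = (v @ attn.transpose(-2, -1)).reshape(B, pdim, H, W)
    return torch.cat([out, x2], dim=1)             # untouched channels pass through

x = torch.randn(2, 64, 7, 7)
w = torch.randn(2 * 16 + 32, 32, 1, 1)             # pdim=32, qk_dim=16
print(shsa_sketch(x, w, pdim=32, qk_dim=16).shape)  # torch.Size([2, 64, 7, 7])
```
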
@@ -202,12 +202,12 @@ def __init__(
            act_layer: LayerType = nn.ReLU,
    ):
        super().__init__()
-        self.conv = Residule(Conv2d_BN(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0))
+        self.conv = Residual(Conv2dNorm(dim, dim, 3, 1, 1, groups=dim, bn_weight_init=0))
        if type == "s":
-            self.mixer = Residule(SHSA(dim, qk_dim, pdim, norm_layer, act_layer))
+            self.mixer = Residual(SHSA(dim, qk_dim, pdim, norm_layer, act_layer))
        else:
            self.mixer = nn.Identity()
-        self.ffn = Residule(FFN(dim, int(dim * 2)))
+        self.ffn = Residual(FFN(dim, int(dim * 2)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
@@ -231,11 +231,11 @@ def __init__(
        super().__init__()
        self.grad_checkpointing = False
        self.downsample = nn.Sequential(
-            Residule(Conv2d_BN(prev_dim, prev_dim, 3, 1, 1, groups=prev_dim)),
-            Residule(FFN(prev_dim, int(prev_dim * 2), act_layer)),
+            Residual(Conv2dNorm(prev_dim, prev_dim, 3, 1, 1, groups=prev_dim)),
+            Residual(FFN(prev_dim, int(prev_dim * 2), act_layer)),
            PatchMerging(prev_dim, dim, act_layer),
-            Residule(Conv2d_BN(dim, dim, 3, 1, 1, groups=dim)),
-            Residule(FFN(dim, int(dim * 2), act_layer)),
+            Residual(Conv2dNorm(dim, dim, 3, 1, 1, groups=dim)),
+            Residual(FFN(dim, int(dim * 2), act_layer)),
        ) if prev_dim != dim else nn.Identity()

        self.blocks = nn.Sequential(*[
@@ -274,13 +274,13 @@ def __init__(
        # Patch embedding
        stem_chs = embed_dim[0]
        self.patch_embed = nn.Sequential(
-            Conv2d_BN(in_chans, stem_chs // 8, 3, 2, 1),
+            Conv2dNorm(in_chans, stem_chs // 8, 3, 2, 1),
            act_layer(),
-            Conv2d_BN(stem_chs // 8, stem_chs // 4, 3, 2, 1),
+            Conv2dNorm(stem_chs // 8, stem_chs // 4, 3, 2, 1),
            act_layer(),
-            Conv2d_BN(stem_chs // 4, stem_chs // 2, 3, 2, 1),
+            Conv2dNorm(stem_chs // 4, stem_chs // 2, 3, 2, 1),
            act_layer(),
-            Conv2d_BN(stem_chs // 2, stem_chs, 3, 2, 1)
+            Conv2dNorm(stem_chs // 2, stem_chs, 3, 2, 1)
        )

        # Build SHViT blocks
@@ -305,7 +305,7 @@ def __init__(
        self.num_features = self.head_hidden_size = embed_dim[-1]
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
-        self.head = BN_Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
+        self.head = NormLinear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()

    @torch.jit.ignore
    def no_weight_decay(self) -> Set:
@@ -336,7 +336,7 @@ def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
        # cannot meaningfully change pooling of efficient head after creation
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
-        self.head = BN_Linear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()
+        self.head = NormLinear(self.head_hidden_size, num_classes) if num_classes > 0 else nn.Identity()

    def forward_intermediates(
            self,
@@ -426,36 +426,36 @@ def fuse_children(net):


def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]:
-    if 'model' in state_dict:
-        state_dict = state_dict['model']
-    out_dict = {}
-
-    replace_rules = [
-        (re.compile(r'^blocks1\.'), 'stages.0.blocks.'),
-        (re.compile(r'^blocks2\.'), 'stages.1.blocks.'),
-        (re.compile(r'^blocks3\.'), 'stages.2.blocks.'),
-    ]
-    downsample_mapping = {}
-    for i in range(1, 3):
-        downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.0\\.'] = f'stages.{i}.downsample.0.'
-        downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.1\\.'] = f'stages.{i}.downsample.1.'
-        downsample_mapping[f'^stages\\.{i}\\.blocks\\.1\\.'] = f'stages.{i}.downsample.2.'
-        downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.0\\.'] = f'stages.{i}.downsample.3.'
-        downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.1\\.'] = f'stages.{i}.downsample.4.'
-        for j in range(3, 10):
-            downsample_mapping[f'^stages\\.{i}\\.blocks\\.{j}\\.'] = f'stages.{i}.blocks.{j - 3}.'
-
-    downsample_patterns = [
-        (re.compile(pattern), replacement) for pattern, replacement in downsample_mapping.items()]
-
-    for k, v in state_dict.items():
-        for pattern, replacement in replace_rules:
-            k = pattern.sub(replacement, k)
-        for pattern, replacement in downsample_patterns:
-            k = pattern.sub(replacement, k)
-        out_dict[k] = v
-
-    return out_dict
+    state_dict = state_dict.get('model', state_dict)
+
+    # out_dict = {}
+    #
+    # replace_rules = [
+    #     (re.compile(r'^blocks1\.'), 'stages.0.blocks.'),
+    #     (re.compile(r'^blocks2\.'), 'stages.1.blocks.'),
+    #     (re.compile(r'^blocks3\.'), 'stages.2.blocks.'),
+    # ]
+    # downsample_mapping = {}
+    # for i in range(1, 3):
+    #     downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.0\\.'] = f'stages.{i}.downsample.0.'
+    #     downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.1\\.'] = f'stages.{i}.downsample.1.'
+    #     downsample_mapping[f'^stages\\.{i}\\.blocks\\.1\\.'] = f'stages.{i}.downsample.2.'
+    #     downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.0\\.'] = f'stages.{i}.downsample.3.'
+    #     downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.1\\.'] = f'stages.{i}.downsample.4.'
+    #     for j in range(3, 10):
+    #         downsample_mapping[f'^stages\\.{i}\\.blocks\\.{j}\\.'] = f'stages.{i}.blocks.{j - 3}.'
+    #
+    # downsample_patterns = [
+    #     (re.compile(pattern), replacement) for pattern, replacement in downsample_mapping.items()]
+    #
+    # for k, v in state_dict.items():
+    #     for pattern, replacement in replace_rules:
+    #         k = pattern.sub(replacement, k)
+    #     for pattern, replacement in downsample_patterns:
+    #         k = pattern.sub(replacement, k)
+    #     out_dict[k] = v
+
+    return state_dict


def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:
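
With the key remapping retired, `checkpoint_filter_fn` now only unwraps an optional `'model'` key. A hedged loading sketch, assuming a local checkpoint already saved in the remapped layout (the path is hypothetical; within timm the filter is normally applied automatically during pretrained weight loading):

```python
import timm
import torch

model = timm.create_model('shvit_s1', pretrained=False)
# Hypothetical local file; .get('model', ...) mirrors the filter above.
ckpt = torch.load('shvit_s1.pth', map_location='cpu')
model.load_state_dict(ckpt.get('model', ckpt))
```
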
@@ -473,20 +473,20 @@ def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]:

default_cfgs = generate_default_cfgs({
    'shvit_s1.in1k': _cfg(
-        # hf_hub_id='timm/',
-        url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s1.pth',
+        hf_hub_id='timm/',
+        # url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s1.pth',
    ),
    'shvit_s2.in1k': _cfg(
-        # hf_hub_id='timm/',
-        url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s2.pth',
+        hf_hub_id='timm/',
+        # url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s2.pth',
    ),
    'shvit_s3.in1k': _cfg(
-        # hf_hub_id='timm/',
-        url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s3.pth',
+        hf_hub_id='timm/',
+        # url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s3.pth',
    ),
    'shvit_s4.in1k': _cfg(
-        # hf_hub_id='timm/',
-        url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s4.pth',
+        hf_hub_id='timm/',
+        # url='https://github.com/ysj9909/SHViT/releases/download/v1.0/shvit_s4.pth',
        input_size=(3, 256, 256),
    ),
})
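
With `hf_hub_id` enabled, pretrained weights now resolve through the Hugging Face Hub instead of the GitHub release URLs. A quick smoke test, assuming the `timm/` artifacts are published:

```python
import timm
import torch

model = timm.create_model('shvit_s4.in1k', pretrained=True).eval()
x = torch.randn(1, 3, 256, 256)  # shvit_s4 is configured for 256x256 inputs
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 1000])
```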