@@ -58,7 +58,7 @@ def __init__(
             stride: int = 1,
             padding: int = 0,
             bn_weight_init: int = 1,
-            **kwargs
+            **kwargs,
     ):
         super().__init__()
         self.add_module('c', nn.Conv2d(
@@ -229,21 +229,25 @@ def __init__(
             act_layer: LayerType = nn.ReLU,
     ):
         super().__init__()
-        self.down = nn.Sequential(
+        self.grad_checkpointing = False
+        self.downsample = nn.Sequential(
             Residule(Conv2d_BN(prev_dim, prev_dim, 3, 1, 1, groups=prev_dim)),
             Residule(FFN(prev_dim, int(prev_dim * 2), act_layer)),
             PatchMerging(prev_dim, dim, act_layer),
             Residule(Conv2d_BN(dim, dim, 3, 1, 1, groups=dim)),
             Residule(FFN(dim, int(dim * 2), act_layer)),
         ) if prev_dim != dim else nn.Identity()

-        self.block = nn.Sequential(*[
+        self.blocks = nn.Sequential(*[
             BasicBlock(dim, qk_dim, pdim, type, norm_layer, act_layer) for _ in range(depth)
         ])

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.down(x)
-        x = self.block(x)
+        x = self.downsample(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x, flatten=True)
+        else:
+            x = self.blocks(x)
         return x


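Side note on the new per-stage checkpointing path: `checkpoint_seq` (timm's helper) re-runs each block under activation checkpointing during backward, so a stage stores fewer activations at the cost of recomputation. A minimal standalone sketch of the same pattern in plain PyTorch (`ToyStage` and its dimensions are made up for illustration; this is not the SHViT module itself):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

class ToyStage(nn.Module):
    # Stand-in for StageBlock: a per-stage flag routes forward through checkpointing.
    def __init__(self, dim: int = 64, depth: int = 4):
        super().__init__()
        self.grad_checkpointing = False
        self.blocks = nn.Sequential(*[nn.Conv2d(dim, dim, 3, padding=1) for _ in range(depth)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # checkpoint_sequential stands in for timm's checkpoint_seq here:
            # block outputs are recomputed during backward instead of stored.
            return checkpoint_sequential(self.blocks, len(self.blocks), x, use_reentrant=False)
        return self.blocks(x)

stage = ToyStage()
stage.grad_checkpointing = True
out = stage(torch.randn(2, 64, 14, 14, requires_grad=True))
out.sum().backward()  # gradients flow through the recomputed blocks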
@@ -265,7 +269,6 @@ def __init__(
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
-        self.grad_checkpointing = False
         self.feature_info = []

         # Patch embedding
@@ -281,10 +284,10 @@ def __init__(
         )

         # Build SHViT blocks
-        blocks = []
+        stages = []
         prev_chs = stem_chs
         for i in range(len(embed_dim)):
-            blocks.append(StageBlock(
+            stages.append(StageBlock(
                 prev_dim=prev_chs,
                 dim=embed_dim[i],
                 qk_dim=qk_dim[i],
@@ -295,9 +298,9 @@ def __init__(
                 act_layer=act_layer,
             ))
             prev_chs = embed_dim[i]
-            self.feature_info.append(dict(num_chs=prev_chs, reduction=2**(i + 4), module=f'blocks.{i}'))
-
-        self.blocks = nn.Sequential(*blocks)
+            self.feature_info.append(dict(num_chs=prev_chs, reduction=2**(i + 4), module=f'stages.{i}'))
+        self.stages = nn.Sequential(*stages)
+
         # Classifier head
         self.num_features = self.head_hidden_size = embed_dim[-1]
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
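Usage note for the rename: with `feature_info` now pointing at `stages.{i}`, feature extraction should keep working unchanged. A hedged example (assumes an SHViT variant is registered under a name like `shvit_s3`; channel counts are whatever the variant defines):

import timm
import torch

# features_only wraps the model so it returns the per-stage feature maps
model = timm.create_model('shvit_s3', pretrained=False, features_only=True)
print(model.feature_info.module_name())  # expected: ['stages.0', 'stages.1', 'stages.2']
print(model.feature_info.reduction())    # expected: [16, 32, 64]
feats = model(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])          # three NCHW maps, one per stage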
@@ -310,12 +313,19 @@ def no_weight_decay(self) -> Set:

     @torch.jit.ignore
     def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
-        matcher = dict(stem=r'^patch_embed', blocks=[(r'^blocks\.(\d+)', None)])
+        matcher = dict(
+            stem=r'^patch_embed',  # stem and embed
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+)\.downsample', (0,)),
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+            ]
+        )
         return matcher

     @torch.jit.ignore
-    def set_grad_checkpointing(self, enable: bool = True):
-        self.grad_checkpointing = enable
+    def set_grad_checkpointing(self, enable: bool = True):
+        for s in self.stages:
+            s.grad_checkpointing = enable

     @torch.jit.ignore
     def get_classifier(self) -> nn.Module:
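To see what the finer-grained matcher does, here is a quick sanity check of the new patterns against illustrative parameter names (the names below are hypothetical; actual SHViT parameter names depend on the submodules, but only the prefixes matter for grouping):

import re

names = [
    'patch_embed.0.c.weight',                # -> stem group
    'stages.1.downsample.2.m.c.weight',      # -> downsample of stage 1, pinned to group index 0
    'stages.2.blocks.3.mixer.qkv.c.weight',  # -> block 3 of stage 2
]
patterns = [
    (re.compile(r'^patch_embed'), 'stem'),
    (re.compile(r'^stages\.(\d+)\.downsample'), 'downsample'),
    (re.compile(r'^stages\.(\d+)\.blocks\.(\d+)'), 'block'),
]
for name in names:
    for pat, label in patterns:
        m = pat.match(name)
        if m:
            print(f'{name} -> {label} {m.groups()}')
            break

With coarse=True all stage parameters fall into a single `^stages\.(\d+)` group; the per-block patterns only matter for things like layer-wise learning-rate decay.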
@@ -351,14 +361,14 @@ def forward_intermediates(
351361 """
352362 assert output_fmt in ('NCHW' ,), 'Output shape must be NCHW.'
353363 intermediates = []
354- take_indices , max_index = feature_take_indices (len (self .blocks ), indices )
364+ take_indices , max_index = feature_take_indices (len (self .stages ), indices )
355365
356366 # forward pass
357367 x = self .patch_embed (x )
358368 if torch .jit .is_scripting () or not stop_early : # can't slice blocks in torchscript
359- stages = self .blocks
369+ stages = self .stages
360370 else :
361- stages = self .blocks [:max_index + 1 ]
371+ stages = self .stages [:max_index + 1 ]
362372
363373 for feat_idx , stage in enumerate (stages ):
364374 x = stage (x )
@@ -378,18 +388,15 @@ def prune_intermediate_layers(
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
-        self.blocks = self.blocks[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
         if prune_head:
             self.reset_classifier(0, '')
         return take_indices

     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
-        if self.grad_checkpointing and not torch.jit.is_scripting():
-            x = checkpoint_seq(self.blocks, x, flatten=True)
-        else:
-            x = self.blocks(x)
+        x = self.stages(x)
         return x

     def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
@@ -424,19 +431,19 @@ def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module)
     out_dict = {}

     replace_rules = [
-        (re.compile(r'^blocks1\.'), 'blocks.0.block.'),
-        (re.compile(r'^blocks2\.'), 'blocks.1.block.'),
-        (re.compile(r'^blocks3\.'), 'blocks.2.block.'),
+        (re.compile(r'^blocks1\.'), 'stages.0.blocks.'),
+        (re.compile(r'^blocks2\.'), 'stages.1.blocks.'),
+        (re.compile(r'^blocks3\.'), 'stages.2.blocks.'),
     ]
     downsample_mapping = {}
     for i in range(1, 3):
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.0\\.0\\.'] = f'blocks.{i}.down.0.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.0\\.1\\.'] = f'blocks.{i}.down.1.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.1\\.'] = f'blocks.{i}.down.2.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.2\\.0\\.'] = f'blocks.{i}.down.3.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.2\\.1\\.'] = f'blocks.{i}.down.4.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.0\\.'] = f'stages.{i}.downsample.0.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.1\\.'] = f'stages.{i}.downsample.1.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.1\\.'] = f'stages.{i}.downsample.2.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.0\\.'] = f'stages.{i}.downsample.3.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.1\\.'] = f'stages.{i}.downsample.4.'
         for j in range(3, 10):
-            downsample_mapping[f'^blocks\\.{i}\\.block\\.{j}\\.'] = f'blocks.{i}.block.{j - 3}.'
+            downsample_mapping[f'^stages\\.{i}\\.blocks\\.{j}\\.'] = f'stages.{i}.blocks.{j - 3}.'

     downsample_patterns = [
         (re.compile(pattern), replacement) for pattern, replacement in downsample_mapping.items()]
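For reference, the combined remapping converts original SHViT checkpoint keys in two passes: the coarse `blocksN.` rename, then the downsample split and index shift. A condensed sketch with hypothetical keys (only part of the mapping table is reproduced):

import re

replace_rules = [
    (re.compile(r'^blocks1\.'), 'stages.0.blocks.'),
    (re.compile(r'^blocks2\.'), 'stages.1.blocks.'),
    (re.compile(r'^blocks3\.'), 'stages.2.blocks.'),
]
downsample_mapping = {}
for i in range(1, 3):
    # leading entries of the original stage sequence become the downsample path ...
    downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.0\\.'] = f'stages.{i}.downsample.0.'
    downsample_mapping[f'^stages\\.{i}\\.blocks\\.1\\.'] = f'stages.{i}.downsample.2.'
    # ... and the remaining blocks shift down past the downsample entries
    for j in range(3, 10):
        downsample_mapping[f'^stages\\.{i}\\.blocks\\.{j}\\.'] = f'stages.{i}.blocks.{j - 3}.'
downsample_patterns = [(re.compile(p), r) for p, r in downsample_mapping.items()]

def remap(key: str) -> str:
    for pat, repl in replace_rules:
        key = pat.sub(repl, key)
    for pat, repl in downsample_patterns:
        key = pat.sub(repl, key)
    return key

# hypothetical original-checkpoint keys:
print(remap('blocks2.0.0.m.c.weight'))     # -> stages.1.downsample.0.m.c.weight
print(remap('blocks2.4.conv.m.c.weight'))  # -> stages.1.blocks.1.conv.m.c.weight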