
Commit 89d2952

update group_matcher
1 parent 7fc0692 commit 89d2952

File tree: 4 files changed, +76 −54 lines

    timm/models/fasternet.py
    timm/models/shvit.py
    timm/models/starnet.py
    timm/models/swiftformer.py

timm/models/fasternet.py
Lines changed: 25 additions & 18 deletions

@@ -16,7 +16,7 @@
 # Licensed under the MIT License.
 
 from functools import partial
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -118,6 +118,7 @@ def __init__(
             merge_size: Union[int, Tuple[int, int]] = 2,
     ):
         super().__init__()
+        self.grad_checkpointing = False
         self.blocks = nn.Sequential(*[
             MLPBlock(
                 dim=dim,
@@ -127,18 +128,22 @@ def __init__(
                 layer_scale_init_value=layer_scale_init_value,
                 norm_layer=norm_layer,
                 act_layer=act_layer,
-                pconv_fw_type=pconv_fw_type
+                pconv_fw_type=pconv_fw_type,
             )
             for i in range(depth)
         ])
-        self.down = PatchMerging(
+        self.downsample = PatchMerging(
             dim=dim // 2,
             patch_size=merge_size,
             norm_layer=norm_layer,
         ) if use_merge else nn.Identity()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.blocks(self.down(x))
+        x = self.downsample(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x, flatten=True)
+        else:
+            x = self.blocks(x)
         return x
 
 
@@ -202,7 +207,6 @@ def __init__(
             depths = (depths)  # it means the model has only one stage
         self.num_stages = len(depths)
         self.feature_info = []
-        self.grad_checkpointing = False
 
         self.patch_embed = PatchEmbed(
             in_chans=in_chans,
@@ -255,20 +259,26 @@ def _initialize_weights(self):
                 if m.bias is not None:
                     nn.init.constant_(m.bias, 0)
 
+    @torch.jit.ignore
+    def no_weight_decay(self) -> Set:
+        return set()
+
     @torch.jit.ignore
     def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
         matcher = dict(
-            stem=r'patch_embed',
-            blocks=[
-                (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
-                (r'conv_head', (99999,))
+            stem=r'^patch_embed',  # stem and embed
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+                (r'^conv_head', (99999,)),
             ]
         )
         return matcher
 
     @torch.jit.ignore
-    def set_grad_checkpointing(self, enable: bool = True):
-        self.grad_checkpointing = enable
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable
 
     @torch.jit.ignore
     def get_classifier(self) -> nn.Module:
@@ -339,10 +349,7 @@ def prune_intermediate_layers(
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
-        if self.grad_checkpointing and not torch.jit.is_scripting():
-            x = checkpoint_seq(self.stages, x, flatten=True)
-        else:
-            x = self.stages(x)
+        x = self.stages(x)
         return x
 
     def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
@@ -371,11 +378,11 @@ def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module)
     }
 
     stage_mapping = {
-        'stages.1.': 'stages.1.down.',
+        'stages.1.': 'stages.1.downsample.',
         'stages.2.': 'stages.1.',
-        'stages.3.': 'stages.2.down.',
+        'stages.3.': 'stages.2.downsample.',
         'stages.4.': 'stages.2.',
-        'stages.5.': 'stages.3.down.',
+        'stages.5.': 'stages.3.downsample.',
         'stages.6.': 'stages.3.'
     }

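The FasterNet change moves gradient checkpointing from the model level into each stage: set_grad_checkpointing() now flips a per-stage flag, and each stage wraps only its own blocks in checkpoint_seq, so forward_features can stay a plain self.stages(x) call. A minimal standalone sketch of that pattern, using a toy module rather than the real FasterNet stage class and assuming checkpoint_seq is importable from timm.models:

import torch
import torch.nn as nn
from timm.models import checkpoint_seq  # helper defined in timm.models._manipulate

class ToyStage(nn.Module):
    # Toy stand-in for the per-stage pattern in the diff above, not the actual FasterNet stage.
    def __init__(self, dim: int, depth: int):
        super().__init__()
        self.grad_checkpointing = False   # per-stage flag, toggled by set_grad_checkpointing()
        self.downsample = nn.Identity()
        self.blocks = nn.Sequential(*[nn.Linear(dim, dim) for _ in range(depth)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.downsample(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # recompute block activations during backward instead of storing them
            x = checkpoint_seq(self.blocks, x, flatten=True)
        else:
            x = self.blocks(x)
        return x

stage = ToyStage(dim=8, depth=3)
stage.grad_checkpointing = True           # what the new set_grad_checkpointing(True) does per stage
out = stage(torch.randn(2, 8, requires_grad=True))
out.sum().backward()                      # checkpointed blocks are re-run here

Checkpointing per stage also composes naturally with forward_intermediates-style code that runs the stages one at a time (as in the SHViT diff below) rather than through a single model-level sequential.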
timm/models/shvit.py
Lines changed: 39 additions & 32 deletions

@@ -58,7 +58,7 @@ def __init__(
             stride: int = 1,
             padding: int = 0,
             bn_weight_init: int = 1,
-            **kwargs
+            **kwargs,
     ):
         super().__init__()
         self.add_module('c', nn.Conv2d(
@@ -229,21 +229,25 @@ def __init__(
             act_layer: LayerType = nn.ReLU,
     ):
         super().__init__()
-        self.down = nn.Sequential(
+        self.grad_checkpointing = False
+        self.downsample = nn.Sequential(
             Residule(Conv2d_BN(prev_dim, prev_dim, 3, 1, 1, groups=prev_dim)),
             Residule(FFN(prev_dim, int(prev_dim * 2), act_layer)),
             PatchMerging(prev_dim, dim, act_layer),
             Residule(Conv2d_BN(dim, dim, 3, 1, 1, groups=dim)),
             Residule(FFN(dim, int(dim * 2), act_layer)),
         ) if prev_dim != dim else nn.Identity()
 
-        self.block = nn.Sequential(*[
+        self.blocks = nn.Sequential(*[
             BasicBlock(dim, qk_dim, pdim, type, norm_layer, act_layer) for _ in range(depth)
         ])
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.down(x)
-        x = self.block(x)
+        x = self.downsample(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.blocks, x, flatten=True)
+        else:
+            x = self.blocks(x)
         return x
 
 
@@ -265,7 +269,6 @@ def __init__(
         super().__init__()
         self.num_classes = num_classes
        self.drop_rate = drop_rate
-        self.grad_checkpointing = False
        self.feature_info = []
 
        # Patch embedding
@@ -281,10 +284,10 @@ def __init__(
        )
 
        # Build SHViT blocks
-        blocks = []
+        stages = []
        prev_chs = stem_chs
        for i in range(len(embed_dim)):
-            blocks.append(StageBlock(
+            stages.append(StageBlock(
                prev_dim=prev_chs,
                dim=embed_dim[i],
                qk_dim=qk_dim[i],
@@ -295,9 +298,9 @@ def __init__(
                act_layer=act_layer,
            ))
            prev_chs = embed_dim[i]
-            self.feature_info.append(dict(num_chs=prev_chs, reduction=2**(i+4), module=f'blocks.{i}'))
-
-        self.blocks = nn.Sequential(*blocks)
+            self.feature_info.append(dict(num_chs=prev_chs, reduction=2**(i+4), module=f'stages.{i}'))
+        self.stages = nn.Sequential(*stages)
+
        # Classifier head
        self.num_features = self.head_hidden_size = embed_dim[-1]
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
@@ -310,12 +313,19 @@ def no_weight_decay(self) -> Set:
 
     @torch.jit.ignore
     def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
-        matcher = dict(stem=r'^patch_embed', blocks=[(r'^blocks\.(\d+)', None)])
+        matcher = dict(
+            stem=r'^patch_embed',  # stem and embed
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+            ]
+        )
         return matcher
 
     @torch.jit.ignore
-    def set_grad_checkpointing(self, enable: bool = True):
-        self.grad_checkpointing = enable
+    def set_grad_checkpointing(self, enable=True):
+        for s in self.stages:
+            s.grad_checkpointing = enable
 
     @torch.jit.ignore
     def get_classifier(self) -> nn.Module:
@@ -351,14 +361,14 @@ def forward_intermediates(
         """
         assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
         intermediates = []
-        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
 
         # forward pass
         x = self.patch_embed(x)
         if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
-            stages = self.blocks
+            stages = self.stages
         else:
-            stages = self.blocks[:max_index + 1]
+            stages = self.stages[:max_index + 1]
 
         for feat_idx, stage in enumerate(stages):
             x = stage(x)
@@ -378,18 +388,15 @@ def prune_intermediate_layers(
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
-        self.blocks = self.blocks[:max_index + 1]  # truncate blocks w/ stem as idx 0
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate blocks w/ stem as idx 0
         if prune_head:
             self.reset_classifier(0, '')
         return take_indices
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)
-        if self.grad_checkpointing and not torch.jit.is_scripting():
-            x = checkpoint_seq(self.blocks, x, flatten=True)
-        else:
-            x = self.blocks(x)
+        x = self.stages(x)
         return x
 
     def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
@@ -424,19 +431,19 @@ def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module)
     out_dict = {}
 
     replace_rules = [
-        (re.compile(r'^blocks1\.'), 'blocks.0.block.'),
-        (re.compile(r'^blocks2\.'), 'blocks.1.block.'),
-        (re.compile(r'^blocks3\.'), 'blocks.2.block.'),
+        (re.compile(r'^blocks1\.'), 'stages.0.blocks.'),
+        (re.compile(r'^blocks2\.'), 'stages.1.blocks.'),
+        (re.compile(r'^blocks3\.'), 'stages.2.blocks.'),
     ]
     downsample_mapping = {}
     for i in range(1, 3):
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.0\\.0\\.'] = f'blocks.{i}.down.0.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.0\\.1\\.'] = f'blocks.{i}.down.1.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.1\\.'] = f'blocks.{i}.down.2.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.2\\.0\\.'] = f'blocks.{i}.down.3.'
-        downsample_mapping[f'^blocks\\.{i}\\.block\\.2\\.1\\.'] = f'blocks.{i}.down.4.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.0\\.'] = f'stages.{i}.downsample.0.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.0\\.1\\.'] = f'stages.{i}.downsample.1.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.1\\.'] = f'stages.{i}.downsample.2.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.0\\.'] = f'stages.{i}.downsample.3.'
+        downsample_mapping[f'^stages\\.{i}\\.blocks\\.2\\.1\\.'] = f'stages.{i}.downsample.4.'
         for j in range(3, 10):
-            downsample_mapping[f'^blocks\\.{i}\\.block\\.{j}\\.'] = f'blocks.{i}.block.{j - 3}.'
+            downsample_mapping[f'^stages\\.{i}\\.blocks\\.{j}\\.'] = f'stages.{i}.blocks.{j - 3}.'
 
     downsample_patterns = [
         (re.compile(pattern), replacement) for pattern, replacement in downsample_mapping.items()]

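Because SHViT's container modules are renamed (blocks → stages, down/block → downsample/blocks), checkpoint_filter_fn now remaps original checkpoint keys in two passes: first the flat blocks1/blocks2/blocks3 lists map onto stages.N.blocks, then the leading entries of stages 1 and 2 are moved into downsample.* with the remaining block indices shifted down. A quick standalone check of the first pass, using only stdlib re; the sample key suffixes are made up for illustration, not taken from a real SHViT checkpoint:

import re

# First-pass rules from checkpoint_filter_fn above.
replace_rules = [
    (re.compile(r'^blocks1\.'), 'stages.0.blocks.'),
    (re.compile(r'^blocks2\.'), 'stages.1.blocks.'),
    (re.compile(r'^blocks3\.'), 'stages.2.blocks.'),
]

def remap(key: str) -> str:
    # Apply each prefix rewrite; at most one rule matches a given key.
    for pat, repl in replace_rules:
        key = pat.sub(repl, key)
    return key

print(remap('blocks1.0.conv.m.c.weight'))   # -> stages.0.blocks.0.conv.m.c.weight
print(remap('blocks3.4.ffn.pw1.c.weight'))  # -> stages.2.blocks.4.ffn.pw1.c.weight

The second pass then remaps the first three entries of stages 1 and 2 onto the five modules of downsample.* and shifts the remaining block indices down by three, matching the new StageBlock layout.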
timm/models/starnet.py
Lines changed: 7 additions & 3 deletions

@@ -34,7 +34,7 @@ def __init__(
             stride: int = 1,
             padding: int = 0,
             with_bn: bool = True,
-            **kwargs
+            **kwargs,
     ):
         super().__init__()
         self.add_module('conv', nn.Conv2d(
@@ -141,7 +141,10 @@ def no_weight_decay(self) -> Set:
     def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
         matcher = dict(
             stem=r'^stem\.\d+',
-            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
+            blocks=[
+                (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
+                (r'norm', (99999,))
+            ]
         )
         return matcher
 
@@ -206,7 +209,8 @@ def forward_intermediates(
         if intermediates_only:
             return intermediates
 
-        x = self.norm(x)
+        if feat_idx == last_idx:
+            x = self.norm(x)
 
         return x, intermediates
 

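For StarNet the fine-grained (coarse=False) matcher now captures both the stage index and the block index, so parameter grouping, and therefore layer-wise LR decay, can step per block instead of per stage, with the norm pinned to the last group via the (99999,) ordinal. A tiny standalone check of what the two patterns capture; the parameter name below is made up for illustration:

import re

coarse = r'^stages\.(\d+)'               # one group per stage
fine = r'^stages\.(\d+)\.(\d+)'          # one group per block within a stage

name = 'stages.2.1.f1.conv.weight'       # illustrative StarNet-style parameter name
print(re.match(coarse, name).groups())   # ('2',)
print(re.match(fine, name).groups())     # ('2', '1')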
timm/models/swiftformer.py
Lines changed: 5 additions & 1 deletion

@@ -402,7 +402,11 @@ def no_weight_decay(self) -> Set:
     def group_matcher(self, coarse: bool = False) -> Dict[str, Any]:
         matcher = dict(
             stem=r'^stem',  # stem and embed
-            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
+            blocks=r'^stages\.(\d+)' if coarse else [
+                (r'^stages\.(\d+).downsample', (0,)),
+                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
+                (r'^norm', (99999,)),
+            ]
         )
         return matcher
 

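SwiftFormer gets the same fine-grained matcher shape as FasterNet and SHViT: the (0,) ordinal folds each stage's downsample parameters into the group of that stage's first block, while (99999,) pushes the final norm into the last group. A sketch of how such a matcher is consumed; it assumes a timm build that registers these models and exposes group_parameters, and the variant name 'swiftformer_xs' is only an example:

import timm
from timm.models import group_parameters  # helper defined in timm.models._manipulate

# Any of the four affected models would do; the exact variant name is illustrative.
model = timm.create_model('swiftformer_xs', pretrained=False)

# group_parameters applies the regex/ordinal matcher to named parameters and
# returns {layer_id: [parameter names]}; the ids drive layer-wise LR decay.
groups = group_parameters(model, model.group_matcher(coarse=False))
for layer_id in sorted(groups):
    print(layer_id, len(groups[layer_id]), 'params, e.g.', groups[layer_id][0])

This grouping is what timm's layer_decay option in create_optimizer_v2 relies on when it builds per-layer learning-rate scales.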