Skip to content

Commit 7d3c2dc

Browse files
committed
Add group_matcher for DaViT
1 parent 7bc7798 commit 7d3c2dc

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

timm/models/davit.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,17 @@ def _init_weights(self, m):
547547
if isinstance(m, nn.Linear) and m.bias is not None:
548548
nn.init.constant_(m.bias, 0)
549549

550+
@torch.jit.ignore
551+
def group_matcher(self, coarse=False):
552+
return dict(
553+
stem=r'^stem', # stem and embed
554+
blocks=r'^stages\.(\d+)' if coarse else [
555+
(r'^stages\.(\d+).downsample', (0,)),
556+
(r'^stages\.(\d+)\.blocks\.(\d+)', None),
557+
(r'^norm_pre', (99999,)),
558+
]
559+
)
560+
550561
@torch.jit.ignore
551562
def set_grad_checkpointing(self, enable=True):
552563
self.grad_checkpointing = enable

0 commit comments

Comments
 (0)