We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 7bc7798 commit 7d3c2dcCopy full SHA for 7d3c2dc
timm/models/davit.py
@@ -547,6 +547,17 @@ def _init_weights(self, m):
547
if isinstance(m, nn.Linear) and m.bias is not None:
548
nn.init.constant_(m.bias, 0)
549
550
+ @torch.jit.ignore
551
+ def group_matcher(self, coarse=False):
552
+ return dict(
553
+ stem=r'^stem', # stem and embed
554
+ blocks=r'^stages\.(\d+)' if coarse else [
555
+ (r'^stages\.(\d+).downsample', (0,)),
556
+ (r'^stages\.(\d+)\.blocks\.(\d+)', None),
557
+ (r'^norm_pre', (99999,)),
558
+ ]
559
+ )
560
+
561
@torch.jit.ignore
562
def set_grad_checkpointing(self, enable=True):
563
self.grad_checkpointing = enable
0 commit comments