Skip to content

Commit 33ada0c

Browse files
committed
Add group_matcher to focalnet for proper layer-wise LR decay
1 parent b271dc0 commit 33ada0c

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

timm/models/focalnet.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,20 @@ def __init__(
436436
def no_weight_decay(self):
437437
return {''}
438438

439+
@torch.jit.ignore
def group_matcher(self, coarse=False):
    """Return a parameter-group matcher for layer-wise LR decay.

    Groups parameters by name: the stem forms its own group, each stage
    under ``layers`` is grouped by its index, and the final ``norm`` is
    pinned to the last group via the sentinel ordinal ``(99999,)``.

    Args:
        coarse: If True, match whole stages only; otherwise additionally
            distinguish each stage's downsample and its numbered
            sub-modules (blocks) within the stage.

    Returns:
        dict: ``{'stem': regex_str, 'blocks': [(regex, group_ordinal), ...]}``
        where a ``None`` ordinal means "derive the group from the regex
        capture groups".
    """
    return dict(
        stem=r'^stem',
        blocks=[
            (r'^layers\.(\d+)', None),
            (r'^norm', (99999,))
        ] if coarse else [
            # Fix: escape the dot before 'downsample' — unescaped it matched
            # any character, not just the literal '.' separator.
            (r'^layers\.(\d+)\.downsample', (0,)),
            (r'^layers\.(\d+)\.\w+\.(\d+)', None),
            (r'^norm', (99999,)),
        ]
    )
452+
439453
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
    """Set the gradient-checkpointing flag.

    Only stores the flag; the forward path is assumed to consult
    ``self.grad_checkpointing`` — consumer not visible in this view.
    """
    self.grad_checkpointing = enable

0 commit comments

Comments (0)