Skip to content

Commit 8c9696c

Browse files
committed
More model and test fixes
1 parent ca52108 commit 8c9696c

File tree

6 files changed: +31 additions, −17 deletions

tests/test_models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@
2727
NON_STD_FILTERS = [
2828
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
2929
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*',
30-
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*']
30+
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
31+
'coatne?t_*', 'max?vit_*',
32+
]
3133
NUM_NON_STD = len(NON_STD_FILTERS)
3234

3335
# exclude models that cause specific test failures

timm/models/gcvit.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def _cfg(url='', **kwargs):
4343
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
4444
'crop_pct': 0.875, 'interpolation': 'bicubic',
4545
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
46-
'first_conv': 'stem.conv', 'classifier': 'head.fc',
46+
'first_conv': 'stem.conv1', 'classifier': 'head.fc',
4747
'fixed_input_size': True,
4848
**kwargs
4949
}
@@ -106,7 +106,7 @@ def __init__(
106106
dim_out=None,
107107
reduction='conv',
108108
act_layer=nn.GELU,
109-
norm_layer=LayerNorm2d,
109+
norm_layer=LayerNorm2d, # NOTE in NCHW
110110
):
111111
super().__init__()
112112
dim_out = dim_out or dim
@@ -163,12 +163,10 @@ def __init__(
163163
self,
164164
in_chs: int = 3,
165165
out_chs: int = 96,
166-
act_layer: str = 'gelu',
167-
norm_layer: str = 'layernorm2d', # NOTE norm for NCHW
166+
act_layer: Callable = nn.GELU,
167+
norm_layer: Callable = LayerNorm2d, # NOTE stem in NCHW
168168
):
169169
super().__init__()
170-
act_layer = get_act_layer(act_layer)
171-
norm_layer = get_norm_layer(norm_layer)
172170
self.conv1 = nn.Conv2d(in_chs, out_chs, kernel_size=3, stride=2, padding=1)
173171
self.down = Downsample2d(out_chs, act_layer=act_layer, norm_layer=norm_layer)
174172

@@ -333,15 +331,11 @@ def __init__(
333331
proj_drop: float = 0.,
334332
attn_drop: float = 0.,
335333
drop_path: Union[List[float], float] = 0.0,
336-
act_layer: str = 'gelu',
337-
norm_layer: str = 'layernorm2d',
338-
norm_layer_cl: str = 'layernorm',
334+
act_layer: Callable = nn.GELU,
335+
norm_layer: Callable = nn.LayerNorm,
336+
norm_layer_cl: Callable = LayerNorm2d,
339337
):
340338
super().__init__()
341-
act_layer = get_act_layer(act_layer)
342-
norm_layer = get_norm_layer(norm_layer)
343-
norm_layer_cl = get_norm_layer(norm_layer_cl)
344-
345339
if downsample:
346340
self.downsample = Downsample2d(
347341
dim=dim,
@@ -421,8 +415,13 @@ def __init__(
421415
act_layer: str = 'gelu',
422416
norm_layer: str = 'layernorm2d',
423417
norm_layer_cl: str = 'layernorm',
418+
norm_eps: float = 1e-5,
424419
):
425420
super().__init__()
421+
act_layer = get_act_layer(act_layer)
422+
norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
423+
norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
424+
426425
img_size = to_2tuple(img_size)
427426
feat_size = tuple(d // 4 for d in img_size) # stem reduction by 4
428427
self.global_pool = global_pool
@@ -432,7 +431,11 @@ def __init__(
432431
self.num_features = int(embed_dim * 2 ** (num_stages - 1))
433432

434433
self.stem = Stem(
435-
in_chs=in_chans, out_chs=embed_dim, act_layer=act_layer, norm_layer=norm_layer)
434+
in_chs=in_chans,
435+
out_chs=embed_dim,
436+
act_layer=act_layer,
437+
norm_layer=norm_layer
438+
)
436439

437440
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
438441
stages = []

timm/models/layers/create_norm_act.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
batchnorm=BatchNormAct2d,
1919
batchnorm2d=BatchNormAct2d,
2020
groupnorm=GroupNormAct,
21+
groupnorm1=functools.partial(GroupNormAct, num_groups=1),
2122
layernorm=LayerNormAct,
2223
layernorm2d=LayerNormAct2d,
2324
evonormb0=EvoNorm2dB0,
@@ -72,6 +73,8 @@ def get_norm_act_layer(norm_layer, act_layer=None):
7273
norm_act_layer = BatchNormAct2d
7374
elif type_name.startswith('groupnorm'):
7475
norm_act_layer = GroupNormAct
76+
elif type_name.startswith('groupnorm1'):
77+
norm_act_layer = functools.partial(GroupNormAct, num_groups=1)
7578
elif type_name.startswith('layernorm2d'):
7679
norm_act_layer = LayerNormAct2d
7780
elif type_name.startswith('layernorm'):

timm/models/layers/norm_act.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def __init__(
226226
self.act = act_layer(**act_args)
227227
else:
228228
self.act = nn.Identity()
229+
self._fast_norm = is_fast_norm()
229230

230231
def forward(self, x):
231232
x = x.permute(0, 2, 3, 1)

timm/models/mvitv2.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from torch import nn
2525

2626
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
27+
from .fx_features import register_notrace_function
2728
from .helpers import build_model_with_cfg
2829
from .layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple
2930
from .registry import register_model
@@ -35,7 +36,8 @@ def _cfg(url='', **kwargs):
3536
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
3637
'crop_pct': .9, 'interpolation': 'bicubic',
3738
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
38-
'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': True,
39+
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
40+
'fixed_input_size': True,
3941
**kwargs
4042
}
4143

@@ -169,6 +171,7 @@ def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
169171
return x.flatten(2).transpose(1, 2), x.shape[-2:]
170172

171173

174+
@register_notrace_function
172175
def reshape_pre_pool(
173176
x,
174177
feat_size: List[int],
@@ -183,6 +186,7 @@ def reshape_pre_pool(
183186
return x, cls_tok
184187

185188

189+
@register_notrace_function
186190
def reshape_post_pool(
187191
x,
188192
num_heads: int,
@@ -196,6 +200,7 @@ def reshape_post_pool(
196200
return x, feat_size
197201

198202

203+
@register_notrace_function
199204
def cal_rel_pos_type(
200205
attn: torch.Tensor,
201206
q: torch.Tensor,

timm/models/pvt_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _cfg(url='', **kwargs):
3636
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
3737
'crop_pct': 0.9, 'interpolation': 'bicubic',
3838
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
39-
'first_conv': 'patch_embed.conv', 'classifier': 'head', 'fixed_input_size': False,
39+
'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': False,
4040
**kwargs
4141
}
4242

0 commit comments

Comments (0)