Skip to content

Commit 5db7452

Browse files
committed
Fix visformer in_chans stem handling
1 parent fd92ba0 commit 5db7452

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

tests/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def test_model_features_pretrained(model_name, batch_size):
190190
def test_model_forward_torchscript(model_name, batch_size):
191191
"""Run a single forward pass with each model"""
192192
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
193-
if max(input_size) > MAX_JIT_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional
193+
if max(input_size) > MAX_JIT_SIZE:
194194
pytest.skip("Fixed input size model > limit.")
195195

196196
with set_scriptable(True):

timm/models/visformer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def _cfg(url='', **kwargs):
2626
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
2727
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
2828
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
29-
'first_conv': 'patch_embed.proj', 'classifier': 'head',
29+
'first_conv': 'stem.0', 'classifier': 'head',
3030
**kwargs
3131
}
3232

@@ -183,7 +183,7 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, in
183183
img_size //= 8
184184
else:
185185
self.stem = nn.Sequential(
186-
nn.Conv2d(3, self.init_channels, 7, stride=2, padding=3, bias=False),
186+
nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False),
187187
nn.BatchNorm2d(self.init_channels),
188188
nn.ReLU(inplace=True)
189189
)

0 commit comments

Comments
 (0)