Commit b41cffa

Fix a few issues loading pretrained vit/bit npz weights w/ num_classes=0 __init__ arg. Missed a few other small classifier handling details on Mlp, GhostNet, Levit. Should fix #713
1 parent dc42282 commit b41cffa
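
For context, the use case this fixes is building a pretrained model as a headless feature extractor. A minimal sketch (model name and shapes are illustrative, assuming a current timm install):

    import torch
    import timm

    # num_classes=0 removes the classifier at construction time; with this commit,
    # pretrained vit/bit npz weights still load and the model returns features.
    model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0).eval()

    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))
    # feats.shape == (1, model.num_features), e.g. (1, 768) for this ViT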

7 files changed: +23, -6 lines changed

tests/test_models.py

Lines changed: 17 additions & 0 deletions
@@ -147,6 +147,15 @@ def test_model_default_cfgs(model_name, batch_size):
             # FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
             assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
 
+        if 'pruned' not in model_name:  # FIXME better pruned model handling
+            # test classifier + global pool deletion via __init__
+            model = create_model(model_name, pretrained=False, num_classes=0, global_pool='').eval()
+            outputs = model.forward(input_tensor)
+            assert len(outputs.shape) == 4
+            if not isinstance(model, timm.models.MobileNetV3) and not isinstance(model, timm.models.GhostNet):
+                # FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
+                assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
+
         # check classifier name matches default_cfg
         classifier = cfg['classifier']
         if not isinstance(classifier, (tuple, list)):
@@ -193,6 +202,13 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
     assert len(outputs.shape) == 2
     assert outputs.shape[1] == model.num_features
 
+    model = create_model(model_name, pretrained=False, num_classes=0).eval()
+    outputs = model.forward(input_tensor)
+    if isinstance(outputs, tuple):
+        outputs = outputs[0]
+    assert len(outputs.shape) == 2
+    assert outputs.shape[1] == model.num_features
+
     # check classifier name matches default_cfg
     classifier = cfg['classifier']
     if not isinstance(classifier, (tuple, list)):
@@ -217,6 +233,7 @@ def test_model_load_pretrained(model_name, batch_size):
     """Create that pretrained weights load, verify support for in_chans != 3 while doing so."""
     in_chans = 3 if 'pruned' in model_name else 1  # pruning not currently supported with in_chans change
     create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=5)
+    create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=0)
 
 @pytest.mark.timeout(120)
 @pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=NON_STD_FILTERS))

timm/models/ghostnet.py

Lines changed: 1 addition & 1 deletion
@@ -182,7 +182,7 @@ def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, o
         self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True)
         self.act2 = nn.ReLU(inplace=True)
         self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
-        self.classifier = Linear(out_chs, num_classes)
+        self.classifier = Linear(out_chs, num_classes) if num_classes > 0 else nn.Identity()
 
     def get_classifier(self):
         return self.classifier
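
The GhostNet change above (and the identical MlpMixer change further down) follows the usual timm convention: with num_classes=0 the classifier becomes an nn.Identity pass-through, so forward() returns features instead of logits. A standalone sketch of the pattern, with illustrative names rather than timm code:

    import torch
    import torch.nn as nn

    class TinyNet(nn.Module):
        # illustrative module showing the num_classes > 0 guard on the head
        def __init__(self, num_features=64, num_classes=10):
            super().__init__()
            self.backbone = nn.Sequential(
                nn.Conv2d(3, num_features, 3, stride=2, padding=1),
                nn.AdaptiveAvgPool2d(1),
                nn.Flatten(1))
            # pass-through head when the caller asks for no classifier
            self.classifier = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()

        def forward(self, x):
            return self.classifier(self.backbone(x))

    x = torch.randn(1, 3, 32, 32)
    print(TinyNet(num_classes=10)(x).shape)  # torch.Size([1, 10]) -> logits
    print(TinyNet(num_classes=0)(x).shape)   # torch.Size([1, 64]) -> features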

timm/models/levit.py

Lines changed: 1 addition & 1 deletion
@@ -542,7 +542,7 @@ def checkpoint_filter_fn(state_dict, model):
         state_dict = state_dict['model']
     D = model.state_dict()
     for k in state_dict.keys():
-        if D[k].ndim == 4 and state_dict[k].ndim == 2:
+        if k in D and D[k].ndim == 4 and state_dict[k].ndim == 2:
             state_dict[k] = state_dict[k][:, :, None, None]
     return state_dict
 
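
The Levit fix guards the reshape so the filter no longer raises a KeyError for checkpoint entries with no counterpart in the model's state_dict (e.g. a head removed via num_classes=0). A hypothetical standalone filter using the same kind of guard, not timm's exact function:

    def filter_checkpoint(state_dict, model):
        # keep only entries the target model actually has; reshape 2D linear weights
        # into 4D 1x1-conv layout where the model expects a conv (mirrors the `k in D` guard)
        target = model.state_dict()
        out = {}
        for k, v in state_dict.items():
            if k not in target:
                continue  # e.g. classifier weights when the model was built with num_classes=0
            if target[k].ndim == 4 and v.ndim == 2:
                v = v[:, :, None, None]
            out[k] = v
        return out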

timm/models/mlp_mixer.py

Lines changed: 1 addition & 1 deletion
@@ -266,7 +266,7 @@ def __init__(
                 act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate)
             for _ in range(num_blocks)])
         self.norm = norm_layer(embed_dim)
-        self.head = nn.Linear(embed_dim, self.num_classes)  # zero init
+        self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
 
         self.init_weights(nlhb=nlhb)
 

timm/models/resnetv2.py

Lines changed: 2 additions & 1 deletion
@@ -424,7 +424,8 @@ def t2p(conv_weights):
     model.stem.conv.weight.copy_(stem_conv_w)
     model.norm.weight.copy_(t2p(weights[f'{prefix}group_norm/gamma']))
     model.norm.bias.copy_(t2p(weights[f'{prefix}group_norm/beta']))
-    if model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
+    if isinstance(model.head.fc, nn.Conv2d) and \
+            model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
         model.head.fc.weight.copy_(t2p(weights[f'{prefix}head/conv2d/kernel']))
         model.head.fc.bias.copy_(t2p(weights[f'{prefix}head/conv2d/bias']))
     for i, (sname, stage) in enumerate(model.stages.named_children()):

timm/models/visformer.py

Lines changed: 0 additions & 1 deletion
@@ -237,7 +237,6 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, in
         self.num_features = embed_dim if self.vit_stem else embed_dim * 2
         self.norm = norm_layer(self.num_features)
         self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
-        self.head = nn.Linear(self.num_features, num_classes)
 
         # weights init
         if self.pos_embed:
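
The Visformer line removed above was overwriting the head that create_classifier had just built; create_classifier already returns an nn.Identity head when num_classes == 0, so re-assigning a fresh nn.Linear undid the num_classes=0 request. Roughly (a sketch, assuming create_classifier is importable from timm.models.layers):

    from timm.models.layers import create_classifier

    # returns (pooling module, head); the head is nn.Identity() when num_classes == 0
    global_pool, head = create_classifier(384, num_classes=0, pool_type='avg')
    print(type(head).__name__)  # expected: Identity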

timm/models/vision_transformer.py

Lines changed: 1 addition & 1 deletion
@@ -448,7 +448,7 @@ def _n2p(w, t=True):
     model.pos_embed.copy_(pos_embed_w)
     model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
     model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
-    if model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
+    if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
         model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
         model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
     for i, block in enumerate(model.blocks.children()):
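
Both npz loaders (resnetv2 above, vision_transformer here) now check the head module type before copying classifier weights, because a model built with num_classes=0 carries an nn.Identity head with no parameters to copy into. A minimal standalone sketch of that guard; the helper name and kernel layout are assumptions, not timm code:

    import numpy as np
    import torch
    import torch.nn as nn

    def maybe_copy_head(model, head_kernel: np.ndarray, head_bias: np.ndarray):
        # copy classifier weights only if a real linear head exists and the class count matches;
        # otherwise (nn.Identity from num_classes=0, or a different num_classes) leave it alone
        if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == head_bias.shape[-1]:
            with torch.no_grad():
                model.head.weight.copy_(torch.from_numpy(head_kernel).T)  # npz kernel assumed (in, out)
                model.head.bias.copy_(torch.from_numpy(head_bias))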
