Skip to content

Commit 6d0238f

Browse files
Merge pull request #81 from SegmentationBLWX/dev
merge tests
2 parents 65ba520 + 2c26cf7 commit 6d0238f

File tree

7 files changed

+294
-1
lines changed

7 files changed

+294
-1
lines changed

ssseg/modules/models/backbones/swin.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
'swin_base_patch4_window12_384_22k': 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth',
2525
'swin_base_patch4_window7_224_22k': 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window7_224_22k.pth',
2626
'swin_large_patch4_window12_384_22k': 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth',
27+
'swin_large_patch4_window12_384_22kto1k': 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22kto1k.pth',
2728
}
2829
'''AUTO_ASSERT_STRUCTURE_TYPES'''
2930
AUTO_ASSERT_STRUCTURE_TYPES = {
@@ -62,6 +63,11 @@
6263
'depths': [2, 2, 18, 2], 'num_heads': [6, 12, 24, 48], 'qkv_bias': True, 'qk_scale': None, 'patch_norm': True,
6364
'drop_rate': 0., 'attn_drop_rate': 0., 'drop_path_rate': 0.3, 'use_abs_pos_embed': False,
6465
},
66+
'swin_large_patch4_window12_384_22kto1k': {
67+
'pretrain_img_size': 384, 'in_channels': 3, 'embed_dims': 192, 'patch_size': 4, 'window_size': 12, 'mlp_ratio': 4,
68+
'depths': [2, 2, 18, 2], 'num_heads': [6, 12, 24, 48], 'qkv_bias': True, 'qk_scale': None, 'patch_norm': True,
69+
'drop_rate': 0., 'attn_drop_rate': 0., 'drop_path_rate': 0.3, 'use_abs_pos_embed': False,
70+
},
6571
}
6672

6773

ssseg/modules/models/backbones/vit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def _forward(x):
7171
class VisionTransformer(nn.Module):
7272
def __init__(self, structure_type, img_size=224, patch_size=16, patch_pad='corner', in_channels=3, embed_dims=768, num_layers=12, num_heads=12, mlp_ratio=4, out_origin=False, out_indices=(9, 14, 19, 23),
7373
qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., with_cls_token=True, output_cls_token=False, norm_cfg={'type': 'LayerNorm', 'eps': 1e-6}, act_cfg={'type': 'GELU'},
74-
patch_norm=False, patch_bias=False, pre_norm=False, final_norm=False, interpolate_mode='bilinear', num_fcs=2, use_checkpoint=False, pretrained=True, pretrained_model_path=''):
74+
patch_norm=False, patch_bias=True, pre_norm=False, final_norm=False, interpolate_mode='bilinear', num_fcs=2, use_checkpoint=False, pretrained=True, pretrained_model_path=''):
7575
super(VisionTransformer, self).__init__()
7676
img_size = tolen2tuple(img_size)
7777
# set attributes

tests/test_backbones/test_mit.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
'''
Function:
    Implementation of Testing MiT
Author:
    Zhenchao Jin
'''
from ssseg.modules import BuildBackbone, loadpretrainedweights
from ssseg.modules.models.backbones.mit import DEFAULT_MODEL_URLS


'''MiTs'''
# one config per released MiT variant (b0-b5); they differ only in embed_dims and num_layers
cfgs = [
    {'type': 'MixVisionTransformer', 'structure_type': 'mit-b0', 'pretrained': True, 'pretrained_model_path': 'mit_b0.pth',
     'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm', 'eps': 1e-6},
     'embed_dims': 32, 'num_stages': 4, 'num_layers': [2, 2, 2, 2], 'num_heads': [1, 2, 5, 8], 'patch_sizes': [7, 3, 3, 3],
     'sr_ratios': [8, 4, 2, 1], 'mlp_ratio': 4, 'qkv_bias': True, 'drop_rate': 0.0, 'attn_drop_rate': 0.0, 'drop_path_rate': 0.1,},
    {'type': 'MixVisionTransformer', 'structure_type': 'mit-b1', 'pretrained': True, 'pretrained_model_path': 'mit_b1.pth',
     'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm', 'eps': 1e-6},
     'embed_dims': 64, 'num_stages': 4, 'num_layers': [2, 2, 2, 2], 'num_heads': [1, 2, 5, 8], 'patch_sizes': [7, 3, 3, 3],
     'sr_ratios': [8, 4, 2, 1], 'mlp_ratio': 4, 'qkv_bias': True, 'drop_rate': 0.0, 'attn_drop_rate': 0.0, 'drop_path_rate': 0.1,},
    {'type': 'MixVisionTransformer', 'structure_type': 'mit-b2', 'pretrained': True, 'pretrained_model_path': 'mit_b2.pth',
     'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm', 'eps': 1e-6},
     'embed_dims': 64, 'num_stages': 4, 'num_layers': [3, 4, 6, 3], 'num_heads': [1, 2, 5, 8], 'patch_sizes': [7, 3, 3, 3],
     'sr_ratios': [8, 4, 2, 1], 'mlp_ratio': 4, 'qkv_bias': True, 'drop_rate': 0.0, 'attn_drop_rate': 0.0, 'drop_path_rate': 0.1,},
    {'type': 'MixVisionTransformer', 'structure_type': 'mit-b3', 'pretrained': True, 'pretrained_model_path': 'mit_b3.pth',
     'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm', 'eps': 1e-6},
     'embed_dims': 64, 'num_stages': 4, 'num_layers': [3, 4, 18, 3], 'num_heads': [1, 2, 5, 8], 'patch_sizes': [7, 3, 3, 3],
     'sr_ratios': [8, 4, 2, 1], 'mlp_ratio': 4, 'qkv_bias': True, 'drop_rate': 0.0, 'attn_drop_rate': 0.0, 'drop_path_rate': 0.1,},
    {'type': 'MixVisionTransformer', 'structure_type': 'mit-b4', 'pretrained': True, 'pretrained_model_path': 'mit_b4.pth',
     'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm', 'eps': 1e-6},
     'embed_dims': 64, 'num_stages': 4, 'num_layers': [3, 8, 27, 3], 'num_heads': [1, 2, 5, 8], 'patch_sizes': [7, 3, 3, 3],
     'sr_ratios': [8, 4, 2, 1], 'mlp_ratio': 4, 'qkv_bias': True, 'drop_rate': 0.0, 'attn_drop_rate': 0.0, 'drop_path_rate': 0.1,},
    {'type': 'MixVisionTransformer', 'structure_type': 'mit-b5', 'pretrained': True, 'pretrained_model_path': 'mit_b5.pth',
     'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm', 'eps': 1e-6},
     'embed_dims': 64, 'num_stages': 4, 'num_layers': [3, 6, 40, 3], 'num_heads': [1, 2, 5, 8], 'patch_sizes': [7, 3, 3, 3],
     'sr_ratios': [8, 4, 2, 1], 'mlp_ratio': 4, 'qkv_bias': True, 'drop_rate': 0.0, 'attn_drop_rate': 0.0, 'drop_path_rate': 0.1,},
]
for backbone_cfg in cfgs:
    # build the backbone, fetch the matching official checkpoint, and convert its key layout
    mit = BuildBackbone(backbone_cfg=backbone_cfg)
    converted = loadpretrainedweights(
        structure_type=backbone_cfg['structure_type'], pretrained_model_path='', default_model_urls=DEFAULT_MODEL_URLS
    )
    converted = mit.mitconvert(converted)
    # exercise non-strict loading first, then strict loading; any mismatch is printed, not raised
    for strict in (False, True):
        try:
            mit.load_state_dict(converted, strict=strict)
        except Exception as err:
            print(err)
tests/test_backbones/test_mobilenet.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
'''
Function:
    Implementation of Testing Mobilenets
Author:
    Zhenchao Jin
'''
from ssseg.modules import BuildBackbone, loadpretrainedweights
from ssseg.modules.models.backbones.mobilenet import DEFAULT_MODEL_URLS


def _stripbackboneprefix(state_dict):
    '''Drop the leading "backbone." from checkpoint keys so they match the bare backbone module.'''
    for key in list(state_dict.keys()):
        if key.startswith('backbone.'):
            state_dict[key[len('backbone.'):]] = state_dict.pop(key)
    return state_dict


def _tryloadweights(backbone, state_dict):
    '''Attempt non-strict and then strict loading; print (rather than raise) any mismatch errors.'''
    for strict in (False, True):
        try:
            backbone.load_state_dict(state_dict, strict=strict)
        except Exception as err:
            print(err)


'''mobilenetv2'''
cfgs = [
    {'type': 'MobileNetV2', 'structure_type': 'mobilenetv2', 'pretrained': True, 'outstride': 8, 'selected_indices': (0, 1, 2, 3),},
]
for cfg in cfgs:
    mobilenet = BuildBackbone(backbone_cfg=cfg)
    state_dict = loadpretrainedweights(structure_type=cfg['structure_type'], pretrained_model_path='', default_model_urls=DEFAULT_MODEL_URLS)
    _tryloadweights(mobilenet, _stripbackboneprefix(state_dict))


'''mobilenetv3'''
cfgs = [
    {'type': 'MobileNetV3', 'structure_type': 'mobilenetv3_small', 'pretrained': True, 'outstride': 8,
     'arch_type': 'small', 'out_indices': (0, 1, 12), 'selected_indices': (0, 1, 2),},
    {'type': 'MobileNetV3', 'structure_type': 'mobilenetv3_large', 'pretrained': True,
     'outstride': 8, 'arch_type': 'large', 'selected_indices': (0, 1, 2),},
]
for cfg in cfgs:
    mobilenet = BuildBackbone(backbone_cfg=cfg)
    state_dict = loadpretrainedweights(structure_type=cfg['structure_type'], pretrained_model_path='', default_model_urls=DEFAULT_MODEL_URLS)
    _tryloadweights(mobilenet, _stripbackboneprefix(state_dict))

tests/test_backbones/test_swin.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
'''
Function:
    Implementation of Testing SwinTransformer
Author:
    Zhenchao Jin
'''
from collections import OrderedDict

import torch.nn.functional as F
from ssseg.modules import BuildBackbone, loadpretrainedweights
from ssseg.modules.models.backbones.swin import DEFAULT_MODEL_URLS


'''SwinTransformers'''
cfgs = [
    {'type': 'SwinTransformer', 'structure_type': 'swin_large_patch4_window12_384_22kto1k', 'pretrained': True,
     'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm'},
     'pretrain_img_size': 384, 'in_channels': 3, 'embed_dims': 192, 'patch_size': 4, 'window_size': 12, 'mlp_ratio': 4,
     'depths': [2, 2, 18, 2], 'num_heads': [6, 12, 24, 48], 'qkv_bias': True, 'qk_scale': None, 'patch_norm': True,
     'drop_rate': 0., 'attn_drop_rate': 0., 'drop_path_rate': 0.3, 'use_abs_pos_embed': False,},
    {'type': 'SwinTransformer', 'structure_type': 'swin_large_patch4_window12_384_22k', 'pretrained': True,
     'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm'},
     'pretrain_img_size': 384, 'in_channels': 3, 'embed_dims': 192, 'patch_size': 4, 'window_size': 12, 'mlp_ratio': 4,
     'depths': [2, 2, 18, 2], 'num_heads': [6, 12, 24, 48], 'qkv_bias': True, 'qk_scale': None, 'patch_norm': True,
     'drop_rate': 0., 'attn_drop_rate': 0., 'drop_path_rate': 0.3, 'use_abs_pos_embed': False,},
    {'type': 'SwinTransformer', 'structure_type': 'swin_base_patch4_window12_384', 'pretrained': True,
     'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm'},
     'pretrain_img_size': 384, 'in_channels': 3, 'embed_dims': 128, 'patch_size': 4, 'window_size': 12, 'mlp_ratio': 4,
     'depths': [2, 2, 18, 2], 'num_heads': [4, 8, 16, 32], 'qkv_bias': True, 'qk_scale': None, 'patch_norm': True,
     'drop_rate': 0., 'attn_drop_rate': 0., 'drop_path_rate': 0.3, 'use_abs_pos_embed': False,},
    {'type': 'SwinTransformer', 'structure_type': 'swin_base_patch4_window7_224', 'pretrained': True,
     'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm'},
     'pretrain_img_size': 224, 'in_channels': 3, 'embed_dims': 128, 'patch_size': 4, 'window_size': 7, 'mlp_ratio': 4,
     'depths': [2, 2, 18, 2], 'num_heads': [4, 8, 16, 32], 'qkv_bias': True, 'qk_scale': None, 'patch_norm': True,
     'drop_rate': 0., 'attn_drop_rate': 0., 'drop_path_rate': 0.3, 'use_abs_pos_embed': False,},
    {'type': 'SwinTransformer', 'structure_type': 'swin_base_patch4_window12_384_22k', 'pretrained': True,
     'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm'},
     'pretrain_img_size': 384, 'in_channels': 3, 'embed_dims': 128, 'patch_size': 4, 'window_size': 12, 'mlp_ratio': 4,
     'depths': [2, 2, 18, 2], 'num_heads': [4, 8, 16, 32], 'qkv_bias': True, 'qk_scale': None, 'patch_norm': True,
     'drop_rate': 0., 'attn_drop_rate': 0., 'drop_path_rate': 0.3, 'use_abs_pos_embed': False,},
    {'type': 'SwinTransformer', 'structure_type': 'swin_base_patch4_window7_224_22k', 'pretrained': True,
     'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm'},
     'pretrain_img_size': 224, 'in_channels': 3, 'embed_dims': 128, 'patch_size': 4, 'window_size': 7, 'mlp_ratio': 4,
     'depths': [2, 2, 18, 2], 'num_heads': [4, 8, 16, 32], 'qkv_bias': True, 'qk_scale': None, 'patch_norm': True,
     'drop_rate': 0., 'attn_drop_rate': 0., 'drop_path_rate': 0.3, 'use_abs_pos_embed': False,},
    {'type': 'SwinTransformer', 'structure_type': 'swin_small_patch4_window7_224', 'pretrained': True,
     'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm'},
     'pretrain_img_size': 224, 'in_channels': 3, 'embed_dims': 96, 'patch_size': 4, 'window_size': 7, 'mlp_ratio': 4,
     'depths': [2, 2, 18, 2], 'num_heads': [3, 6, 12, 24], 'qkv_bias': True, 'qk_scale': None, 'patch_norm': True,
     'drop_rate': 0., 'attn_drop_rate': 0., 'drop_path_rate': 0.3, 'use_abs_pos_embed': False,},
    {'type': 'SwinTransformer', 'structure_type': 'swin_tiny_patch4_window7_224', 'pretrained': True,
     'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm'},
     'pretrain_img_size': 224, 'in_channels': 3, 'embed_dims': 96, 'patch_size': 4, 'window_size': 7, 'mlp_ratio': 4,
     'depths': [2, 2, 6, 2], 'num_heads': [3, 6, 12, 24], 'qkv_bias': True, 'qk_scale': None, 'patch_norm': True,
     'drop_rate': 0., 'attn_drop_rate': 0., 'drop_path_rate': 0.3, 'use_abs_pos_embed': False,},
]
for cfg in cfgs:
    # keyword argument keeps this call consistent with the other backbone tests
    swin = BuildBackbone(backbone_cfg=cfg)
    state_dict = loadpretrainedweights(
        structure_type=cfg['structure_type'], pretrained_model_path='', default_model_urls=DEFAULT_MODEL_URLS
    )
    state_dict = swin.swinconvert(state_dict)
    # be consistent: drop any 'backbone.' key prefix so names match the bare module
    state_dict_new = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith('backbone.'):
            state_dict_new[k[9:]] = v
        else:
            state_dict_new[k] = v
    state_dict = state_dict_new
    # strip prefix of state_dict (checkpoints saved from a (Distributed)DataParallel wrapper)
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    # reshape absolute position embedding from (N, L, C) to (N, C, H, W), only when sizes agree
    if state_dict.get('absolute_pos_embed') is not None:
        absolute_pos_embed = state_dict['absolute_pos_embed']
        N1, L, C1 = absolute_pos_embed.size()
        N2, C2, H, W = swin.absolute_pos_embed.size()
        if not (N1 != N2 or C1 != C2 or L != H * W):
            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2).contiguous()
    # interpolate relative position bias tables when pretrained/current window sizes differ
    relative_position_bias_table_keys = [k for k in state_dict.keys() if 'relative_position_bias_table' in k]
    for table_key in relative_position_bias_table_keys:
        table_pretrained = state_dict[table_key]
        table_current = swin.state_dict()[table_key]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = table_current.size()
        if (nH1 == nH2) and (L1 != L2):
            S1 = int(L1**0.5)
            S2 = int(L2**0.5)
            table_pretrained_resized = F.interpolate(table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1), size=(S2, S2), mode='bicubic')
            state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0).contiguous()
    # exercise non-strict loading first, then strict loading; mismatches are printed, not raised
    try:
        swin.load_state_dict(state_dict, strict=False)
    except Exception as err:
        print(err)
    try:
        swin.load_state_dict(state_dict, strict=True)
    except Exception as err:
        print(err)

tests/test_backbones/test_twins.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
'''
Function:
    Implementation of Testing Twins
Author:
    Zhenchao Jin
'''
from ssseg.modules import BuildBackbone, loadpretrainedweights
from ssseg.modules.models.backbones.twins import DEFAULT_MODEL_URLS


'''Twins'''
# NOTE(review): 'windiow_sizes' looks misspelled, but it is a runtime config key that
# presumably mirrors the SVT constructor's parameter name — confirm before renaming.
cfgs = [
    {'type': 'PCPVT', 'structure_type': 'pcpvt_base', 'pretrained': True, 'selected_indices': (0, 1, 2, 3),
     'norm_cfg': {'type': 'LayerNorm'}, 'depths': [3, 4, 18, 3], 'drop_path_rate': 0.3,},
    {'type': 'PCPVT', 'structure_type': 'pcpvt_large', 'pretrained': True, 'selected_indices': (0, 1, 2, 3),
     'norm_cfg': {'type': 'LayerNorm'}, 'depths': [3, 8, 27, 3], 'drop_path_rate': 0.3,},
    {'type': 'PCPVT', 'structure_type': 'pcpvt_small', 'pretrained': True, 'selected_indices': (0, 1, 2, 3),
     'norm_cfg': {'type': 'LayerNorm'}, 'depths': [3, 4, 6, 3], 'drop_path_rate': 0.2,},
    {'type': 'SVT', 'structure_type': 'svt_base', 'pretrained': True, 'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm'},
     'embed_dims': [96, 192, 384, 768], 'num_heads': [3, 6, 12, 24], 'mlp_ratios': [4, 4, 4, 4], 'depths': [2, 2, 18, 2],
     'windiow_sizes': [7, 7, 7, 7], 'norm_after_stage': True, 'drop_path_rate': 0.2},
    {'type': 'SVT', 'structure_type': 'svt_large', 'pretrained': True, 'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm'},
     'embed_dims': [128, 256, 512, 1024], 'num_heads': [4, 8, 16, 32], 'mlp_ratios': [4, 4, 4, 4], 'depths': [2, 2, 18, 2],
     'windiow_sizes': [7, 7, 7, 7], 'norm_after_stage': True, 'drop_path_rate': 0.3},
    {'type': 'SVT', 'structure_type': 'svt_small', 'pretrained': True, 'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm'},
     'embed_dims': [64, 128, 256, 512], 'num_heads': [2, 4, 8, 16], 'mlp_ratios': [4, 4, 4, 4], 'depths': [2, 2, 10, 4],
     'windiow_sizes': [7, 7, 7, 7], 'norm_after_stage': True, 'drop_path_rate': 0.2},
]
for backbone_cfg in cfgs:
    # build the backbone, fetch the matching checkpoint, and convert its key layout
    twins = BuildBackbone(backbone_cfg=backbone_cfg)
    weights = loadpretrainedweights(
        structure_type=backbone_cfg['structure_type'], pretrained_model_path='', default_model_urls=DEFAULT_MODEL_URLS
    )
    weights = twins.twinsconvert(backbone_cfg['structure_type'], weights)
    # exercise non-strict loading first, then strict loading; mismatches are printed, not raised
    for strict in (False, True):
        try:
            twins.load_state_dict(weights, strict=strict)
        except Exception as err:
            print(err)

tests/test_backbones/test_vit.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
'''
Function:
    Implementation of Testing ViT
Author:
    Zhenchao Jin
'''
import math
from ssseg.modules import BuildBackbone, loadpretrainedweights
from ssseg.modules.models.backbones.vit import DEFAULT_MODEL_URLS


'''ViTs'''
cfgs = [
    {'type': 'VisionTransformer', 'structure_type': 'jx_vit_large_p16_384', 'img_size': (512, 512), 'out_indices': (9, 14, 19, 23),
     'norm_cfg': {'type': 'LayerNorm', 'eps': 1e-6}, 'pretrained': True, 'selected_indices': (0, 1, 2, 3),
     'patch_size': 16, 'embed_dims': 1024, 'num_layers': 24, 'num_heads': 16, 'mlp_ratio': 4,
     'qkv_bias': True, 'drop_rate': 0.1, 'attn_drop_rate': 0., 'drop_path_rate': 0., 'with_cls_token': True,
     'output_cls_token': False, 'patch_norm': False, 'final_norm': False, 'num_fcs': 2,}
]
for backbone_cfg in cfgs:
    # build the backbone, fetch the matching checkpoint, and convert its key layout
    vit = BuildBackbone(backbone_cfg=backbone_cfg)
    weights = loadpretrainedweights(
        structure_type=backbone_cfg['structure_type'], pretrained_model_path='', default_model_urls=DEFAULT_MODEL_URLS
    )
    weights = vit.vitconvert(weights)
    # resize the pretrained position embedding when the configured input resolution differs
    if 'pos_embed' in weights.keys() and vit.pos_embed.shape != weights['pos_embed'].shape:
        h, w = vit.img_size
        # -1 accounts for the class token prepended to the patch tokens
        pos_size = int(math.sqrt(weights['pos_embed'].shape[1] - 1))
        weights['pos_embed'] = vit.resizeposembed(weights['pos_embed'], (h // vit.patch_size, w // vit.patch_size), (pos_size, pos_size), vit.interpolate_mode)
    # exercise non-strict loading first, then strict loading; mismatches are printed, not raised
    for strict in (False, True):
        try:
            vit.load_state_dict(weights, strict=strict)
        except Exception as err:
            print(err)

0 commit comments

Comments
 (0)