'''
Function:
    Test the SwinTransformer implementation by loading the official pretrained weights
Author:
    Zhenchao Jin
'''
from collections import OrderedDict

import torch
import torch.nn.functional as F
from ssseg.modules import BuildBackbone, loadpretrainedweights
from ssseg.modules.models.backbones.swin import DEFAULT_MODEL_URLS


'''SwinTransformers'''
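# each config below mirrors the constructor arguments of the SwinTransformer backbone:
# 'embed_dims', 'depths' and 'num_heads' select the variant (tiny/small/base/large),
# 'window_size' and 'pretrain_img_size' must match the pretrained checkpoint, and
# 'selected_indices' picks which of the four stages' feature maps the backbone returns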
cfgs = [
    {'type': 'SwinTransformer', 'structure_type': 'swin_large_patch4_window12_384_22kto1k', 'pretrained': True,
     'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm'},
     'pretrain_img_size': 384, 'in_channels': 3, 'embed_dims': 192, 'patch_size': 4, 'window_size': 12, 'mlp_ratio': 4,
     'depths': [2, 2, 18, 2], 'num_heads': [6, 12, 24, 48], 'qkv_bias': True, 'qk_scale': None, 'patch_norm': True,
     'drop_rate': 0., 'attn_drop_rate': 0., 'drop_path_rate': 0.3, 'use_abs_pos_embed': False},
    {'type': 'SwinTransformer', 'structure_type': 'swin_large_patch4_window12_384_22k', 'pretrained': True,
     'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm'},
     'pretrain_img_size': 384, 'in_channels': 3, 'embed_dims': 192, 'patch_size': 4, 'window_size': 12, 'mlp_ratio': 4,
     'depths': [2, 2, 18, 2], 'num_heads': [6, 12, 24, 48], 'qkv_bias': True, 'qk_scale': None, 'patch_norm': True,
     'drop_rate': 0., 'attn_drop_rate': 0., 'drop_path_rate': 0.3, 'use_abs_pos_embed': False},
    {'type': 'SwinTransformer', 'structure_type': 'swin_base_patch4_window12_384', 'pretrained': True,
     'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm'},
     'pretrain_img_size': 384, 'in_channels': 3, 'embed_dims': 128, 'patch_size': 4, 'window_size': 12, 'mlp_ratio': 4,
     'depths': [2, 2, 18, 2], 'num_heads': [4, 8, 16, 32], 'qkv_bias': True, 'qk_scale': None, 'patch_norm': True,
     'drop_rate': 0., 'attn_drop_rate': 0., 'drop_path_rate': 0.3, 'use_abs_pos_embed': False},
    {'type': 'SwinTransformer', 'structure_type': 'swin_base_patch4_window7_224', 'pretrained': True,
     'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm'},
     'pretrain_img_size': 224, 'in_channels': 3, 'embed_dims': 128, 'patch_size': 4, 'window_size': 7, 'mlp_ratio': 4,
     'depths': [2, 2, 18, 2], 'num_heads': [4, 8, 16, 32], 'qkv_bias': True, 'qk_scale': None, 'patch_norm': True,
     'drop_rate': 0., 'attn_drop_rate': 0., 'drop_path_rate': 0.3, 'use_abs_pos_embed': False},
    {'type': 'SwinTransformer', 'structure_type': 'swin_base_patch4_window12_384_22k', 'pretrained': True,
     'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm'},
     'pretrain_img_size': 384, 'in_channels': 3, 'embed_dims': 128, 'patch_size': 4, 'window_size': 12, 'mlp_ratio': 4,
     'depths': [2, 2, 18, 2], 'num_heads': [4, 8, 16, 32], 'qkv_bias': True, 'qk_scale': None, 'patch_norm': True,
     'drop_rate': 0., 'attn_drop_rate': 0., 'drop_path_rate': 0.3, 'use_abs_pos_embed': False},
    {'type': 'SwinTransformer', 'structure_type': 'swin_base_patch4_window7_224_22k', 'pretrained': True,
     'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm'},
     'pretrain_img_size': 224, 'in_channels': 3, 'embed_dims': 128, 'patch_size': 4, 'window_size': 7, 'mlp_ratio': 4,
     'depths': [2, 2, 18, 2], 'num_heads': [4, 8, 16, 32], 'qkv_bias': True, 'qk_scale': None, 'patch_norm': True,
     'drop_rate': 0., 'attn_drop_rate': 0., 'drop_path_rate': 0.3, 'use_abs_pos_embed': False},
    {'type': 'SwinTransformer', 'structure_type': 'swin_small_patch4_window7_224', 'pretrained': True,
     'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm'},
     'pretrain_img_size': 224, 'in_channels': 3, 'embed_dims': 96, 'patch_size': 4, 'window_size': 7, 'mlp_ratio': 4,
     'depths': [2, 2, 18, 2], 'num_heads': [3, 6, 12, 24], 'qkv_bias': True, 'qk_scale': None, 'patch_norm': True,
     'drop_rate': 0., 'attn_drop_rate': 0., 'drop_path_rate': 0.3, 'use_abs_pos_embed': False},
    {'type': 'SwinTransformer', 'structure_type': 'swin_tiny_patch4_window7_224', 'pretrained': True,
     'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm'},
     'pretrain_img_size': 224, 'in_channels': 3, 'embed_dims': 96, 'patch_size': 4, 'window_size': 7, 'mlp_ratio': 4,
     'depths': [2, 2, 6, 2], 'num_heads': [3, 6, 12, 24], 'qkv_bias': True, 'qk_scale': None, 'patch_norm': True,
     'drop_rate': 0., 'attn_drop_rate': 0., 'drop_path_rate': 0.3, 'use_abs_pos_embed': False},
]
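# for each config: build the backbone, fetch the matching pretrained checkpoint via
# DEFAULT_MODEL_URLS, remap the checkpoint keys to this implementation's naming, and
# verify that the weights load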
for cfg in cfgs:
    swin = BuildBackbone(cfg)
    state_dict = loadpretrainedweights(
        structure_type=cfg['structure_type'], pretrained_model_path='', default_model_urls=DEFAULT_MODEL_URLS
    )
    # convert the official checkpoint's key naming to this implementation's naming
    state_dict = swin.swinconvert(state_dict)
    # strip the 'backbone.' prefix so the keys match this standalone backbone
    state_dict_new = OrderedDict()
    for k, v in state_dict.items():
        if k.startswith('backbone.'):
            state_dict_new[k[len('backbone.'):]] = v
        else:
            state_dict_new[k] = v
    state_dict = state_dict_new
    # strip the 'module.' prefix left by DataParallel/DistributedDataParallel checkpoints
    if list(state_dict.keys())[0].startswith('module.'):
        state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
    # reshape the absolute position embedding from (N, L, C) to (N, C, H, W) if present
    if state_dict.get('absolute_pos_embed') is not None:
        absolute_pos_embed = state_dict['absolute_pos_embed']
        N1, L, C1 = absolute_pos_embed.size()
        N2, C2, H, W = swin.absolute_pos_embed.size()
        if N1 == N2 and C1 == C2 and L == H * W:
            state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2).contiguous()
    # bicubically interpolate the relative position bias tables when the pretrained window
    # size differs from the current one (each table has (2 * window_size - 1) ** 2 rows)
    relative_position_bias_table_keys = [k for k in state_dict.keys() if 'relative_position_bias_table' in k]
    for table_key in relative_position_bias_table_keys:
        table_pretrained = state_dict[table_key]
        table_current = swin.state_dict()[table_key]
        L1, nH1 = table_pretrained.size()
        L2, nH2 = table_current.size()
        if nH1 == nH2 and L1 != L2:
            S1, S2 = int(L1 ** 0.5), int(L2 ** 0.5)
            table_pretrained_resized = F.interpolate(table_pretrained.permute(1, 0).reshape(1, nH1, S1, S1), size=(S2, S2), mode='bicubic')
            state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0).contiguous()
    # attempt a strict load first so any remaining key or shape mismatch is reported,
    # then fall back to a non-strict load that ignores missing/unexpected keys
    try:
        swin.load_state_dict(state_dict, strict=True)
    except Exception as err:
        print(err)
        swin.load_state_dict(state_dict, strict=False)
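    # optional sanity check (a minimal sketch, assuming the backbone accepts an NCHW image
    # tensor and returns one feature map per index in 'selected_indices'): run a dummy
    # forward pass and print the shape of each returned feature map
    swin.eval()
    with torch.no_grad():
        feats = swin(torch.randn(1, 3, cfg['pretrain_img_size'], cfg['pretrain_img_size']))
    for idx, feat in enumerate(feats):
        print(cfg['structure_type'], idx, tuple(feat.shape))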