Skip to content

Commit 65ba520

Browse files
Merge pull request #80 from SegmentationBLWX/dev
merge from dev
2 parents fcab7be + aee79d9 commit 65ba520

File tree

3 files changed

+61
-0
lines changed

3 files changed

+61
-0
lines changed

ssseg/modules/models/backbones/mae.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ def initweights(self):
3030
class MAETransformerEncoderLayer(BEiTTransformerEncoderLayer):
3131
'''buildattn'''
3232
def buildattn(self, attn_cfg):
33+
valid_keys = ['embed_dims', 'num_heads', 'window_size', 'bias', 'qk_scale', 'attn_drop_rate', 'proj_drop_rate']
34+
for key in list(attn_cfg.keys()):
35+
if key not in valid_keys: attn_cfg.pop(key)
3336
self.attn = MAEAttention(**attn_cfg)
3437

3538

tests/test_backbones/test_mae.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
'''
2+
Function:
3+
Implementation of Testing MAE
4+
Author:
5+
Zhenchao Jin
6+
'''
7+
from ssseg.modules import BuildBackbone, loadpretrainedweights
8+
from ssseg.modules.models.backbones.mae import DEFAULT_MODEL_URLS
9+
10+
11+
'''MAEs'''
12+
cfgs = [{
13+
'type': 'MAE', 'structure_type': 'mae_pretrain_vit_base', 'pretrained': True,
14+
'img_size': (512, 512), 'patch_size': 16, 'embed_dims': 768, 'num_layers': 12,
15+
'num_heads': 12, 'mlp_ratio': 4, 'init_values': 1.0, 'drop_path_rate': 0.1,
16+
'selected_indices': (0, 1, 2, 3), 'norm_cfg': {'type': 'LayerNorm', 'eps': 1e-6},
17+
}]
18+
for cfg in cfgs:
19+
mae = BuildBackbone(cfg)
20+
state_dict = loadpretrainedweights(
21+
structure_type=cfg['structure_type'], pretrained_model_path='', default_model_urls=DEFAULT_MODEL_URLS
22+
)
23+
state_dict = mae.beitconvert(state_dict)
24+
state_dict = mae.resizerelposembed(state_dict)
25+
try:
26+
mae.load_state_dict(state_dict, strict=False)
27+
except Exception as err:
28+
print(err)
29+
try:
30+
mae.load_state_dict(state_dict, strict=True)
31+
except Exception as err:
32+
print(err)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
'''
2+
Function:
3+
Implementation of Testing ResNeSt
4+
Author:
5+
Zhenchao Jin
6+
'''
7+
from ssseg.modules import BuildBackbone, loadpretrainedweights
8+
from ssseg.modules.models.backbones.resnest import DEFAULT_MODEL_URLS
9+
10+
11+
'''resnests'''
12+
for depth in [50, 101, 200]:
13+
resnest = BuildBackbone(backbone_cfg={
14+
'type': 'ResNeSt', 'depth': depth, 'structure_type': f'resnest{depth}', 'pretrained': True, 'outstride': 8, 'selected_indices': (0, 1, 2, 3), 'stem_channels': {50: 64, 101: 128, 200: 128}[depth]
15+
})
16+
state_dict = loadpretrainedweights(
17+
structure_type=f'resnest{depth}', pretrained_model_path='', default_model_urls=DEFAULT_MODEL_URLS
18+
)
19+
try:
20+
resnest.load_state_dict(state_dict, strict=False)
21+
except Exception as err:
22+
print(err)
23+
try:
24+
resnest.load_state_dict(state_dict, strict=True)
25+
except Exception as err:
26+
print(err)

0 commit comments

Comments
 (0)