Skip to content

Commit be448ba

Browse files
tests for hrnet
1 parent e1cefb5 commit be448ba

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

tests/test_backbones/test_hrnet.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
'''
2+
Function:
3+
Implementation of Testing HRNet
4+
Author:
5+
Zhenchao Jin
6+
'''
7+
from ssseg.modules import BuildBackbone, loadpretrainedweights
8+
from ssseg.modules.models.backbones.hrnet import DEFAULT_MODEL_URLS
9+
10+
11+
'''hrnets'''
12+
cfgs = [
13+
{'type': 'HRNet', 'structure_type': 'hrnetv2_w18_small', 'arch': 'hrnetv2_w18_small', 'pretrained': True, 'selected_indices': (0, 0),},
14+
{'type': 'HRNet', 'structure_type': 'hrnetv2_w18', 'arch': 'hrnetv2_w18', 'pretrained': True, 'selected_indices': (0, 0),},
15+
{'type': 'HRNet', 'structure_type': 'hrnetv2_w32', 'arch': 'hrnetv2_w32', 'pretrained': True, 'selected_indices': (0, 0),},
16+
{'type': 'HRNet', 'structure_type': 'hrnetv2_w40', 'arch': 'hrnetv2_w40', 'pretrained': True, 'selected_indices': (0, 0),},
17+
{'type': 'HRNet', 'structure_type': 'hrnetv2_w48', 'arch': 'hrnetv2_w48', 'pretrained': True, 'selected_indices': (0, 0),},
18+
]
19+
for cfg in cfgs:
20+
hrnet = BuildBackbone(backbone_cfg=cfg)
21+
state_dict = loadpretrainedweights(structure_type=cfg['structure_type'], pretrained_model_path='', default_model_urls=DEFAULT_MODEL_URLS)
22+
try:
23+
hrnet.load_state_dict(state_dict, strict=False)
24+
except Exception as err:
25+
print(err)
26+
try:
27+
hrnet.load_state_dict(state_dict, strict=True)
28+
except Exception as err:
29+
print(err)

0 commit comments

Comments
 (0)