1+ import torch
2+ import torch .nn as nn
3+
4+ from model_constructor .net import Net , NewResBlock , ResBlock
5+ # from model_constructor.layers import SEModule, SimpleSelfAttention
6+
7+
8+ bs_test = 4
9+
10+
11+ params = dict (
12+ block = [ResBlock , NewResBlock ],
13+ expansion = [1 , 2 ],
14+ groups = [1 , 2 ],
15+ dw = [0 , 1 ],
16+ div_groups = [None , 2 ],
17+ sa = [0 , 1 ],
18+ se = [0 , 1 ],
19+ bn_1st = [True , False ],
20+ zero_bn = [True , False ],
21+ stem_bn_end = [True , False ],
22+ stem_stride_on = [0 , 1 ]
23+ )
24+
25+
26+ def value_name (value ) -> str : # pragma: no cover
27+ name = getattr (value , "__name__" , None )
28+ if name is not None :
29+ return name
30+ if isinstance (value , nn .Module ):
31+ return value ._get_name ()
32+ else :
33+ return value
34+
35+
36+ def ids_fn (key , value ):
37+ return [f"{ key [:2 ]} _{ value_name (v )} " for v in value ]
38+
39+
40+ def pytest_generate_tests (metafunc ):
41+ for key , value in params .items ():
42+ if key in metafunc .fixturenames :
43+ metafunc .parametrize (key , value , ids = ids_fn (key , value ))
44+
45+
46+ def test_Net (
47+ block , expansion ,
48+ groups ,
49+ # dw, div_groups,
50+ ):
51+ """test Net"""
52+ c_in = 3
53+ img_size = 16
54+ c_out = 8
55+ name = "Test name"
56+
57+ mc = Net (
58+ name , c_in , c_out , block ,
59+ expansion = expansion ,
60+ stem_sizes = [8 , 16 ],
61+ block_sizes = [16 , 32 , 64 , 128 ],
62+ groups = groups ,
63+ # dw=dw,
64+ # div_groups=div_groups,
65+ # bn_1st=bn_1st, zero_bn=zero_bn,
66+ # stem_bn_end=stem_bn_end,
67+ )
68+ assert f"{ name } constructor" in str (mc )
69+ model = mc ()
70+ xb = torch .randn (bs_test , c_in , img_size , img_size )
71+ pred = model (xb )
72+ assert pred .shape == torch .Size ([bs_test , c_out ])
73+
74+
75+ def test_Net_div_gr (
76+ block , expansion ,
77+ div_groups ,
78+ ):
79+ """test Net"""
80+ c_in = 3
81+ img_size = 16
82+ c_out = 8
83+ name = "Test name"
84+
85+ mc = Net (
86+ name , c_in , c_out , block ,
87+ expansion = expansion ,
88+ stem_sizes = [8 , 16 ],
89+ block_sizes = [16 , 32 , 64 , 128 ],
90+ div_groups = div_groups ,
91+ )
92+ assert f"{ name } constructor" in str (mc )
93+ model = mc ()
94+ xb = torch .randn (bs_test , c_in , img_size , img_size )
95+ pred = model (xb )
96+ assert pred .shape == torch .Size ([bs_test , c_out ])
97+
98+
99+ def test_Net_dw (
100+ block , expansion ,
101+ dw
102+ ):
103+ """test Net"""
104+ c_in = 3
105+ img_size = 16
106+ c_out = 8
107+ name = "Test name"
108+
109+ mc = Net (
110+ name , c_in , c_out , block ,
111+ expansion = expansion ,
112+ stem_sizes = [8 , 16 ],
113+ block_sizes = [16 , 32 , 64 , 128 ],
114+ dw = dw
115+ )
116+ assert f"{ name } constructor" in str (mc )
117+ model = mc ()
118+ xb = torch .randn (bs_test , c_in , img_size , img_size )
119+ pred = model (xb )
120+ assert pred .shape == torch .Size ([bs_test , c_out ])
121+
122+
123+ def test_Net_2 (
124+ block , expansion ,
125+ bn_1st , zero_bn ,
126+ ):
127+ """test Net"""
128+ c_in = 3
129+ img_size = 16
130+ c_out = 8
131+ name = "Test name"
132+
133+ mc = Net (
134+ name , c_in , c_out , block ,
135+ expansion = expansion ,
136+ stem_sizes = [8 , 16 ],
137+ block_sizes = [16 , 32 , 64 , 128 ],
138+ bn_1st = bn_1st , zero_bn = zero_bn ,
139+ )
140+ assert f"{ name } constructor" in str (mc )
141+ model = mc ()
142+ xb = torch .randn (bs_test , c_in , img_size , img_size )
143+ pred = model (xb )
144+ assert pred .shape == torch .Size ([bs_test , c_out ])
145+
146+
147+ def test_Net_stem (
148+ stem_bn_end ,
149+ stem_stride_on
150+ ):
151+ """test Net"""
152+ c_in = 3
153+ img_size = 16
154+ c_out = 8
155+ name = "Test name"
156+
157+ mc = Net (
158+ name , c_in , c_out ,
159+ stem_sizes = [8 , 16 ],
160+ block_sizes = [16 , 32 , 64 , 128 ],
161+ stem_bn_end = stem_bn_end ,
162+ stem_stride_on = stem_stride_on
163+ )
164+ assert f"{ name } constructor" in str (mc )
165+ model = mc ()
166+ xb = torch .randn (bs_test , c_in , img_size , img_size )
167+ pred = model (xb )
168+ assert pred .shape == torch .Size ([bs_test , c_out ])
0 commit comments