1+ import torch
2+ import torch .nn as nn
3+
4+ from model_constructor .net import Net , NewResBlock , ResBlock
5+ # from model_constructor.layers import SEModule, SimpleSelfAttention
6+
7+
8+ bs_test = 4
9+
10+
11+ params = dict (
12+ block = [ResBlock , NewResBlock ],
13+ expansion = [1 , 2 ],
14+ groups = [1 , 2 ],
15+ dw = [0 , 1 ],
16+ div_groups = [None , 2 ],
17+ sa = [0 , 1 ],
18+ se = [0 , 1 ],
19+ bn_1st = [True , False ],
20+ zero_bn = [True , False ],
21+ stem_bn_end = [True , False ],
22+ stem_stride_on = [0 , 1 ]
23+ )
24+
25+
26+ def value_name (value ) -> str : # pragma: no cover
27+ name = getattr (value , "__name__" , None )
28+ if name is not None :
29+ return name
30+ if isinstance (value , nn .Module ):
31+ return value ._get_name ()
32+ else :
33+ return value
34+
35+
36+ def ids_fn (key , value ):
37+ return [f"{ key [:2 ]} _{ value_name (v )} " for v in value ]
38+
39+
40+ def pytest_generate_tests (metafunc ):
41+ for key , value in params .items ():
42+ if key in metafunc .fixturenames :
43+ metafunc .parametrize (key , value , ids = ids_fn (key , value ))
44+
45+
46+ def test_Net (
47+ block , expansion ,
48+ groups ,
49+ ):
50+ """test Net"""
51+ c_in = 3
52+ img_size = 16
53+ c_out = 8
54+ name = "Test name"
55+
56+ mc = Net (
57+ name , c_in , c_out , block ,
58+ expansion = expansion ,
59+ stem_sizes = [8 , 16 ],
60+ block_sizes = [16 , 32 , 64 , 128 ],
61+ groups = groups ,
62+ # dw=dw,
63+ # div_groups=div_groups,
64+ # bn_1st=bn_1st, zero_bn=zero_bn,
65+ # stem_bn_end=stem_bn_end,
66+ )
67+ assert f"{ name } constructor" in str (mc )
68+ model = mc ()
69+ xb = torch .randn (bs_test , c_in , img_size , img_size )
70+ pred = model (xb )
71+ assert pred .shape == torch .Size ([bs_test , c_out ])
72+
73+
74+ def test_Net_SE_SA (
75+ block , expansion ,
76+ se , sa
77+ ):
78+ """test Net"""
79+ c_in = 3
80+ img_size = 16
81+ c_out = 8
82+ name = "Test name"
83+
84+ mc = Net (
85+ name , c_in , c_out , block ,
86+ expansion = expansion ,
87+ stem_sizes = [8 , 16 ],
88+ block_sizes = [16 , 32 , 64 , 128 ],
89+ se = se , sa = sa
90+ )
91+ assert f"{ name } constructor" in str (mc )
92+ model = mc ()
93+ xb = torch .randn (bs_test , c_in , img_size , img_size )
94+ pred = model (xb )
95+ assert pred .shape == torch .Size ([bs_test , c_out ])
96+
97+
98+ def test_Net_div_gr (
99+ block , expansion ,
100+ div_groups ,
101+ ):
102+ """test Net"""
103+ c_in = 3
104+ img_size = 16
105+ c_out = 8
106+ name = "Test name"
107+
108+ mc = Net (
109+ name , c_in , c_out , block ,
110+ expansion = expansion ,
111+ stem_sizes = [8 , 16 ],
112+ block_sizes = [16 , 32 , 64 , 128 ],
113+ div_groups = div_groups ,
114+ )
115+ assert f"{ name } constructor" in str (mc )
116+ model = mc ()
117+ xb = torch .randn (bs_test , c_in , img_size , img_size )
118+ pred = model (xb )
119+ assert pred .shape == torch .Size ([bs_test , c_out ])
120+
121+
122+ def test_Net_dw (
123+ block , expansion ,
124+ dw
125+ ):
126+ """test Net"""
127+ c_in = 3
128+ img_size = 16
129+ c_out = 8
130+ name = "Test name"
131+
132+ mc = Net (
133+ name , c_in , c_out , block ,
134+ expansion = expansion ,
135+ stem_sizes = [8 , 16 ],
136+ block_sizes = [16 , 32 , 64 , 128 ],
137+ dw = dw
138+ )
139+ assert f"{ name } constructor" in str (mc )
140+ model = mc ()
141+ xb = torch .randn (bs_test , c_in , img_size , img_size )
142+ pred = model (xb )
143+ assert pred .shape == torch .Size ([bs_test , c_out ])
144+
145+
146+ def test_Net_2 (
147+ block , expansion ,
148+ bn_1st , zero_bn ,
149+ ):
150+ """test Net"""
151+ c_in = 3
152+ img_size = 16
153+ c_out = 8
154+ name = "Test name"
155+
156+ mc = Net (
157+ name , c_in , c_out , block ,
158+ expansion = expansion ,
159+ stem_sizes = [8 , 16 ],
160+ block_sizes = [16 , 32 , 64 , 128 ],
161+ bn_1st = bn_1st , zero_bn = zero_bn ,
162+ )
163+ assert f"{ name } constructor" in str (mc )
164+ model = mc ()
165+ xb = torch .randn (bs_test , c_in , img_size , img_size )
166+ pred = model (xb )
167+ assert pred .shape == torch .Size ([bs_test , c_out ])
168+
169+
170+ def test_Net_stem (
171+ stem_bn_end ,
172+ stem_stride_on
173+ ):
174+ """test Net"""
175+ c_in = 3
176+ img_size = 16
177+ c_out = 8
178+ name = "Test name"
179+
180+ mc = Net (
181+ name , c_in , c_out ,
182+ stem_sizes = [8 , 16 ],
183+ block_sizes = [16 , 32 , 64 , 128 ],
184+ stem_bn_end = stem_bn_end ,
185+ stem_stride_on = stem_stride_on
186+ )
187+ assert f"{ name } constructor" in str (mc )
188+ model = mc ()
189+ xb = torch .randn (bs_test , c_in , img_size , img_size )
190+ pred = model (xb )
191+ assert pred .shape == torch .Size ([bs_test , c_out ])
0 commit comments