22import torch .nn as nn
33
44from model_constructor .net import Net , NewResBlock , ResBlock
5+
56# from model_constructor.layers import SEModule, SimpleSelfAttention
67
78
1920 bn_1st = [True , False ],
2021 zero_bn = [True , False ],
2122 stem_bn_end = [True , False ],
22- stem_stride_on = [0 , 1 ]
23+ stem_stride_on = [0 , 1 ],
2324)
2425
2526
@@ -28,9 +29,8 @@ def value_name(value) -> str: # pragma: no cover
2829 if name is not None :
2930 return name
3031 if isinstance (value , nn .Module ):
31- return value ._get_name ()
32- else :
33- return value
32+ return value ._get_name () # pylint: disable=W0212
33+ return value
3434
3535
3636def ids_fn (key , value ):
@@ -44,7 +44,8 @@ def pytest_generate_tests(metafunc):
4444
4545
4646def test_Net (
47- block , expansion ,
47+ block ,
48+ expansion ,
4849 groups ,
4950):
5051 """test Net"""
@@ -54,15 +55,14 @@ def test_Net(
5455 name = "Test name"
5556
5657 mc = Net (
57- name , c_in , c_out , block ,
58+ name ,
59+ c_in ,
60+ c_out ,
61+ block ,
5862 expansion = expansion ,
5963 stem_sizes = [8 , 16 ],
6064 block_sizes = [16 , 32 , 64 , 128 ],
6165 groups = groups ,
62- # dw=dw,
63- # div_groups=div_groups,
64- # bn_1st=bn_1st, zero_bn=zero_bn,
65- # stem_bn_end=stem_bn_end,
6666 )
6767 assert f"{ name } constructor" in str (mc )
6868 model = mc ()
@@ -71,22 +71,23 @@ def test_Net(
7171 assert pred .shape == torch .Size ([bs_test , c_out ])
7272
7373
74- def test_Net_SE_SA (
75- block , expansion ,
76- se , sa
77- ):
74+ def test_Net_SE_SA (block , expansion , se , sa ):
7875 """test Net"""
7976 c_in = 3
8077 img_size = 16
8178 c_out = 8
8279 name = "Test name"
8380
8481 mc = Net (
85- name , c_in , c_out , block ,
82+ name ,
83+ c_in ,
84+ c_out ,
85+ block ,
8686 expansion = expansion ,
8787 stem_sizes = [8 , 16 ],
8888 block_sizes = [16 , 32 , 64 , 128 ],
89- se = se , sa = sa
89+ se = se ,
90+ sa = sa ,
9091 )
9192 assert f"{ name } constructor" in str (mc )
9293 model = mc ()
@@ -96,7 +97,8 @@ def test_Net_SE_SA(
9697
9798
9899def test_Net_div_gr (
99- block , expansion ,
100+ block ,
101+ expansion ,
100102 div_groups ,
101103):
102104 """test Net"""
@@ -106,7 +108,10 @@ def test_Net_div_gr(
106108 name = "Test name"
107109
108110 mc = Net (
109- name , c_in , c_out , block ,
111+ name ,
112+ c_in ,
113+ c_out ,
114+ block ,
110115 expansion = expansion ,
111116 stem_sizes = [8 , 16 ],
112117 block_sizes = [16 , 32 , 64 , 128 ],
@@ -119,22 +124,22 @@ def test_Net_div_gr(
119124 assert pred .shape == torch .Size ([bs_test , c_out ])
120125
121126
122- def test_Net_dw (
123- block , expansion ,
124- dw
125- ):
127+ def test_Net_dw (block , expansion , dw ):
126128 """test Net"""
127129 c_in = 3
128130 img_size = 16
129131 c_out = 8
130132 name = "Test name"
131133
132134 mc = Net (
133- name , c_in , c_out , block ,
135+ name ,
136+ c_in ,
137+ c_out ,
138+ block ,
134139 expansion = expansion ,
135140 stem_sizes = [8 , 16 ],
136141 block_sizes = [16 , 32 , 64 , 128 ],
137- dw = dw
142+ dw = dw ,
138143 )
139144 assert f"{ name } constructor" in str (mc )
140145 model = mc ()
@@ -144,8 +149,10 @@ def test_Net_dw(
144149
145150
146151def test_Net_2 (
147- block , expansion ,
148- bn_1st , zero_bn ,
152+ block ,
153+ expansion ,
154+ bn_1st ,
155+ zero_bn ,
149156):
150157 """test Net"""
151158 c_in = 3
@@ -154,11 +161,15 @@ def test_Net_2(
154161 name = "Test name"
155162
156163 mc = Net (
157- name , c_in , c_out , block ,
164+ name ,
165+ c_in ,
166+ c_out ,
167+ block ,
158168 expansion = expansion ,
159169 stem_sizes = [8 , 16 ],
160170 block_sizes = [16 , 32 , 64 , 128 ],
161- bn_1st = bn_1st , zero_bn = zero_bn ,
171+ bn_1st = bn_1st ,
172+ zero_bn = zero_bn ,
162173 )
163174 assert f"{ name } constructor" in str (mc )
164175 model = mc ()
@@ -167,25 +178,24 @@ def test_Net_2(
167178 assert pred .shape == torch .Size ([bs_test , c_out ])
168179
169180
170- def test_Net_stem (
171- stem_bn_end ,
172- stem_stride_on
173- ):
181+ def test_Net_stem (stem_bn_end , stem_stride_on ):
174182 """test Net"""
175183 c_in = 3
176184 img_size = 16
177185 c_out = 8
178186 name = "Test name"
179187
180188 mc = Net (
181- name , c_in , c_out ,
189+ name ,
190+ c_in ,
191+ c_out ,
182192 stem_sizes = [8 , 16 ],
183193 block_sizes = [16 , 32 , 64 , 128 ],
184194 stem_bn_end = stem_bn_end ,
185- stem_stride_on = stem_stride_on
195+ stem_stride_on = stem_stride_on ,
186196 )
187197 assert f"{ name } constructor" in str (mc )
188198 model = mc ()
189199 xb = torch .randn (bs_test , c_in , img_size , img_size )
190200 pred = model (xb )
191- assert pred .shape == torch .Size ([bs_test , c_out ])
201+ assert pred .shape == torch .Size ([bs_test , c_out ])
0 commit comments