@@ -129,7 +129,10 @@ def avgpool(float(B, C, H, W) input) -> (output) {
129129 T = tc .define (tc_str , tc .make_naive_options_factory ())
130130 inp = torch .ones (1 , 1 , 4 , 4 , device = 'cuda' )
131131 out = T .avgpool (inp )
132- # TODO: test results!!!
132+
133+ from torch .nn .modules .pooling import AvgPool2d
134+ ref = AvgPool2d (2 , stride = 1 ).forward (inp )
135+ tc .assert_almost_equal (ref , out , inp )
133136
134137 #
135138 # This test implements group normalization as a single TC kernel.
@@ -138,13 +141,16 @@ def avgpool(float(B, C, H, W) input) -> (output) {
138141 def test_group_norm_fused (self ):
139142 group_normalization = """
140143 def group_normalization(
141- float(N, G, D, H, W) I, float(G, D) gamma, float(G, D) beta) -> (Sum, SumSq, O)
144+ float(N, G, D, H, W) I, float(G, D) gamma, float(G, D) beta)
145+ -> (Sum, SumSq, O)
142146 {
143147 Sum(n, g) +=! I(n, g, r_d, r_h, r_w)
144148 SumSq(n, g) +=! I(n, g, r_d, r_h, r_w) * I(n, g, r_d, r_h, r_w)
145- O(n, g, d, h, w) = gamma(g, d)
149+ O(n, g, d, h, w) = gamma(g, d)
146150 * ( I(n, g, d, h, w) - Sum(n, g) / (D * H * W))
147- * rsqrt( (SumSq(n, g) / (D * H * W) - Sum(n, g) * Sum(n, g)) + 1e-5 )
151+ * rsqrt( (SumSq(n, g) - Sum(n, g) * Sum(n, g) / (D * H * W))
152+ / (D * H * W)
153+ + 1e-5)
148154 + beta(g, d)
149155 }
150156 """
@@ -157,10 +163,15 @@ def group_normalization(
157163 tuner_config = tuner_config ))
158164 I , gamma , beta = (
159165 torch .randn (N , G , D , H , W , device = 'cuda' ),
160- torch .randn (G , D , device = 'cuda' ),
161- torch .randn (G , D , device = 'cuda' ))
166+ torch .randn (G , D , device = 'cuda' ). fill_ ( 1.0 ) ,
167+ torch .randn (G , D , device = 'cuda' ). zero_ () )
162168 Sum , SumSq , O = T .group_normalization (I , gamma , beta )
163- # TODO: test results!!!
169+
170+ from torch .nn .modules .normalization import GroupNorm
171+ GN = GroupNorm (G , G * D ).cuda ()
172+ ref = GN .forward (I .view ((N , G * D , H , W )))
173+
174+ tc .assert_almost_equal (ref , O .view ((N , G * D , H , W )), I , operations = D * H * W )
164175
165176 #
166177 # This test implements group normalization as 2 TC kernels
@@ -191,8 +202,8 @@ def group_normalization(
191202 N , G , D , H , W = 32 , 32 , 4 , 56 , 56
192203 I , gamma , beta = (
193204 torch .randn (N , G , D , H , W , device = 'cuda' ),
194- torch .randn (G , D , device = 'cuda' ),
195- torch .randn (G , D , device = 'cuda' ))
205+ torch .randn (G , D , device = 'cuda' ). fill_ ( 1.0 ) ,
206+ torch .randn (G , D , device = 'cuda' ). zero_ () )
196207
197208 T = tc .define (
198209 group_normalization ,
@@ -208,7 +219,12 @@ def group_normalization(
208219 mean , var = T .moments (I .view ((N * G , - 1 )))
209220 out = T .group_normalization (
210221 I , gamma , beta , mean .view ((N , G )), var .view ((N , G )))
211- # TODO: test results!!!
222+
223+ from torch .nn .modules .normalization import GroupNorm
224+ GN = GroupNorm (G , G * D ).cuda ()
225+ ref = GN .forward (I .view ((N , G * D , H , W )))
226+
227+ tc .assert_almost_equal (ref , out .view ((N , G * D , H , W )), I , operations = D * H * W )
212228
213229 #
214230 # TC example without fallback but with tuning starting from MappingOptions('naive').
@@ -239,8 +255,8 @@ def group_normalization(
239255 N , G , D , H , W = 32 , 32 , 4 , 56 , 56
240256 I , gamma , beta = (
241257 torch .randn (N , G , D , H , W , device = 'cuda' ),
242- torch .randn (G , D , device = 'cuda' ),
243- torch .randn (G , D , device = 'cuda' ))
258+ torch .randn (G , D , device = 'cuda' ). fill_ ( 1.0 ) ,
259+ torch .randn (G , D , device = 'cuda' ). zero_ () )
244260
245261 T = tc .define (
246262 group_normalization ,
@@ -266,45 +282,63 @@ def group_normalization(
266282 out = T .group_normalization (
267283 I , gamma , beta , mean .view ((N , G )), var .view ((N , G )))
268284
285+ from torch .nn .modules .normalization import GroupNorm
286+ GN = GroupNorm (G , G * D ).cuda ()
287+ ref = GN .forward (I .view ((N , G * D , H , W )))
288+
289+ tc .assert_almost_equal (ref , out .view ((N , G * D , H , W )), I , operations = D * H * W )
290+
269291
270292 #
271293 # This tests single kernel forward/backward with tc.make_autograd.
272294 #
273295 def test_conv_with_backward_fused (self ):
274296 conv = """
275- def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {
297+ def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(M) Bias)
298+ -> (O)
299+ {
276300 O(n, m, h, w) +=!
277301 I(n, r_c, h + r_kh, w + r_kw) * W1(m, r_c, r_kh, r_kw)
302+ O(n, m, h, w) = O(n, m, h, w) + Bias(m)
278303 }
279304 def convolution_grad(
280- float(N,C,H,W) I, float(M,C,KH,KW) W1, float(N,M,H,W) d_O)
281- -> (d_I, d_W1)
305+ float(N,C,H,W) I, float(M,C,KH,KW) W1, float(M) Bias, float( N,M,H,W) d_O)
306+ -> (d_I, d_W1, d_Bias )
282307 {
283308 d_I(n, c, h, w) +=!
284309 d_O( n, r_m, h - r_kh, w - r_kw) * W1(r_m, c, r_kh, r_kw)
285310 d_W1(m, c, kh, kw) +=!
286311 d_O(r_n, m, r_h - kh, r_w - kw) * I(r_n, c, r_h, r_w)
312+ # TODO: Bias incorrect + check
313+ d_Bias(m) = Bias(m)
287314 }
288315 """
289316
290317 N , C , H , W , O , kH , kW = 32 , 4 , 56 , 56 , 16 , 1 , 1
291- I , W = (
292- torch .randn (N , C , H , W , device = 'cuda' , requires_grad = True ),
293- torch .randn (O , C , kH , kW , device = 'cuda' , requires_grad = True ))
318+ I = torch .randn (N , C , H , W , device = 'cuda' , requires_grad = True )
294319 T = tc .define (
295320 conv ,
296321 tc .make_autotuned_options_factory (
297322 starting_options = 'naive' ,
298323 tuner_config = tuner_config ))
299324 convolution = tc .make_autograd (T .convolution , T .convolution_grad )
300325
326+ # Reference
327+ from torch .nn .modules .conv import Conv2d
328+ Conv = Conv2d (C , O , 1 , stride = 1 ).cuda ()
329+ ref = Conv .forward (I )
330+
331+ W = Conv .weight .clone ()
332+ Bias = Conv .bias .clone ()
333+
301334 # First occurrence triggers tuning (make_autotuned_options_factory)
302- out = convolution (I , W )
335+ out = convolution (I , W , Bias )
303336 out .sum ().backward ()
304337
305- out = convolution (I , W )
338+ out = convolution (I , W , Bias )
306339 out .sum ().backward ()
307- # TODO: test results!!!
340+
341+ tc .assert_almost_equal (ref , out , I , operations = C * kH * kW )
308342
309343 #
310344 # This tests 1-kernel forward/ 2-kernel backward with tc.make_autograd.
@@ -314,9 +348,12 @@ def convolution_grad(
314348 #
315349 def test_conv_with_backward_2kernels (self ):
316350 conv = """
317- def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {
351+ def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1, float(M) Bias)
352+ -> (O)
353+ {
318354 O(n, m, h, w) +=!
319355 I(n, r_c, h + r_kh, w + r_kw) * W1(m, r_c, r_kh, r_kw)
356+ O(n, m, h, w) = O(n, m, h, w) + Bias(m)
320357 }
321358 def convolution_igrad(float(M,C,KH,KW) W1, float(N,M,H,W) d_O)
322359 -> (d_I)
@@ -329,6 +366,11 @@ def convolution_wgrad(float(N,C,H,W) I, float(N,M,H,W) d_O) -> (d_W1)
329366 d_W1(m, c, kh, kw) +=!
330367 d_O(r_n, m, r_h - kh, r_w - kw) * I(r_n, c, r_h, r_w)
331368 }
369+ def convolution_biasgrad(float(M) Bias) -> (d_Bias)
370+ {
371+ # TODO: Bias incorrect + check
372+ d_Bias(m) = Bias(m)
373+ }
332374 """
333375
334376 N , C , H , W , O , kH , kW = 32 , 4 , 56 , 56 , 16 , 1 , 1
@@ -337,26 +379,34 @@ def convolution_wgrad(float(N,C,H,W) I, float(N,M,H,W) d_O) -> (d_W1)
337379 tc .make_autotuned_options_factory (
338380 starting_options = 'naive' ,
339381 tuner_config = tuner_config ))
340- I , W = (
341- torch .randn (N , C , H , W , device = 'cuda' , requires_grad = True ),
342- torch .randn (O , C , kH , kW , device = 'cuda' , requires_grad = True ))
382+ I = torch .randn (N , C , H , W , device = 'cuda' , requires_grad = True )
383+
384+ # Reference
385+ from torch .nn .modules .conv import Conv2d
386+ Conv = Conv2d (C , O , 1 , stride = 1 ).cuda ()
387+ ref = Conv .forward (I )
343388
344- def convolution_backward (I , W , d_O ):
389+ W = Conv .weight .clone ()
390+ Bias = Conv .bias .clone ()
391+
392+ def convolution_backward (I , W , Bias , d_O ):
345393 d_I = T .convolution_igrad (W , d_O )
346394 d_O = T .convolution_wgrad (I , d_O )
347- return (d_I , d_O )
395+ d_Bias = T .convolution_biasgrad (Bias )
396+ return (d_I , d_O , d_Bias )
348397
349398 convolution_function = tc .make_autograd (
350399 T .convolution , convolution_backward )
351400
352401 # First occurrence triggers tuning
353- out = convolution_function (I , W )
402+ out = convolution_function (I , W , Bias )
354403 out .sum ().backward ()
355404
356405 # Subsequent occurrences do not
357- out = convolution_function (I , W )
406+ out = convolution_function (I , W , Bias )
358407 out .sum ().backward ()
359- # TODO: test results!!!
408+
409+ tc .assert_almost_equal (ref , out , I , operations = C * kH * kW )
360410
361411 #
362412 # This tests the direct use of pybinds which are closer to C++
@@ -424,7 +474,13 @@ def tensordot(float(N, C1, C2, H, W) I0, float(N, C2, C3, H, W) I1)
424474 executor = tclib .compile (
425475 tensordot_str , entry_point , (I0 , I1 ), best_options )
426476 O = executor .run ((I0 , I1 ), ())
427- # TODO: test results!!!
477+
478+ # No simple torch baseline, compare against naive
479+ executor = tclib .compile (
480+ tensordot_str , entry_point , (I0 , I1 ), tc .MappingOptions ('naive' ))
481+ ref = executor .run ((I0 , I1 ), ())
482+
483+ tc .assert_almost_equal (ref , O , I0 , I1 , operations = C2 )
428484
429485if __name__ == '__main__' :
430486 unittest .main ()
0 commit comments