 import hypothesis.strategies as st
 import caffe2.python.hypothesis_test_util as hu
 import tensor_comprehensions as tc
+import torch
 
 from hypothesis import given, settings
 from caffe2.python import core, dyndep
@@ -33,9 +34,6 @@
 def matmul(float(M,N) A, float(N,K) B) -> (output) {
     output(m, k) +=! A(m, r_n) * B(r_n, k)
 }
-"""
-
-MATMUL_GRAD_LANG = """
 def matmul_grad(float(M, N) A, float(N, K) B, float(M, K) d_O) -> (d_A, d_B) {
     d_A(m, n) +=! d_O(m, r_k) * B(n, r_k)
     d_B(n, k) +=! d_O(r_m, k) * A(r_m, n)
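For context on what the hunk above accomplishes: deleting the closing triple quote and the MATMUL_GRAD_LANG = assignment merges the gradient definition into the same TC string as the forward definition. Reconstructed purely from the lines shown here (the constant is assumed to keep the name MATMUL_LANG), the merged string reads roughly:

MATMUL_LANG = """
def matmul(float(M,N) A, float(N,K) B) -> (output) {
    output(m, k) +=! A(m, r_n) * B(r_n, k)
}
def matmul_grad(float(M, N) A, float(N, K) B, float(M, K) d_O) -> (d_A, d_B) {
    d_A(m, n) +=! d_O(m, r_k) * B(n, r_k)
    d_B(n, k) +=! d_O(r_m, k) * A(r_m, n)
}
"""

This is why the later hunks can pass MATMUL_LANG for both tc_def and tc_grad_def and select the entry point by name ("matmul" vs. "matmul_grad").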
@@ -61,7 +59,7 @@ def ref(X, W):
             "TcOp", ["X", "Y"], "out",
             tc_def=MATMUL_LANG,
             tc_name="matmul",
-            tc_grad_def=MATMUL_GRAD_LANG,
+            tc_grad_def=MATMUL_LANG,
             tc_grad_name="matmul_grad",
             inputs_used_by_gradient=[0, 1],
             output_gradients_used_by_gradient=[0],
@@ -91,24 +89,23 @@ def ref(X, W):
            **hu.gcs_gpu_only)
     @settings(max_examples=2)
     def test_matmul_tune_and_run(self, n, m, k, seed, gc, dc):
-        matmul = tc.define(MATMUL_LANG, name="matmul")
-        matmul_grad = tc.define(MATMUL_GRAD_LANG, name="matmul_grad")
-
-        mapping_options = matmul.autotune(
-            (n, k), (k, m),
-            generations=3,
-            threads=32,
-            pop_size=2,
-            tuner_min_launch_total_threads=1,
-        )
-
-        grad_mapping_options = matmul_grad.autotune(
-            (n, k), (k, m), (n, m),
-            generations=1,
-            threads=32,
-            pop_size=2,
-            tuner_min_launch_total_threads=1,
-        )
+        tuner = tc.Tuner(MATMUL_LANG)
+        tuner_config = (
+            tc.TunerConfig().generations(3).threads(32).pop_size(2)
+            .tuner_min_launch_total_threads(1))
+        matmul_top1 = tuner.tune(
+            'matmul',
+            (torch.randn(n, k, device='cuda'),
+             torch.randn(k, m, device='cuda')),
+            tc.MappingOptions('naive'),
+            tuner_config)
+        matmul_grad_top1 = tuner.tune(
+            'matmul_grad',
+            (torch.randn(n, k, device='cuda'),
+             torch.randn(k, m, device='cuda'),
+             torch.randn(n, m, device='cuda')),
+            tc.MappingOptions('naive'),
+            tuner_config)
 
         X = np.random.rand(m, k).astype(np.float32)
         W = np.random.rand(k, n).astype(np.float32)
@@ -120,13 +117,13 @@ def ref(X, W):
             "TcOp", ["X", "Y"], "out",
             tc_def=MATMUL_LANG,
             tc_name="matmul",
-            tc_grad_def=MATMUL_GRAD_LANG,
+            tc_grad_def=MATMUL_LANG,
             tc_grad_name="matmul_grad",
             inputs_used_by_gradient=[0, 1],
             output_gradients_used_by_gradient=[0],
             inputs_to_compute_gradients_of=[0, 1],
-            mapping_options=mapping_options.serialize(),
-            grad_mapping_options=grad_mapping_options.serialize(),
+            mapping_options=matmul_top1.serialize(),
+            grad_mapping_options=matmul_grad_top1.serialize(),
         )
 
         self.assertReferenceChecks(
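For readers unfamiliar with the autotuner API this test migrates to, the sketch below isolates the tuning flow. It is a minimal sketch, not the library's documented usage: it relies only on calls that appear in the diff above (tc.Tuner, tc.TunerConfig, tc.MappingOptions, Tuner.tune, serialize), assumes a CUDA device is available, and uses illustrative tensor sizes.

import tensor_comprehensions as tc
import torch

LANG = """
def matmul(float(M,N) A, float(N,K) B) -> (output) {
    output(m, k) +=! A(m, r_n) * B(r_n, k)
}
"""

# Small evolutionary-search budget, mirroring the settings used in the test.
config = (tc.TunerConfig()
          .generations(3)
          .threads(32)
          .pop_size(2)
          .tuner_min_launch_total_threads(1))

# Tune the 'matmul' entry point on concrete CUDA tensors, starting from
# naive mapping options; tune() returns the best options found.
tuner = tc.Tuner(LANG)
best = tuner.tune(
    'matmul',
    (torch.randn(8, 4, device='cuda'), torch.randn(4, 16, device='cuda')),
    tc.MappingOptions('naive'),
    config)

# serialize() produces the form that the Caffe2 TcOp consumes through its
# mapping_options / grad_mapping_options arguments, as in the hunks above.
print(best.serialize())

The test does the same thing twice, once for the matmul entry point and once for matmul_grad, and feeds both serialized results to a single TcOp.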