import time
import torch
import torch.optim as optim
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
import tensor_comprehensions as tc
import numpy as np
from enum import IntEnum

# Total number of tuned mapping options; indices are given by MappingOptionsIdx below.
NB_HYPERPARAMS = 26

# Module-level default read by the free functions below (getRawVectorFromTcOpt,
# optionsFromVector). ExpTunerConfig.__init__ keeps it in sync with the
# per-experiment flag so those functions stay self-contained.
USE_MAX_SHARED_MEMORY = 0

class ExpTunerConfig:
    """Holds the TC kernel, its inputs, and the per-option candidate values for tuning."""

    def __init__(self, use_max_shared_memory=0):
        # Mirror the flag at module level for the free functions below.
        global USE_MAX_SHARED_MEMORY
        USE_MAX_SHARED_MEMORY = use_max_shared_memory
        self.INIT_INPUT_SZ = -1
        self.USE_MAX_SHARED_MEMORY = use_max_shared_memory
        self.tc_code = ""
        self.tc_name = ""
        self.inp = -1
        self.cat_val = -1
        self.cat_sz = -1

    def set_convolution_tc(self, size_type="default", inp_sz_list=None, use_max_shared_memory=False):
        self.INIT_INPUT_SZ = 7
        self.tc_name = "convolution"
        self.tc_code = """
def convolution(float(N,C,H,W) I, float(M,C,KH,KW) W1) -> (O) {
    O(n, m, h, w) +=! I(n, r_c, h + r_kh, w + r_kw) * W1(m, r_c, r_kh, r_kw)
}
"""

        if size_type == "input":
            N, C, H, W, O, kH, kW = tuple(inp_sz_list)
        elif size_type == "default":
            N, C, H, W, O, kH, kW = 16, 4, 56, 56, 16, 1, 1  # 8, 2, 28, 28, 8, 1, 1
        elif size_type == "random":
            N, C, H, W, O, kH, kW = \
                get_rand([8, 16, 32, 64]), \
                get_rand([2, 4, 8, 16]), \
                get_rand([28, 56, 112]), \
                get_rand([28, 56, 112]), \
                get_rand([8, 16, 32]), \
                get_rand([1, 2, 4]), \
                get_rand([1, 2, 4])
        else:
            raise ValueError("Unknown size type: " + str(size_type))
        I, W1 = torch.randn(N, C, H, W, device='cuda'), torch.randn(O, C, kH, kW, device='cuda')
        self.inp = (I, W1)
        self.init_input_sz = np.array([N, C, H, W, O, kH, kW])
        print(self.init_input_sz)
        self.init_input_sz = torch.from_numpy(self.init_input_sz).float()

        self.computeCat()

    def computeCat(self):
        inp = self.inp
        self.cat_sz = np.zeros(NB_HYPERPARAMS).astype(int)
        self.cat_val = [[] for _ in range(NB_HYPERPARAMS)]

        divs = getAllDivs(inp)
        if self.USE_MAX_SHARED_MEMORY:
            divs2 = getAllDivs([np.array([tc.tclib.shared_memory_size()])])

        self.cat_val[MappingOptionsIdx.outerScheduleFusionStrategy] = [0, 1, 2]
        self.cat_val[MappingOptionsIdx.intraTileScheduleFusionStrategy] = [0, 1, 2]
        self.cat_val[MappingOptionsIdx.fixParametersBeforeScheduling] = [0, 1]
        self.cat_val[MappingOptionsIdx.nTiledDims] = [i + 1 for i in range(6)]
        for i in range(6):  # tiling
            self.cat_val[MappingOptionsIdx.tiling1 + i] = divs + [0]
        self.cat_val[MappingOptionsIdx.unroll] = [2 ** i for i in range(8)]
        self.cat_val[MappingOptionsIdx.matchLibraryCalls] = [0, 1]
        self.cat_val[MappingOptionsIdx.nMappedToBlocksDims] = [i + 1 for i in range(3)]
        for i in range(3):  # mapping to blocks
            self.cat_val[MappingOptionsIdx.mappingToBlocks1 + i] = divs
        self.cat_val[MappingOptionsIdx.nMappedToThreadsDims] = [i + 1 for i in range(3)]
        for i in range(3):  # mapping to threads
            self.cat_val[MappingOptionsIdx.mappingToThreads1 + i] = divs
        self.cat_val[MappingOptionsIdx.useSharedMemory] = [0, 1]
        self.cat_val[MappingOptionsIdx.usePrivateMemory] = [0, 1]
        self.cat_val[MappingOptionsIdx.unrollCopyShared] = [0, 1]
        self.cat_val[MappingOptionsIdx.maxSharedMemory] = \
            divs2 if self.USE_MAX_SHARED_MEMORY else [0]
        self.cat_val[MappingOptionsIdx.useReadOnlyCache] = [0, 1]
        self.cat_val[MappingOptionsIdx.privateDepth] = [i for i in range(6)]

        for i in range(NB_HYPERPARAMS):
            self.cat_sz[i] = len(self.cat_val[i])

    def catVec_to_optVec(self, catVec):
        # Map a vector of category indices to the corresponding option values.
        opt = [self.cat_val[i][catVec[i]] for i in range(NB_HYPERPARAMS)]
        return opt

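# The candidate encoding in a nutshell: cat_sz[i] is the number of choices for
# option i and cat_val[i] holds the concrete values, so a search strategy
# proposes a vector of indices and catVec_to_optVec maps it to option values.
# A sketch, assuming `config` is an already-initialized ExpTunerConfig:
#
#     cat_vec = [np.random.randint(sz) for sz in config.cat_sz]  # one random candidate
#     opt_vec = config.catVec_to_optVec(cat_vec)                 # e.g. tile sizes, 0/1 flags
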
class MappingOptionsIdx(IntEnum):
    # Flattened indices of the tuned TC mapping options; the order must match
    # the NB_HYPERPARAMS entries filled in ExpTunerConfig.computeCat.
    outerScheduleFusionStrategy = 0
    intraTileScheduleFusionStrategy = 1
    fixParametersBeforeScheduling = 2
    nTiledDims = 3
    tiling1 = 4
    tiling2 = 5
    tiling3 = 6
    tiling4 = 7
    tiling5 = 8
    tiling6 = 9
    unroll = 10
    matchLibraryCalls = 11
    nMappedToBlocksDims = 12
    mappingToBlocks1 = 13
    mappingToBlocks2 = 14
    mappingToBlocks3 = 15
    nMappedToThreadsDims = 16
    mappingToThreads1 = 17
    mappingToThreads2 = 18
    mappingToThreads3 = 19
    useSharedMemory = 20
    usePrivateMemory = 21
    unrollCopyShared = 22
    maxSharedMemory = 23
    useReadOnlyCache = 24
    privateDepth = 25

def get_rand(l):
    return np.random.choice(l).item()

def print_opt(options):
    print(options.tolist())

def evalTime(opt, exptuner_config, iters=50, warmup=10, estimator="mean", prune=-1, curr_best=-1):
    tc_code, tc_name, inp = \
        exptuner_config.tc_code, exptuner_config.tc_name, exptuner_config.inp
    infty = 30000
    opt = exptuner_config.catVec_to_optVec(opt)
    opt = optionsFromVector(opt)
    try:
        tc_prog = tc.compile(tc_code, tc_name, opt, *inp)
        first_ft = tc_prog.executor.profile_kernel(inp)
    except (KeyboardInterrupt, SystemExit):
        raise
    except Exception:
        # Compilation or execution failed: treat this candidate as infinitely slow.
        return infty
    # Prune before warmup if the very first run is already far beyond the best.
    if prune != -1 and first_ft > 100 * curr_best:
        return first_ft
    for _ in range(warmup - 1):
        tc_prog.executor.profile_kernel(inp)

    first_t = tc_prog.executor.profile_kernel(inp)

    # Prune after warmup against the tighter, user-supplied factor.
    if prune != -1 and first_t > prune * curr_best:
        return first_t

    tc_time_list = [first_t]
    for i in range(iters - 1):
        iter_time = tc_prog.executor.profile_kernel(inp)
        tc_time_list.append(iter_time)
    if estimator == "mean":
        return np.mean(tc_time_list)
    elif estimator == "median":
        return np.median(tc_time_list)
    elif estimator == "p25":
        return np.percentile(tc_time_list, 25)
    print("Unknown estimator")
    return infty

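# Example call (a sketch; `cat_vec`, `config`, and `best_t` are illustrative
# names for a candidate index vector, an ExpTunerConfig, and the best time seen
# so far; pruning bails out early once a run exceeds prune * best_t):
#
#     t = evalTime(cat_vec, config, estimator="median", prune=5, curr_best=best_t)
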
def getRawVectorFromTcOpt(tc_opt):
    # Flatten a MappingOptions dict (as produced by the TC tuner) into a raw
    # vector of option values, indexed by MappingOptionsIdx.
    tr_dic = {"Max": 0, "Preserve3Coincident": 1, "Min": 2}
    opt_vect = np.zeros(NB_HYPERPARAMS).astype(int)
    opt_vect[MappingOptionsIdx.outerScheduleFusionStrategy] = \
        tr_dic[tc_opt["outerScheduleFusionStrategy"]]
    opt_vect[MappingOptionsIdx.intraTileScheduleFusionStrategy] = \
        tr_dic[tc_opt["intraTileScheduleFusionStrategy"]]
    opt_vect[MappingOptionsIdx.fixParametersBeforeScheduling] = \
        tc_opt["fixParametersBeforeScheduling"]
    opt_vect[MappingOptionsIdx.nTiledDims] = \
        len(tc_opt["tile"])
    assert opt_vect[MappingOptionsIdx.nTiledDims] < 7, "Too many tilings"
    opt_vect[MappingOptionsIdx.tiling1:
             MappingOptionsIdx.tiling1 + opt_vect[MappingOptionsIdx.nTiledDims]] = \
        tc_opt["tile"]
    opt_vect[MappingOptionsIdx.unroll] = \
        tc_opt["unroll"]
    #opt_vect[MappingOptionsIdx.tileImperfectlyNested] = \
    #    tc_opt["tileImperfectlyNested"] #todo: pybind
    opt_vect[MappingOptionsIdx.matchLibraryCalls] = \
        tc_opt["matchLibraryCalls"]
    opt_vect[MappingOptionsIdx.nMappedToBlocksDims] = \
        len(tc_opt["mapToBlocks"])
    opt_vect[MappingOptionsIdx.mappingToBlocks1:
             MappingOptionsIdx.mappingToBlocks1 + opt_vect[MappingOptionsIdx.nMappedToBlocksDims]] = \
        tc_opt["mapToBlocks"]
    opt_vect[MappingOptionsIdx.nMappedToThreadsDims] = \
        len(tc_opt["mapToThreads"])
    opt_vect[MappingOptionsIdx.mappingToThreads1:
             MappingOptionsIdx.mappingToThreads1 + opt_vect[MappingOptionsIdx.nMappedToThreadsDims]] = \
        tc_opt["mapToThreads"]
    opt_vect[MappingOptionsIdx.useSharedMemory] = \
        tc_opt["useSharedMemory"]
    opt_vect[MappingOptionsIdx.usePrivateMemory] = \
        tc_opt["usePrivateMemory"]
    opt_vect[MappingOptionsIdx.unrollCopyShared] = \
        tc_opt["unrollCopyShared"]
    if USE_MAX_SHARED_MEMORY and "maxSharedMemory" in tc_opt:
        opt_vect[MappingOptionsIdx.maxSharedMemory] = \
            tc_opt["maxSharedMemory"]
    opt_vect[MappingOptionsIdx.useReadOnlyCache] = \
        tc_opt["useReadOnlyCache"]
    opt_vect[MappingOptionsIdx.privateDepth] = \
        tc_opt["privateDepth"]
    return opt_vect

def optionsFromVector(vect):
    # Rebuild a tc.MappingOptions from a raw option-value vector, starting
    # from the "naive" defaults.
    strat_str = ["Max", "Preserve3Coincident", "Min"]
    options = tc.MappingOptions("naive")
    options.outerScheduleFusionStrategy(
        strat_str[vect[MappingOptionsIdx.outerScheduleFusionStrategy]])
    options.intraTileScheduleFusionStrategy(
        strat_str[vect[MappingOptionsIdx.intraTileScheduleFusionStrategy]])
    options.fixParametersBeforeScheduling(
        vect[MappingOptionsIdx.fixParametersBeforeScheduling])
    options.tile(
        list(vect[MappingOptionsIdx.tiling1:
                  MappingOptionsIdx.tiling1 + vect[MappingOptionsIdx.nTiledDims]]))
    options.unroll(
        vect[MappingOptionsIdx.unroll])
    options.matchLibraryCalls(
        vect[MappingOptionsIdx.matchLibraryCalls])
    options.mapToBlocks(
        list(vect[MappingOptionsIdx.mappingToBlocks1:
                  MappingOptionsIdx.mappingToBlocks1 + vect[MappingOptionsIdx.nMappedToBlocksDims]]))
    options.mapToThreads(
        list(vect[MappingOptionsIdx.mappingToThreads1:
                  MappingOptionsIdx.mappingToThreads1 + vect[MappingOptionsIdx.nMappedToThreadsDims]]))
    options.useSharedMemory(
        vect[MappingOptionsIdx.useSharedMemory])
    options.usePrivateMemory(
        vect[MappingOptionsIdx.usePrivateMemory])
    options.unrollCopyShared(
        vect[MappingOptionsIdx.unrollCopyShared])
    if USE_MAX_SHARED_MEMORY:
        options.maxSharedMemory(
            vect[MappingOptionsIdx.maxSharedMemory])
    options.useReadOnlyCache(
        vect[MappingOptionsIdx.useReadOnlyCache])
    options.privateDepth(
        vect[MappingOptionsIdx.privateDepth])
    return options

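# The two helpers above are inverses at the level of option *values*:
# getRawVectorFromTcOpt flattens a MappingOptions dict into a raw vector, and
# optionsFromVector rebuilds a tc.MappingOptions from such a vector. The raw
# vector holds option values, not category indices; to recover indices one
# would look each value up in ExpTunerConfig.cat_val (an assumption about how
# the surrounding tuner consumes these helpers):
#
#     vect = getRawVectorFromTcOpt(tc_opt_dict)  # tc_opt_dict: dict form of MappingOptions
#     cat_vec = [config.cat_val[i].index(vect[i]) for i in range(NB_HYPERPARAMS)]
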
def computeDivs(sz):
    # Candidate sizes derived from sz: ceil(sz / 2**i) for increasing i,
    # stopping once 2**i exceeds sz.
    l = []
    for i in range(sz):
        if 2 ** i > sz:
            break
        l.append((sz + 2 ** i - 1) // (2 ** i))
    return l

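# For example, computeDivs(56) walks i = 0..5 (2**6 = 64 > 56 stops the loop)
# and yields ceil(56 / 2**i): [56, 28, 14, 7, 4, 2].
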
def getAllDivs(inp, maxp2=8):
    # Collect the divisor-like candidates of every input dimension, plus all
    # powers of two up to 2**maxp2, deduplicated and sorted.
    p2 = [2 ** i for i in range(maxp2 + 1)]
    l = []
    for elem in inp:
        for sz in elem.shape:
            l += computeDivs(sz)
    divs_list = list(set(l + p2))
    return sorted(divs_list)
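
# A minimal end-to-end sketch (an assumption about typical use: it needs a CUDA
# device and a working tensor_comprehensions install, and a random candidate
# will often fail to compile, in which case evalTime scores it as infty):
if __name__ == "__main__":
    config = ExpTunerConfig(use_max_shared_memory=0)
    config.set_convolution_tc(size_type="default")
    cat_vec = [np.random.randint(sz) for sz in config.cat_sz]  # one random candidate
    print(evalTime(cat_vec, config, iters=10, warmup=2))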