1
+
2
+ # Autotuner (not using Petabricks)
3
+
4
+ # Connelly TODO research:
5
+ # - enumerate valid schedules
6
+ # - machine learning schedule at a given point
7
+ # - dynamic programming or tree search or iterative algorithm to find initial guess / plausible schedules
8
+
9
+ # TODO:
10
+ # - chunk (and all functions actually) should list variables that are created due to split/tile (see e.g. camera_raw).
11
+ # - global creation (for more than one function)
12
+ # - create random schedule as well as enumerate all schedules of a given length
13
+
14
+ import halide
15
+ import random
16
+ import collections
17
+ import itertools
18
+
19
+ FUNC_ROOT = 0
20
+ FUNC_INLINE = 1
21
+ FUNC_CHUNK = 2 # Needs a variable in the caller
22
+ # Chunk not implemented yet
23
+
24
+ VAR_SERIAL = 0
25
+ VAR_VECTORIZE = 1
26
+ VAR_PARALLEL = 2
27
+ VAR_UNROLL = 3
28
+ VAR_TILE = 4
29
+ VAR_SPLIT = 5
30
+ # Tile and split not implemented yet (recursion not implemented). Also vectorize() and unroll() implicitly create a new variable so they recurse also.
31
+ # transpose() always there or sometimes not there
32
+ # To get long schedules should be able to randomize from an existing schedule.
33
+ # Also schedules have some global interactions when new variables are introduced so refactor to handle that.
34
+
35
+ def default_check (cls , L ):
36
+ def count (C ):
37
+ return sum ([isinstance (x , C ) for x in L ])
38
+ if len (L ) == 0 :
39
+ return True
40
+ else :
41
+ # Handle singleton fragments
42
+ if isinstance (L [0 ], FragmentRoot ) and count (FragmentRoot ) == 1 and count (FragmentChunk ) == 0 :
43
+ return True
44
+ elif isinstance (L [0 ], FragmentChunk ) and len (L ) == 1 :
45
+ return True
46
+ return False
47
+
48
+ class Fragment :
49
+ "Base class for schedule fragment e.g. .vectorize(x), .parallel(y), .root(), etc."
50
+ def __init__ (self , var = None , value = None ):
51
+ self .var = var
52
+ self .value = value
53
+
54
+ @staticmethod
55
+ def fragments (root_func , func , cls , vars ):
56
+ "Given class and variable list (of strings) returns fragments possible at this point."
57
+ # print 'fragments base', cls
58
+ return [cls ()]
59
+
60
+ def ___str__ (self ):
61
+ "Returns schedule_str, e.g. '.parallel(y)'."
62
+
63
+ def new_vars (self ):
64
+ "List of new variable names, e.g. ['v'] or []."
65
+ return []
66
+
67
+ def randomize (self ):
68
+ "Randomize values e.g. change vectorize(x, 8) => vectorize(x, (random value))."
69
+
70
+ def check (self , L ):
71
+ "Given list of Schedule fragments (applied to a function) returns True if valid else False."
72
+ return default_check (self .__class__ , L )
73
+
74
+ class FragmentVarMixin :
75
+ @staticmethod
76
+ def fragments (root_func , func , cls , vars ):
77
+ # print 'fragments', cls
78
+ return [cls (x ) for x in vars ]
79
+
80
+ def blocksize_random ():
81
+ return random .choice ([2 ,4 ,8 ,16 ,32 ])
82
+
83
+ class FragmentBlocksizeMixin (FragmentVarMixin ):
84
+ def __init__ (self , var = None , value = None ):
85
+ # print '__init__', self.__class__
86
+ self .var = var
87
+ self .value = value
88
+ if self .value is None :
89
+ self .randomize ()
90
+
91
+ def randomize (self ):
92
+ self .value = blocksize_random ()
93
+
94
+ def check (self , L ):
95
+ return check_duplicates (self .__class__ , L )
96
+
97
+ def check_duplicates (cls , L ):
98
+ if not default_check (cls , L ):
99
+ return False
100
+ #count = collections.defaultdict(lambda: 0)
101
+ #for x in L:
102
+ # if isinstance(x, cls):
103
+ # count[x.var] += 1
104
+ # if count[x.var] >= 2:
105
+ # return False
106
+ d = set ()
107
+ for x in L :
108
+ s = repr (x )
109
+ if s in d :
110
+ return False
111
+ d .add (s )
112
+
113
+ return True
114
+
115
+ class FragmentRoot (Fragment ):
116
+ def __str__ (self ):
117
+ return '.root()'
118
+
119
+ class FragmentVectorize (FragmentBlocksizeMixin ,Fragment ):
120
+ def __str__ (self ):
121
+ return '.vectorize(%s,%d)' % (self .var , self .value )
122
+
123
+ class FragmentParallel (FragmentBlocksizeMixin ,Fragment ):
124
+ def __str__ (self ):
125
+ return '.parallel(%s,%d)' % (self .var ,self .value )
126
+
127
+ class FragmentUnroll (FragmentBlocksizeMixin ,Fragment ):
128
+ def __str__ (self ):
129
+ return '.unroll(%s,%d)' % (self .var ,self .value )
130
+
131
+ class FragmentChunk (Fragment ):
132
+ @staticmethod
133
+ def fragments (root_func , func , cls , vars ):
134
+ return [cls (x ) for x in caller_vars (root_func , func )]
135
+
136
+ def check (self , L ):
137
+ return check_duplicates (self .__class__ , L )
138
+
139
+ def __str__ (self ):
140
+ return '.chunk(%s)' % self .var
141
+
142
+ def create_var (vars ): #count=[0]):
143
+ #count[0] += 1
144
+ for i in itertools .count (0 ):
145
+ s = '_c%d' % i #count[0]
146
+ if not s in vars :
147
+ return s
148
+
149
+ class FragmentSplit (FragmentBlocksizeMixin ,Fragment ):
150
+ def __init__ (self , var = None , value = None , newvar = None , reuse_outer = False ,vars = None ):
151
+ FragmentBlocksizeMixin .__init__ (self , var , value )
152
+ self .newvar = newvar
153
+ if self .newvar is None :
154
+ self .newvar = create_var (vars )
155
+ self .reuse_outer = reuse_outer
156
+
157
+ @staticmethod
158
+ def fragments (root_func , func , cls , vars ):
159
+ return ([cls (x ,reuse_outer = False ,vars = vars ) for x in vars ] +
160
+ [cls (x ,reuse_outer = True ,vars = vars ) for x in vars ])
161
+
162
+ def new_vars (self ):
163
+ return [self .newvar ]
164
+
165
+ def __str__ (self ):
166
+ return '.split(%s,%s,%s,%d)' % (self .var ,self .var if self .reuse_outer else self .newvar ,
167
+ self .var if not self .reuse_outer else self .newvar , self .value )
168
+
169
+ class FragmentTile (FragmentBlocksizeMixin ,Fragment ):
170
+ def __init__ (self , xvar = None , yvar = None , newvar = None , vars = None ):
171
+ self .xvar = xvar
172
+ self .yvar = yvar
173
+ self .randomize ()
174
+ self .xnewvar = create_var (vars )
175
+ self .ynewvar = create_var (vars + [self .xnewvar ])
176
+
177
+ def randomize (self ):
178
+ self .xsize = blocksize_random ()
179
+ self .ysize = blocksize_random ()
180
+
181
+ def check (self , L ):
182
+ return check_duplicates (self .__class__ , L )
183
+
184
+ @staticmethod
185
+ def fragments (root_func , func , cls , vars ):
186
+ return [cls (x ,y ,vars = vars ) for x in vars for y in vars if x != y ]
187
+
188
+ def new_vars (self ):
189
+ return [self .xnewvar , self .ynewvar ]
190
+
191
+ def __str__ (self ):
192
+ return '.tile(%s,%s,%s,%s,%d,%d)' % (self .xvar ,self .yvar ,self .xnewvar ,self .ynewvar ,self .xsize ,self .ysize )
193
+
194
+ fragment_classes = [FragmentRoot , FragmentVectorize , FragmentParallel , FragmentUnroll , FragmentChunk , FragmentSplit , FragmentTile ]
195
+
196
+ class FragmentList (list ):
197
+ def __init__ (self , func , L ):
198
+ self .func = func
199
+ list .__init__ (self , L )
200
+
201
+ def __str__ (self ):
202
+ #print '__str__', list(self)
203
+ ans = []
204
+ for x in self :
205
+ #print 'str of', x
206
+ #print 'next'
207
+ ans .append (str (x ))
208
+ if len (ans ):
209
+ #print 'returning list'
210
+ return self .func .name () + '' .join (ans )
211
+ #print 'returning empty'
212
+ return ''
213
+
214
+ def __repr__ (self ):
215
+ return str (self ) #return 'FragmentList(%s, %r)' % (self.func, repr([str(x) for x in list(self)]))
216
+
217
+ def randomize (self ):
218
+ for x in self :
219
+ x .randomize ()
220
+
221
+ def schedules_depth (root_func , func , vars , depth = 0 ):
222
+ "Un-checked schedules of exactly the specified depth for the given function."
223
+ # print func
224
+ # print vars
225
+ assert depth >= 0 and isinstance (depth , (int ,long ))
226
+
227
+ if depth == 0 :
228
+ yield FragmentList (func , [])
229
+ else :
230
+ for cls in fragment_classes :
231
+ for L in schedules_depth (root_func , func , vars , depth - 1 ):
232
+ all_vars = list (vars )
233
+ for fragment in L :
234
+ all_vars .extend (fragment .new_vars ())
235
+ #print 'all_vars', all_vars
236
+ for fragment in cls .fragments (root_func , func , cls , all_vars ):
237
+ #print 'fragment', fragment
238
+ #print '=>', fragment
239
+ #print '*', len(L), L
240
+ yield FragmentList (func , list (L ) + [fragment ])
241
+
242
+ def valid_schedules (root_func , func , max_depth = 4 ):
243
+ "A sequence of valid schedules for a function, each of which is a list of schedule fragments (up to a maximum depth)."
244
+ vars = halide .func_varlist (func )
245
+ for depth in range (max_depth + 1 ):
246
+ for L in schedules_depth (root_func , func , vars , depth ):
247
+ ok = True
248
+ for x in L :
249
+ #print 'at depth=%d, checking'%depth, str(L)#, len(L)
250
+ if not x .check (L ):
251
+ #print 'check failed'
252
+ ok = False
253
+ break
254
+ if ok :
255
+ yield L
256
+
257
+ def func_lhs_var_names (f ):
258
+ ans = []
259
+ for y in f .args ():
260
+ for x in y .vars ():
261
+ ans .append (x .name ())
262
+ return ans
263
+
264
+ def caller_vars (root_func , func ):
265
+ "Given a root Func and current function return list of variables of the caller."
266
+ func_name = func .name ()
267
+ for (name , g ) in halide .all_funcs (root_func ).items ():
268
+ rhs_names = [x .name () for x in g .rhs ().funcs ()]
269
+ if func_name in rhs_names :
270
+ return func_lhs_var_names (g )
271
+ return []
272
+
273
+ def test_schedules ():
274
+ f = halide .Func ('f' )
275
+ x = halide .Var ('x' )
276
+ y = halide .Var ('y' )
277
+ g = halide .Func ('g' )
278
+ v = halide .Var ('v' )
279
+ f [x ,y ] = 1
280
+ g [v ] = f [v ,v ]
281
+
282
+ print halide .func_varlist (f )
283
+ print 'caller_vars(f) =' , caller_vars (g , f )
284
+ print 'caller_vars(g) =' , caller_vars (g , g )
285
+
286
+ validL = list (valid_schedules (g , f , 3 ))
287
+
288
+ for L in validL :
289
+ print repr (repr (L ))
290
+
291
+ print 'number valid: ' , len (validL )
292
+
293
+ if __name__ == '__main__' :
294
+ test_schedules ()
295
+
0 commit comments