Skip to content

Commit c947ba6

Browse files
committed
Merge branch 'master' of github.com:mit-gfx/Halide
2 parents 33be949 + 308989d commit c947ba6

File tree

4 files changed

+308
-3
lines changed

4 files changed

+308
-3
lines changed

py_bindings/.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
*.def
2+
*.lowered
3+
*.sexp
4+
*.bak
5+
*.so

py_bindings/autotune.py

Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
2+
# Autotuner (not using Petabricks)
3+
4+
# Connelly TODO research:
5+
# - enumerate valid schedules
6+
# - machine learning schedule at a given point
7+
# - dynamic programming or tree search or iterative algorithm to find initial guess / plausible schedules
8+
9+
# TODO:
10+
# - chunk (and all functions actually) should list variables that are created due to split/tile (see e.g. camera_raw).
11+
# - global creation (for more than one function)
12+
# - create random schedule as well as enumerate all schedules of a given length
13+
14+
import halide
15+
import random
16+
import collections
17+
import itertools
18+
19+
FUNC_ROOT = 0
20+
FUNC_INLINE = 1
21+
FUNC_CHUNK = 2 # Needs a variable in the caller
22+
# Chunk not implemented yet
23+
24+
VAR_SERIAL = 0
25+
VAR_VECTORIZE = 1
26+
VAR_PARALLEL = 2
27+
VAR_UNROLL = 3
28+
VAR_TILE = 4
29+
VAR_SPLIT = 5
30+
# Tile and split not implemented yet (recursion not implemented). Also vectorize() and unroll() implicitly create a new variable so they recurse also.
31+
# transpose() always there or sometimes not there
32+
# To get long schedules should be able to randomize from an existing schedule.
33+
# Also schedules have some global interactions when new variables are introduced so refactor to handle that.
34+
35+
def default_check(cls, L):
36+
def count(C):
37+
return sum([isinstance(x, C) for x in L])
38+
if len(L) == 0:
39+
return True
40+
else:
41+
# Handle singleton fragments
42+
if isinstance(L[0], FragmentRoot) and count(FragmentRoot) == 1 and count(FragmentChunk) == 0:
43+
return True
44+
elif isinstance(L[0], FragmentChunk) and len(L) == 1:
45+
return True
46+
return False
47+
48+
class Fragment:
49+
"Base class for schedule fragment e.g. .vectorize(x), .parallel(y), .root(), etc."
50+
def __init__(self, var=None, value=None):
51+
self.var = var
52+
self.value = value
53+
54+
@staticmethod
55+
def fragments(root_func, func, cls, vars):
56+
"Given class and variable list (of strings) returns fragments possible at this point."
57+
# print 'fragments base', cls
58+
return [cls()]
59+
60+
def ___str__(self):
61+
"Returns schedule_str, e.g. '.parallel(y)'."
62+
63+
def new_vars(self):
64+
"List of new variable names, e.g. ['v'] or []."
65+
return []
66+
67+
def randomize(self):
68+
"Randomize values e.g. change vectorize(x, 8) => vectorize(x, (random value))."
69+
70+
def check(self, L):
71+
"Given list of Schedule fragments (applied to a function) returns True if valid else False."
72+
return default_check(self.__class__, L)
73+
74+
class FragmentVarMixin:
75+
@staticmethod
76+
def fragments(root_func, func, cls, vars):
77+
# print 'fragments', cls
78+
return [cls(x) for x in vars]
79+
80+
def blocksize_random():
81+
return random.choice([2,4,8,16,32])
82+
83+
class FragmentBlocksizeMixin(FragmentVarMixin):
84+
def __init__(self, var=None, value=None):
85+
# print '__init__', self.__class__
86+
self.var = var
87+
self.value = value
88+
if self.value is None:
89+
self.randomize()
90+
91+
def randomize(self):
92+
self.value = blocksize_random()
93+
94+
def check(self, L):
95+
return check_duplicates(self.__class__, L)
96+
97+
def check_duplicates(cls, L):
98+
if not default_check(cls, L):
99+
return False
100+
#count = collections.defaultdict(lambda: 0)
101+
#for x in L:
102+
# if isinstance(x, cls):
103+
# count[x.var] += 1
104+
# if count[x.var] >= 2:
105+
# return False
106+
d = set()
107+
for x in L:
108+
s = repr(x)
109+
if s in d:
110+
return False
111+
d.add(s)
112+
113+
return True
114+
115+
class FragmentRoot(Fragment):
116+
def __str__(self):
117+
return '.root()'
118+
119+
class FragmentVectorize(FragmentBlocksizeMixin,Fragment):
120+
def __str__(self):
121+
return '.vectorize(%s,%d)'%(self.var, self.value)
122+
123+
class FragmentParallel(FragmentBlocksizeMixin,Fragment):
124+
def __str__(self):
125+
return '.parallel(%s,%d)'%(self.var,self.value)
126+
127+
class FragmentUnroll(FragmentBlocksizeMixin,Fragment):
128+
def __str__(self):
129+
return '.unroll(%s,%d)'%(self.var,self.value)
130+
131+
class FragmentChunk(Fragment):
132+
@staticmethod
133+
def fragments(root_func, func, cls, vars):
134+
return [cls(x) for x in caller_vars(root_func, func)]
135+
136+
def check(self, L):
137+
return check_duplicates(self.__class__, L)
138+
139+
def __str__(self):
140+
return '.chunk(%s)'%self.var
141+
142+
def create_var(vars): #count=[0]):
143+
#count[0] += 1
144+
for i in itertools.count(0):
145+
s = '_c%d'%i#count[0]
146+
if not s in vars:
147+
return s
148+
149+
class FragmentSplit(FragmentBlocksizeMixin,Fragment):
150+
def __init__(self, var=None, value=None, newvar=None, reuse_outer=False,vars=None):
151+
FragmentBlocksizeMixin.__init__(self, var, value)
152+
self.newvar = newvar
153+
if self.newvar is None:
154+
self.newvar = create_var(vars)
155+
self.reuse_outer = reuse_outer
156+
157+
@staticmethod
158+
def fragments(root_func, func, cls, vars):
159+
return ([cls(x,reuse_outer=False,vars=vars) for x in vars] +
160+
[cls(x,reuse_outer=True,vars=vars) for x in vars])
161+
162+
def new_vars(self):
163+
return [self.newvar]
164+
165+
def __str__(self):
166+
return '.split(%s,%s,%s,%d)'%(self.var,self.var if self.reuse_outer else self.newvar,
167+
self.var if not self.reuse_outer else self.newvar, self.value)
168+
169+
class FragmentTile(FragmentBlocksizeMixin,Fragment):
170+
def __init__(self, xvar=None, yvar=None, newvar=None, vars=None):
171+
self.xvar=xvar
172+
self.yvar=yvar
173+
self.randomize()
174+
self.xnewvar = create_var(vars)
175+
self.ynewvar = create_var(vars+[self.xnewvar])
176+
177+
def randomize(self):
178+
self.xsize = blocksize_random()
179+
self.ysize = blocksize_random()
180+
181+
def check(self, L):
182+
return check_duplicates(self.__class__, L)
183+
184+
@staticmethod
185+
def fragments(root_func, func, cls, vars):
186+
return [cls(x,y,vars=vars) for x in vars for y in vars if x != y]
187+
188+
def new_vars(self):
189+
return [self.xnewvar, self.ynewvar]
190+
191+
def __str__(self):
192+
return '.tile(%s,%s,%s,%s,%d,%d)'%(self.xvar,self.yvar,self.xnewvar,self.ynewvar,self.xsize,self.ysize)
193+
194+
fragment_classes = [FragmentRoot, FragmentVectorize, FragmentParallel, FragmentUnroll, FragmentChunk, FragmentSplit, FragmentTile]
195+
196+
class FragmentList(list):
197+
def __init__(self, func, L):
198+
self.func = func
199+
list.__init__(self, L)
200+
201+
def __str__(self):
202+
#print '__str__', list(self)
203+
ans = []
204+
for x in self:
205+
#print 'str of', x
206+
#print 'next'
207+
ans.append(str(x))
208+
if len(ans):
209+
#print 'returning list'
210+
return self.func.name() + ''.join(ans)
211+
#print 'returning empty'
212+
return ''
213+
214+
def __repr__(self):
215+
return str(self) #return 'FragmentList(%s, %r)' % (self.func, repr([str(x) for x in list(self)]))
216+
217+
def randomize(self):
218+
for x in self:
219+
x.randomize()
220+
221+
def schedules_depth(root_func, func, vars, depth=0):
222+
"Un-checked schedules of exactly the specified depth for the given function."
223+
# print func
224+
# print vars
225+
assert depth >= 0 and isinstance(depth, (int,long))
226+
227+
if depth == 0:
228+
yield FragmentList(func, [])
229+
else:
230+
for cls in fragment_classes:
231+
for L in schedules_depth(root_func, func, vars, depth-1):
232+
all_vars = list(vars)
233+
for fragment in L:
234+
all_vars.extend(fragment.new_vars())
235+
#print 'all_vars', all_vars
236+
for fragment in cls.fragments(root_func, func, cls, all_vars):
237+
#print 'fragment', fragment
238+
#print '=>', fragment
239+
#print '*', len(L), L
240+
yield FragmentList(func, list(L) + [fragment])
241+
242+
def valid_schedules(root_func, func, max_depth=4):
243+
"A sequence of valid schedules for a function, each of which is a list of schedule fragments (up to a maximum depth)."
244+
vars = halide.func_varlist(func)
245+
for depth in range(max_depth+1):
246+
for L in schedules_depth(root_func, func, vars, depth):
247+
ok = True
248+
for x in L:
249+
#print 'at depth=%d, checking'%depth, str(L)#, len(L)
250+
if not x.check(L):
251+
#print 'check failed'
252+
ok = False
253+
break
254+
if ok:
255+
yield L
256+
257+
def func_lhs_var_names(f):
258+
ans = []
259+
for y in f.args():
260+
for x in y.vars():
261+
ans.append(x.name())
262+
return ans
263+
264+
def caller_vars(root_func, func):
265+
"Given a root Func and current function return list of variables of the caller."
266+
func_name = func.name()
267+
for (name, g) in halide.all_funcs(root_func).items():
268+
rhs_names = [x.name() for x in g.rhs().funcs()]
269+
if func_name in rhs_names:
270+
return func_lhs_var_names(g)
271+
return []
272+
273+
def test_schedules():
274+
f = halide.Func('f')
275+
x = halide.Var('x')
276+
y = halide.Var('y')
277+
g = halide.Func('g')
278+
v = halide.Var('v')
279+
f[x,y] = 1
280+
g[v] = f[v,v]
281+
282+
print halide.func_varlist(f)
283+
print 'caller_vars(f) =', caller_vars(g, f)
284+
print 'caller_vars(g) =', caller_vars(g, g)
285+
286+
validL = list(valid_schedules(g, f, 3))
287+
288+
for L in validL:
289+
print repr(repr(L))
290+
291+
print 'number valid: ', len(validL)
292+
293+
if __name__ == '__main__':
294+
test_schedules()
295+

py_bindings/halide.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
from cHalide import *
23
import numpy
34
import Image as PIL
@@ -635,8 +636,8 @@ def test(func):
635636
def test_autotune():
636637
locals_d = test_func()
637638

638-
import halide_autotune
639-
halide_autotune.autotune(locals_d['blur_y'], locals_d['test'], locals_d)
639+
import petabricks_autotune
640+
petabricks_autotune.autotune(locals_d['blur_y'], locals_d['test'], locals_d)
640641

641642
def test_segfault():
642643
locals_d = test_func(compile=False)
@@ -677,6 +678,7 @@ def test_examples():
677678
names = []
678679
do_filter = True
679680

681+
# for example_name in ['interpolate']: #
680682
for example_name in 'interpolate snake blur dilate boxblur_sat boxblur_cumsum local_laplacian'.split(): #[examples.snake, examples.blur, examples.dilate, examples.boxblur_sat, examples.boxblur_cumsum, examples.local_laplacian]:
681683
example = getattr(examples, example_name)
682684
first = True
@@ -723,7 +725,7 @@ def test_examples():
723725
print 'Function names:'
724726
for (example_name, func_names) in names:
725727
print example_name, func_names
726-
728+
727729
def test():
728730
exit_on_signal()
729731
# print 'a'
@@ -733,10 +735,13 @@ def test():
733735
# print 'c'
734736
# pass
735737

738+
"""
736739
test_core()
737740
test_segfault()
738741
test_blur()
739742
test_examples()
743+
"""
744+
test_examples()
740745
# test_autotune()
741746
print
742747
print 'All tests passed, done'
File renamed without changes.

0 commit comments

Comments
 (0)