@@ -94,29 +94,7 @@ function transform_gpu!(def, constargs, force_inbounds)
9494 if force_inbounds
9595 push! (new_stmts, Expr (:inbounds , true ))
9696 end
97-
98- # fix convergence
99- active_stmts = Any[]
100- for stmt in stmts
101- has_sync = find_sync (stmt)
102- if has_sync
103- push! (new_stmts, Expr (:if , :__active_lane__ , Expr (:block , active_stmts... )))
104- empty! (active_stmts)
105- push! (new_stmts, stmt)
106- continue
107- end
108- if @capture (stmt, @uniform x_)
109- push! (new_stmts, stmt)
110- continue
111- elseif @capture (stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
112- if @capture (rhs, @localmem (args__) | @uniform (args__))
113- push! (new_stmts, stmt)
114- continue
115- end
116- end
117- push! (active_stmts, stmt)
118- end
119- push! (new_stmts, Expr (:if , :__active_lane__ , Expr (:block , active_stmts... )))
97+ append! (new_stmts, split (emit_gpu, body. args))
12098 if force_inbounds
12199 push! (new_stmts, Expr (:inbounds , :pop ))
122100 end
@@ -151,7 +129,7 @@ function transform_cpu!(def, constargs, force_inbounds)
151129 if force_inbounds
152130 push! (new_stmts, Expr (:inbounds , true ))
153131 end
154- append! (new_stmts, split (body. args))
132+ append! (new_stmts, split (emit_cpu, body. args))
155133 if force_inbounds
156134 push! (new_stmts, Expr (:inbounds , :pop ))
157135 end
191169
192170# TODO proper handling of LineInfo
193171function split (
172+ emit,
194173 stmts,
195174 indicies = Any[], private = Set {Symbol} (),
196175 )
@@ -221,7 +200,7 @@ function split(
221200 function recurse (expr:: Expr )
222201 expr = unblock (expr)
223202 if is_scope_construct (expr) && any (find_sync, expr. args)
224- new_args = unblock (split (expr. args, deepcopy (indicies), deepcopy (private)))
203+ new_args = unblock (split (emit, expr. args, deepcopy (indicies), deepcopy (private)))
225204 return Expr (expr. head, new_args... )
226205 else
227206 return Expr (expr. head, map (recurse, expr. args)... )
@@ -270,7 +249,7 @@ function split(
270249 return new_stmts
271250end
272251
273- function emit (loop)
252+ function emit_cpu (loop)
274253 idx = gensym (:I )
275254 for stmt in loop. indicies
276255 # splice index into the i = @index(Cartesian, $idx)
@@ -324,3 +303,37 @@ function emit(loop)
324303
325304 return unblock (Expr (:block , stmts... ))
326305end
306+
"""
    emit_gpu(loop)

Assemble the GPU-side expression for one `loop` record.

The returned (unblocked) block contains, in order:

1. every entry of `loop.allocations`,
2. every entry of `loop.private_allocations`, re-emitted as `lhs = rhs`,
3. unless `loop.stmts` holds nothing but `LineNumberNode`s: the entries of
   `loop.indicies`, followed by the statements wrapped in
   `if __active_lane__ ... end`.

Raises an error if a private allocation is not of the form `lhs = rhs`, or if any
assignment inside `loop.stmts` has a left-hand side contained in `loop.private`.
"""
function emit_gpu(loop)
    emitted = Any[]
    append!(emitted, loop.allocations)

    # Private allocations must be plain assignments; they are re-emitted as such.
    for stmt in loop.private_allocations
        @capture(stmt, lhs_ = rhs_) || error(" @private $stmt not an assignment")
        push!(emitted, :($lhs = $rhs))
    end

    # Only emit the guarded body when there is something besides line info to run.
    has_work = any(s -> !(s isa LineNumberNode), loop.stmts)
    if has_work
        body = Expr(:block, loop.stmts...)

        # This walk only validates: every node is handed back unchanged.
        body = postwalk(body) do node
            if @capture(node, lhs_ = rhs_) && lhs in loop.private
                error(" Can't assign to variables marked private")
            end
            return node
        end

        guarded = quote
            $(loop.indicies...)
            if __active_lane__
                $(unblock(body))
            end
        end
        push!(emitted, guarded)
    end

    return unblock(Expr(:block, emitted...))
end
0 commit comments