@@ -58,22 +58,165 @@ function transform_gpu!(def, constargs, force_inbounds)
5858 end
5959 end
6060 pushfirst! (def[:args ], :__ctx__ )
61- body = def[:body ]
61+ new_stmts = Expr[]
62+ body = MacroTools. flatten (def[:body ])
63+ stmts = body. args
64+ push! (new_stmts, Expr (:aliasscope ))
65+ push! (new_stmts, :(__active_lane__ = $ __validindex (__ctx__)))
6266 if force_inbounds
63- body = quote
64- @inbounds $ (body)
65- end
67+ push! (new_stmts, Expr (:inbounds , true ))
6668 end
67- body = quote
68- if $ __validindex (__ctx__)
69- $ (body)
70- end
71- return nothing
69+ append! (new_stmts, split (emit_gpu, body. args))
70+ if force_inbounds
71+ push! (new_stmts, Expr (:inbounds , :pop ))
7272 end
73+ push! (new_stmts, Expr (:popaliasscope ))
74+ push! (new_stmts, :(return nothing ))
7375 def[:body ] = Expr (
7476 :let ,
7577 Expr (:block , let_constargs... ),
76- body ,
78+ Expr ( :block , new_stmts ... ) ,
7779 )
7880 return
7981end
82+
"""
    WorkgroupLoop

One group of kernel statements between two `@synchronize` points, as collected by
`split` and consumed by an emitter such as `emit_gpu`.

# Fields
- `indicies`: statements whose right-hand side is an `@index(...)` call; the emitter
  splices them in front of the group's body.
- `stmts`: the ordinary statements that make up the body of the group.
- `allocations`: `@uniform` / `@localmem` statements and the lowered legacy
  `x = @private T dims` allocations; emitted before the body.
- `private_allocations`: `lhs = rhs` assignments originating from `@private lhs = rhs`.
- `private`: names that were marked private; the emitter rejects assignments to them.
"""
struct WorkgroupLoop
    indicies::Vector{Any}
    stmts::Vector{Any}
    allocations::Vector{Any}
    private_allocations::Vector{Any}
    private::Set{Symbol}
end
90+
"""
    is_sync(expr)

Return `true` when `expr` itself matches `@synchronize()` or `@synchronize(a)`,
i.e. a `@synchronize` call with zero or one argument. Nested occurrences are not
searched here.
"""
function is_sync(expr)
    return @capture(expr, @synchronize() | @synchronize(a_))
end
92+
"""
    is_scope_construct(expr::Expr)

Return `true` if `expr` is a construct that `split` descends into when looking for
nested `@synchronize` calls. Only `:block` expressions are recognized; the check
for `:let` is currently disabled.
"""
is_scope_construct(expr::Expr) = Meta.isexpr(expr, :block)
97+
"""
    find_sync(stmt)

Return `true` if `stmt`, or any sub-expression visited by `postwalk`, satisfies
`is_sync`.
"""
function find_sync(stmt)
    found = Ref(false)
    postwalk(stmt) do node
        # Accumulate over every visited node; the tree itself is left unchanged.
        found[] |= is_sync(node)
        return node
    end
    return found[]
end
106+
107+ # TODO proper handling of LineInfo
"""
    split(emit, stmts, indicies = Any[], private = Set{Symbol}())

Partition `stmts` into groups separated by `@synchronize` and return a vector holding
the result of `emit(loop::WorkgroupLoop)` for every group, interleaved with any
rewritten statements that contained a nested `@synchronize`.

While scanning, statements are classified as follows:

- `lhs = @index(...)` is appended to `indicies` and thereby carried into this and
  every later group;
- `@uniform x`, `lhs = @localmem(...)` and `lhs = @uniform(...)` go to the group's
  `allocations`;
- `@private lhs = rhs` records `lhs` in `private` and stores `lhs = rhs` in the
  group's `private_allocations`;
- the legacy form `lhs = @private T dims` is lowered to
  `lhs = Scratchpad(__ctx__, T, Val(dims))`, stored in `allocations`, and `lhs` is
  recorded in `private`;
- everything else becomes part of the current group's body.

# Arguments
- `emit`: callable receiving a `WorkgroupLoop` and returning the expression to emit.
- `stmts`: iterable of statements to process.
- `indicies`: `@index` statements already in effect. Mutated in place.
- `private`: names already marked private. Mutated in place.

Nested `:block` expressions that contain a `@synchronize` are processed recursively
with *copies* of `indicies` and `private`, so additions made inside them do not leak
back into the enclosing level.
"""
function split(
        emit,
        stmts,
        indicies = Any[], private = Set{Symbol}(),
    )
    # 1. Split the code into blocks separated by `@synchronize`
    # 2. Aggregate `@index` expressions
    # 3. Hoist allocations
    # 4. Hoist uniforms

    current = Any[]
    allocations = Any[]
    private_allocations = Any[]
    new_stmts = Any[]
    for stmt in stmts
        has_sync = find_sync(stmt)
        if has_sync
            # Close the group collected so far, then start a fresh one.
            loop = WorkgroupLoop(deepcopy(indicies), current, allocations, private_allocations, deepcopy(private))
            push!(new_stmts, emit(loop))
            allocations = Any[]
            private_allocations = Any[]
            current = Any[]
            # A bare `@synchronize` only delimits groups; nothing else to emit for it.
            is_sync(stmt) && continue

            # Recurse into scope constructs
            # TODO: This currently implements hard scoping;
            #       we probably need to implement soft scoping
            #       by not deepcopying the environment.
            recurse(x) = x
            function recurse(expr::Expr)
                expr = unblock(expr)
                # `unblock` unwraps a block with a single statement, which may leave a
                # bare symbol or literal behind. Such a leaf has no `head`/`args` and
                # `is_scope_construct` has no method for it, so return it unchanged.
                expr isa Expr || return expr
                if is_scope_construct(expr) && any(find_sync, expr.args)
                    new_args = unblock(split(emit, expr.args, deepcopy(indicies), deepcopy(private)))
                    return Expr(expr.head, new_args...)
                else
                    return Expr(expr.head, map(recurse, expr.args)...)
                end
            end
            push!(new_stmts, recurse(stmt))
            continue
        end

        if @capture(stmt, @uniform x_)
            push!(allocations, stmt)
            continue
        elseif @capture(stmt, @private lhs_ = rhs_)
            push!(private, lhs)
            push!(private_allocations, :($lhs = $rhs))
            continue
        elseif @capture(stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
            if @capture(rhs, @index(args__))
                push!(indicies, stmt)
                continue
            elseif @capture(rhs, @localmem(args__) | @uniform(args__))
                push!(allocations, stmt)
                continue
            elseif @capture(rhs, @private(T_, dims_))
                # Implement the legacy `mem = @private T dims` as
                # mem = Scratchpad(T, Val(dims))

                if dims isa Integer
                    dims = (dims,)
                end
                alloc = :($Scratchpad(__ctx__, $T, Val($dims)))
                push!(allocations, :($lhs = $alloc))
                push!(private, lhs)
                continue
            end
        end

        push!(current, stmt)
    end

    # everything since the last `@synchronize`
    if !isempty(current)
        loop = WorkgroupLoop(deepcopy(indicies), current, allocations, private_allocations, deepcopy(private))
        push!(new_stmts, emit(loop))
    end
    return new_stmts
end
188+
"""
    emit_gpu(loop)

Build the expression for one `WorkgroupLoop`.

The result consists of, in order: the loop's `allocations`, its `private_allocations`
(each must be an assignment), and — unless the body is empty or contains only
`LineNumberNode`s — the index statements and body wrapped in `if __active_lane__`,
followed by a call to `__synchronize()`. The assembled block is passed through
`unblock` before being returned.

Raises an error if a private allocation is not an assignment, or if the body assigns
to a name listed in `loop.private`.
"""
function emit_gpu(loop)
    emitted = Any[]
    append!(emitted, loop.allocations)
    for stmt in loop.private_allocations
        @capture(stmt, lhs_ = rhs_) || error(" @private $stmt not an assignment")
        push!(emitted, :($lhs = $rhs))
    end

    # don't emit empty loops
    has_code = any(s -> !(s isa LineNumberNode), loop.stmts)
    if has_code
        body = Expr(:block, loop.stmts...)
        # Walk the body only to validate it: private names must not be reassigned.
        body = postwalk(body) do node
            if @capture(node, lhs_ = rhs_) && lhs in loop.private
                error(" Can't assign to variables marked private")
            end
            return node
        end
        guarded = quote
            if __active_lane__
                $(loop.indicies...)
                $(unblock(body))
            end
            $__synchronize()
        end
        push!(emitted, guarded)
    end

    return unblock(Expr(:block, emitted...))
end
0 commit comments