@@ -31,19 +31,6 @@ function __kernel(expr, generate_cpu = true, force_inbounds = false)
3131 constargs[i] = false
3232 end
3333
34- # create two functions
35- # 1. GPU function
36- # 2. CPU function with work-group loops inserted
37- #
38- # Without the deepcopy we might accidentially modify expr shared between CPU and GPU
39- cpu_name = Symbol (:cpu_ , name)
40- if generate_cpu
41- def_cpu = deepcopy (def)
42- def_cpu[:name ] = cpu_name
43- transform_cpu! (def_cpu, constargs, force_inbounds)
44- cpu_function = combinedef (def_cpu)
45- end
46-
4734 def_gpu = deepcopy (def)
4835 def_gpu[:name ] = gpu_name = Symbol (:gpu_ , name)
4936 transform_gpu! (def_gpu, constargs, force_inbounds)
@@ -56,24 +43,12 @@ function __kernel(expr, generate_cpu = true, force_inbounds = false)
5643 $ name (dev, size) = $ name (dev, $ StaticSize (size), $ DynamicSize ())
5744 $ name (dev, size, range) = $ name (dev, $ StaticSize (size), $ StaticSize (range))
5845 function $name (dev:: Dev , sz:: S , range:: NDRange ) where {Dev, S <: $_Size , NDRange <: $_Size }
59- if $ isgpu (dev)
60- return $ construct (dev, sz, range, $ gpu_name)
61- else
62- if $ generate_cpu
63- return $ construct (dev, sz, range, $ cpu_name)
64- else
65- error (" This kernel is unavailable for backend CPU" )
66- end
67- end
46+ return $ construct (dev, sz, range, $ gpu_name)
6847 end
6948 end
7049 end
7150
72- if generate_cpu
73- return Expr (:block , esc (cpu_function), esc (gpu_function), esc (constructors))
74- else
75- return Expr (:block , esc (gpu_function), esc (constructors))
76- end
51+ return Expr (:block , esc (gpu_function), esc (constructors))
7752end
7853
7954# The easy case, transform the function for GPU execution
@@ -105,198 +80,3 @@ function transform_gpu!(def, constargs, force_inbounds)
10580 )
10681 return
10782end
108-
109- # The hard case, transform the function for CPU execution
110- # - mark constant arguments by applying `constify`.
111- # - insert aliasscope markers
112- # - insert implied loop bodys
113- # - handle indicies
114- # - hoist workgroup definitions
115- # - hoist uniform variables
116- function transform_cpu! (def, constargs, force_inbounds)
117- let_constargs = Expr[]
118- for (i, arg) in enumerate (def[:args ])
119- if constargs[i]
120- push! (let_constargs, :($ arg = $ constify ($ arg)))
121- end
122- end
123- pushfirst! (def[:args ], :__ctx__ )
124- new_stmts = Expr[]
125- body = MacroTools. flatten (def[:body ])
126- push! (new_stmts, Expr (:aliasscope ))
127- if force_inbounds
128- push! (new_stmts, Expr (:inbounds , true ))
129- end
130- append! (new_stmts, split (body. args))
131- if force_inbounds
132- push! (new_stmts, Expr (:inbounds , :pop ))
133- end
134- push! (new_stmts, Expr (:popaliasscope ))
135- push! (new_stmts, :(return nothing ))
136- def[:body ] = Expr (
137- :let ,
138- Expr (:block , let_constargs... ),
139- Expr (:block , new_stmts... ),
140- )
141- return
142- end
143-
144- struct WorkgroupLoop
145- indicies:: Vector{Any}
146- stmts:: Vector{Any}
147- allocations:: Vector{Any}
148- private_allocations:: Vector{Any}
149- private:: Set{Symbol}
150- end
151-
152- is_sync (expr) = @capture (expr, @synchronize () | @synchronize (a_))
153-
154- function is_scope_construct (expr:: Expr )
155- return expr. head === :block # ||
156- # expr.head === :let
157- end
158-
159- function find_sync (stmt)
160- result = false
161- postwalk (stmt) do expr
162- result |= is_sync (expr)
163- expr
164- end
165- return result
166- end
167-
168- # TODO proper handling of LineInfo
169- function split (
170- stmts,
171- indicies = Any[], private = Set {Symbol} (),
172- )
173- # 1. Split the code into blocks separated by `@synchronize`
174- # 2. Aggregate `@index` expressions
175- # 3. Hoist allocations
176- # 4. Hoist uniforms
177-
178- current = Any[]
179- allocations = Any[]
180- private_allocations = Any[]
181- new_stmts = Any[]
182- for stmt in stmts
183- has_sync = find_sync (stmt)
184- if has_sync
185- loop = WorkgroupLoop (deepcopy (indicies), current, allocations, private_allocations, deepcopy (private))
186- push! (new_stmts, emit (loop))
187- allocations = Any[]
188- private_allocations = Any[]
189- current = Any[]
190- is_sync (stmt) && continue
191-
192- # Recurse into scope constructs
193- # TODO : This currently implements hard scoping
194- # probably need to implemet soft scoping
195- # by not deepcopying the environment.
196- recurse (x) = x
197- function recurse (expr:: Expr )
198- expr = unblock (expr)
199- if is_scope_construct (expr) && any (find_sync, expr. args)
200- new_args = unblock (split (expr. args, deepcopy (indicies), deepcopy (private)))
201- return Expr (expr. head, new_args... )
202- else
203- return Expr (expr. head, map (recurse, expr. args)... )
204- end
205- end
206- push! (new_stmts, recurse (stmt))
207- continue
208- end
209-
210- if @capture (stmt, @uniform x_)
211- push! (allocations, stmt)
212- continue
213- elseif @capture (stmt, @private lhs_ = rhs_)
214- push! (private, lhs)
215- push! (private_allocations, :($ lhs = $ rhs))
216- continue
217- elseif @capture (stmt, lhs_ = rhs_ | (vs__, lhs_ = rhs_))
218- if @capture (rhs, @index (args__))
219- push! (indicies, stmt)
220- continue
221- elseif @capture (rhs, @localmem (args__) | @uniform (args__))
222- push! (allocations, stmt)
223- continue
224- elseif @capture (rhs, @private (T_, dims_))
225- # Implement the legacy `mem = @private T dims` as
226- # mem = Scratchpad(T, Val(dims))
227-
228- if dims isa Integer
229- dims = (dims,)
230- end
231- alloc = :($ Scratchpad (__ctx__, $ T, Val ($ dims)))
232- push! (allocations, :($ lhs = $ alloc))
233- push! (private, lhs)
234- continue
235- end
236- end
237-
238- push! (current, stmt)
239- end
240-
241- # everything since the last `@synchronize`
242- if ! isempty (current)
243- loop = WorkgroupLoop (deepcopy (indicies), current, allocations, private_allocations, deepcopy (private))
244- push! (new_stmts, emit (loop))
245- end
246- return new_stmts
247- end
248-
249- function emit (loop)
250- idx = gensym (:I )
251- for stmt in loop. indicies
252- # splice index into the i = @index(Cartesian, $idx)
253- @assert stmt. head === :(= )
254- rhs = stmt. args[2 ]
255- push! (rhs. args, idx)
256- end
257- stmts = Any[]
258- append! (stmts, loop. allocations)
259-
260- # private_allocations turn into lhs = ntuple(i->rhs, length(__workitems_iterspace()))
261- N = gensym (:N )
262- push! (stmts, :($ N = length ($ __workitems_iterspace (__ctx__))))
263-
264- for stmt in loop. private_allocations
265- if @capture (stmt, lhs_ = rhs_)
266- push! (stmts, :($ lhs = ntuple (_ -> $ rhs, $ N)))
267- else
268- error (" @private $stmt not an assignment" )
269- end
270- end
271-
272- # don't emit empty loops
273- if ! (isempty (loop. stmts) || all (s -> s isa LineNumberNode, loop. stmts))
274- body = Expr (:block , loop. stmts... )
275- body = postwalk (body) do expr
276- if @capture (expr, lhs_ = rhs_)
277- if lhs in loop. private
278- error (" Can't assign to variables marked private" )
279- end
280- elseif @capture (expr, A_[i__])
281- if A in loop. private
282- return :($ A[$ __index_Local_Linear (__ctx__, $ (idx))][$ (i... )])
283- end
284- elseif expr isa Symbol
285- if expr in loop. private
286- return :($ expr[$ __index_Local_Linear (__ctx__, $ (idx))])
287- end
288- end
289- return expr
290- end
291- loopexpr = quote
292- for $ idx in $ __workitems_iterspace (__ctx__)
293- $ __validindex (__ctx__, $ idx) || continue
294- $ (loop. indicies... )
295- $ (unblock (body))
296- end
297- end
298- push! (stmts, loopexpr)
299- end
300-
301- return unblock (Expr (:block , stmts... ))
302- end
0 commit comments