11module CUDAKernels
22
33import CUDA
4- import SpecialFunctions
54import StaticArrays
65import StaticArrays: MArray
7- import Cassette
86import Adapt
97import KernelAbstractions
108
@@ -191,7 +189,7 @@ function (obj::Kernel{CUDADevice})(args...; ndrange=nothing, dependencies=Event(
191189 ndrange, workgroupsize, iterspace, dynamic = launch_config (obj, ndrange, workgroupsize)
192190 # this might not be the final context, since we may tune the workgroupsize
193191 ctx = mkcontext (obj, ndrange, iterspace)
194- kernel = CUDA. @cuda launch= false name = String ( nameof ( obj. f)) Cassette . overdub (CUDACTX, obj . f, ctx, args... )
192+ kernel = CUDA. @cuda launch= false obj. f ( ctx, args... )
195193
196194 # figure out the optimal workgroupsize automatically
197195 if KernelAbstractions. workgroupsize (obj) <: DynamicSize && workgroupsize === nothing
@@ -220,52 +218,49 @@ function (obj::Kernel{CUDADevice})(args...; ndrange=nothing, dependencies=Event(
220218
221219 # Launch kernel
222220 event = CUDA. CuEvent (CUDA. EVENT_DISABLE_TIMING)
223- kernel (CUDACTX, obj . f, ctx, args... ; threads= threads, blocks= nblocks, stream= stream)
221+ kernel (ctx, args... ; threads= threads, blocks= nblocks, stream= stream)
224222
225223 CUDA. record (event, stream)
226224 return CudaEvent (event)
227225end
228226
229- Cassette . @context CUDACtx
227+ import CUDA : @device_override
230228
231229import KernelAbstractions: CompilerMetadata, CompilerPass, DynamicCheck, LinearIndices
232230import KernelAbstractions: __index_Local_Linear, __index_Group_Linear, __index_Global_Linear, __index_Local_Cartesian, __index_Group_Cartesian, __index_Global_Cartesian, __validindex, __print
233231import KernelAbstractions: mkcontext, expand, __iterspace, __ndrange, __dynamic_checkbounds
234232
235- const CUDACTX = Cassette. disablehooks (CUDACtx (pass = CompilerPass))
236- KernelAbstractions. cassette (:: Kernel{CUDADevice} ) = CUDACTX
237-
238233function mkcontext (kernel:: Kernel{CUDADevice} , _ndrange, iterspace)
239234 CompilerMetadata {KernelAbstractions.ndrange(kernel), DynamicCheck} (_ndrange, iterspace)
240235end
241236
242- @inline function Cassette . overdub ( :: CUDACtx , :: typeof ( __index_Local_Linear), ctx)
237+ @device_override @ inline function __index_Local_Linear ( ctx)
243238 return CUDA. threadIdx (). x
244239end
245240
246- @inline function Cassette . overdub ( :: CUDACtx , :: typeof ( __index_Group_Linear), ctx)
241+ @device_override @ inline function __index_Group_Linear ( ctx)
247242 return CUDA. blockIdx (). x
248243end
249244
250- @inline function Cassette . overdub ( :: CUDACtx , :: typeof ( __index_Global_Linear), ctx)
245+ @device_override @ inline function __index_Global_Linear ( ctx)
251246 I = @inbounds expand (__iterspace (ctx), CUDA. blockIdx (). x, CUDA. threadIdx (). x)
252247 # TODO : This is unfortunate, can we get the linear index cheaper
253248 @inbounds LinearIndices (__ndrange (ctx))[I]
254249end
255250
256- @inline function Cassette . overdub ( :: CUDACtx , :: typeof ( __index_Local_Cartesian), ctx)
251+ @device_override @ inline function __index_Local_Cartesian ( ctx)
257252 @inbounds workitems (__iterspace (ctx))[CUDA. threadIdx (). x]
258253end
259254
260- @inline function Cassette . overdub ( :: CUDACtx , :: typeof ( __index_Group_Cartesian), ctx)
255+ @device_override @ inline function __index_Group_Cartesian ( ctx)
261256 @inbounds blocks (__iterspace (ctx))[CUDA. blockIdx (). x]
262257end
263258
264- @inline function Cassette . overdub ( :: CUDACtx , :: typeof ( __index_Global_Cartesian), ctx)
259+ @device_override @ inline function __index_Global_Cartesian ( ctx)
265260 return @inbounds expand (__iterspace (ctx), CUDA. blockIdx (). x, CUDA. threadIdx (). x)
266261end
267262
268- @inline function Cassette . overdub ( :: CUDACtx , :: typeof ( __validindex), ctx)
263+ @device_override @ inline function __validindex ( ctx)
269264 if __dynamic_checkbounds (ctx)
270265 I = @inbounds expand (__iterspace (ctx), CUDA. blockIdx (). x, CUDA. threadIdx (). x)
271266 return I in __ndrange (ctx)
276271
277272import KernelAbstractions: groupsize, __groupsize, __workitems_iterspace, add_float_contract, sub_float_contract, mul_float_contract
278273
279- KernelAbstractions. generate_overdubs (@__MODULE__ , CUDACtx)
280-
281- # ##
282- # CUDA specific method rewrites
283- # ##
284-
285- @inline Cassette. overdub (:: CUDACtx , :: typeof (^ ), x:: Float64 , y:: Float64 ) = ^ (x, y)
286- @inline Cassette. overdub (:: CUDACtx , :: typeof (^ ), x:: Float32 , y:: Float32 ) = ^ (x, y)
287- @inline Cassette. overdub (:: CUDACtx , :: typeof (^ ), x:: Float64 , y:: Int32 ) = ^ (x, y)
288- @inline Cassette. overdub (:: CUDACtx , :: typeof (^ ), x:: Float32 , y:: Int32 ) = ^ (x, y)
289- @inline Cassette. overdub (:: CUDACtx , :: typeof (^ ), x:: Union{Float32, Float64} , y:: Int64 ) = ^ (x, y)
290-
291- # libdevice.jl
292- const cudafuns = (:cos , :cospi , :sin , :sinpi , :tan ,
293- :acos , :asin , :atan ,
294- :cosh , :sinh , :tanh ,
295- :acosh , :asinh , :atanh ,
296- :log , :log10 , :log1p , :log2 ,
297- :exp , :exp2 , :exp10 , :expm1 , :ldexp ,
298- # :isfinite, :isinf, :isnan, :signbit,
299- :abs ,
300- :sqrt , :cbrt ,
301- :ceil , :floor ,)
302- for f in cudafuns
303- @eval function Cassette. overdub (ctx:: CUDACtx , :: typeof (Base.$ f), x:: Union{Float32, Float64} )
304- @Base . _inline_meta
305- return Base.$ f (x)
306- end
307- end
308-
309- @inline Cassette. overdub (:: CUDACtx , :: typeof (sincos), x:: Union{Float32, Float64} ) = (Base. sin (x), Base. cos (x))
310- @inline Cassette. overdub (:: CUDACtx , :: typeof (exp), x:: Union{ComplexF32, ComplexF64} ) = Base. exp (x)
311-
312- @inline Cassette. overdub (:: CUDACtx , :: typeof (SpecialFunctions. gamma), x:: Union{Float32, Float64} ) = CUDA. tgamma (x)
313- @inline Cassette. overdub (:: CUDACtx , :: typeof (SpecialFunctions. erf), x:: Union{Float32, Float64} ) = SpecialFunctions. erf (x)
314- @inline Cassette. overdub (:: CUDACtx , :: typeof (SpecialFunctions. erfc), x:: Union{Float32, Float64} ) = SpecialFunctions. erfc (x)
315-
316274@static if Base. isbindingresolved (CUDA, :emit_shmem ) && Base. isdefined (CUDA, :emit_shmem )
317275 const emit_shmem = CUDA. emit_shmem
318276else
@@ -325,7 +283,7 @@ import KernelAbstractions: ConstAdaptor, SharedMemory, Scratchpad, __synchronize
325283# GPU implementation of shared memory
326284# ##
327285
328- @inline function Cassette . overdub ( :: CUDACtx , :: typeof ( SharedMemory), :: Type{T} , :: Val{Dims} , :: Val{Id} ) where {T, Dims, Id}
286+ @device_override @ inline function SharedMemory ( :: Type{T} , :: Val{Dims} , :: Val{Id} ) where {T, Dims, Id}
329287 ptr = emit_shmem (T, Val (prod (Dims)))
330288 CUDA. CuDeviceArray (Dims, ptr)
331289end
@@ -335,15 +293,15 @@ end
335293# - private memory for each workitem
336294# ##
337295
338- @inline function Cassette . overdub ( :: CUDACtx , :: typeof ( Scratchpad), ctx, :: Type{T} , :: Val{Dims} ) where {T, Dims}
296+ @device_override @ inline function Scratchpad ( ctx, :: Type{T} , :: Val{Dims} ) where {T, Dims}
339297 MArray {__size(Dims), T} (undef)
340298end
341299
342- @inline function Cassette . overdub ( :: CUDACtx , :: typeof ( __synchronize) )
300+ @device_override @ inline function __synchronize ( )
343301 CUDA. sync_threads ()
344302end
345303
346- @inline function Cassette . overdub ( :: CUDACtx , :: typeof ( __print), args... )
304+ @device_override @ inline function __print ( args... )
347305 CUDA. _cuprint (args... )
348306end
349307
@@ -356,29 +314,4 @@ Adapt.adapt_storage(to::ConstAdaptor, a::CUDA.CuDeviceArray) = Base.Experimental
356314# Argument conversion
357315KernelAbstractions. argconvert (k:: Kernel{CUDADevice} , arg) = CUDA. cudaconvert (arg)
358316
359- # Cassette.jl#195
360- # Device intrinsics are inferred in a different World (1.6) or using MethodOverlay tables (1.7)
361- # Cassette sees neither of them and thus overdubbing them fails.
362- @inline function Cassette. overdub (:: CUDACtx , :: typeof (CUDA. arrayref), args... )
363- CUDA. arrayref (args... )
364- end
365- @inline function Cassette. overdub (:: CUDACtx , :: typeof (CUDA. arrayset), args... )
366- CUDA. arrayset (args... )
367- end
368- @inline function Cassette. overdub (:: CUDACtx , :: typeof (CUDA. const_arrayref), args... )
369- CUDA. const_arrayref (args... )
370- end
371- @inline function Cassette. overdub (:: CUDACtx , :: typeof (CUDA. logb), args... )
372- CUDA. logb (args... )
373- end
374- # @inline function Cassette.overdub(::CUDACtx, ::typeof(CUDA.tgamma), args...)
375- # CUDA.tgamma(args...)
376- # end
377- @inline function Cassette. overdub (:: CUDACtx , :: typeof (CUDA. compute_capability), args... )
378- CUDA. compute_capability (args... )
379- end
380- @inline function Cassette. overdub (:: CUDACtx , :: typeof (CUDA. ptx_isa_version), args... )
381- CUDA. ptx_isa_version (args... )
382- end
383-
384317end
0 commit comments