@@ -16,13 +16,15 @@ using StaticArrays
1616using Adapt
1717
1818"""
19- @kernel function f(args) end
19+ @kernel [N] function f(args) end
2020
2121Takes a function definition and generates a [`Kernel`](@ref) constructor from it.
2222The enclosed function is allowed to contain kernel language constructs.
2323In order to call it the kernel has first to be specialized on the backend
2424and then invoked on the arguments.
2525
26+ The optional `N` parameter can be used to fix the number of dimensions used for the ndrange.
27+
2628# Kernel language
2729
2830- [`@Const`](@ref)
@@ -54,7 +56,7 @@ macro kernel(expr)
5456end
5557
5658"""
57- @kernel config function f(args) end
59+ @kernel [N] config function f(args) end
5860
5961This allows for two different configurations:
6062
@@ -584,17 +586,17 @@ in a workgroup.
584586 ```
585587 As well as the on-device functionality.
586588"""
587- struct Kernel{Backend, WorkgroupSize <: _Size , NDRange <: _Size , Fun}
589+ struct Kernel{Backend, N, WorkgroupSize <: _Size , NDRange <: _Size , Fun}
588590 backend:: Backend
589591 f:: Fun
590592end
591593
592- function Base. similar (kernel:: Kernel{D, WS, ND} , f:: F ) where {D, WS, ND, F}
593- Kernel {D, WS, ND, F} (kernel. backend, f)
594+ function Base. similar (kernel:: Kernel{D, N, WS, ND} , f:: F ) where {D, N , WS, ND, F}
595+ Kernel {D, N, WS, ND, F} (kernel. backend, f)
594596end
595597
596- workgroupsize (:: Kernel{D, WorkgroupSize} ) where {D, WorkgroupSize} = WorkgroupSize
597- ndrange (:: Kernel{D, WorkgroupSize, NDRange} ) where {D, WorkgroupSize, NDRange} = NDRange
598+ workgroupsize (:: Kernel{D, N, WorkgroupSize} ) where {D, WorkgroupSize} = WorkgroupSize
599+ ndrange (:: Kernel{D, N, WorkgroupSize, NDRange} ) where {D, WorkgroupSize, NDRange} = NDRange
598600backend (kernel:: Kernel ) = kernel. backend
599601
600602"""
@@ -657,8 +659,8 @@ Partition a kernel for the given ndrange and workgroupsize.
657659 return iterspace, dynamic
658660end
659661
660- function construct (backend:: Backend , :: S , :: NDRange , xpu_name:: XPUName ) where {Backend <: Union{CPU, GPU} , S <: _Size , NDRange <: _Size , XPUName}
661- return Kernel {Backend, S, NDRange, XPUName} (backend, xpu_name)
662+ function construct (backend:: Backend , :: Val{N} , :: S , :: NDRange , xpu_name:: XPUName ) where {Backend <: Union{CPU, GPU} , N , S <: _Size , NDRange <: _Size , XPUName}
663+ return Kernel {Backend, N, S, NDRange, XPUName} (backend, xpu_name)
662664end
663665
664666# ##
0 commit comments