@@ -52,7 +52,7 @@ synchronize(backend)
5252```
5353"""
5454macro kernel (expr)
# Single-expression form of the macro: forwards `expr` straight to `__kernel`
# with CPU generation enabled (`true`) and forced inbounds disabled (`false`),
# as labelled by the inline `#= ... =#` annotations.
# The `+` revision inserts `DynamicSize()` as a new leading argument, ahead of `expr`.
55- __kernel (expr, #= generate_cpu=# true , #= force_inbounds=# false )
55+ __kernel (DynamicSize (), expr, #= generate_cpu=# true , #= force_inbounds=# false )
5656end
5757
5858"""
@@ -70,10 +70,11 @@ This allows for two different configurations:
7070"""
7171macro kernel (ex... )
7272 if length (ex) == 1
73- __kernel (ex[1 ], true , false )
73+ __kernel (DynamicSize (), ex[1 ], true , false )
7474 else
7575 generate_cpu = true
7676 force_inbounds = false
77+ N = DynamicSize () # TODO parse N
7778 for i in 1 : (length (ex) - 1 )
7879 if ex[i] isa Expr && ex[i]. head == :(= ) &&
7980 ex[i]. args[1 ] == :cpu && ex[i]. args[2 ] isa Bool
@@ -90,7 +91,7 @@ macro kernel(ex...)
9091 )
9192 end
9293 end
93- __kernel (ex[end ], generate_cpu, force_inbounds)
94+ __kernel (N, ex[end ], generate_cpu, force_inbounds)
9495 end
9596end
9697
@@ -586,7 +587,7 @@ in a workgroup.
586587 ```
587588 As well as the on-device functionality.
588589"""
589- struct Kernel{Backend, N, WorkgroupSize <: _Size , NDRange <: _Size , Fun}
590+ struct Kernel{Backend, N <: _Size , WorkgroupSize <: _Size , NDRange <: _Size , Fun}
# Type parameters, in order: Backend, N, WorkgroupSize, NDRange, Fun.
# The `+` revision constrains `N` to be a subtype of `_Size`, matching the existing
# bounds on `WorkgroupSize` and `NDRange`; in the `-` revision `N` was unconstrained.
# Only `backend` and `f` are stored as fields; `N`, `WorkgroupSize` and `NDRange`
# exist purely at the type level.
590591 backend:: Backend
591592 f:: Fun
592593end
@@ -595,8 +596,9 @@ function Base.similar(kernel::Kernel{D, N, WS, ND}, f::F) where {D, N, WS, ND, F
595596 Kernel {D, N, WS, ND, F} (kernel. backend, f)
596597end
597598
# Accessors that read a `Kernel`'s type parameters (and its `backend` field).
# In the `-` revisions `N` appears in the `Kernel{...}` signature but is not listed in
# the `where` clause; the `+` revisions bind it as a type variable.
598- workgroupsize (:: Kernel{D, N, WorkgroupSize} ) where {D, WorkgroupSize} = WorkgroupSize
599- ndrange (:: Kernel{D, N, WorkgroupSize, NDRange} ) where {D, WorkgroupSize, NDRange} = NDRange
599+ workgroupsize (:: Kernel{D, N, WorkgroupSize} ) where {D, N, WorkgroupSize} = WorkgroupSize
600+ ndrange (:: Kernel{D, N, WorkgroupSize, NDRange} ) where {D, N, WorkgroupSize, NDRange} = NDRange
# Returns the `N` type parameter itself (a `_Size` subtype in the `+` struct revision),
# not an integer dimension count.
# NOTE(review): unless `Base.ndims` is explicitly imported in this module, this defines
# a separate module-local `ndims` that shadows `Base.ndims` inside the module — confirm
# that is intended.
601+ ndims (:: Kernel{D, N} ) where {D, N} = N
600602backend (kernel:: Kernel ) = kernel. backend
601603
602604"""
@@ -605,6 +607,7 @@ Partition a kernel for the given ndrange and workgroupsize.
605607@inline function partition (kernel, ndrange, workgroupsize)
606608 static_ndrange = KernelAbstractions. ndrange (kernel)
607609 static_workgroupsize = KernelAbstractions. workgroupsize (kernel)
610+ static_ndims = KernelAbstractions. ndims (kernel)
608611
609612 if ndrange === nothing && static_ndrange <: DynamicSize ||
610613 workgroupsize === nothing && static_workgroupsize <: DynamicSize
@@ -655,11 +658,16 @@ Partition a kernel for the given ndrange and workgroupsize.
655658 workgroupsize = CartesianIndices (workgroupsize)
656659 end
657660
661+ if static_ndims <: StaticSize
662+ @assert get (static_ndims) == length (ndrange)
663+ end
664+
665+ # TODO : Add static_ndims
658666 iterspace = NDRange {length(ndrange), static_blocks, static_workgroupsize} (blocks, workgroupsize)
659667 return iterspace, dynamic
660668end
661669
# Builds a `Kernel` for a CPU or GPU backend from size singletons and the kernel function.
# `-` revision: the second argument is a `Val{N}` with an unconstrained `N`.
# `+` revision: the second argument is an instance whose type `N <: _Size` is used
# directly as the kernel's `N` type parameter.
# Within this method `NDRange` is a `where` type variable (the type of the fourth
# argument), not the `NDRange{...}` iteration-space type constructed in `partition`.
662- function construct (backend:: Backend , :: Val{N} , :: S , :: NDRange , xpu_name:: XPUName ) where {Backend <: Union{CPU, GPU} , N, S <: _Size , NDRange <: _Size , XPUName}
670+ function construct (backend:: Backend , :: N , :: S , :: NDRange , xpu_name:: XPUName ) where {Backend <: Union{CPU, GPU} , N <: _Size , S <: _Size , NDRange <: _Size , XPUName}
663671 return Kernel {Backend, N, S, NDRange, XPUName} (backend, xpu_name)
664672end
665673
0 commit comments