@@ -59,6 +59,7 @@ from ._backend cimport ( # noqa: E211
5959 DPCTLWorkGroupMemory_Delete,
6060 _arg_data_type,
6161 _backend_type,
62+ _md_local_accessor,
6263 _queue_property_type,
6364)
6465from .memory._memory cimport _Memory
@@ -125,6 +126,47 @@ cdef class kernel_arg_type_attribute:
125126 return self .attr_value
126127
127128
129+ cdef class LocalAccessor:
130+ cdef _md_local_accessor lacc
131+
132+ def __cinit__ (self , size_t ndim , str type , size_t dim0 , size_t dim1 , size_t dim2 ):
133+ self .lacc.ndim = ndim
134+ self .lacc.dim0 = dim0
135+ self .lacc.dim1 = dim1
136+ self .lacc.dim2 = dim2
137+
138+ if ndim < 1 or ndim > 3 :
139+ raise ValueError
140+ if type == ' i1' :
141+ self .lacc.dpctl_type_id = _arg_data_type._INT8_T
142+ elif type == ' u1' :
143+ self .lacc.dpctl_type_id = _arg_data_type._UINT8_T
144+ elif type == ' i2' :
145+ self .lacc.dpctl_type_id = _arg_data_type._INT16_T
146+ elif type == ' u2' :
147+ self .lacc.dpctl_type_id = _arg_data_type._UINT16_T
148+ elif type == ' i4' :
149+ self .lacc.dpctl_type_id = _arg_data_type._INT32_T
150+ elif type == ' u4' :
151+ self .lacc.dpctl_type_id = _arg_data_type._UINT32_T
152+ elif type == ' i8' :
153+ self .lacc.dpctl_type_id = _arg_data_type._INT64_T
154+ elif type == ' u8' :
155+ self .lacc.dpctl_type_id = _arg_data_type._UINT64_T
156+ elif type == ' f4' :
157+ self .lacc.dpctl_type_id = _arg_data_type._FLOAT
158+ elif type == ' f8' :
159+ self .lacc.dpctl_type_id = _arg_data_type._DOUBLE
160+ else :
161+ raise ValueError (f" Unrecornigzed type value: '{type}'" )
162+
163+ def __repr__ (self ):
164+ return " LocalAccessor(" + self .ndim + " )"
165+
166+ cdef size_t addressof(self ):
167+ return < size_t> & self .lacc
168+
169+
128170cdef class _kernel_arg_type:
129171 """
130172 An enumeration of supported kernel argument types in
@@ -865,6 +907,9 @@ cdef class SyclQueue(_SyclQueue):
865907 elif isinstance (arg, WorkGroupMemory):
866908 kargs[idx] = < void * > (< size_t> arg._ref)
867909 kargty[idx] = _arg_data_type._WORK_GROUP_MEMORY
910+ elif isinstance (arg, LocalAccessor):
911+ kargs[idx] = < void * > ((< LocalAccessor> arg).addressof())
912+ kargty[idx] = _arg_data_type._LOCAL_ACCESSOR
868913 else :
869914 ret = - 1
870915 return ret
0 commit comments