@@ -59,6 +59,7 @@ from ._backend cimport ( # noqa: E211
5959 DPCTLWorkGroupMemory_Delete,
6060 _arg_data_type,
6161 _backend_type,
62+ _md_local_accessor,
6263 _queue_property_type,
6364)
6465from .memory._memory cimport _Memory
@@ -125,6 +126,95 @@ cdef class kernel_arg_type_attribute:
125126 return self .attr_value
126127
127128
129+ cdef class LocalAccessor:
130+ """
131+ LocalAccessor(dtype, shape)
132+
133+ Python class for specifying the dimensionality and type of a
134+ ``sycl::local_accessor``, to be used as a kernel argument type.
135+
136+ Args:
137+ dtype (str):
138+ the data type of the local memory.
139+ The permitted values are
140+
141+ `'i1'`, `'i2'`, `'i4'`, `'i8'`:
142+ signed integral types int8_t, int16_t, int32_t, int64_t
143+ `'u1'`, `'u2'`, `'u4'`, `'u8'`
144+ unsigned integral types uint8_t, uint16_t, uint32_t,
145+ uint64_t
146+ `'f4'`, `'f8'`,
147+ single- and double-precision floating-point types float and
148+ double
149+ shape (tuple, list):
150+ Size of LocalAccessor dimensions. Dimension of the LocalAccessor is
151+ determined by the length of the tuple. Must be of length 1, 2, or 3,
152+ and contain only non-negative integers.
153+
154+ Raises:
155+ TypeError:
156+ If the given shape is not a tuple or list.
157+ ValueError:
158+ If the given shape sequence is not between one and three elements long.
159+ TypeError:
160+ If the shape is not a sequence of integers.
161+ ValueError:
162+ If the shape contains a negative integer.
163+ ValueError:
164+ If the dtype string is unrecognized.
165+ """
166+ cdef _md_local_accessor lacc
167+
168+ def __cinit__ (self , str dtype , shape ):
169+ if not isinstance (shape, (list , tuple )):
170+ raise TypeError (f" `shape` must be a list or tuple, got {type(shape)}" )
171+ ndim = len (shape)
172+ if ndim < 1 or ndim > 3 :
173+ raise ValueError (" LocalAccessor must have dimension between one and three" )
174+ for s in shape:
175+ if not isinstance (s, numbers.Integral):
176+ raise TypeError (" LocalAccessor shape must be a sequence of integers" )
177+ if s < 0 :
178+ raise ValueError (" LocalAccessor dimensions must be non-negative" )
179+ self .lacc.ndim = ndim
180+ self .lacc.dim0 = < size_t> shape[0 ]
181+ self .lacc.dim1 = < size_t> shape[1 ] if ndim > 1 else 1
182+ self .lacc.dim2 = < size_t> shape[2 ] if ndim > 2 else 1
183+
184+ if dtype == ' i1' :
185+ self .lacc.dpctl_type_id = _arg_data_type._INT8_T
186+ elif dtype == ' u1' :
187+ self .lacc.dpctl_type_id = _arg_data_type._UINT8_T
188+ elif dtype == ' i2' :
189+ self .lacc.dpctl_type_id = _arg_data_type._INT16_T
190+ elif dtype == ' u2' :
191+ self .lacc.dpctl_type_id = _arg_data_type._UINT16_T
192+ elif dtype == ' i4' :
193+ self .lacc.dpctl_type_id = _arg_data_type._INT32_T
194+ elif dtype == ' u4' :
195+ self .lacc.dpctl_type_id = _arg_data_type._UINT32_T
196+ elif dtype == ' i8' :
197+ self .lacc.dpctl_type_id = _arg_data_type._INT64_T
198+ elif dtype == ' u8' :
199+ self .lacc.dpctl_type_id = _arg_data_type._UINT64_T
200+ elif dtype == ' f4' :
201+ self .lacc.dpctl_type_id = _arg_data_type._FLOAT
202+ elif dtype == ' f8' :
203+ self .lacc.dpctl_type_id = _arg_data_type._DOUBLE
204+ else :
205+ raise ValueError (f" Unrecognized type value: '{dtype}'" )
206+
207+ def __repr__ (self ):
208+ return f" LocalAccessor({self.lacc.ndim})"
209+
210+ cdef size_t addressof(self ):
211+ """
212+ Returns the address of the _md_local_accessor for this LocalAccessor
213+ cast to ``size_t``.
214+ """
215+ return < size_t> & self .lacc
216+
217+
128218cdef class _kernel_arg_type:
129219 """
130220 An enumeration of supported kernel argument types in
@@ -865,6 +955,9 @@ cdef class SyclQueue(_SyclQueue):
865955 elif isinstance (arg, WorkGroupMemory):
866956 kargs[idx] = < void * > (< size_t> arg._ref)
867957 kargty[idx] = _arg_data_type._WORK_GROUP_MEMORY
958+ elif isinstance (arg, LocalAccessor):
959+ kargs[idx] = < void * > ((< LocalAccessor> arg).addressof())
960+ kargty[idx] = _arg_data_type._LOCAL_ACCESSOR
868961 else :
869962 ret = - 1
870963 return ret
0 commit comments