@@ -1555,15 +1555,47 @@ cdef class WorkGroupMemory:
15551555 This is based on a DPC++ SYCL extension and only available in newer
15561556 versions. Use ``is_available()`` to check availability in your build.
15571557
1558+ There are multiple ways to create a `WorkGroupMemory`.
1559+
1560+ - If the constructor is invoked with just a single argument, this argument
1561+ is interpreted as the number of bytes to allocated in the shared local
1562+ memory.
1563+
1564+ - If the constructor is invoked with two arguments, the first argument is
1565+ interpreted as the datatype of the local memory, using the numpy type
1566+ naming scheme.
1567+ The second argument is interpreted as the number of elements to allocate.
1568+ The number of bytes to allocate is then computed from the byte size of
1569+ the data type and the element count.
1570+
15581571 Args:
1559- nbytes (int)
1560- number of bytes to allocate in local memory.
1561- Expected to be positive.
1572+ args:
1573+ Variadic argument, see class documentation.
1574+
1575+ Raises:
1576+ TypeError: In case of incorrect arguments given to constructors,
1577+ unexpected types of input arguments.
15621578 """
1563- def __cinit__ (self , Py_ssize_t nbytes ):
1579+ def __cinit__ (self , *args ):
1580+ cdef size_t nbytes
15641581 if not DPCTLWorkGroupMemory_Available():
15651582 raise RuntimeError (" Workgroup memory extension not available" )
15661583
1584+ if not (0 < len (args) < 3 ):
1585+ raise TypeError (" WorkGroupMemory constructor takes 1 or 2 "
1586+ f" arguments, but {len(args)} were given" )
1587+
1588+ if len (args) == 1 :
1589+ nbytes = < size_t> (args[0 ])
1590+ else :
1591+ dtype = < str > (args[0 ])
1592+ count = < size_t> (args[1 ])
1593+ ty = dtype[0 ]
1594+ if not ty in [" i" , " u" , " f" ]:
1595+ raise TypeError (f" Unrecognized type value: '{dtype}'" )
1596+ byte_size = < size_t> (int (dtype[1 :]))
1597+ nbytes = count * byte_size
1598+
15671599 self ._mem_ref = DPCTLWorkGroupMemory_Create(nbytes)
15681600
15691601 """ Check whether the work_group_memory extension is available"""
0 commit comments