@@ -64,6 +64,7 @@ from ._backend cimport ( # noqa: E211
6464from .memory._memory cimport _Memory
6565
6666import ctypes
67+ import numbers
6768
6869from .enum_types import backend_type
6970
@@ -1586,14 +1587,24 @@ cdef class WorkGroupMemory:
15861587 f" arguments, but {len(args)} were given" )
15871588
15881589 if len (args) == 1 :
1590+ if not isinstance (args[0 ], numbers.Integral):
1591+ raise TypeError (" WorkGroupMemory single argument constructor"
1592+ " expects number of bytes as integer value" )
15891593 nbytes = < size_t> (args[0 ])
15901594 else :
1595+ if not isinstance (args[0 ], str ) or not isinstance (args[1 ], numbers.Integral):
1596+ raise TypeError (" WorkGroupMemory constructor expects type as"
1597+ " string and number of bytes as integer value." )
15911598 dtype = < str > (args[0 ])
15921599 count = < size_t> (args[1 ])
1593- ty = dtype[0 ]
1594- if not ty in [" i" , " u" , " f" ]:
1600+ if not dtype[0 ] in [" i" , " u" , " f" ]:
15951601 raise TypeError (f" Unrecognized type value: '{dtype}'" )
1596- byte_size = < size_t> (int (dtype[1 :]))
1602+ try :
1603+ bit_width = int (dtype[1 :])
1604+ except ValueError :
1605+ raise TypeError (f" Unrecognized type value: '{dtype}'" )
1606+
1607+ byte_size = < size_t> bit_width
15971608 nbytes = count * byte_size
15981609
15991610 self ._mem_ref = DPCTLWorkGroupMemory_Create(nbytes)
0 commit comments