@@ -37,6 +37,7 @@ cimport dpctl.memory as c_dpmem
3737cimport dpctl.tensor._dlpack as c_dlpack
3838
3939import dpctl.tensor._flags as _flags
40+ from dpctl.tensor._tensor_impl import default_device_fp_type
4041
4142include " _stride_utils.pxi"
4243include " _types.pxi"
@@ -104,7 +105,7 @@ cdef class InternalUSMArrayError(Exception):
104105
105106
106107cdef class usm_ndarray:
107- """ usm_ndarray(shape, dtype="|f8" , strides=None, buffer="device", \
108+ """ usm_ndarray(shape, dtype=None , strides=None, buffer="device", \
108109 offset=0, order="C", buffer_ctor_kwargs=dict(), \
109110 array_namespace=None)
110111
@@ -116,6 +117,8 @@ cdef class usm_ndarray:
116117 Shape of the array to be created.
117118 dtype (str, dtype):
118119 Array data type, i.e. the type of array elements.
120+ If ``dtype`` has the value ``None``, it is determined by default
121+ floating point type supported by target device.
119122 The supported types are
120123 * ``bool``
121124 boolean type
@@ -134,7 +137,7 @@ cdef class usm_ndarray:
134137 double-precision real and complex floating
135138 types, supported if target device's property
136139 ``has_aspect_fp64`` is ``True``.
137- Default: ``"|f8" ``.
140+ Default: ``None ``.
138141 strides (tuple, optional):
139142 Strides of the array to be created in elements.
140143 If ``strides`` has the value ``None``, it is determined by the
@@ -219,7 +222,7 @@ cdef class usm_ndarray:
219222 " Data pointers of cloned and original objects are different." )
220223 return res
221224
222- def __cinit__ (self , shape , dtype = " |f8 " , strides = None , buffer = ' device' ,
225+ def __cinit__ (self , shape , dtype = None , strides = None , buffer = ' device' ,
223226 Py_ssize_t offset = 0 , order = ' C' ,
224227 buffer_ctor_kwargs = dict (),
225228 array_namespace = None ):
@@ -252,6 +255,13 @@ cdef class usm_ndarray:
252255 except Exception :
253256 raise TypeError (" Argument shape must be a list or a tuple." )
254257 nd = len (shape)
258+ if dtype is None :
259+ q = buffer_ctor_kwargs.get(" queue" )
260+ if q is not None :
261+ dtype = default_device_fp_type(q)
262+ else :
263+ dev = dpctl.select_default_device()
264+ dtype = " f8" if dev.has_aspect_fp64 else " f4"
255265 typenum = dtype_to_typenum(dtype)
256266 if (typenum < 0 ):
257267 if typenum == - 2 :
0 commit comments