@@ -866,9 +866,14 @@ def diag(v, /, k=0, *, device=None, usm_type=None, sycl_queue=None):
866866 v : array_like
867867 Input data, in any form that can be converted to an array. This
868868 includes scalars, lists, lists of tuples, tuples, tuples of tuples,
869- tuples of lists, and ndarrays. If `v` is a 2-D array, return a copy of
870- its k-th diagonal. If `v` is a 1-D array, return a 2-D array with `v`
871- on the k-th diagonal.
869+ tuples of lists, and ndarrays.
870+ If `v` is a 1-D array, return a 2-D array with `v`
871+ on the `k`-th diagonal.
872+ If `v` is a 2-D array and is an instance of
873+ {dpnp.ndarray, usm_ndarray}, then:
874+ - If `device`, `usm_type`, and `sycl_queue` are set to their
875+ default values, returns a read/write view of its k-th diagonal.
876+ - Otherwise, returns a copy of its k-th diagonal.
872877 k : int, optional
873878 Diagonal in question. The default is 0. Use k > 0 for diagonals above
874879 the main diagonal, and k < 0 for diagonals below the main diagonal.
@@ -894,79 +899,62 @@ def diag(v, /, k=0, *, device=None, usm_type=None, sycl_queue=None):
894899 --------
895900 :obj:`diagonal` : Return specified diagonals.
896901 :obj:`diagflat` : Create a 2-D array with the flattened input as a diagonal.
897- :obj:`trace` : Return sum along diagonals.
898- :obj:`triu` : Return upper triangle of an array.
899- :obj:`tril` : Return lower triangle of an array.
902+ :obj:`trace` : Return the sum along diagonals of the array .
903+ :obj:`triu` : Upper triangle of an array.
904+ :obj:`tril` : Lower triangle of an array.
900905
901906 Examples
902907 --------
903908 >>> import dpnp as np
904- >>> x0 = np.arange(9).reshape((3, 3))
905- >>> x0
909+ >>> x = np.arange(9).reshape((3, 3))
910+ >>> x
906911 array([[0, 1, 2],
907912 [3, 4, 5],
908913 [6, 7, 8]])
909914
910- >>> np.diag(x0 )
915+ >>> np.diag(x )
911916 array([0, 4, 8])
912- >>> np.diag(x0 , k=1)
917+ >>> np.diag(x , k=1)
913918 array([1, 5])
914- >>> np.diag(x0 , k=-1)
919+ >>> np.diag(x , k=-1)
915920 array([3, 7])
916921
917- >>> np.diag(np.diag(x0 ))
922+ >>> np.diag(np.diag(x ))
918923 array([[0, 0, 0],
919924 [0, 4, 0],
920925 [0, 0, 8]])
921926
922927 Creating an array on a different device or with a specified usm_type
923928
924- >>> x = np.diag(x0 ) # default case
925- >>> x, x .device, x .usm_type
929+ >>> res = np.diag(x ) # default case
930+ >>> res, res .device, res .usm_type
926931 (array([0, 4, 8]), Device(level_zero:gpu:0), 'device')
927932
928- >>> y = np.diag(x0 , device="cpu")
929- >>> y, y .device, y .usm_type
933+ >>> res_cpu = np.diag(x , device="cpu")
934+ >>> res_cpu, res_cpu .device, res_cpu .usm_type
930935 (array([0, 4, 8]), Device(opencl:cpu:0), 'device')
931936
932- >>> z = np.diag(x0 , usm_type="host")
933- >>> z, z .device, z .usm_type
937+ >>> res_host = np.diag(x , usm_type="host")
938+ >>> res_host, res_host .device, res_host .usm_type
934939 (array([0, 4, 8]), Device(level_zero:gpu:0), 'host')
935940
936941 """
937942
938943 if not isinstance (k , int ):
939944 raise TypeError (f"An integer is required, but got { type (k )} " )
940945
941- v = dpnp .asarray (v , device = device , usm_type = usm_type , sycl_queue = sycl_queue )
946+ v = dpnp .asanyarray (
947+ v , device = device , usm_type = usm_type , sycl_queue = sycl_queue
948+ )
942949
943- init0 = max (0 , - k )
944- init1 = max (0 , k )
945950 if v .ndim == 1 :
946951 size = v .shape [0 ] + abs (k )
947- m = dpnp .zeros (
948- (size , size ),
949- dtype = v .dtype ,
950- usm_type = v .usm_type ,
951- sycl_queue = v .sycl_queue ,
952- )
953- for i in range (v .shape [0 ]):
954- m [(init0 + i ), init1 + i ] = v [i ]
955- return m
952+ ret = dpnp .zeros_like (v , shape = (size , size ))
953+ ret .diagonal (k )[:] = v
954+ return ret
956955
957956 if v .ndim == 2 :
958- size = max (
959- 0 , min (v .shape [0 ], v .shape [0 ] + k , v .shape [1 ], v .shape [1 ] - k )
960- )
961- m = dpnp .zeros (
962- (size ,),
963- dtype = v .dtype ,
964- usm_type = v .usm_type ,
965- sycl_queue = v .sycl_queue ,
966- )
967- for i in range (size ):
968- m [i ] = v [(init0 + i ), init1 + i ]
969- return m
957+ return v .diagonal (k )
970958
971959 raise ValueError ("Input must be a 1-D or 2-D array." )
972960
@@ -1008,9 +996,9 @@ def diagflat(v, /, k=0, *, device=None, usm_type=None, sycl_queue=None):
1008996
1009997 See Also
1010998 --------
1011- :obj:`diag` : Return the extracted diagonal or constructed diagonal array.
1012- :obj:`diagonal` : Return specified diagonals.
1013- :obj:`trace` : Return sum along diagonals.
999+ :obj:`dpnp. diag` : Extract a diagonal or construct a diagonal array.
1000+ :obj:`dpnp. diagonal` : Return specified diagonals.
1001+ :obj:`dpnp. trace` : Return sum along diagonals.
10141002
10151003 Examples
10161004 --------
@@ -1324,6 +1312,11 @@ def eye(
13241312 Parameter `like` is supported only with default value ``None``.
13251313 Otherwise, the function raises `NotImplementedError` exception.
13261314
1315+ See Also
1316+ --------
1317+ :obj:`dpnp.identity` : Return the identity array.
1318+ :obj:`dpnp.diag` : Extract a diagonal or construct a diagonal array.
1319+
13271320 Examples
13281321 --------
13291322 >>> import dpnp as np
@@ -2264,7 +2257,7 @@ def identity(
22642257 :obj:`dpnp.eye` : Return a 2-D array with ones on the diagonal and zeros
22652258 elsewhere.
22662259 :obj:`dpnp.ones` : Return a new array setting values to one.
2267- :obj:`dpnp.diag` : Return diagonal 2-D array from an input 1-D array.
2260+ :obj:`dpnp.diag` : Extract a diagonal or construct a diagonal array.
22682261
22692262 Examples
22702263 --------
0 commit comments