@@ -320,12 +320,8 @@ def slogdet(x: Array, /) -> SlogdetResult:
320320def _solve (a : np .ndarray , b : np .ndarray ) -> np .ndarray :
321321 try :
322322 from numpy .linalg ._linalg import ( # type: ignore[attr-defined]
323- _assert_stacked_2d ,
324- _assert_stacked_square ,
325- _commonType ,
326- _makearray ,
327- _raise_linalgerror_singular ,
328- isComplexType ,
323+ _makearray , _assert_stacked_2d , _assert_stacked_square ,
324+ _commonType , isComplexType , _raise_linalgerror_singular
329325 )
330326 except ImportError :
331327 from numpy .linalg .linalg import ( # type: ignore[attr-defined]
@@ -412,7 +408,8 @@ def trace(x: Array, /, *, offset: int = 0, dtype: DType | None = None) -> Array:
412408
413409 # Note: trace always operates on the last two axes, whereas np.trace
414410 # operates on the first two axes by default
415- return Array ._new (np .asarray (np .trace (x ._array , offset = offset , axis1 = - 2 , axis2 = - 1 , dtype = np_dtype )), device = x .device )
411+ res = np .trace (x ._array , offset = offset , axis1 = - 2 , axis2 = - 1 , dtype = np_dtype )
412+ return Array ._new (np .asarray (res ), device = x .device )
416413
417414# Note: the name here is different from norm(). The array API norm is split
418415# into matrix_norm and vector_norm().
0 commit comments