7979 "svd" ,
8080 "slogdet" ,
8181 "tensorinv" ,
82+ "tensorsolve" ,
8283]
8384
8485
@@ -935,7 +936,7 @@ def slogdet(a):
935936
936937def tensorinv (a , ind = 2 ):
937938 """
938- Compute the ` inverse` of a tensor .
939+ Compute the ' inverse' of an N-dimensional array .
939940
940941 For full documentation refer to :obj:`numpy.linalg.tensorinv`.
941942
@@ -944,7 +945,7 @@ def tensorinv(a, ind=2):
944945 a : {dpnp.ndarray, usm_ndarray}
945946 Tensor to `invert`. Its shape must be 'square', i. e.,
946947 ``prod(a.shape[:ind]) == prod(a.shape[ind:])``.
947- ind : int
948+ ind : int, optional
948949 Number of first indices that are involved in the inverse sum.
949950 Must be a positive integer.
950951 Default: 2.
@@ -989,3 +990,74 @@ def tensorinv(a, ind=2):
989990 a_inv = inv (a )
990991
991992 return a_inv .reshape (* inv_shape )
993+
994+
995+ def tensorsolve (a , b , axes = None ):
996+ """
997+ Solve the tensor equation ``a x = b`` for x.
998+
999+ For full documentation refer to :obj:`numpy.linalg.tensorsolve`.
1000+
1001+ Parameters
1002+ ----------
1003+ a : {dpnp.ndarray, usm_ndarray}
1004+ Coefficient tensor, of shape ``b.shape + Q``. `Q`, a tuple, equals
1005+ the shape of that sub-tensor of `a` consisting of the appropriate
1006+ number of its rightmost indices, and must be such that
1007+ ``prod(Q) == prod(b.shape)`` (in which sense `a` is said to be
1008+ 'square').
1009+ b : {dpnp.ndarray, usm_ndarray}
1010+ Right-hand tensor, which can be of any shape.
1011+ axes : tuple of ints, optional
1012+ Axes in `a` to reorder to the right, before inversion.
1013+ If ``None`` , no reordering is done.
1014+ Default: ``None``.
1015+
1016+ Returns
1017+ -------
1018+ out : dpnp.ndarray
1019+ The tensor with shape ``Q`` such that ``b.shape + Q == a.shape``.
1020+
1021+ See Also
1022+ --------
1023+ :obj:`dpnp.linalg.tensordot` : Compute tensor dot product along specified axes.
1024+ :obj:`dpnp.linalg.tensorinv` : Compute the 'inverse' of an N-dimensional array.
1025+ :obj:`dpnp.einsum` : Evaluates the Einstein summation convention on the operands.
1026+
1027+ Examples
1028+ --------
1029+ >>> import dpnp as np
1030+ >>> a = np.eye(2*3*4)
1031+ >>> a.shape = (2*3, 4, 2, 3, 4)
1032+ >>> b = np.random.randn(2*3, 4)
1033+ >>> x = np.linalg.tensorsolve(a, b)
1034+ >>> x.shape
1035+ (2, 3, 4)
1036+ >>> np.allclose(np.tensordot(a, x, axes=3), b)
1037+ array([ True])
1038+
1039+ """
1040+
1041+ dpnp .check_supported_arrays_type (a , b )
1042+ a_ndim = a .ndim
1043+
1044+ if axes is not None :
1045+ all_axes = list (range (a_ndim ))
1046+ for k in axes :
1047+ all_axes .remove (k )
1048+ all_axes .insert (a_ndim , k )
1049+ a = a .transpose (tuple (all_axes ))
1050+
1051+ old_shape = a .shape [- (a_ndim - b .ndim ) :]
1052+ prod = numpy .prod (old_shape )
1053+
1054+ if a .size != prod ** 2 :
1055+ raise dpnp .linalg .LinAlgError (
1056+ "Input arrays must satisfy the requirement \
1057+ prod(a.shape[b.ndim:]) == prod(a.shape[:b.ndim])"
1058+ )
1059+
1060+ a = a .reshape (- 1 , prod )
1061+ b = b .ravel ()
1062+ res = solve (a , b )
1063+ return res .reshape (old_shape )
0 commit comments