@@ -208,3 +208,47 @@ def test_pinv_size_0(self):
208208 self .check_x ((0 , 0 ), rcond = 1e-15 )
209209 self .check_x ((0 , 2 , 3 ), rcond = 1e-15 )
210210 self .check_x ((2 , 0 , 3 ), rcond = 1e-15 )
211+
212+
213+ class TestTensorInv (unittest .TestCase ):
214+ @testing .for_dtypes ("ifdFD" )
215+ @_condition .retry (10 )
216+ def check_x (self , a_shape , ind , dtype ):
217+ a_cpu = numpy .random .randint (0 , 10 , size = a_shape ).astype (dtype )
218+ a_gpu = cupy .asarray (a_cpu )
219+ a_gpu_copy = a_gpu .copy ()
220+ result_cpu = numpy .linalg .tensorinv (a_cpu , ind = ind )
221+ result_gpu = cupy .linalg .tensorinv (a_gpu , ind = ind )
222+ assert_dtype_allclose (result_gpu , result_cpu )
223+ testing .assert_array_equal (a_gpu_copy , a_gpu )
224+
225+ def check_shape (self , a_shape , ind ):
226+ a = cupy .random .rand (* a_shape )
227+ with self .assertRaises (
228+ (numpy .linalg .LinAlgError , cupy .linalg .LinAlgError )
229+ ):
230+ cupy .linalg .tensorinv (a , ind = ind )
231+
232+ def check_ind (self , a_shape , ind ):
233+ a = cupy .random .rand (* a_shape )
234+ with self .assertRaises (ValueError ):
235+ cupy .linalg .tensorinv (a , ind = ind )
236+
237+ def test_tensorinv (self ):
238+ self .check_x ((12 , 3 , 4 ), ind = 1 )
239+ self .check_x ((3 , 8 , 24 ), ind = 2 )
240+ self .check_x ((18 , 3 , 3 , 2 ), ind = 1 )
241+ self .check_x ((1 , 4 , 2 , 2 ), ind = 2 )
242+ self .check_x ((2 , 3 , 5 , 30 ), ind = 3 )
243+ self .check_x ((24 , 2 , 2 , 3 , 2 ), ind = 1 )
244+ self .check_x ((3 , 4 , 2 , 3 , 2 ), ind = 2 )
245+ self .check_x ((1 , 2 , 3 , 2 , 3 ), ind = 3 )
246+ self .check_x ((3 , 2 , 1 , 2 , 12 ), ind = 4 )
247+
248+ def test_invalid_shape (self ):
249+ self .check_shape ((2 , 3 , 4 ), ind = 1 )
250+ self .check_shape ((1 , 2 , 3 , 4 ), ind = 3 )
251+
252+ def test_invalid_index (self ):
253+ self .check_ind ((12 , 3 , 4 ), ind = - 1 )
254+ self .check_ind ((18 , 3 , 3 , 2 ), ind = 0 )
0 commit comments