@@ -6,39 +6,60 @@ import numpy as np
66import optype .numpy as onp
77import optype .numpy .compat as npc
88
9- from scipy .linalg import solve , solve_banded , solve_circulant , solve_toeplitz , solve_triangular
9+ from scipy .linalg import inv , solve , solve_banded , solve_circulant , solve_toeplitz , solve_triangular
10+
11+ b1_nd : onp .ArrayND [np .bool_ ]
1012
1113i8_1d : onp .Array1D [np .int8 ]
1214i8_2d : onp .Array2D [np .int8 ]
1315i8_3d : onp .Array3D [np .int8 ]
16+ i8_nd : onp .ArrayND [np .int8 ]
17+
18+ i32_1d : onp .Array1D [np .int32 ]
19+ i32_2d : onp .Array2D [np .int32 ]
20+ i32_3d : onp .Array3D [np .int32 ]
21+ i32_nd : onp .ArrayND [np .int32 ]
1422
1523f16_1d : onp .Array1D [np .float16 ]
1624f16_2d : onp .Array2D [np .float16 ]
1725f16_3d : onp .Array3D [np .float16 ]
26+ f16_nd : onp .ArrayND [np .float16 ]
1827
1928f32_1d : onp .Array1D [np .float32 ]
2029f32_2d : onp .Array2D [np .float32 ]
2130f32_3d : onp .Array3D [np .float32 ]
31+ f32_nd : onp .ArrayND [np .float32 ]
2232
2333f64_1d : onp .Array1D [np .float64 ]
2434f64_2d : onp .Array2D [np .float64 ]
2535f64_3d : onp .Array3D [np .float64 ]
36+ f64_nd : onp .ArrayND [np .float64 ]
2637
2738f80_1d : onp .Array1D [np .longdouble ]
2839f80_2d : onp .Array2D [np .longdouble ]
2940f80_3d : onp .Array3D [np .longdouble ]
41+ f80_nd : onp .ArrayND [np .longdouble ]
3042
3143c64_1d : onp .Array1D [np .complex64 ]
3244c64_2d : onp .Array2D [np .complex64 ]
3345c64_3d : onp .Array3D [np .complex64 ]
46+ c64_nd : onp .ArrayND [np .complex64 ]
3447
3548c128_1d : onp .Array1D [np .complex128 ]
3649c128_2d : onp .Array2D [np .complex128 ]
3750c128_3d : onp .Array3D [np .complex128 ]
51+ c128_nd : onp .ArrayND [np .complex128 ]
3852
3953c160_1d : onp .Array1D [np .clongdouble ]
4054c160_2d : onp .Array2D [np .clongdouble ]
4155c160_3d : onp .Array3D [np .clongdouble ]
56+ c160_nd : onp .ArrayND [np .clongdouble ]
57+
58+ py_b_2d : list [list [bool ]]
59+ py_b_3d : list [list [list [bool ]]]
60+
61+ py_i_2d : list [list [int ]]
62+ py_i_3d : list [list [list [int ]]]
4263
4364py_f_1d : list [float ]
4465py_f_2d : list [list [float ]]
@@ -314,4 +335,50 @@ assert_type(solve_circulant(py_c_1d, py_c_3d), onp.ArrayND[np.complex128])
314335assert_type (solve_circulant (py_c_2d , py_c_1d ), onp .ArrayND [np .complex128 ])
315336
316337###
317- # TODO(jorenham): inv, pinv, pinvh, det, lstsq, matrix_balance, matmul_toeplitz
338+ # inv
339+
340+ assert_type (inv (f32_2d ), onp .Array2D [np .float32 ])
341+ assert_type (inv (f64_2d ), onp .Array2D [np .float64 ])
342+ assert_type (inv (c64_2d ), onp .Array2D [np .complex64 ])
343+ assert_type (inv (c128_2d ), onp .Array2D [np .complex128 ])
344+
345+ assert_type (inv (py_b_2d ), onp .Array2D [np .float32 ])
346+ assert_type (inv (py_i_2d ), onp .Array2D [np .float64 ])
347+ assert_type (inv (py_f_2d ), onp .Array2D [np .float64 ])
348+ assert_type (inv (py_c_2d ), onp .Array2D [np .complex128 ])
349+
350+ assert_type (inv (f32_3d ), onp .Array3D [np .float32 ])
351+ assert_type (inv (f64_3d ), onp .Array3D [np .float64 ])
352+ assert_type (inv (c64_3d ), onp .Array3D [np .complex64 ])
353+ assert_type (inv (c128_3d ), onp .Array3D [np .complex128 ])
354+
355+ assert_type (inv (py_b_3d ), onp .ArrayND [np .float32 ])
356+ assert_type (inv (py_i_3d ), onp .ArrayND [np .float64 ])
357+ assert_type (inv (py_f_3d ), onp .ArrayND [np .float64 ])
358+ assert_type (inv (py_c_3d ), onp .ArrayND [np .complex128 ])
359+
360+ assert_type (inv (b1_nd ), onp .ArrayND [np .float32 ])
361+ assert_type (inv (i8_nd ), onp .ArrayND [np .float32 ])
362+ assert_type (inv (f16_nd ), onp .ArrayND [np .float32 ])
363+ assert_type (inv (f32_nd ), onp .ArrayND [np .float32 ])
364+ assert_type (inv (i32_nd ), onp .ArrayND [np .float64 ])
365+ assert_type (inv (f64_nd ), onp .ArrayND [np .float64 ])
366+ assert_type (inv (f80_nd ), onp .ArrayND [np .float64 ])
367+ assert_type (inv (c64_nd ), onp .ArrayND [np .complex64 ])
368+ assert_type (inv (c128_nd ), onp .ArrayND [np .complex128 ])
369+ assert_type (inv (c160_nd ), onp .ArrayND [np .complex128 ])
370+
371+ ###
372+ # TODO(jorenham): det
373+
374+ ###
375+ # TODO(jorenham): lstsq
376+
377+ ###
378+ # TODO(jorenham): pinv[h]
379+
380+ ###
381+ # TODO(jorenham): matrix_balance
382+
383+ ###
384+ # TODO(jorenham): matmul_toeplitz
0 commit comments