@@ -11,8 +11,11 @@ import optype.typing as opt
1111from ._base import _spbase , sparray
1212from ._bsr import bsr_array , bsr_matrix
1313from ._coo import coo_array , coo_matrix
14+ from ._csc import csc_array , csc_matrix
1415from ._csr import csr_array , csr_matrix
1516from ._dia import dia_array , dia_matrix
17+ from ._dok import dok_array , dok_matrix
18+ from ._lil import lil_array , lil_matrix
1619from ._matrix import spmatrix
1720from ._typing import Numeric , SPFormat , ToShape2D , _CanStack , _CanStackAs
1821
@@ -43,7 +46,7 @@ _ShapeT = TypeVar("_ShapeT", bound=tuple[int, *tuple[int, ...]], default=tuple[A
4346_ToArray1D : TypeAlias = Seq [_SCT ] | onp .CanArrayND [_SCT ]
4447_ToArray2D : TypeAlias = Seq [Seq [_SCT | int ] | onp .CanArrayND [_SCT ]] | onp .CanArrayND [_SCT ]
4548_ToSpMatrix : TypeAlias = spmatrix [_SCT ] | _ToArray2D [_SCT ]
46- _ToSparse : TypeAlias = _spbase [_SCT ] | _ToArray2D [_SCT ]
49+ _ToSparse2D : TypeAlias = _spbase [_SCT , tuple [ int , int ] ] | _ToArray2D [_SCT ]
4750
4851_SpBase : TypeAlias = _spbase [_SCT , _ShapeT ] | Any
4952_SpMatrix : TypeAlias = spmatrix [_SCT ] | Any
@@ -54,15 +57,21 @@ _SpArray1D: TypeAlias = _SpArray[_SCT, tuple[int]]
5457_SpArray2D : TypeAlias = _SpArray [_SCT , tuple [int , int ]]
5558
5659_BSRArray : TypeAlias = bsr_array [_SCT ]
57- _CSRArray : TypeAlias = csr_array [_SCT , tuple [int , int ]]
60+ _COOArray2D : TypeAlias = coo_array [_SCT , tuple [int , int ]]
61+ _CSCArray : TypeAlias = csc_array [_SCT ]
62+ _CSRArray2D : TypeAlias = csr_array [_SCT , tuple [int , int ]]
63+ _DIAArray : TypeAlias = dia_array [_SCT ]
64+ _DOKArray2D : TypeAlias = dok_array [_SCT , tuple [int , int ]]
65+ _LILArray : TypeAlias = lil_array [_SCT ]
5866
5967_FmtBSR : TypeAlias = Literal ["bsr" ]
6068_FmtCOO : TypeAlias = Literal ["coo" ]
69+ _FmtCSC : TypeAlias = Literal ["csc" ]
6170_FmtCSR : TypeAlias = Literal ["csr" ]
6271_FmtDIA : TypeAlias = Literal ["dia" ]
63- _FmtNonBSR : TypeAlias = Literal ["coo" , "csc" , "csr" , "dia" , "dok" , "lil" ]
72+ _FmtDOK : TypeAlias = Literal ["dok" ]
73+ _FmtLIL : TypeAlias = Literal ["lil" ]
6474_FmtNonCOO : TypeAlias = Literal ["bsr" , "csc" , "csr" , "dia" , "dok" , "lil" ]
65- _FmtNonCSR : TypeAlias = Literal ["bsr" , "coo" , "csc" , "dia" , "dok" , "lil" ]
6675_FmtNonDIA : TypeAlias = Literal ["bsr" , "coo" , "csc" , "csr" , "dok" , "lil" ]
6776
6877_DataRVS : TypeAlias = Callable [[int ], onp .ArrayND [Numeric ]]
@@ -506,32 +515,92 @@ def eye(
506515#
507516@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: {"bsr", None} = ...
508517def kron (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtBSR | None = None ) -> bsr_matrix [_SCT ]: ...
509- @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: <otherwise>
510- def kron (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtNonBSR ) -> _SpMatrix [_SCT ]: ...
511- @overload # A: sparray, B: sparse, format: {"bsr", None} = ...
512- def kron (A : sparray [_SCT ], B : _ToSparse [_SCT ], format : _FmtBSR | None = None ) -> _BSRArray [_SCT ]: ...
513- @overload # A: sparray, B: sparse, format: <otherwise>
514- def kron (A : sparray [_SCT ], B : _ToSparse [_SCT ], format : _FmtNonBSR ) -> _SpArray2D [_SCT ]: ...
518+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "coo"
519+ def kron (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtCOO ) -> coo_matrix [_SCT ]: ...
520+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "csc"
521+ def kron (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtCSC ) -> csc_matrix [_SCT ]: ...
522+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "csr"
523+ def kron (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtCSR ) -> csr_matrix [_SCT ]: ...
524+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "dia"
525+ def kron (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtDIA ) -> dia_matrix [_SCT ]: ...
526+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "dok"
527+ def kron (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtDOK ) -> dok_matrix [_SCT ]: ...
528+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "lil"
529+ def kron (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtLIL ) -> lil_matrix [_SCT ]: ...
530+ @overload # A: sparray, B: 2D sparse, format: {"bsr", None} = ...
531+ def kron (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtBSR | None = None ) -> _BSRArray [_SCT ]: ...
532+ @overload # A: sparray, B: sparse, format: "coo"
533+ def kron (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtCOO ) -> _COOArray2D [_SCT ]: ...
534+ @overload # A: sparray, B: sparse, format: "csc"
535+ def kron (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtCSC ) -> _CSCArray [_SCT ]: ...
536+ @overload # A: sparray, B: sparse, format: "csr"
537+ def kron (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtCSR ) -> _CSRArray2D [_SCT ]: ...
538+ @overload # A: sparray, B: sparse, format: "dia"
539+ def kron (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtDIA ) -> _DIAArray [_SCT ]: ...
540+ @overload # A: sparray, B: sparse, format: "dok"
541+ def kron (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtDOK ) -> _DOKArray2D [_SCT ]: ...
542+ @overload # A: sparray, B: sparse, format: "lil"
543+ def kron (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtLIL ) -> _LILArray [_SCT ]: ...
515544@overload # A: sparse, B: sparray, format: {"bsr", None} = ...
516- def kron (A : _ToSparse [_SCT ], B : sparray [_SCT ], format : _FmtBSR | None = None ) -> _BSRArray [_SCT ]: ...
517- @overload # A: sparse, B: sparray, format: <otherwise>
518- def kron (A : _ToSparse [_SCT ], B : sparray [_SCT ], format : _FmtNonBSR ) -> _SpArray2D [_SCT ]: ...
545+ def kron (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtBSR | None = None ) -> _BSRArray [_SCT ]: ...
546+ @overload # A: sparray, B: sparse, format: "coo"
547+ def kron (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtCOO ) -> _COOArray2D [_SCT ]: ...
548+ @overload # A: sparray, B: sparse, format: "csc"
549+ def kron (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtCSC ) -> _CSCArray [_SCT ]: ...
550+ @overload # A: sparray, B: sparse, format: "csr"
551+ def kron (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtCSR ) -> _CSRArray2D [_SCT ]: ...
552+ @overload # A: sparray, B: sparse, format: "dia"
553+ def kron (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtDIA ) -> _DIAArray [_SCT ]: ...
554+ @overload # A: sparray, B: sparse, format: "dok"
555+ def kron (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtDOK ) -> _DOKArray2D [_SCT ]: ...
556+ @overload # A: sparray, B: sparse, format: "lil"
557+ def kron (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtLIL ) -> _LILArray [_SCT ]: ...
519558@overload # A: unknown array-like, B: unknown array-like (catch-all)
520559def kron (A : onp .ToComplex2D , B : onp .ToComplex2D , format : SPFormat | None = None ) -> _SpBase2D [Incomplete ]: ...
521560
522561# NOTE: The `overload-overlap` mypy errors are false positives.
523562@overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: {"csr", None} = ...
524563def kronsum (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtCSR | None = None ) -> csr_matrix [_SCT ]: ...
525- @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: <otherwise>
526- def kronsum (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtNonCSR ) -> _SpMatrix [_SCT ]: ...
564+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "bsr"
565+ def kronsum (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtBSR ) -> bsr_matrix [_SCT ]: ...
566+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "coo"
567+ def kronsum (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtCOO ) -> coo_matrix [_SCT ]: ...
568+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "csc"
569+ def kronsum (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtCSC ) -> csc_matrix [_SCT ]: ...
570+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "dia"
571+ def kronsum (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtDIA ) -> dia_matrix [_SCT ]: ...
572+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "dok"
573+ def kronsum (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtDOK ) -> dok_matrix [_SCT ]: ...
574+ @overload # A: spmatrix or 2d array-like, B: spmatrix or 2d array-like, format: "lil"
575+ def kronsum (A : _ToSpMatrix [_SCT ], B : _ToSpMatrix [_SCT ], format : _FmtLIL ) -> lil_matrix [_SCT ]: ...
527576@overload # A: sparray, B: sparse, format: {"csr", None} = ...
528- def kronsum (A : sparray [_SCT ], B : _ToSparse [_SCT ], format : _FmtCSR | None = None ) -> _CSRArray [_SCT ]: ...
529- @overload # A: sparray, B: sparse, format: <otherwise>
530- def kronsum (A : sparray [_SCT ], B : _ToSparse [_SCT ], format : _FmtNonCSR ) -> _SpArray2D [_SCT ]: ...
577+ def kronsum (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtCSR | None = None ) -> _CSRArray2D [_SCT ]: ...
578+ @overload # A: sparray, B: sparse, format: "bsr"
579+ def kronsum (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtBSR ) -> _BSRArray [_SCT ]: ...
580+ @overload # A: sparray, B: sparse, format: "coo"
581+ def kronsum (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtCOO ) -> _COOArray2D [_SCT ]: ...
582+ @overload # A: sparray, B: sparse, format: "csc"
583+ def kronsum (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtCSC ) -> _CSCArray [_SCT ]: ...
584+ @overload # A: sparray, B: sparse, format: "dia"
585+ def kronsum (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtDIA ) -> _DIAArray [_SCT ]: ...
586+ @overload # A: sparray, B: sparse, format: "dok"
587+ def kronsum (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtDOK ) -> _DOKArray2D [_SCT ]: ...
588+ @overload # A: sparray, B: sparse, format: "lil"
589+ def kronsum (A : sparray [_SCT , tuple [int , int ]], B : _ToSparse2D [_SCT ], format : _FmtLIL ) -> _LILArray [_SCT ]: ...
531590@overload # A: sparse, B: sparray, format: {"csr", None} = ...
532- def kronsum (A : _ToSparse [_SCT ], B : sparray [_SCT ], format : _FmtCSR | None = None ) -> _CSRArray [_SCT ]: ...
533- @overload # A: sparse, B: sparray, format: <otherwise>
534- def kronsum (A : _ToSparse [_SCT ], B : sparray [_SCT ], format : _FmtNonCSR ) -> _SpArray2D [_SCT ]: ...
591+ def kronsum (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtCSR | None = None ) -> _CSRArray2D [_SCT ]: ...
592+ @overload # A: sparse, B: sparray, format: "bsr"
593+ def kronsum (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtBSR ) -> _BSRArray [_SCT ]: ...
594+ @overload # A: sparse, B: sparray, format: "coo"
595+ def kronsum (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtCOO ) -> _COOArray2D [_SCT ]: ...
596+ @overload # A: sparse, B: sparray, format: "csc"
597+ def kronsum (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtCSC ) -> _CSCArray [_SCT ]: ...
598+ @overload # A: sparse, B: sparray, format: "dia"
599+ def kronsum (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtDIA ) -> _DIAArray [_SCT ]: ...
600+ @overload # A: sparse, B: sparray, format: "dok"
601+ def kronsum (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtDOK ) -> _DOKArray2D [_SCT ]: ...
602+ @overload # A: sparse, B: sparray, format: "lil"
603+ def kronsum (A : _ToSparse2D [_SCT ], B : sparray [_SCT , tuple [int , int ]], format : _FmtLIL ) -> _LILArray [_SCT ]: ...
535604@overload # A: unknown array-like, B: unknown array-like (catch-all)
536605def kronsum (A : onp .ToComplex2D , B : onp .ToComplex2D , format : SPFormat | None = None ) -> _SpBase2D [Incomplete ]: ...
537606
@@ -571,8 +640,6 @@ def vstack(blocks: Seq[_CanStackAs[Any, _T]], format: None = None, *, dtype: npt
571640@overload # TODO(jorenham): Support for `format=...`
572641def vstack (blocks : Seq [_spbase ], format : SPFormat , dtype : npt .DTypeLike | None = None ) -> Incomplete : ...
573642
574- _COOArray2D : TypeAlias = coo_array [_SCT , tuple [int , int ]]
575-
576643# TODO(jorenham): Use `_CanStack` here, which requires a way to map matrix types to array types.
577644@overload # blocks: <known dtype>, format: <default>, dtype: <default>
578645def block_array (blocks : _ToBlocks [_SCT ], * , format : _FmtCOO | None = None , dtype : None = None ) -> _COOArray2D [_SCT ]: ...
0 commit comments