11import ctypes
2+ from collections .abc import Iterable
23from typing import Any
34
45import mlir .runtime as rt
89import numpy as np
910import scipy .sparse as sps
1011
11- from ._common import RefableList , _hold_self_ref_in_ret , _take_owneship , fn_cache
12+ from ._common import PackedArgumentTuple , _hold_self_ref_in_ret , _take_owneship , fn_cache
1213from ._core import ctx , libc
1314from ._dtypes import DType , asdtype
1415
@@ -118,14 +119,16 @@ class Coo(ctypes.Structure):
118119 _index_dtype = index_dtype
119120
120121 @classmethod
121- def from_sps (cls , arr : sps .coo_array | np .ndarray ) -> "Coo" :
122+ def from_sps (cls , arr : sps .coo_array | Iterable [ np .ndarray ] ) -> "Coo" :
122123 if isinstance (arr , sps .coo_array ):
123- assert arr .has_canonical_format , "COO must have canonical format"
124+ if not arr .has_canonical_format :
125+ raise Exception ("COO must have canonical format" )
124126 np_pos = np .array ([0 , arr .size ], dtype = index_dtype .np_dtype )
125127 np_coords = np .stack (arr .coords , axis = 1 , dtype = index_dtype .np_dtype )
126128 np_data = arr .data
127129 else :
128- assert len (arr ) == 3 , "COO must be comprised of three arrays"
130+ if len (arr ) != 3 :
131+ raise Exception ("COO must be comprised of three arrays" )
129132 np_pos , np_coords , np_data = arr
130133
131134 pos = numpy_to_ranked_memref (np_pos )
@@ -142,7 +145,11 @@ def to_sps(self, shape: tuple[int, ...]) -> sps.coo_array | list[np.ndarray]:
142145 pos = ranked_memref_to_numpy (self .pos )
143146 coords = ranked_memref_to_numpy (self .coords )[pos [0 ] : pos [1 ]]
144147 data = ranked_memref_to_numpy (self .data )
145- return sps .coo_array ((data , coords .T ), shape = shape ) if len (shape ) == 2 else RefableList ([pos , coords , data ])
148+ return (
149+ sps .coo_array ((data , coords .T ), shape = shape )
150+ if len (shape ) == 2
151+ else PackedArgumentTuple ((pos , coords , data ))
152+ )
146153
147154 def to_module_arg (self ) -> list :
148155 return [
@@ -201,7 +208,7 @@ def from_sps(cls, arrs: list[np.ndarray]) -> "Csf":
201208 return csf_instance
202209
203210 def to_sps (self , shape : tuple [int , ...]) -> list [np .ndarray ]:
204- return RefableList ( ranked_memref_to_numpy (field ) for field in self .get__fields_ ())
211+ return PackedArgumentTuple ( tuple ( ranked_memref_to_numpy (field ) for field in self .get__fields_ () ))
205212
206213 def to_module_arg (self ) -> list :
207214 return [ctypes .pointer (ctypes .pointer (field )) for field in self .get__fields_ ()]
0 commit comments