88import numpy as np
99import scipy .sparse as sps
1010
11- from ._common import _hold_self_ref_in_ret , _take_owneship , fn_cache
11+ from ._common import RefableList , _hold_self_ref_in_ret , _take_owneship , fn_cache
1212from ._core import ctx , libc
1313from ._dtypes import DType , asdtype
1414
@@ -118,26 +118,31 @@ class Coo(ctypes.Structure):
118118 _index_dtype = index_dtype
119119
120120 @classmethod
121- def from_sps (cls , arr : sps .coo_array ) -> "Coo" :
122- assert arr .has_canonical_format , "COO must have canonical format"
123- np_pos = np .array ([0 , arr .size ], dtype = index_dtype .np_dtype )
124- np_coords = np .stack (arr .coords , axis = 1 , dtype = index_dtype .np_dtype )
121+ def from_sps (cls , arr : sps .coo_array | np .ndarray ) -> "Coo" :
122+ if isinstance (arr , sps .coo_array ):
123+ assert arr .has_canonical_format , "COO must have canonical format"
124+ np_pos = np .array ([0 , arr .size ], dtype = index_dtype .np_dtype )
125+ np_coords = np .stack (arr .coords , axis = 1 , dtype = index_dtype .np_dtype )
126+ np_data = arr .data
127+ else :
128+ assert len (arr ) == 3 , "COO must be comprised of three arrays"
129+ np_pos , np_coords , np_data = arr
130+
125131 pos = numpy_to_ranked_memref (np_pos )
126132 coords = numpy_to_ranked_memref (np_coords )
127- data = numpy_to_ranked_memref (arr .data )
128-
133+ data = numpy_to_ranked_memref (np_data )
129134 coo_instance = cls (pos = pos , coords = coords , data = data )
130135 _take_owneship (coo_instance , np_pos )
131136 _take_owneship (coo_instance , np_coords )
132- _take_owneship (coo_instance , arr )
137+ _take_owneship (coo_instance , np_data )
133138
134139 return coo_instance
135140
136- def to_sps (self , shape : tuple [int , ...]) -> sps .coo_array :
141+ def to_sps (self , shape : tuple [int , ...]) -> sps .coo_array | list [ np . ndarray ] :
137142 pos = ranked_memref_to_numpy (self .pos )
138143 coords = ranked_memref_to_numpy (self .coords )[pos [0 ] : pos [1 ]]
139144 data = ranked_memref_to_numpy (self .data )
140- return sps .coo_array ((data , coords .T ), shape = shape )
145+ return sps .coo_array ((data , coords .T ), shape = shape ) if len ( shape ) == 2 else RefableList ([ pos , coords , data ])
141146
142147 def to_module_arg (self ) -> list :
143148 return [
@@ -159,8 +164,13 @@ def get_tensor_definition(cls, shape: tuple[int, ...]) -> ir.RankedTensorType:
159164 compressed_lvl = sparse_tensor .EncodingAttr .build_level_type (
160165 sparse_tensor .LevelFormat .compressed , [sparse_tensor .LevelProperty .non_unique ]
161166 )
162- levels = (compressed_lvl , sparse_tensor .LevelFormat .singleton )
163- ordering = ir .AffineMap .get_permutation ([0 , 1 ])
167+ mid_singleton_lvls = [
168+ sparse_tensor .EncodingAttr .build_level_type (
169+ sparse_tensor .LevelFormat .singleton , [sparse_tensor .LevelProperty .non_unique ]
170+ )
171+ ] * (len (shape ) - 2 )
172+ levels = (compressed_lvl , * mid_singleton_lvls , sparse_tensor .LevelFormat .singleton )
173+ ordering = ir .AffineMap .get_permutation ([* range (len (shape ))])
164174 encoding = sparse_tensor .EncodingAttr .get (levels , ordering , ordering , index_width , index_width )
165175 return ir .RankedTensorType .get (list (shape ), values_dtype , encoding )
166176
@@ -191,10 +201,7 @@ def from_sps(cls, arrs: list[np.ndarray]) -> "Csf":
191201 return csf_instance
192202
193203 def to_sps (self , shape : tuple [int , ...]) -> list [np .ndarray ]:
194- class List (list ):
195- pass
196-
197- return List (ranked_memref_to_numpy (field ) for field in self .get__fields_ ())
204+ return RefableList (ranked_memref_to_numpy (field ) for field in self .get__fields_ ())
198205
199206 def to_module_arg (self ) -> list :
200207 return [ctypes .pointer (ctypes .pointer (field )) for field in self .get__fields_ ()]
@@ -310,20 +317,20 @@ def __init__(
310317
311318 if obj .format in ("csr" , "csc" ):
312319 order = "r" if obj .format == "csr" else "c"
313- index_dtype = asdtype (obj .indptr .dtype )
314- self ._format_class = get_csx_class (self ._values_dtype , index_dtype , order )
320+ self . _index_dtype = asdtype (obj .indptr .dtype )
321+ self ._format_class = get_csx_class (self ._values_dtype , self . _index_dtype , order )
315322 self ._obj = self ._format_class .from_sps (obj )
316323 elif obj .format == "coo" :
317- index_dtype = asdtype (obj .coords [0 ].dtype )
318- self ._format_class = get_coo_class (self ._values_dtype , index_dtype )
324+ self . _index_dtype = asdtype (obj .coords [0 ].dtype )
325+ self ._format_class = get_coo_class (self ._values_dtype , self . _index_dtype )
319326 self ._obj = self ._format_class .from_sps (obj )
320327 else :
321328 raise Exception (f"{ obj .format } SciPy format not supported." )
322329
323330 elif _is_numpy_obj (obj ):
324331 self ._owns_memory = False
325- index_dtype = asdtype (np .intp )
326- self ._format_class = get_dense_class (self ._values_dtype , index_dtype )
332+ self . _index_dtype = asdtype (np .intp )
333+ self ._format_class = get_dense_class (self ._values_dtype , self . _index_dtype )
327334 self ._obj = self ._format_class .from_sps (obj )
328335
329336 elif _is_mlir_obj (obj ):
@@ -332,11 +339,13 @@ def __init__(
332339 self ._obj = obj
333340
334341 elif format is not None :
335- if format == "csf" :
342+ if format in ["csf" , "coo" ]:
343+ fn_format_class = get_csf_class if format == "csf" else get_coo_class
336344 self ._owns_memory = False
337- index_dtype = asdtype (np .intp )
338- self ._format_class = get_csf_class (self ._values_dtype , index_dtype )
345+ self . _index_dtype = asdtype (np .intp )
346+ self ._format_class = fn_format_class (self ._values_dtype , self . _index_dtype )
339347 self ._obj = self ._format_class .from_sps (obj )
348+
340349 else :
341350 raise Exception (f"Format { format } not supported." )
342351
0 commit comments