11import ctypes
2+ from collections .abc import Iterable
23from typing import Any
34
45import mlir .runtime as rt
89import numpy as np
910import scipy .sparse as sps
1011
11- from ._common import _hold_self_ref_in_ret , _take_owneship , fn_cache
12+ from ._common import PackedArgumentTuple , _hold_self_ref_in_ret , _take_owneship , fn_cache
1213from ._core import ctx , libc
1314from ._dtypes import DType , asdtype
1415
@@ -118,26 +119,37 @@ class Coo(ctypes.Structure):
118119 _index_dtype = index_dtype
119120
120121 @classmethod
121- def from_sps (cls , arr : sps .coo_array ) -> "Coo" :
122- assert arr .has_canonical_format , "COO must have canonical format"
123- np_pos = np .array ([0 , arr .size ], dtype = index_dtype .np_dtype )
124- np_coords = np .stack (arr .coords , axis = 1 , dtype = index_dtype .np_dtype )
122+ def from_sps (cls , arr : sps .coo_array | Iterable [np .ndarray ]) -> "Coo" :
123+ if isinstance (arr , sps .coo_array ):
124+ if not arr .has_canonical_format :
125+ raise Exception ("COO must have canonical format" )
126+ np_pos = np .array ([0 , arr .size ], dtype = index_dtype .np_dtype )
127+ np_coords = np .stack (arr .coords , axis = 1 , dtype = index_dtype .np_dtype )
128+ np_data = arr .data
129+ else :
130+ if len (arr ) != 3 :
131+ raise Exception ("COO must be comprised of three arrays" )
132+ np_pos , np_coords , np_data = arr
133+
125134 pos = numpy_to_ranked_memref (np_pos )
126135 coords = numpy_to_ranked_memref (np_coords )
127- data = numpy_to_ranked_memref (arr .data )
128-
136+ data = numpy_to_ranked_memref (np_data )
129137 coo_instance = cls (pos = pos , coords = coords , data = data )
130138 _take_owneship (coo_instance , np_pos )
131139 _take_owneship (coo_instance , np_coords )
132- _take_owneship (coo_instance , arr )
140+ _take_owneship (coo_instance , np_data )
133141
134142 return coo_instance
135143
136- def to_sps (self , shape : tuple [int , ...]) -> sps .coo_array :
144+ def to_sps (self , shape : tuple [int , ...]) -> sps .coo_array | list [ np . ndarray ] :
137145 pos = ranked_memref_to_numpy (self .pos )
138146 coords = ranked_memref_to_numpy (self .coords )[pos [0 ] : pos [1 ]]
139147 data = ranked_memref_to_numpy (self .data )
140- return sps .coo_array ((data , coords .T ), shape = shape )
148+ return (
149+ sps .coo_array ((data , coords .T ), shape = shape )
150+ if len (shape ) == 2
151+ else PackedArgumentTuple ((pos , coords , data ))
152+ )
141153
142154 def to_module_arg (self ) -> list :
143155 return [
@@ -159,8 +171,13 @@ def get_tensor_definition(cls, shape: tuple[int, ...]) -> ir.RankedTensorType:
159171 compressed_lvl = sparse_tensor .EncodingAttr .build_level_type (
160172 sparse_tensor .LevelFormat .compressed , [sparse_tensor .LevelProperty .non_unique ]
161173 )
162- levels = (compressed_lvl , sparse_tensor .LevelFormat .singleton )
163- ordering = ir .AffineMap .get_permutation ([0 , 1 ])
174+ mid_singleton_lvls = [
175+ sparse_tensor .EncodingAttr .build_level_type (
176+ sparse_tensor .LevelFormat .singleton , [sparse_tensor .LevelProperty .non_unique ]
177+ )
178+ ] * (len (shape ) - 2 )
179+ levels = (compressed_lvl , * mid_singleton_lvls , sparse_tensor .LevelFormat .singleton )
180+ ordering = ir .AffineMap .get_permutation ([* range (len (shape ))])
164181 encoding = sparse_tensor .EncodingAttr .get (levels , ordering , ordering , index_width , index_width )
165182 return ir .RankedTensorType .get (list (shape ), values_dtype , encoding )
166183
@@ -191,10 +208,7 @@ def from_sps(cls, arrs: list[np.ndarray]) -> "Csf":
191208 return csf_instance
192209
193210 def to_sps (self , shape : tuple [int , ...]) -> list [np .ndarray ]:
194- class List (list ):
195- pass
196-
197- return List (ranked_memref_to_numpy (field ) for field in self .get__fields_ ())
211+ return PackedArgumentTuple (tuple (ranked_memref_to_numpy (field ) for field in self .get__fields_ ()))
198212
199213 def to_module_arg (self ) -> list :
200214 return [ctypes .pointer (ctypes .pointer (field )) for field in self .get__fields_ ()]
@@ -310,20 +324,20 @@ def __init__(
310324
311325 if obj .format in ("csr" , "csc" ):
312326 order = "r" if obj .format == "csr" else "c"
313- index_dtype = asdtype (obj .indptr .dtype )
314- self ._format_class = get_csx_class (self ._values_dtype , index_dtype , order )
327+ self . _index_dtype = asdtype (obj .indptr .dtype )
328+ self ._format_class = get_csx_class (self ._values_dtype , self . _index_dtype , order )
315329 self ._obj = self ._format_class .from_sps (obj )
316330 elif obj .format == "coo" :
317- index_dtype = asdtype (obj .coords [0 ].dtype )
318- self ._format_class = get_coo_class (self ._values_dtype , index_dtype )
331+ self . _index_dtype = asdtype (obj .coords [0 ].dtype )
332+ self ._format_class = get_coo_class (self ._values_dtype , self . _index_dtype )
319333 self ._obj = self ._format_class .from_sps (obj )
320334 else :
321335 raise Exception (f"{ obj .format } SciPy format not supported." )
322336
323337 elif _is_numpy_obj (obj ):
324338 self ._owns_memory = False
325- index_dtype = asdtype (np .intp )
326- self ._format_class = get_dense_class (self ._values_dtype , index_dtype )
339+ self . _index_dtype = asdtype (np .intp )
340+ self ._format_class = get_dense_class (self ._values_dtype , self . _index_dtype )
327341 self ._obj = self ._format_class .from_sps (obj )
328342
329343 elif _is_mlir_obj (obj ):
@@ -332,11 +346,13 @@ def __init__(
332346 self ._obj = obj
333347
334348 elif format is not None :
335- if format == "csf" :
349+ if format in ["csf" , "coo" ]:
350+ fn_format_class = get_csf_class if format == "csf" else get_coo_class
336351 self ._owns_memory = False
337- index_dtype = asdtype (np .intp )
338- self ._format_class = get_csf_class (self ._values_dtype , index_dtype )
352+ self . _index_dtype = asdtype (np .intp )
353+ self ._format_class = fn_format_class (self ._values_dtype , self . _index_dtype )
339354 self ._obj = self ._format_class .from_sps (obj )
355+
340356 else :
341357 raise Exception (f"Format { format } not supported." )
342358
0 commit comments