11import ctypes
2+ from typing import Any
23
34import mlir .runtime as rt
45from mlir import ir
@@ -48,18 +49,23 @@ def free_memref(obj: ctypes.Structure) -> None:
4849
4950
5051@fn_cache
51- def get_csr_class (values_dtype : type [DType ], index_dtype : type [DType ]) -> type :
52- class Csr (ctypes .Structure ):
52+ def get_csx_class (
53+ values_dtype : type [DType ],
54+ index_dtype : type [DType ],
55+ order : str ,
56+ ) -> type [ctypes .Structure ]:
57+ class Csx (ctypes .Structure ):
5358 _fields_ = [
5459 ("indptr" , get_nd_memref_descr (1 , index_dtype )),
5560 ("indices" , get_nd_memref_descr (1 , index_dtype )),
5661 ("data" , get_nd_memref_descr (1 , values_dtype )),
5762 ]
5863 dtype = values_dtype
5964 _index_dtype = index_dtype
65+ _order = order
6066
6167 @classmethod
62- def from_sps (cls , arr : sps .csr_array ) -> "Csr " :
68+ def from_sps (cls , arr : sps .csr_array | sps . csc_array ) -> "Csx " :
6369 indptr = numpy_to_ranked_memref (arr .indptr )
6470 indices = numpy_to_ranked_memref (arr .indices )
6571 data = numpy_to_ranked_memref (arr .data )
@@ -69,11 +75,11 @@ def from_sps(cls, arr: sps.csr_array) -> "Csr":
6975
7076 return csr_instance
7177
72- def to_sps (self , shape : tuple [int , ...]) -> sps .csr_array :
78+ def to_sps (self , shape : tuple [int , ...]) -> sps .csr_array | sps . csc_array :
7379 pos = ranked_memref_to_numpy (self .indptr )
7480 crd = ranked_memref_to_numpy (self .indices )
7581 data = ranked_memref_to_numpy (self .data )
76- return sps . csr_array ((data , crd , pos ), shape = shape )
82+ return get_csx_scipy_class ( self . _order ) ((data , crd , pos ), shape = shape )
7783
7884 def to_module_arg (self ) -> list :
7985 return [
@@ -93,15 +99,15 @@ def get_tensor_definition(cls, shape: tuple[int, ...]) -> ir.RankedTensorType:
9399 index_dtype = cls ._index_dtype .get_mlir_type ()
94100 index_width = getattr (index_dtype , "width" , 0 )
95101 levels = (sparse_tensor .LevelFormat .dense , sparse_tensor .LevelFormat .compressed )
96- ordering = ir .AffineMap .get_permutation ([ 0 , 1 ] )
102+ ordering = ir .AffineMap .get_permutation (get_order_tuple ( cls . _order ) )
97103 encoding = sparse_tensor .EncodingAttr .get (levels , ordering , ordering , index_width , index_width )
98104 return ir .RankedTensorType .get (list (shape ), values_dtype , encoding )
99105
100- return Csr
106+ return Csx
101107
102108
103109@fn_cache
104- def get_coo_class (values_dtype : type [DType ], index_dtype : type [DType ]) -> type :
110+ def get_coo_class (values_dtype : type [DType ], index_dtype : type [DType ]) -> type [ ctypes . Structure ] :
105111 class Coo (ctypes .Structure ):
106112 _fields_ = [
107113 ("pos" , get_nd_memref_descr (1 , index_dtype )),
@@ -162,12 +168,61 @@ def get_tensor_definition(cls, shape: tuple[int, ...]) -> ir.RankedTensorType:
162168
163169
164170@fn_cache
165- def get_csf_class (values_dtype : type [DType ], index_dtype : type [DType ]) -> type :
166- raise NotImplementedError
171+ def get_csf_class (
172+ values_dtype : type [DType ],
173+ index_dtype : type [DType ],
174+ ) -> type [ctypes .Structure ]:
175+ class Csf (ctypes .Structure ):
176+ _fields_ = [
177+ ("indptr_1" , get_nd_memref_descr (1 , index_dtype )),
178+ ("indices_1" , get_nd_memref_descr (1 , index_dtype )),
179+ ("indptr_2" , get_nd_memref_descr (1 , index_dtype )),
180+ ("indices_2" , get_nd_memref_descr (1 , index_dtype )),
181+ ("data" , get_nd_memref_descr (1 , values_dtype )),
182+ ]
183+ dtype = values_dtype
184+ _index_dtype = index_dtype
185+
186+ @classmethod
187+ def from_sps (cls , arrs : list [np .ndarray ]) -> "Csf" :
188+ csf_instance = cls (* [numpy_to_ranked_memref (arr ) for arr in arrs ])
189+ for arr in arrs :
190+ _take_owneship (csf_instance , arr )
191+ return csf_instance
192+
193+ def to_sps (self , shape : tuple [int , ...]) -> list [np .ndarray ]:
194+ class List (list ):
195+ pass
196+
197+ return List (ranked_memref_to_numpy (field ) for field in self .get__fields_ ())
198+
199+ def to_module_arg (self ) -> list :
200+ return [ctypes .pointer (ctypes .pointer (field )) for field in self .get__fields_ ()]
201+
202+ def get__fields_ (self ) -> list :
203+ return [self .indptr_1 , self .indices_1 , self .indptr_2 , self .indices_2 , self .data ]
204+
205+ @classmethod
206+ @fn_cache
207+ def get_tensor_definition (cls , shape : tuple [int , ...]) -> ir .RankedTensorType :
208+ with ir .Location .unknown (ctx ):
209+ values_dtype = cls .dtype .get_mlir_type ()
210+ index_dtype = cls ._index_dtype .get_mlir_type ()
211+ index_width = getattr (index_dtype , "width" , 0 )
212+ levels = (
213+ sparse_tensor .LevelFormat .dense ,
214+ sparse_tensor .LevelFormat .compressed ,
215+ sparse_tensor .LevelFormat .compressed ,
216+ )
217+ ordering = ir .AffineMap .get_permutation ([0 , 1 , 2 ])
218+ encoding = sparse_tensor .EncodingAttr .get (levels , ordering , ordering , index_width , index_width )
219+ return ir .RankedTensorType .get (list (shape ), values_dtype , encoding )
220+
221+ return Csf
167222
168223
169224@fn_cache
170- def get_dense_class (values_dtype : type [DType ], index_dtype : type [DType ]) -> type :
225+ def get_dense_class (values_dtype : type [DType ], index_dtype : type [DType ]) -> type [ ctypes . Structure ] :
171226 class Dense (ctypes .Structure ):
172227 _fields_ = [
173228 ("data" , get_nd_memref_descr (1 , values_dtype )),
@@ -221,22 +276,42 @@ def _is_mlir_obj(x) -> bool:
221276 return isinstance (x , ctypes .Structure )
222277
223278
279+ def get_order_tuple (order : str ) -> tuple [int , int ]:
280+ if order in ("r" , "c" ):
281+ return (0 , 1 ) if order == "r" else (1 , 0 )
282+ raise Exception (f"Invalid order: { order } " )
283+
284+
285+ def get_csx_scipy_class (order : str ) -> type [sps .sparray ]:
286+ if order in ("r" , "c" ):
287+ return sps .csr_array if order == "r" else sps .csc_array
288+ raise Exception (f"Invalid order: { order } " )
289+
290+
224291################
225292# Tensor class #
226293################
227294
228295
229296class Tensor :
230- def __init__ (self , obj , shape = None ) -> None :
297+ def __init__ (
298+ self ,
299+ obj : Any ,
300+ shape : tuple [int , ...] | None = None ,
301+ dtype : type [DType ] | None = None ,
302+ format : str | None = None ,
303+ ) -> None :
231304 self .shape = shape if shape is not None else obj .shape
232- self ._values_dtype = asdtype (obj .dtype )
305+ self .ndim = len (self .shape )
306+ self ._values_dtype = dtype if dtype is not None else asdtype (obj .dtype )
233307
234308 if _is_scipy_sparse_obj (obj ):
235309 self ._owns_memory = False
236310
237- if obj .format == "csr" :
311+ if obj .format in ("csr" , "csc" ):
312+ order = "r" if obj .format == "csr" else "c"
238313 index_dtype = asdtype (obj .indptr .dtype )
239- self ._format_class = get_csr_class (self ._values_dtype , index_dtype )
314+ self ._format_class = get_csx_class (self ._values_dtype , index_dtype , order )
240315 self ._obj = self ._format_class .from_sps (obj )
241316 elif obj .format == "coo" :
242317 index_dtype = asdtype (obj .coords [0 ].dtype )
@@ -256,6 +331,15 @@ def __init__(self, obj, shape=None) -> None:
256331 self ._format_class = type (obj )
257332 self ._obj = obj
258333
334+ elif format is not None :
335+ if format == "csf" :
336+ self ._owns_memory = False
337+ index_dtype = asdtype (np .intp )
338+ self ._format_class = get_csf_class (self ._values_dtype , index_dtype )
339+ self ._obj = self ._format_class .from_sps (obj )
340+ else :
341+ raise Exception (f"Format { format } not supported." )
342+
259343 else :
260344 raise Exception (f"{ type (obj )} not supported." )
261345
@@ -269,5 +353,5 @@ def to_scipy_sparse(self) -> sps.sparray | np.ndarray:
269353 return self ._obj .to_sps (self .shape )
270354
271355
272- def asarray (obj ) -> Tensor :
273- return Tensor (obj )
356+ def asarray (obj , shape = None , dtype = None , format = None ) -> Tensor :
357+ return Tensor (obj , shape , dtype , format )
0 commit comments