@@ -49,6 +49,55 @@ def free_memref(obj: ctypes.Structure) -> None:
4949###########
5050
5151
52+ @fn_cache
53+ def get_sparse_vector_class (
54+ values_dtype : type [DType ],
55+ index_dtype : type [DType ],
56+ ) -> type [ctypes .Structure ]:
57+ class SparseVector (ctypes .Structure ):
58+ _fields_ = [
59+ ("indptr" , get_nd_memref_descr (1 , index_dtype )),
60+ ("indices" , get_nd_memref_descr (1 , index_dtype )),
61+ ("data" , get_nd_memref_descr (1 , values_dtype )),
62+ ]
63+ dtype = values_dtype
64+ _index_dtype = index_dtype
65+
66+ @classmethod
67+ def from_sps (cls , arrs : list [np .ndarray ]) -> "SparseVector" :
68+ sv_instance = cls (* [numpy_to_ranked_memref (arr ) for arr in arrs ])
69+ for arr in arrs :
70+ _take_owneship (sv_instance , arr )
71+ return sv_instance
72+
73+ def to_sps (self , shape : tuple [int , ...]) -> int :
74+ return PackedArgumentTuple (tuple (ranked_memref_to_numpy (field ) for field in self .get__fields_ ()))
75+
76+ def to_module_arg (self ) -> list :
77+ return [
78+ ctypes .pointer (ctypes .pointer (self .indptr )),
79+ ctypes .pointer (ctypes .pointer (self .indices )),
80+ ctypes .pointer (ctypes .pointer (self .data )),
81+ ]
82+
83+ def get__fields_ (self ) -> list :
84+ return [self .indptr , self .indices , self .data ]
85+
86+ @classmethod
87+ @fn_cache
88+ def get_tensor_definition (cls , shape : tuple [int , ...]) -> ir .RankedTensorType :
89+ with ir .Location .unknown (ctx ):
90+ values_dtype = cls .dtype .get_mlir_type ()
91+ index_dtype = cls ._index_dtype .get_mlir_type ()
92+ index_width = getattr (index_dtype , "width" , 0 )
93+ levels = (sparse_tensor .LevelFormat .compressed ,)
94+ ordering = ir .AffineMap .get_permutation ([0 ])
95+ encoding = sparse_tensor .EncodingAttr .get (levels , ordering , ordering , index_width , index_width )
96+ return ir .RankedTensorType .get (list (shape ), values_dtype , encoding )
97+
98+ return SparseVector
99+
100+
52101@fn_cache
53102def get_csx_class (
54103 values_dtype : type [DType ],
@@ -302,6 +351,16 @@ def get_csx_scipy_class(order: str) -> type[sps.sparray]:
302351 raise Exception (f"Invalid order: { order } " )
303352
304353
354+ _constructor_class_dict = {
355+ "csr" : get_csx_class ,
356+ "csc" : get_csx_class ,
357+ "csf" : get_csf_class ,
358+ "coo" : get_coo_class ,
359+ "sparse_vector" : get_sparse_vector_class ,
360+ "dense" : get_dense_class ,
361+ }
362+
363+
305364################
306365# Tensor class #
307366################
@@ -346,8 +405,8 @@ def __init__(
346405 self ._obj = obj
347406
348407 elif format is not None :
349- if format in ["csf" , "coo" ]:
350- fn_format_class = get_csf_class if format == "csf" else get_coo_class
408+ if format in ["csf" , "coo" , "sparse_vector" ]:
409+ fn_format_class = _constructor_class_dict [ format ]
351410 self ._owns_memory = False
352411 self ._index_dtype = asdtype (np .intp )
353412 self ._format_class = fn_format_class (self ._values_dtype , self ._index_dtype )
0 commit comments