11import ctypes
2- import ctypes .util
32import functools
43import weakref
54
@@ -61,7 +60,7 @@ def get_module(shape: tuple[int], values_dtype: DType, index_dtype: DType):
6160 values_dtype = values_dtype .get_mlir_type ()
6261 index_dtype = index_dtype .get_mlir_type ()
6362 index_width = getattr (index_dtype , "width" , 0 )
64- levels = (sparse_tensor .LevelType .dense , sparse_tensor .LevelType .dense )
63+ levels = (sparse_tensor .LevelFormat .dense , sparse_tensor .LevelFormat .dense )
6564 ordering = ir .AffineMap .get_permutation ([0 , 1 ])
6665 encoding = sparse_tensor .EncodingAttr .get (levels , ordering , ordering , index_width , index_width )
6766 dense_shaped = ir .RankedTensorType .get (list (shape ), values_dtype , encoding )
@@ -71,19 +70,19 @@ def get_module(shape: tuple[int], values_dtype: DType, index_dtype: DType):
7170
7271 @func .FuncOp .from_py_func (tensor_1d )
7372 def assemble (data ):
74- return sparse_tensor .assemble (dense_shaped , data , [] )
73+ return sparse_tensor .assemble (dense_shaped , [], data )
7574
7675 @func .FuncOp .from_py_func (dense_shaped )
7776 def disassemble (tensor_shaped ):
7877 data = tensor .EmptyOp ([arith .constant (ir .IndexType .get (), 0 )], values_dtype )
7978 data , data_len = sparse_tensor .disassemble (
79+ [],
8080 tensor_1d ,
8181 [],
8282 index_dtype ,
83- [],
8483 tensor_shaped ,
85- data ,
8684 [],
85+ data ,
8786 )
8887 shape_x = arith .constant (index_dtype , shape [0 ])
8988 shape_y = arith .constant (index_dtype , shape [1 ])
@@ -154,7 +153,7 @@ def get_module(shape: tuple[int], values_dtype: type[DType], index_dtype: type[D
154153 values_dtype = values_dtype .get_mlir_type ()
155154 index_dtype = index_dtype .get_mlir_type ()
156155 index_width = getattr (index_dtype , "width" , 0 )
157- levels = (sparse_tensor .LevelType .dense , sparse_tensor .LevelType .compressed )
156+ levels = (sparse_tensor .LevelFormat .dense , sparse_tensor .LevelFormat .compressed )
158157 ordering = ir .AffineMap .get_permutation ([0 , 1 ])
159158 encoding = sparse_tensor .EncodingAttr .get (levels , ordering , ordering , index_width , index_width )
160159 csr_shaped = ir .RankedTensorType .get (list (shape ), values_dtype , encoding )
@@ -166,25 +165,25 @@ def get_module(shape: tuple[int], values_dtype: type[DType], index_dtype: type[D
166165
167166 @func .FuncOp .from_py_func (tensor_1d_index , tensor_1d_index , tensor_1d_values )
168167 def assemble (pos , crd , data ):
169- return sparse_tensor .assemble (csr_shaped , data , (pos , crd ))
168+ return sparse_tensor .assemble (csr_shaped , (pos , crd ), data )
170169
171170 @func .FuncOp .from_py_func (csr_shaped )
172171 def disassemble (tensor_shaped ):
173172 pos = tensor .EmptyOp ([arith .constant (ir .IndexType .get (), 0 )], index_dtype )
174173 crd = tensor .EmptyOp ([arith .constant (ir .IndexType .get (), 0 )], index_dtype )
175174 data = tensor .EmptyOp ([arith .constant (ir .IndexType .get (), 0 )], values_dtype )
176- data , pos , crd , data_len , pos_len , crd_len = sparse_tensor .disassemble (
177- tensor_1d_values ,
175+ pos , crd , data , pos_len , crd_len , data_len = sparse_tensor .disassemble (
178176 (tensor_1d_index , tensor_1d_index ),
179- index_dtype ,
177+ tensor_1d_values ,
180178 (index_dtype , index_dtype ),
179+ index_dtype ,
181180 tensor_shaped ,
182- data ,
183181 (pos , crd ),
182+ data ,
184183 )
185184 shape_x = arith .constant (index_dtype , shape [0 ])
186185 shape_y = arith .constant (index_dtype , shape [1 ])
187- return data , pos , crd , data_len , pos_len , crd_len , shape_x , shape_y
186+ return pos , crd , data , pos_len , crd_len , data_len , shape_x , shape_y
188187
189188 @func .FuncOp .from_py_func (csr_shaped )
190189 def free_tensor (tensor_shaped ):
@@ -219,12 +218,12 @@ def assemble(cls, module: ir.Module, arr: sps.csr_array) -> ctypes.c_void_p:
219218 def disassemble (cls , module : ir .Module , ptr : ctypes .c_void_p , dtype : type [DType ]) -> sps .csr_array :
220219 class Csr (ctypes .Structure ):
221220 _fields_ = [
222- ("data" , make_memref_ctype (dtype , 1 )),
223221 ("pos" , make_memref_ctype (Index , 1 )),
224222 ("crd" , make_memref_ctype (Index , 1 )),
225- ("data_len " , np . ctypeslib . c_intp ),
223+ ("data " , make_memref_ctype ( dtype , 1 ) ),
226224 ("pos_len" , np .ctypeslib .c_intp ),
227225 ("crd_len" , np .ctypeslib .c_intp ),
226+ ("data_len" , np .ctypeslib .c_intp ),
228227 ("shape_x" , np .ctypeslib .c_intp ),
229228 ("shape_y" , np .ctypeslib .c_intp ),
230229 ]
0 commit comments