11import numpy as np
22
33from ._common import _check_device
4- from ._compressed import GCXS
4+ from ._compressed import CSC , CSR , GCXS
55from ._coo .core import COO
66from ._sparse_array import SparseArray
77
@@ -145,7 +145,6 @@ def from_binsparse(arr, /, *, device=None, copy: bool | None = None) -> SparseAr
145145
146146 format = desc ["format" ]
147147 format_err_str = f"Unsupported format: `{ format !r} `."
148- invalid_dtype_str = "Invalid dtype: `{dtype!s}`, expected `{expected!s}`."
149148
150149 if isinstance (format , str ):
151150 match format :
@@ -180,15 +179,15 @@ def from_binsparse(arr, /, *, device=None, copy: bool | None = None) -> SparseAr
180179 case _:
181180 raise RuntimeError (format_err_str )
182181
183- format = desc ["format" ]
182+ format = desc ["format" ]["custom" ]
183+ rank = 0
184+ level = format
185+ while "level" in level :
186+ if "rank" not in level :
187+ level ["rank" ] = 1
188+ rank += level ["rank" ]
189+ level = level ["level" ]
184190 if "transpose" not in format :
185- rank = 0
186- level = format
187- while "level" in level :
188- if "rank" not in level :
189- level ["rank" ] = 1
190- rank += level ["rank" ]
191-
192191 format ["transpose" ] = list (range (rank ))
193192
194193 match desc :
@@ -225,25 +224,8 @@ def from_binsparse(arr, /, *, device=None, copy: bool | None = None) -> SparseAr
225224 coord_arr : np .ndarray = np .from_dlpack (arrs [1 ])
226225 value_arr : np .ndarray = np .from_dlpack (arrs [2 ])
227226
228- if str (coord_arr .dtype ) != coords_dtype :
229- raise BufferError (
230- invalid_dtype_str .format (
231- dtype = str (coord_arr .dtype ),
232- expected = coords_dtype ,
233- )
234- )
235-
236- if value_dtype .startswith ("complex[float" ) and value_dtype .endswith ("]" ):
237- complex_bits = 2 * int (value_arr [len ("complex[float" ) : - len ("]" )])
238- value_dtype : str = f"complex{ complex_bits } "
239-
240- if str (value_arr .dtype ) != value_dtype :
241- raise BufferError (
242- invalid_dtype_str .format (
243- dtype = str (coord_arr .dtype ),
244- expected = coords_dtype ,
245- )
246- )
227+ _check_binsparse_dt (coord_arr , coords_dtype )
228+ _check_binsparse_dt (value_arr , value_dtype )
247229
248230 return COO (
249231 coord_arr [:, start :end ],
@@ -254,5 +236,68 @@ def from_binsparse(arr, /, *, device=None, copy: bool | None = None) -> SparseAr
254236 prune = False ,
255237 idx_dtype = coord_arr .dtype ,
256238 )
239+ case {
240+ "format" : {
241+ "custom" : {
242+ "transpose" : transpose ,
243+ "level" : {
244+ "level_desc" : "dense" ,
245+ "rank" : 1 ,
246+ "level" : {
247+ "level_desc" : "sparse" ,
248+ "rank" : 1 ,
249+ "level" : {
250+ "level_desc" : "element" ,
251+ },
252+ },
253+ },
254+ },
255+ },
256+ "shape" : shape ,
257+ "number_of_stored_values" : nnz ,
258+ "data_types" : {
259+ "pointers_to_1" : ptr_dtype ,
260+ "indices_1" : crd_dtype ,
261+ "values" : val_dtype ,
262+ },
263+ ** _kwargs ,
264+ }:
265+ crd_arr = np .from_dlpack (arrs [0 ])
266+ _check_binsparse_dt (crd_arr , crd_dtype )
267+ ptr_arr = np .from_dlpack (arrs [1 ])
268+ _check_binsparse_dt (ptr_arr , ptr_dtype )
269+ val_arr = np .from_dlpack (arrs [2 ])
270+ _check_binsparse_dt (val_arr , val_dtype )
271+
272+ match transpose :
273+ case [0 , 1 ]:
274+ sparse_type = CSR
275+ case [1 , 0 ]:
276+ sparse_type = CSC
277+ case _:
278+ raise RuntimeError (format_err_str )
279+
280+ return sparse_type ((val_arr , ptr_arr , crd_arr ), shape = shape )
257281 case _:
282+ print (desc )
258283 raise RuntimeError (format_err_str )
284+
285+
286+ def _convert_binsparse_dtype (dt : str ) -> np .dtype :
287+ if dt .startswith ("complex[float" ) and dt .endswith ("]" ):
288+ complex_bits = 2 * int (dt [len ("complex[float" ) : - len ("]" )])
289+ dt : str = f"complex{ complex_bits } "
290+
291+ return np .dtype (dt )
292+
293+
294+ def _check_binsparse_dt (arr : np .ndarray , dt : str ) -> None :
295+ invalid_dtype_str = "Invalid dtype: `{dtype!s}`, expected `{expected!s}`."
296+ dt = _convert_binsparse_dtype (dt )
297+ if dt != arr .dtype :
298+ raise BufferError (
299+ invalid_dtype_str .format (
300+ dtype = arr .dtype ,
301+ expected = dt ,
302+ )
303+ )
0 commit comments