@@ -50,6 +50,7 @@ def is_numpy_array(x):
5050 is_torch_array
5151 is_dask_array
5252 is_jax_array
53+ is_pydata_sparse
5354 """
5455 # Avoid importing NumPy if it isn't already
5556 if 'numpy' not in sys .modules :
@@ -79,6 +80,7 @@ def is_cupy_array(x):
7980 is_torch_array
8081 is_dask_array
8182 is_jax_array
83+ is_pydata_sparse
8284 """
8385 # Avoid importing NumPy if it isn't already
8486 if 'cupy' not in sys .modules :
@@ -105,6 +107,7 @@ def is_torch_array(x):
105107 is_cupy_array
106108 is_dask_array
107109 is_jax_array
110+ is_pydata_sparse
108111 """
109112 # Avoid importing torch if it isn't already
110113 if 'torch' not in sys .modules :
@@ -131,6 +134,7 @@ def is_dask_array(x):
131134 is_cupy_array
132135 is_torch_array
133136 is_jax_array
137+ is_pydata_sparse
134138 """
135139 # Avoid importing dask if it isn't already
136140 if 'dask.array' not in sys .modules :
@@ -157,6 +161,7 @@ def is_jax_array(x):
157161 is_cupy_array
158162 is_torch_array
159163 is_dask_array
164+ is_pydata_sparse
160165 """
161166 # Avoid importing jax if it isn't already
162167 if 'jax' not in sys .modules :
@@ -166,6 +171,35 @@ def is_jax_array(x):
166171
167172 return isinstance (x , jax .Array ) or _is_jax_zero_gradient_array (x )
168173
174+
175+ def is_pydata_sparse (x ) -> bool :
176+ """
177+ Return True if `x` is an array from the `sparse` package.
178+
179+ This function does not import `sparse` if it has not already been imported
180+ and is therefore cheap to use.
181+
182+
183+ See Also
184+ --------
185+
186+ array_namespace
187+ is_array_api_obj
188+ is_numpy_array
189+ is_cupy_array
190+ is_torch_array
191+ is_dask_array
192+ is_jax_array
193+ """
194+ # Avoid importing jax if it isn't already
195+ if 'sparse' not in sys .modules :
196+ return False
197+
198+ import sparse
199+
200+ # TODO: Account for other backends.
201+ return isinstance (x , sparse .SparseArray )
202+
169203def is_array_api_obj (x ):
170204 """
171205 Return True if `x` is an array API compatible array object.
@@ -185,6 +219,7 @@ def is_array_api_obj(x):
185219 or is_torch_array (x ) \
186220 or is_dask_array (x ) \
187221 or is_jax_array (x ) \
222+ or is_pydata_sparse (x ) \
188223 or hasattr (x , '__array_namespace__' )
189224
190225def _check_api_version (api_version ):
@@ -253,6 +288,7 @@ def your_function(x, y):
253288 is_torch_array
254289 is_dask_array
255290 is_jax_array
291+ is_pydata_sparse
256292
257293 """
258294 if use_compat not in [None , True , False ]:
@@ -312,6 +348,15 @@ def your_function(x, y):
312348 # not have a wrapper submodule for it.
313349 import jax .experimental .array_api as jnp
314350 namespaces .add (jnp )
351+ elif is_pydata_sparse (x ):
352+ if use_compat is True :
353+ _check_api_version (api_version )
354+ raise ValueError ("`sparse` does not have an array-api-compat wrapper" )
355+ else :
356+ import sparse
357+ # `sparse` is already an array namespace. We do not have a wrapper
358+ # submodule for it.
359+ namespaces .add (sparse )
315360 elif hasattr (x , '__array_namespace__' ):
316361 if use_compat is True :
317362 raise ValueError ("The given array does not have an array-api-compat wrapper" )
@@ -406,8 +451,23 @@ def device(x: Array, /) -> Device:
406451 return x .device ()
407452 else :
408453 return x .device
454+ elif is_pydata_sparse (x ):
455+ # `sparse` will gain `.device`, so check for this first.
456+ x_device = getattr (x , 'device' , None )
457+ if x_device is not None :
458+ return x_device
459+ # Everything but DOK has this attr.
460+ try :
461+ inner = x .data
462+ except AttributeError :
463+ return "cpu"
464+ # Return the device of the constituent array
465+ return device (inner )
409466 return x .device
410467
468+ # Prevent shadowing, used below
469+ _device = device
470+
411471# Based on cupy.array_api.Array.to_device
412472def _cupy_to_device (x , device , / , stream = None ):
413473 import cupy as cp
@@ -523,6 +583,10 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
523583 # This import adds to_device to x
524584 import jax .experimental .array_api # noqa: F401
525585 return x .to_device (device , stream = stream )
586+ elif is_pydata_sparse (x ) and device == _device (x ):
587+ # Perform trivial check to return the same array if
588+ # device is same instead of err-ing.
589+ return x
526590 return x .to_device (device , stream = stream )
527591
528592def size (x ):
@@ -549,6 +613,7 @@ def size(x):
549613 "is_jax_array" ,
550614 "is_numpy_array" ,
551615 "is_torch_array" ,
616+ "is_pydata_sparse" ,
552617 "size" ,
553618 "to_device" ,
554619]
0 commit comments