@@ -50,7 +50,7 @@ def is_numpy_array(x):
5050 is_torch_array
5151 is_dask_array
5252 is_jax_array
53- is_pydata_sparse
53+ is_pydata_sparse_array
5454 """
5555 # Avoid importing NumPy if it isn't already
5656 if 'numpy' not in sys .modules :
@@ -80,7 +80,7 @@ def is_cupy_array(x):
8080 is_torch_array
8181 is_dask_array
8282 is_jax_array
83- is_pydata_sparse
83+ is_pydata_sparse_array
8484 """
8585 # Avoid importing NumPy if it isn't already
8686 if 'cupy' not in sys .modules :
@@ -107,7 +107,7 @@ def is_torch_array(x):
107107 is_cupy_array
108108 is_dask_array
109109 is_jax_array
110- is_pydata_sparse
110+ is_pydata_sparse_array
111111 """
112112 # Avoid importing torch if it isn't already
113113 if 'torch' not in sys .modules :
@@ -134,7 +134,7 @@ def is_dask_array(x):
134134 is_cupy_array
135135 is_torch_array
136136 is_jax_array
137- is_pydata_sparse
137+ is_pydata_sparse_array
138138 """
139139 # Avoid importing dask if it isn't already
140140 if 'dask.array' not in sys .modules :
@@ -161,7 +161,7 @@ def is_jax_array(x):
161161 is_cupy_array
162162 is_torch_array
163163 is_dask_array
164- is_pydata_sparse
164+ is_pydata_sparse_array
165165 """
166166 # Avoid importing jax if it isn't already
167167 if 'jax' not in sys .modules :
@@ -172,7 +172,7 @@ def is_jax_array(x):
172172 return isinstance (x , jax .Array ) or _is_jax_zero_gradient_array (x )
173173
174174
175- def is_pydata_sparse (x ) -> bool :
175+ def is_pydata_sparse_array (x ) -> bool :
176176 """
177177 Return True if `x` is an array from the `sparse` package.
178178
@@ -219,7 +219,7 @@ def is_array_api_obj(x):
219219 or is_torch_array (x ) \
220220 or is_dask_array (x ) \
221221 or is_jax_array (x ) \
222- or is_pydata_sparse (x ) \
222+ or is_pydata_sparse_array (x ) \
223223 or hasattr (x , '__array_namespace__' )
224224
225225def _check_api_version (api_version ):
@@ -288,7 +288,7 @@ def your_function(x, y):
288288 is_torch_array
289289 is_dask_array
290290 is_jax_array
291- is_pydata_sparse
291+ is_pydata_sparse_array
292292
293293 """
294294 if use_compat not in [None , True , False ]:
@@ -348,7 +348,7 @@ def your_function(x, y):
348348 # not have a wrapper submodule for it.
349349 import jax .experimental .array_api as jnp
350350 namespaces .add (jnp )
351- elif is_pydata_sparse (x ):
351+ elif is_pydata_sparse_array (x ):
352352 if use_compat is True :
353353 _check_api_version (api_version )
354354 raise ValueError ("`sparse` does not have an array-api-compat wrapper" )
@@ -451,7 +451,7 @@ def device(x: Array, /) -> Device:
451451 return x .device ()
452452 else :
453453 return x .device
454- elif is_pydata_sparse (x ):
454+ elif is_pydata_sparse_array (x ):
455455 # `sparse` will gain `.device`, so check for this first.
456456 x_device = getattr (x , 'device' , None )
457457 if x_device is not None :
@@ -583,7 +583,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
583583 # This import adds to_device to x
584584 import jax .experimental .array_api # noqa: F401
585585 return x .to_device (device , stream = stream )
586- elif is_pydata_sparse (x ) and device == _device (x ):
586+ elif is_pydata_sparse_array (x ) and device == _device (x ):
587587 # Perform trivial check to return the same array if
588588 # device is same instead of err-ing.
589589 return x
@@ -613,7 +613,7 @@ def size(x):
613613 "is_jax_array" ,
614614 "is_numpy_array" ,
615615 "is_torch_array" ,
616- "is_pydata_sparse " ,
616+ "is_pydata_sparse_array " ,
617617 "size" ,
618618 "to_device" ,
619619]
0 commit comments