Skip to content

Conversation

@mdhaber
Copy link
Contributor

@mdhaber mdhaber commented Nov 4, 2025

Reference issue

Toward gh-20544

What does this implement/fix?

Adds array API support to scipy.stats.epps_singleton_2samp.

Additional information

Updates array_api_extra to bring in the new vectorized cov.

@mdhaber mdhaber added the enhancement A new feature or improvement label Nov 4, 2025
@mdhaber mdhaber added array types Items related to array API support and input array validation (see gh-18286) scipy._lib and removed scipy._lib labels Nov 4, 2025
@mdhaber
Copy link
Contributor Author

mdhaber commented Nov 5, 2025

Most CI failures are data-apis/array-api-extra#454. I'll wait for that to be resolved before re-running CI.

Comment on lines +32 to +33
@xp_capabilities(skip_backends=[("dask.array", "lazy -> no _axis_nan_policy")],
jax_jit=False)
Copy link
Contributor Author

@mdhaber mdhaber Nov 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like this resolves the CI failures, although it's hard to tell with all the noise.

Suggested change
@xp_capabilities(skip_backends=[("dask.array", "lazy -> no _axis_nan_policy")],
jax_jit=False)
@xp_capabilities(skip_backends=[("dask.array", "lazy -> no _axis_nan_policy"),
("jax.numpy', 'lazy -> no _axis_nan_policy")])

@lucascolley
Copy link
Member

lucascolley commented Nov 5, 2025

The following rudimentary diff makes the tests pass with NumPy for me:

diff --git a/scipy/spatial/transform/_rigid_transform.py b/scipy/spatial/transform/_rigid_transform.py
index 4d40b98ea4..e58954fee4 100644
--- a/scipy/spatial/transform/_rigid_transform.py
+++ b/scipy/spatial/transform/_rigid_transform.py
@@ -410,3 +410,3 @@ class RigidTransform:
         if self._single:
-            matrix = xpx.atleast_nd(matrix, ndim=3, xp=xp)
+            matrix = xp.expand_dims(matrix, axis=0)

@@ -993,5 +993,14 @@ class RigidTransform:
         xp = array_namespace(transforms[0].as_matrix())
-        matrix = xp.concat(
-            [xpx.atleast_nd(x.as_matrix(), ndim=3, xp=xp) for x in transforms]
-        )
+        expanded_transforms = []
+        for x in transforms:
+            x = x.as_matrix()
+            if x.ndim == 3:
+                expanded_transforms.append(x)
+            elif x.ndim == 2:
+                expanded_transforms.append(xp.expand_dims(x, axis=0))
+            else:
+                raise ValueError("panic")
+        matrix = xp.concat(expanded_transforms)
         return RigidTransform._from_raw_matrix(matrix, xp, None)
@@ -1936,3 +1945,3 @@ class RigidTransform:
         if tf._single:
-            matrix = xpx.atleast_nd(matrix, ndim=3, xp=xp)
+            matrix = xp.expand_dims(matrix, axis=0)
         tf._matrix = matrix

@scottshambaugh @amacati what do you think about merging a polished version of this diff?

Long-term I think we should just enforce that atleast_nd adds dimensions to the front, so the existing code would be fine. But right now it is relying on undefined behaviour.

EDIT: context: data-apis/array-api-extra#454 (comment)

@amacati
Copy link
Collaborator

amacati commented Nov 5, 2025

At first glance this seems incorrect.

+            x = x.as_matrix()
+            if x.ndim == 3:
+                expanded_transforms.append(x)
+            elif x.ndim == 2:
+                expanded_transforms.append(xp.expand_dims(x, axis=0))
+            else:
+                raise ValueError("panic")

RigidTransform can be a batch of multiple leading dimensions. Concatenating several 4D RigidTransforms is well-defined, but would error with this loop, right?

@lucascolley
Copy link
Member

in which case we should add test coverage for that too!

@amacati
Copy link
Collaborator

amacati commented Nov 5, 2025

Yes, I am surprised that this was not picked up already.

@amacati
Copy link
Collaborator

amacati commented Nov 5, 2025

Yeah, looking at the concatenate tests, the Nd case is not covered. It would also make sense to introduce an axis argument, but that's maybe something for later, or should be implemented if people actually need it. I will try to expand the test coverage asap, just a bit short of time these days.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

array types Items related to array API support and input array validation (see gh-18286) enhancement A new feature or improvement scipy._lib scipy.stats

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants