@@ -274,6 +274,59 @@ def test_all(library, module):
274274 assert not fails , "Missing exports: %s" % fails
275275
276276
277+ @pytest .mark .parametrize ("module" , list (NAMES ))
278+ @pytest .mark .parametrize ("library" , wrapped_libraries )
279+ def test_compat_doesnt_hide_names (library , module ):
280+ """The base namespace can have more names than the ones explicitly exported
281+ by array-api-compat. Test that we're not suppressing them.
282+ """
283+ bare_xp = pytest .importorskip (library )
284+ compat_xp = pytest .importorskip (f"array_api_compat.{ library } " )
285+ bare_mod = getattr (bare_xp , module ) if module else bare_xp
286+ compat_mod = getattr (compat_xp , module ) if module else compat_xp
287+ aapi_names = set (NAMES [module ])
288+ extra_names = {
289+ name
290+ for name in dir (bare_mod )
291+ if not name .startswith ("_" ) and name not in aapi_names
292+ }
293+ missing = extra_names - set (dir (compat_mod ))
294+
295+ # These are spurious to begin with in the bare libraries
296+ missing -= {"annotations" , "importlib" , "warnings" , "operator" , "sys" , "Sequence" }
297+ if module != "" :
298+ missing -= {"Array" , "test" }
299+
300+ assert not missing , "Non-Array API names have been hidden: %s" % missing
301+
302+
303+ @pytest .mark .parametrize ("module" , list (NAMES ))
304+ @pytest .mark .parametrize ("library" , wrapped_libraries )
305+ def test_compat_spurious_names (library , module ):
306+ """Test that array-api-compat isn't adding non-Array API names
307+ to the namespace.
308+ """
309+ bare_xp = pytest .importorskip (library )
310+ compat_xp = pytest .importorskip (f"array_api_compat.{ library } " )
311+ bare_mod = getattr (bare_xp , module ) if module else bare_xp
312+ compat_mod = getattr (compat_xp , module ) if module else compat_xp
313+ aapi_names = set (NAMES [module ])
314+ compat_spurious_names = (
315+ set (dir (compat_mod ))
316+ - set (dir (bare_mod ))
317+ - aapi_names
318+ - {"__all__" }
319+ )
320+ # Quietly ignore *Result dataclasses
321+ compat_spurious_names = {
322+ name for name in compat_spurious_names if not name .endswith ("Result" )
323+ }
324+
325+ assert not compat_spurious_names , (
326+ "array-api-compat is adding non-Array API names: %s" % compat_spurious_names
327+ )
328+
329+
277330@pytest .mark .parametrize (
278331 "name" , [name for name in NAMES ["" ] if hasattr (builtins , name )]
279332)
0 commit comments