@@ -249,6 +249,59 @@ def test_dir(library, module):
249249 assert not fails , "Missing exports: %s" % fails
250250
251251
252+ @pytest .mark .parametrize ("module" , list (NAMES ))
253+ @pytest .mark .parametrize ("library" , wrapped_libraries )
254+ def test_compat_doesnt_hide_names (library , module ):
255+ """The base namespace can have more names than the ones explicitly exported
256+ by array-api-compat. Test that we're not suppressing them.
257+ """
258+ bare_xp = pytest .importorskip (library )
259+ compat_xp = pytest .importorskip (f"array_api_compat.{ library } " )
260+ bare_mod = getattr (bare_xp , module ) if module else bare_xp
261+ compat_mod = getattr (compat_xp , module ) if module else compat_xp
262+ aapi_names = set (NAMES [module ])
263+ extra_names = {
264+ name
265+ for name in dir (bare_mod )
266+ if not name .startswith ("_" ) and name not in aapi_names
267+ }
268+ missing = extra_names - set (dir (compat_mod ))
269+
270+ # These are spurious to begin with in the bare libraries
271+ missing -= {"annotations" , "importlib" , "warnings" , "operator" , "sys" , "Sequence" }
272+ if module != "" :
273+ missing -= {"Array" , "test" }
274+
275+ assert not missing , "Non-Array API names have been hidden: %s" % missing
276+
277+
278+ @pytest .mark .parametrize ("module" , list (NAMES ))
279+ @pytest .mark .parametrize ("library" , wrapped_libraries )
280+ def test_compat_spurious_names (library , module ):
281+ """Test that array-api-compat isn't adding non-Array API names
282+ to the namespace.
283+ """
284+ bare_xp = pytest .importorskip (library )
285+ compat_xp = pytest .importorskip (f"array_api_compat.{ library } " )
286+ bare_mod = getattr (bare_xp , module ) if module else bare_xp
287+ compat_mod = getattr (compat_xp , module ) if module else compat_xp
288+ aapi_names = set (NAMES [module ])
289+ compat_spurious_names = (
290+ set (dir (compat_mod ))
291+ - set (dir (bare_mod ))
292+ - aapi_names
293+ - {"__all__" }
294+ )
295+ # Quietly ignore *Result dataclasses
296+ compat_spurious_names = {
297+ name for name in compat_spurious_names if not name .endswith ("Result" )
298+ }
299+
300+ assert not compat_spurious_names , (
301+ "array-api-compat is adding non-Array API names: %s" % compat_spurious_names
302+ )
303+
304+
252305@pytest .mark .parametrize (
253306 "name" , [name for name in NAMES ["" ] if hasattr (builtins , name )]
254307)
0 commit comments