diff --git a/sklearnex/utils/_array_api.py b/sklearnex/utils/_array_api.py index b2a241c786..d74e866e8a 100644 --- a/sklearnex/utils/_array_api.py +++ b/sklearnex/utils/_array_api.py @@ -14,7 +14,7 @@ # limitations under the License. # ============================================================================== -"""Tools to support array_api.""" +"""Tools to support array API.""" import math from collections.abc import Callable @@ -27,6 +27,7 @@ from daal4py.sklearn._utils import sklearn_check_version from onedal.utils._array_api import _get_sycl_namespace, _is_numpy_namespace +from .._config import get_config from ..base import oneDALEstimator if sklearn_check_version("1.6"): @@ -83,11 +84,19 @@ def get_namespace(*arrays): True of the arrays are containers that implement the Array API spec. """ - sycl_type, xp, is_array_api_compliant = _get_sycl_namespace(*arrays) - - if sycl_type: - return xp, is_array_api_compliant - elif sklearn_check_version("1.2"): + # check required because _get_sycl_namespace only verifies that *arrays + # are of the same sycl namespace, not of the same array namespace. + # When array_api_dispatch is enabled, then sklearn's version is required + # for the additional array namespace check. This is now possible with + # dpnp and dpctl as they both support `__array_namespace__`. + if not get_config().get("array_api_dispatch", False): + sycl_type, xp, is_array_api_compliant = _get_sycl_namespace(*arrays) + if sycl_type: + return xp, is_array_api_compliant + + # sklearn contains a specially patched numpy wrapper that should be + # reused which is yielded from sklearn's get_namespace. + if sklearn_check_version("1.2"): return sklearn_get_namespace(*arrays) else: return np, False