Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions sklearnex/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
# ==============================================================================

"""Tools to support array_api."""
"""Tools to support array API."""

import math
from collections.abc import Callable
Expand All @@ -27,6 +27,7 @@
from daal4py.sklearn._utils import sklearn_check_version
from onedal.utils._array_api import _get_sycl_namespace, _is_numpy_namespace

from .._config import get_config
from ..base import oneDALEstimator

if sklearn_check_version("1.6"):
Expand Down Expand Up @@ -83,11 +84,19 @@ def get_namespace(*arrays):
True of the arrays are containers that implement the Array API spec.
"""

sycl_type, xp, is_array_api_compliant = _get_sycl_namespace(*arrays)

if sycl_type:
return xp, is_array_api_compliant
elif sklearn_check_version("1.2"):
# check required because _get_sycl_namespace only verifies that *arrays
# are of the same sycl namespace, not of the same array namespace.
# When array_api_dispatch is enabled, then sklearn's version is required
# for the additional array namespace check. This is now possible with
# dpnp and dpctl as they both support `__array_namespace__`.
if not get_config().get("array_api_dispatch", False):
sycl_type, xp, is_array_api_compliant = _get_sycl_namespace(*arrays)
if sycl_type:
return xp, is_array_api_compliant

# sklearn contains a specially patched numpy wrapper that should be
# reused which is yielded from sklearn's get_namespace.
if sklearn_check_version("1.2"):
return sklearn_get_namespace(*arrays)
else:
return np, False
Expand Down
Loading