|
48 | 48 | notna, |
49 | 49 | ) |
50 | 50 |
|
| 51 | + |
| 52 | +from pandas.core.util.numba_ import GLOBAL_USE_NUMBA |
| 53 | +from pandas.core import nanops_numba |
| 54 | + |
51 | 55 | if TYPE_CHECKING: |
52 | 56 | from collections.abc import Callable |
53 | 57 |
|
@@ -97,6 +101,38 @@ def _f(*args, **kwargs): |
97 | 101 | return cast(F, _f) |
98 | 102 |
|
99 | 103 |
|
| 104 | +class numba_switch: |
| 105 | + def __init__(self, name=None, **kwargs) -> None: |
| 106 | + self.name = name |
| 107 | + self.kwargs = kwargs |
| 108 | + |
| 109 | + def __call__(self, alt: F) -> F: |
| 110 | + nb_name = self.name or alt.__name__ |
| 111 | + |
| 112 | + try: |
| 113 | + nb_func = getattr(nanops_numba, nb_name) |
| 114 | + except (AttributeError, NameError): # pragma: no cover |
| 115 | + nb_func = None |
| 116 | + |
| 117 | + @functools.wraps(alt) |
| 118 | + def f( |
| 119 | + values: np.ndarray, |
| 120 | + *, |
| 121 | + axis: AxisInt | None = None, |
| 122 | + skipna: bool = True, |
| 123 | + **kwds, |
| 124 | + ): |
| 125 | + disallowed = values.dtype == "O" |
| 126 | + if GLOBAL_USE_NUMBA and not disallowed: |
| 127 | + result = nb_func(values, skipna=skipna, axis=axis, **kwds) |
| 128 | + else: |
| 129 | + result = alt(values, axis=axis, skipna=skipna, **kwds) |
| 130 | + |
| 131 | + return result |
| 132 | + |
| 133 | + return cast(F, f) |
| 134 | + |
| 135 | + |
100 | 136 | class bottleneck_switch: |
101 | 137 | def __init__(self, name=None, **kwargs) -> None: |
102 | 138 | self.name = name |
@@ -593,6 +629,7 @@ def nanall( |
593 | 629 | return values.all(axis) # type: ignore[return-value] |
594 | 630 |
|
595 | 631 |
|
| 632 | +@numba_switch() |
596 | 633 | @disallow("M8") |
597 | 634 | @_datetimelike_compat |
598 | 635 | @maybe_operate_rowwise |
@@ -660,7 +697,7 @@ def _mask_datetimelike_result( |
660 | 697 | return result |
661 | 698 |
|
662 | 699 |
|
663 | | -@bottleneck_switch() |
| 700 | +@numba_switch() |
664 | 701 | @_datetimelike_compat |
665 | 702 | def nanmean( |
666 | 703 | values: np.ndarray, |
@@ -910,7 +947,7 @@ def _get_counts_nanvar( |
910 | 947 | return count, d |
911 | 948 |
|
912 | 949 |
|
913 | | -@bottleneck_switch(ddof=1) |
| 950 | +@numba_switch(ddof=1) |
914 | 951 | def nanstd( |
915 | 952 | values, |
916 | 953 | *, |
@@ -957,7 +994,7 @@ def nanstd( |
957 | 994 |
|
958 | 995 |
|
959 | 996 | @disallow("M8", "m8") |
960 | | -@bottleneck_switch(ddof=1) |
| 997 | +@numba_switch(ddof=1) |
961 | 998 | def nanvar( |
962 | 999 | values: np.ndarray, |
963 | 1000 | *, |
@@ -1035,6 +1072,7 @@ def nanvar( |
1035 | 1072 | return result |
1036 | 1073 |
|
1037 | 1074 |
|
| 1075 | +@numba_switch() |
1038 | 1076 | @disallow("M8", "m8") |
1039 | 1077 | def nansem( |
1040 | 1078 | values: np.ndarray, |
@@ -1089,7 +1127,7 @@ def nansem( |
1089 | 1127 |
|
1090 | 1128 |
|
1091 | 1129 | def _nanminmax(meth, fill_value_typ): |
1092 | | - @bottleneck_switch(name=f"nan{meth}") |
| 1130 | + @numba_switch(name=f"nan{meth}") |
1093 | 1131 | @_datetimelike_compat |
1094 | 1132 | def reduction( |
1095 | 1133 | values: np.ndarray, |
|
0 commit comments