|
48 | 48 | notna, |
49 | 49 | ) |
50 | 50 |
|
| 51 | +from pandas.core.util.numba_ import GLOBAL_USE_NUMBA |
| 52 | + |
| 53 | +if GLOBAL_USE_NUMBA: |
| 54 | + from pandas.core import nanops_numba |
| 55 | + |
| 56 | + |
51 | 57 | if TYPE_CHECKING: |
52 | 58 | from collections.abc import Callable |
53 | 59 |
|
@@ -97,6 +103,38 @@ def _f(*args, **kwargs): |
97 | 103 | return cast(F, _f) |
98 | 104 |
|
99 | 105 |
|
| 106 | +class numba_switch: |
| 107 | + def __init__(self, name=None, **kwargs) -> None: |
| 108 | + self.name = name |
| 109 | + self.kwargs = kwargs |
| 110 | + |
| 111 | + def __call__(self, alt: F) -> F: |
| 112 | + nb_name = self.name or alt.__name__ |
| 113 | + |
| 114 | + try: |
| 115 | + nb_func = getattr(nanops_numba, nb_name) |
| 116 | + except (AttributeError, NameError): # pragma: no cover |
| 117 | + return alt |
| 118 | + |
| 119 | + @functools.wraps(alt) |
| 120 | + def f( |
| 121 | + values: np.ndarray, |
| 122 | + *, |
| 123 | + axis: AxisInt | None = None, |
| 124 | + skipna: bool = True, |
| 125 | + **kwds, |
| 126 | + ): |
| 127 | + disallowed = values.dtype == "O" |
| 128 | + if not disallowed: |
| 129 | + result = nb_func(values, skipna=skipna, axis=axis, **kwds) |
| 130 | + else: |
| 131 | + result = alt(values, axis=axis, skipna=skipna, **kwds) |
| 132 | + |
| 133 | + return result |
| 134 | + |
| 135 | + return cast(F, f) |
| 136 | + |
| 137 | + |
100 | 138 | class bottleneck_switch: |
101 | 139 | def __init__(self, name=None, **kwargs) -> None: |
102 | 140 | self.name = name |
@@ -593,6 +631,7 @@ def nanall( |
593 | 631 | return values.all(axis) # type: ignore[return-value] |
594 | 632 |
|
595 | 633 |
|
| 634 | +@numba_switch() |
596 | 635 | @disallow("M8") |
597 | 636 | @_datetimelike_compat |
598 | 637 | @maybe_operate_rowwise |
@@ -658,7 +697,7 @@ def _mask_datetimelike_result( |
658 | 697 | return result |
659 | 698 |
|
660 | 699 |
|
661 | | -@bottleneck_switch() |
| 700 | +@numba_switch() |
662 | 701 | @_datetimelike_compat |
663 | 702 | def nanmean( |
664 | 703 | values: np.ndarray, |
@@ -908,7 +947,7 @@ def _get_counts_nanvar( |
908 | 947 | return count, d |
909 | 948 |
|
910 | 949 |
|
911 | | -@bottleneck_switch(ddof=1) |
| 950 | +@numba_switch(ddof=1) |
912 | 951 | def nanstd( |
913 | 952 | values, |
914 | 953 | *, |
@@ -955,7 +994,7 @@ def nanstd( |
955 | 994 |
|
956 | 995 |
|
957 | 996 | @disallow("M8", "m8") |
958 | | -@bottleneck_switch(ddof=1) |
| 997 | +@numba_switch(ddof=1) |
959 | 998 | def nanvar( |
960 | 999 | values: np.ndarray, |
961 | 1000 | *, |
@@ -1033,6 +1072,7 @@ def nanvar( |
1033 | 1072 | return result |
1034 | 1073 |
|
1035 | 1074 |
|
| 1075 | +@numba_switch() |
1036 | 1076 | @disallow("M8", "m8") |
1037 | 1077 | def nansem( |
1038 | 1078 | values: np.ndarray, |
@@ -1087,7 +1127,7 @@ def nansem( |
1087 | 1127 |
|
1088 | 1128 |
|
1089 | 1129 | def _nanminmax(meth, fill_value_typ): |
1090 | | - @bottleneck_switch(name=f"nan{meth}") |
| 1130 | + @numba_switch(name=f"nan{meth}") |
1091 | 1131 | @_datetimelike_compat |
1092 | 1132 | def reduction( |
1093 | 1133 | values: np.ndarray, |
|
0 commit comments