|
16 | 16 | to_numeric, |
17 | 17 | ) |
18 | 18 | import pandas._testing as tm |
| 19 | +from pandas.core.arrays.floating import ( |
| 20 | + Float32Dtype, |
| 21 | + Float64Dtype, |
| 22 | +) |
| 23 | +from pandas.core.arrays.integer import ( |
| 24 | + Int8Dtype, |
| 25 | + Int16Dtype, |
| 26 | + Int32Dtype, |
| 27 | + Int64Dtype, |
| 28 | + UInt8Dtype, |
| 29 | + UInt16Dtype, |
| 30 | + UInt32Dtype, |
| 31 | + UInt64Dtype, |
| 32 | +) |
| 33 | + |
| 34 | +try: |
| 35 | + import pyarrow as pa |
| 36 | +except ImportError: |
| 37 | + pa = None |
19 | 38 |
|
20 | 39 |
|
21 | 40 | @pytest.fixture(params=[None, "raise", "coerce"]) |
@@ -904,18 +923,79 @@ def test_coerce_pyarrow_backend(): |
904 | 923 | tm.assert_series_equal(result, expected) |
905 | 924 |
|
906 | 925 |
|
| 926 | +@pytest.mark.parametrize( |
| 927 | + "pyarrow_dtype", |
| 928 | + [ |
| 929 | + pytest.param( |
| 930 | + pa.int8(), marks=pytest.mark.skipif(not pa, reason="pyarrow required") |
| 931 | + ), |
| 932 | + pytest.param( |
| 933 | + pa.int16(), marks=pytest.mark.skipif(not pa, reason="pyarrow required") |
| 934 | + ), |
| 935 | + pytest.param( |
| 936 | + pa.int32(), marks=pytest.mark.skipif(not pa, reason="pyarrow required") |
| 937 | + ), |
| 938 | + pytest.param( |
| 939 | + pa.int64(), marks=pytest.mark.skipif(not pa, reason="pyarrow required") |
| 940 | + ), |
| 941 | + pytest.param( |
| 942 | + pa.uint8(), marks=pytest.mark.skipif(not pa, reason="pyarrow required") |
| 943 | + ), |
| 944 | + pytest.param( |
| 945 | + pa.uint16(), marks=pytest.mark.skipif(not pa, reason="pyarrow required") |
| 946 | + ), |
| 947 | + pytest.param( |
| 948 | + pa.uint32(), marks=pytest.mark.skipif(not pa, reason="pyarrow required") |
| 949 | + ), |
| 950 | + pytest.param( |
| 951 | + pa.uint64(), marks=pytest.mark.skipif(not pa, reason="pyarrow required") |
| 952 | + ), |
| 953 | + pytest.param( |
| 954 | + pa.float16(), marks=pytest.mark.skipif(not pa, reason="pyarrow required") |
| 955 | + ), |
| 956 | + pytest.param( |
| 957 | + pa.float32(), marks=pytest.mark.skipif(not pa, reason="pyarrow required") |
| 958 | + ), |
| 959 | + pytest.param( |
| 960 | + pa.float64(), marks=pytest.mark.skipif(not pa, reason="pyarrow required") |
| 961 | + ), |
| 962 | + pytest.param( |
| 963 | + pa.decimal128(10, 2), |
| 964 | + marks=pytest.mark.skipif(not pa, reason="pyarrow required"), |
| 965 | + ), |
| 966 | + pytest.param( |
| 967 | + pa.decimal256(10, 2), |
| 968 | + marks=pytest.mark.skipif(not pa, reason="pyarrow required"), |
| 969 | + ), |
| 970 | + ], |
| 971 | +) |
| 972 | +def test_to_numeric_arrow_decimal_with_na(pyarrow_dtype): |
| 973 | + # GH 61641 |
| 974 | + numeric_type = ArrowDtype(pyarrow_dtype) |
| 975 | + series = Series([1, None], dtype=numeric_type) |
| 976 | + result = to_numeric(series, errors="coerce") |
| 977 | + |
| 978 | + tm.assert_series_equal(result, series) |
| 979 | + |
| 980 | + |
907 | 981 | @pytest.mark.parametrize( |
908 | 982 | "dtype", |
909 | 983 | [ |
910 | | - "ArrowDtype", |
| 984 | + Int8Dtype, |
| 985 | + Int16Dtype, |
| 986 | + Int32Dtype, |
| 987 | + Int64Dtype, |
| 988 | + UInt8Dtype, |
| 989 | + UInt16Dtype, |
| 990 | + UInt32Dtype, |
| 991 | + UInt64Dtype, |
| 992 | + Float32Dtype, |
| 993 | + Float64Dtype, |
911 | 994 | ], |
912 | 995 | ) |
913 | | -def test_to_numeric_arrow_decimal_with_na(dtype): |
| 996 | +def test_to_numeric_decimal_with_na(dtype): |
914 | 997 | # GH 61641 |
915 | | - pa = pytest.importorskip("pyarrow") |
916 | | - target_class = globals()[dtype] |
917 | | - decimal_type = target_class(pa.decimal128(3, scale=2)) |
918 | | - series = Series([1, None], dtype=decimal_type) |
| 998 | + series = Series([1, None], dtype=dtype()) |
919 | 999 | result = to_numeric(series, errors="coerce") |
920 | 1000 |
|
921 | 1001 | tm.assert_series_equal(result, series) |
0 commit comments