Skip to content

Commit 631410e

Browse files
committed
Add parameterized tests
1 parent 3e9895d commit 631410e

File tree

1 file changed

+86
-6
lines changed

1 file changed

+86
-6
lines changed

pandas/tests/tools/test_to_numeric.py

Lines changed: 86 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,25 @@
1616
to_numeric,
1717
)
1818
import pandas._testing as tm
19+
from pandas.core.arrays.floating import (
20+
Float32Dtype,
21+
Float64Dtype,
22+
)
23+
from pandas.core.arrays.integer import (
24+
Int8Dtype,
25+
Int16Dtype,
26+
Int32Dtype,
27+
Int64Dtype,
28+
UInt8Dtype,
29+
UInt16Dtype,
30+
UInt32Dtype,
31+
UInt64Dtype,
32+
)
33+
34+
try:
35+
import pyarrow as pa
36+
except ImportError:
37+
pa = None
1938

2039

2140
@pytest.fixture(params=[None, "raise", "coerce"])
@@ -904,18 +923,79 @@ def test_coerce_pyarrow_backend():
904923
tm.assert_series_equal(result, expected)
905924

906925

926+
@pytest.mark.parametrize(
927+
"pyarrow_dtype",
928+
[
929+
pytest.param(
930+
pa.int8(), marks=pytest.mark.skipif(not pa, reason="pyarrow required")
931+
),
932+
pytest.param(
933+
pa.int16(), marks=pytest.mark.skipif(not pa, reason="pyarrow required")
934+
),
935+
pytest.param(
936+
pa.int32(), marks=pytest.mark.skipif(not pa, reason="pyarrow required")
937+
),
938+
pytest.param(
939+
pa.int64(), marks=pytest.mark.skipif(not pa, reason="pyarrow required")
940+
),
941+
pytest.param(
942+
pa.uint8(), marks=pytest.mark.skipif(not pa, reason="pyarrow required")
943+
),
944+
pytest.param(
945+
pa.uint16(), marks=pytest.mark.skipif(not pa, reason="pyarrow required")
946+
),
947+
pytest.param(
948+
pa.uint32(), marks=pytest.mark.skipif(not pa, reason="pyarrow required")
949+
),
950+
pytest.param(
951+
pa.uint64(), marks=pytest.mark.skipif(not pa, reason="pyarrow required")
952+
),
953+
pytest.param(
954+
pa.float16(), marks=pytest.mark.skipif(not pa, reason="pyarrow required")
955+
),
956+
pytest.param(
957+
pa.float32(), marks=pytest.mark.skipif(not pa, reason="pyarrow required")
958+
),
959+
pytest.param(
960+
pa.float64(), marks=pytest.mark.skipif(not pa, reason="pyarrow required")
961+
),
962+
pytest.param(
963+
pa.decimal128(10, 2),
964+
marks=pytest.mark.skipif(not pa, reason="pyarrow required"),
965+
),
966+
pytest.param(
967+
pa.decimal256(10, 2),
968+
marks=pytest.mark.skipif(not pa, reason="pyarrow required"),
969+
),
970+
],
971+
)
972+
def test_to_numeric_arrow_decimal_with_na(pyarrow_dtype):
973+
# GH 61641
974+
numeric_type = ArrowDtype(pyarrow_dtype)
975+
series = Series([1, None], dtype=numeric_type)
976+
result = to_numeric(series, errors="coerce")
977+
978+
tm.assert_series_equal(result, series)
979+
980+
907981
@pytest.mark.parametrize(
908982
"dtype",
909983
[
910-
"ArrowDtype",
984+
Int8Dtype,
985+
Int16Dtype,
986+
Int32Dtype,
987+
Int64Dtype,
988+
UInt8Dtype,
989+
UInt16Dtype,
990+
UInt32Dtype,
991+
UInt64Dtype,
992+
Float32Dtype,
993+
Float64Dtype,
911994
],
912995
)
913-
def test_to_numeric_arrow_decimal_with_na(dtype):
996+
def test_to_numeric_decimal_with_na(dtype):
914997
# GH 61641
915-
pa = pytest.importorskip("pyarrow")
916-
target_class = globals()[dtype]
917-
decimal_type = target_class(pa.decimal128(3, scale=2))
918-
series = Series([1, None], dtype=decimal_type)
998+
series = Series([1, None], dtype=dtype())
919999
result = to_numeric(series, errors="coerce")
9201000

9211001
tm.assert_series_equal(result, series)

0 commit comments

Comments
 (0)