Skip to content

Commit 0ebca38

Browse files
committed
update tests
1 parent d9f6983 commit 0ebca38

File tree

1 file changed

+58
-5
lines changed

1 file changed

+58
-5
lines changed

pandas/tests/io/parser/test_preserve_leading_zeros.py

Lines changed: 58 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,9 @@
22

33
import pytest
44

5+
import pandas._testing as tm
6+
from pandas.errors import ParserWarning
7+
58

69
def test_leading_zeros_preserved_with_dtype_str(all_parsers, request):
710
# GH#57666: pyarrow engine strips leading zeros when dtype=str is passed
@@ -16,10 +19,19 @@ def test_leading_zeros_preserved_with_dtype_str(all_parsers, request):
1619
EF,000023607,ghi,0205
1720
GH,100102040,jkl,0205"""
1821

19-
result = parser.read_csv(
20-
StringIO(data),
21-
dtype=str,
22-
)
22+
if engine_name == "pyarrow":
23+
with tm.assert_produces_warning(
24+
ParserWarning, match="pyarrow engine expects a dict mapping"
25+
):
26+
result = parser.read_csv(
27+
StringIO(data),
28+
dtype=str,
29+
)
30+
else:
31+
result = parser.read_csv(
32+
StringIO(data),
33+
dtype=str,
34+
)
2335

2436
try:
2537
assert result.shape == (4, 4)
@@ -40,7 +52,7 @@ def test_leading_zeros_preserved_with_dtype_str(all_parsers, request):
4052
raise
4153

4254

43-
def test_leading_zeros_preserved_with_dtype_dict(all_parsers):
55+
def test_leading_zeros_preserved_with_dtype_dict_str_only(all_parsers):
4456
# GH#57666: pyarrow engine strips leading zeros when dtype=str is passed
4557
# GH#61618: further discussion on ensuring string dtype preservation across engines
4658

@@ -69,3 +81,44 @@ def test_leading_zeros_preserved_with_dtype_dict(all_parsers):
6981
assert result.loc[1, "col3"] == 200
7082
assert result.loc[2, "col3"] == 201
7183
assert result.loc[3, "col3"] == 202
84+
85+
86+
def test_leading_zeros_preserved_with_heterogeneous_dtypes(all_parsers):
87+
# GH#57666: pyarrow engine strips leading zeros from str columns; here dtype is a dict mixing str (col2, col4) and int (col3)
88+
# GH#61618: further discussion on ensuring string dtype preservation across engines
89+
90+
parser = all_parsers
91+
engine_name = getattr(parser, "engine", "unknown")
92+
93+
data = """col1,col2,col3,col4
94+
AB,000388907,199,0150
95+
CD,101044572,200,0150
96+
EF,000023607,201,0205
97+
GH,100102040,202,0205"""
98+
99+
if engine_name == "pyarrow":
100+
with tm.assert_produces_warning(
101+
ParserWarning, match="may not be handled correctly by the pyarrow engine"
102+
):
103+
result = parser.read_csv(
104+
StringIO(data),
105+
dtype={"col2": str, "col3": int, "col4": str},
106+
)
107+
else:
108+
result = parser.read_csv(
109+
StringIO(data),
110+
dtype={"col2": str, "col3": int, "col4": str},
111+
)
112+
113+
assert result.shape == (4, 4)
114+
assert list(result.columns) == ["col1", "col2", "col3", "col4"]
115+
116+
assert result.loc[0, "col2"] == "000388907", "lost zeros in col2 row 0"
117+
assert result.loc[2, "col2"] == "000023607", "lost zeros in col2 row 2"
118+
assert result.loc[0, "col4"] == "0150", "lost zeros in col4 row 0"
119+
assert result.loc[2, "col4"] == "0205", "lost zeros in col4 row 2"
120+
121+
assert result.loc[0, "col3"] == 199
122+
assert result.loc[1, "col3"] == 200
123+
assert result.loc[2, "col3"] == 201
124+
assert result.loc[3, "col3"] == 202

0 commit comments

Comments
 (0)