Skip to content

Commit ac9cd15

Browse files
TST: Add tests for MultiIndex.factorize method with extension dtypes
1 parent 77fdffd commit ac9cd15

File tree

1 file changed

+128
-0
lines changed

1 file changed

+128
-0
lines changed
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
"""
2+
Tests for MultiIndex.factorize method
3+
"""
4+
5+
import numpy as np
6+
import pytest
7+
8+
import pandas as pd
9+
import pandas._testing as tm
10+
11+
12+
class TestMultiIndexFactorize:
13+
def test_factorize_extension_dtype_int32(self):
14+
# GH#62337: factorize should preserve Int32 extension dtype
15+
df = pd.DataFrame({"col": pd.Series([1, None, 2], dtype="Int32")})
16+
mi = pd.MultiIndex.from_frame(df)
17+
18+
codes, uniques = mi.factorize()
19+
20+
result_dtype = uniques.to_frame().iloc[:, 0].dtype
21+
expected_dtype = pd.Int32Dtype()
22+
assert result_dtype == expected_dtype
23+
24+
# Verify codes are correct
25+
expected_codes = np.array([0, 1, 2], dtype=np.intp)
26+
tm.assert_numpy_array_equal(codes, expected_codes)
27+
28+
@pytest.mark.parametrize("dtype", ["Int32", "Int64", "string", "boolean"])
29+
def test_factorize_extension_dtypes(self, dtype):
30+
# GH#62337: factorize should preserve various extension dtypes
31+
if dtype == "boolean":
32+
values = [True, None, False]
33+
elif dtype == "string":
34+
values = ["a", None, "b"]
35+
else: # Int32, Int64
36+
values = [1, None, 2]
37+
38+
df = pd.DataFrame({"col": pd.Series(values, dtype=dtype)})
39+
mi = pd.MultiIndex.from_frame(df)
40+
41+
codes, uniques = mi.factorize()
42+
result_dtype = uniques.to_frame().iloc[:, 0].dtype
43+
44+
assert str(result_dtype) == dtype
45+
46+
def test_factorize_multiple_extension_dtypes(self):
47+
# GH#62337: factorize with multiple columns having extension dtypes
48+
df = pd.DataFrame(
49+
{
50+
"int_col": pd.Series([1, 2, 1], dtype="Int64"),
51+
"str_col": pd.Series(["a", "b", "a"], dtype="string"),
52+
}
53+
)
54+
mi = pd.MultiIndex.from_frame(df)
55+
56+
codes, uniques = mi.factorize()
57+
58+
result_frame = uniques.to_frame()
59+
assert result_frame.iloc[:, 0].dtype == pd.Int64Dtype()
60+
assert result_frame.iloc[:, 1].dtype == pd.StringDtype()
61+
62+
# Should have 2 unique combinations: (1,'a') and (2,'b')
63+
assert len(uniques) == 2
64+
65+
def test_factorize_preserves_names(self):
66+
# GH#62337: factorize should preserve MultiIndex names
67+
df = pd.DataFrame(
68+
{
69+
"level_1": pd.Series([1, 2], dtype="Int32"),
70+
"level_2": pd.Series(["a", "b"], dtype="string"),
71+
}
72+
)
73+
mi = pd.MultiIndex.from_frame(df)
74+
75+
codes, uniques = mi.factorize()
76+
77+
tm.assert_index_equal(uniques.names, mi.names)
78+
79+
def test_factorize_extension_dtype_with_sort(self):
80+
# GH#62337: factorize with sort=True should preserve extension dtypes
81+
df = pd.DataFrame({"col": pd.Series([2, None, 1], dtype="Int32")})
82+
mi = pd.MultiIndex.from_frame(df)
83+
84+
codes, uniques = mi.factorize(sort=True)
85+
86+
result_dtype = uniques.to_frame().iloc[:, 0].dtype
87+
assert result_dtype == pd.Int32Dtype()
88+
89+
def test_factorize_empty_extension_dtype(self):
90+
# GH#62337: factorize on empty MultiIndex with extension dtype
91+
df = pd.DataFrame({"col": pd.Series([], dtype="Int32")})
92+
mi = pd.MultiIndex.from_frame(df)
93+
94+
codes, uniques = mi.factorize()
95+
96+
assert len(codes) == 0
97+
assert len(uniques) == 0
98+
assert uniques.to_frame().iloc[:, 0].dtype == pd.Int32Dtype()
99+
100+
def test_factorize_regular_dtypes_unchanged(self):
101+
# Ensure regular dtypes still work as before
102+
df = pd.DataFrame({"int_col": [1, 2, 1], "float_col": [1.1, 2.2, 1.1]})
103+
mi = pd.MultiIndex.from_frame(df)
104+
105+
codes, uniques = mi.factorize()
106+
107+
result_frame = uniques.to_frame()
108+
assert result_frame.iloc[:, 0].dtype == np.dtype("int64")
109+
assert result_frame.iloc[:, 1].dtype == np.dtype("float64")
110+
111+
# Should have 2 unique combinations
112+
assert len(uniques) == 2
113+
114+
def test_factorize_mixed_extension_regular_dtypes(self):
115+
# Mix of extension and regular dtypes
116+
df = pd.DataFrame(
117+
{
118+
"ext_col": pd.Series([1, 2, 1], dtype="Int64"),
119+
"reg_col": [1.1, 2.2, 1.1], # regular float64
120+
}
121+
)
122+
mi = pd.MultiIndex.from_frame(df)
123+
124+
codes, uniques = mi.factorize()
125+
126+
result_frame = uniques.to_frame()
127+
assert result_frame.iloc[:, 0].dtype == pd.Int64Dtype()
128+
assert result_frame.iloc[:, 1].dtype == np.dtype("float64")

0 commit comments

Comments
 (0)