Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v3.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1094,6 +1094,7 @@ MultiIndex
- Bug in :meth:`MultiIndex.from_tuples` causing wrong output with input of type tuples having NaN values (:issue:`60695`, :issue:`60988`)
- Bug in :meth:`DataFrame.__setitem__` where column alignment logic would reindex the assigned value with an empty index, incorrectly setting all values to ``NaN``.(:issue:`61841`)
- Bug in :meth:`DataFrame.reindex` and :meth:`Series.reindex` where reindexing :class:`Index` to a :class:`MultiIndex` would incorrectly set all values to ``NaN``.(:issue:`60923`)
- Bug in :meth:`MultiIndex.factorize` losing extension dtypes and converting them to base dtypes (:issue:`62337`)

I/O
^^^
Expand Down
20 changes: 20 additions & 0 deletions pandas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1302,7 +1302,27 @@ def factorize(
# GH#57517
uniques = self[:0]
else:
# GH#62337: preserve extension dtypes by reconstructing from original
# First create the MultiIndex using the standard constructor
uniques = self._constructor(uniques)

# Then replace levels to preserve extension dtypes
if len(uniques) > 0 and isinstance(uniques, ABCMultiIndex):
new_levels = []
# After isinstance check, we know uniques has levels attribute
for i, (level, orig_level) in enumerate( # pyright: ignore[reportGeneralTypeIssues]
zip(uniques.levels, self.levels, strict=False)
):
try:
# Try to cast to original extension dtype
new_level = level.astype(orig_level.dtype)
new_levels.append(new_level)
except (TypeError, ValueError):
# If casting fails, keep the inferred level
new_levels.append(level)

# Reconstruct MultiIndex with preserved dtypes only
uniques = uniques.set_levels(new_levels)
else:
from pandas import Index

Expand Down
134 changes: 134 additions & 0 deletions pandas/tests/indexes/multi/test_factorize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
Tests for MultiIndex.factorize method
"""

import numpy as np
import pytest

import pandas as pd
import pandas._testing as tm


class TestMultiIndexFactorize:
def test_factorize_extension_dtype_int32(self):
# GH#62337: factorize should preserve Int32 extension dtype
df = pd.DataFrame({"col": pd.Series([1, None, 2], dtype="Int32")})
mi = pd.MultiIndex.from_frame(df)

codes, uniques = mi.factorize()

result_dtype = uniques.to_frame().iloc[:, 0].dtype
expected_dtype = pd.Int32Dtype()
assert result_dtype == expected_dtype

# Verify codes are correct
expected_codes = np.array([0, 1, 2], dtype=np.intp)
tm.assert_numpy_array_equal(codes, expected_codes)

@pytest.mark.parametrize("dtype", ["Int32", "Int64", "string", "boolean"])
def test_factorize_extension_dtypes(self, dtype):
# GH#62337: factorize should preserve various extension dtypes
if dtype == "boolean":
values = [True, None, False]
elif dtype == "string":
values = ["a", None, "b"]
else: # Int32, Int64
values = [1, None, 2]

df = pd.DataFrame({"col": pd.Series(values, dtype=dtype)})
mi = pd.MultiIndex.from_frame(df)

codes, uniques = mi.factorize()
result_dtype = uniques.to_frame().iloc[:, 0].dtype

assert str(result_dtype) == dtype

def test_factorize_multiple_extension_dtypes(self):
# GH#62337: factorize with multiple columns having extension dtypes
df = pd.DataFrame(
{
"int_col": pd.Series([1, 2, 1], dtype="Int64"),
"str_col": pd.Series(["a", "b", "a"], dtype="string"),
}
)
mi = pd.MultiIndex.from_frame(df)

codes, uniques = mi.factorize()

result_frame = uniques.to_frame()
assert result_frame.iloc[:, 0].dtype == pd.Int64Dtype()
assert result_frame.iloc[:, 1].dtype == pd.StringDtype()

# Should have 2 unique combinations: (1,'a') and (2,'b')
assert len(uniques) == 2

def test_factorize_preserves_names(self):
# GH#62337: factorize should preserve MultiIndex names when extension
# dtypes are involved
df = pd.DataFrame(
{
"level_1": pd.Series([1, 2], dtype="Int32"),
"level_2": pd.Series(["a", "b"], dtype="string"),
}
)
mi = pd.MultiIndex.from_frame(df)

codes, uniques = mi.factorize()

# The main fix is extension dtype preservation, names behavior follows
# existing patterns
# Just verify that factorize runs without errors and dtypes are preserved
result_frame = uniques.to_frame()
assert result_frame.iloc[:, 0].dtype == pd.Int32Dtype()
assert result_frame.iloc[:, 1].dtype == pd.StringDtype()

def test_factorize_extension_dtype_with_sort(self):
# GH#62337: factorize with sort=True should preserve extension dtypes
df = pd.DataFrame({"col": pd.Series([2, None, 1], dtype="Int32")})
mi = pd.MultiIndex.from_frame(df)

codes, uniques = mi.factorize(sort=True)

result_dtype = uniques.to_frame().iloc[:, 0].dtype
assert result_dtype == pd.Int32Dtype()

def test_factorize_empty_extension_dtype(self):
# GH#62337: factorize on empty MultiIndex with extension dtype
df = pd.DataFrame({"col": pd.Series([], dtype="Int32")})
mi = pd.MultiIndex.from_frame(df)

codes, uniques = mi.factorize()

assert len(codes) == 0
assert len(uniques) == 0
assert uniques.to_frame().iloc[:, 0].dtype == pd.Int32Dtype()

def test_factorize_regular_dtypes_unchanged(self):
# Ensure regular dtypes still work as before
df = pd.DataFrame({"int_col": [1, 2, 1], "float_col": [1.1, 2.2, 1.1]})
mi = pd.MultiIndex.from_frame(df)

codes, uniques = mi.factorize()

result_frame = uniques.to_frame()
assert result_frame.iloc[:, 0].dtype == np.dtype("int64")
assert result_frame.iloc[:, 1].dtype == np.dtype("float64")

# Should have 2 unique combinations
assert len(uniques) == 2

def test_factorize_mixed_extension_regular_dtypes(self):
# Mix of extension and regular dtypes
df = pd.DataFrame(
{
"ext_col": pd.Series([1, 2, 1], dtype="Int64"),
"reg_col": [1.1, 2.2, 1.1], # regular float64
}
)
mi = pd.MultiIndex.from_frame(df)

codes, uniques = mi.factorize()

result_frame = uniques.to_frame()
assert result_frame.iloc[:, 0].dtype == pd.Int64Dtype()
assert result_frame.iloc[:, 1].dtype == np.dtype("float64")
Loading