Skip to content

Commit 632b2d6

Browse files
BUG: Override MultiIndex.factorize to preserve extension dtypes
1 parent ccab4f2 commit 632b2d6

File tree

1 file changed

+110
-0
lines changed

1 file changed

+110
-0
lines changed

pandas/core/indexes/multi.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3979,6 +3979,116 @@ def truncate(self, before=None, after=None) -> MultiIndex:
39793979
verify_integrity=False,
39803980
)
39813981

3982+
def factorize(
    self,
    sort: bool = False,
    use_na_sentinel: bool = True,
) -> tuple[npt.NDArray[np.intp], MultiIndex]:
    """
    Encode the object as an enumerated type or categorical variable.

    When at least one level has an extension dtype (e.g., Int64, boolean,
    string), factorization is routed through a dedicated code path that
    keeps those dtypes on the levels of the returned uniques; otherwise
    the parent implementation is used unchanged. See GH#62337.

    Parameters
    ----------
    sort : bool, default False
        Sort uniques and shuffle codes to maintain the relationship.
    use_na_sentinel : bool, default True
        If True, the sentinel -1 will be used for NaN values. If False,
        NaN values will be encoded as non-negative integers and will not
        drop the NaN from the uniques of the values.

    Returns
    -------
    codes : np.ndarray
        An integer ndarray that's an indexer into uniques.
    uniques : MultiIndex
        The unique values with extension dtypes preserved when present.

    See Also
    --------
    Index.factorize : Encode the object as an enumerated type.

    Examples
    --------
    >>> mi = pd.MultiIndex.from_arrays(
    ...     [pd.array([1, 2, 1], dtype="Int64"), ["a", "b", "a"]]
    ... )
    >>> codes, uniques = mi.factorize()
    >>> codes
    array([0, 1, 0])
    >>> uniques.dtypes
    level_0     Int64
    level_1    object
    dtype: object
    """
    # A single extension-typed level is enough to need the dtype-preserving
    # path (GH#62337); stop scanning as soon as one is found.
    for lvl in self.levels:
        if isinstance(lvl.dtype, ExtensionDtype):
            return self._factorize_with_extension_dtypes(
                sort=sort, use_na_sentinel=use_na_sentinel
            )

    # No extension dtypes: defer to the parent implementation for performance.
    return super().factorize(sort=sort, use_na_sentinel=use_na_sentinel)
4039+
4040+
def _factorize_with_extension_dtypes(
    self, sort: bool, use_na_sentinel: bool
) -> tuple[npt.NDArray[np.intp], MultiIndex]:
    """
    Factorize this MultiIndex and put extension dtypes back on the levels.

    The values in ``self._values`` are factorized with ``algos.factorize``;
    the resulting uniques are rebuilt into a MultiIndex whose levels
    are cast back to the dtype of the corresponding original level whenever
    that original dtype is an extension dtype.

    Parameters
    ----------
    sort : bool
        Forwarded to ``algos.factorize``.
    use_na_sentinel : bool
        Forwarded to ``algos.factorize``.

    Returns
    -------
    codes : np.ndarray
        The codes returned by ``algos.factorize``.
    uniques : MultiIndex
        The unique entries, with extension dtypes restored where the cast
        succeeds.
    """
    codes, unique_tuples = algos.factorize(
        self._values, sort=sort, use_na_sentinel=use_na_sentinel
    )
    cls = type(self)

    if len(unique_tuples) == 0:
        # Nothing to rebuild from: slice every original level down to zero
        # length so each keeps its dtype, and pair it with empty codes.
        zero_length_levels = [lvl[:0] for lvl in self.levels]
        no_codes = [[] for _ in zero_length_levels]
        return codes, cls(levels=zero_length_levels, codes=no_codes)

    # Levels here carry whatever dtype from_tuples inferred.
    inferred = cls.from_tuples(unique_tuples)

    def _restore(position, source_level):
        candidate = inferred.levels[position]
        if not isinstance(source_level.dtype, ExtensionDtype):
            # Regular levels keep the inferred dtype.
            return candidate
        try:
            return candidate.astype(source_level.dtype)
        except (TypeError, ValueError):
            # Best effort: if the cast is rejected, keep the inferred level.
            return candidate

    restored = [_restore(i, lvl) for i, lvl in enumerate(self.levels)]
    return codes, inferred.set_levels(restored)
4091+
39824092
def equals(self, other: object) -> bool:
39834093
"""
39844094
Determines if two MultiIndex objects have the same labeling information

0 commit comments

Comments
 (0)