@@ -3979,6 +3979,116 @@ def truncate(self, before=None, after=None) -> MultiIndex:
39793979 verify_integrity = False ,
39803980 )
39813981
def factorize(
    self,
    sort: bool = False,
    use_na_sentinel: bool = True,
) -> tuple[npt.NDArray[np.intp], MultiIndex]:
    """
    Encode the MultiIndex as integer codes plus its unique entries.

    When at least one level has an extension dtype (e.g., Int64, boolean,
    string), the work is delegated to a dedicated code path that tries to
    keep those dtypes on the levels of the returned uniques. Otherwise the
    parent class implementation is used. See GH#62337.

    Parameters
    ----------
    sort : bool, default False
        Sort uniques and shuffle codes to maintain the relationship.
    use_na_sentinel : bool, default True
        If True, the sentinel -1 will be used for NaN values. If False,
        NaN values will be encoded as non-negative integers and will not drop the
        NaN from the uniques of the values.

    Returns
    -------
    codes : np.ndarray
        An integer ndarray that's an indexer into uniques.
    uniques : MultiIndex
        The unique values with extension dtypes preserved when present.

    See Also
    --------
    Index.factorize : Encode the object as an enumerated type.

    Examples
    --------
    >>> mi = pd.MultiIndex.from_arrays(
    ...     [pd.array([1, 2, 1], dtype="Int64"), ["a", "b", "a"]]
    ... )
    >>> codes, uniques = mi.factorize()
    >>> codes
    array([0, 1, 0])
    >>> uniques.dtypes
    level_0     Int64
    level_1    object
    dtype: object
    """
    for level in self.levels:
        if isinstance(level.dtype, ExtensionDtype):
            # GH#62337: a single extension-typed level is enough to require
            # the dtype-preserving path.
            return self._factorize_with_extension_dtypes(
                sort=sort, use_na_sentinel=use_na_sentinel
            )

    # No extension dtypes on any level: defer to the parent implementation.
    return super().factorize(sort=sort, use_na_sentinel=use_na_sentinel)
4039+
def _factorize_with_extension_dtypes(
    self, sort: bool, use_na_sentinel: bool
) -> tuple[npt.NDArray[np.intp], MultiIndex]:
    """
    Factorize the MultiIndex and re-apply extension dtypes to the uniques.

    The codes and unique entries come from ``algos.factorize`` applied to
    ``self._values``. The uniques are rebuilt into a MultiIndex, and every
    level whose original dtype is an extension dtype is cast back to that
    dtype; if the cast fails, the inferred level is kept.

    Parameters
    ----------
    sort : bool
        Forwarded to ``algos.factorize``.
    use_na_sentinel : bool
        Forwarded to ``algos.factorize``.

    Returns
    -------
    codes : np.ndarray
        The codes returned by ``algos.factorize``.
    uniques : MultiIndex
        The unique entries, with extension dtypes restored where the cast
        succeeded.
    """
    codes, unique_tuples = algos.factorize(
        self._values, sort=sort, use_na_sentinel=use_na_sentinel
    )
    source_levels = self.levels

    if len(unique_tuples) == 0:
        # Nothing to rebuild from: construct an empty result directly from
        # zero-length slices of the original levels so their dtypes carry
        # over.
        zero_length_levels = [lvl[:0] for lvl in source_levels]
        empty_mi = type(self)(
            levels=zero_length_levels,
            codes=[[] for _ in zero_length_levels],
        )
        return codes, empty_mi

    # Rebuilding from the unique entries infers each level's dtype afresh.
    inferred = type(self).from_tuples(unique_tuples)

    def _restore(position, template):
        # Return the inferred level, cast to the template's dtype when that
        # dtype is an extension dtype and the cast succeeds.
        candidate = inferred.levels[position]
        if not isinstance(template.dtype, ExtensionDtype):
            return candidate
        try:
            return candidate.astype(template.dtype)
        except (TypeError, ValueError):
            # Best effort: fall back to the inferred level.
            return candidate

    restored_levels = [
        _restore(position, template)
        for position, template in enumerate(source_levels)
    ]
    return codes, inferred.set_levels(restored_levels)
4091+
39824092 def equals (self , other : object ) -> bool :
39834093 """
39844094 Determines if two MultiIndex objects have the same labeling information
0 commit comments