Skip to content

Commit a5f0b1d

Browse files
committed
fix sort_index using level name on MultiIndex
1 parent 94c7e88 commit a5f0b1d

File tree

2 files changed

+107
-4
lines changed

2 files changed

+107
-4
lines changed

pandas/core/sorting.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757

5858
def get_indexer_indexer(
5959
target: Index,
60-
level: Level | list[Level] | None,
60+
level: Level | list[Level] | None, # can level actually be a list here?
6161
ascending: list[bool] | bool,
6262
kind: SortKind,
6363
na_position: NaPosition,
@@ -87,7 +87,19 @@ def get_indexer_indexer(
8787
# error: Incompatible types in assignment (expression has type
8888
# "Union[ExtensionArray, ndarray[Any, Any], Index, Series]", variable has
8989
# type "Index")
90+
91+
# before:
92+
# MultiIndex([('a', 'top10'),
93+
# ('a', 'top2')],
94+
# names=['A', 'B'])
9095
target = ensure_key_mapped(target, key, levels=level) # type: ignore[assignment]
96+
# after:
97+
# MultiIndex([('a', 1),
98+
# ('a', 0)],
99+
# names=['A', None])
100+
# the big problem is that the name is lost as well,
101+
# but with the new change I preserve it
102+
91103
target = target._sort_levels_monotonic()
92104

93105
if level is not None:
@@ -531,11 +543,15 @@ def _ensure_key_mapped_multiindex(
531543
level_iter = [level]
532544
else:
533545
level_iter = level
534-
535546
sort_levels: range | set = {index._get_level_number(lev) for lev in level_iter}
536547
else:
537548
sort_levels = range(index.nlevels)
538549

550+
# breakpoint() # this loops through the levels
551+
# for the levels to be sorted, it applies the key function
552+
# (uses the number, not the name)
553+
# it returns the indexer: ensure_key_mapped(
554+
# index._get_level_values(1), key) = Index([1, 0], dtype='int64')
539555
mapped = [
540556
(
541557
ensure_key_mapped(index._get_level_values(level), key)
@@ -569,19 +585,23 @@ def ensure_key_mapped(
569585
return values
570586

571587
if isinstance(values, ABCMultiIndex):
588+
# redirects to special MultiIndex handler
572589
return _ensure_key_mapped_multiindex(values, key, level=levels)
573590

574591
result = key(values.copy())
575592
if len(result) != len(values):
576593
raise ValueError(
577-
"User-provided `key` function must not change the shape of the array."
594+
"User-provided `key` function must not change the shape of the array."
578595
)
579596

580597
try:
581598
if isinstance(
582599
values, Index
583600
): # convert to a new Index subclass, not necessarily the same
584-
result = Index(result, tupleize_cols=False)
601+
# preserve the original name when creating the new Index
602+
result = Index(
603+
result, tupleize_cols=False, name=getattr(values, "name", None)
604+
)
585605
else:
586606
# try to revert to original type otherwise
587607
type_of_values = type(values)

pandas/tests/frame/methods/test_sort_index.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from natsort import index_natsorted
12
import numpy as np
23
import pytest
34

@@ -943,6 +944,88 @@ def test_sort_index_multiindex_sort_remaining(self, ascending):
943944

944945
tm.assert_frame_equal(result, expected)
945946

947+
def test_sort_multi_index_sort_by_level_name(self):
948+
# GH#62361
949+
950+
df = DataFrame(
951+
[[1, 2], [3, 4]],
952+
columns=MultiIndex.from_product(
953+
[["a"], ["top10", "top2"]], names=("A", "B")
954+
),
955+
)
956+
957+
expected = DataFrame(
958+
[[2, 1], [4, 3]],
959+
columns=MultiIndex.from_product(
960+
[["a"], ["top2", "top10"]], names=("A", "B")
961+
),
962+
)
963+
964+
sorted_df = df.sort_index(
965+
axis=1, level="B", key=lambda x: np.argsort(index_natsorted(x))
966+
)
967+
tm.assert_frame_equal(sorted_df, expected)
968+
969+
def test_sort_multi_index_sort_by_level_name_2(self):
970+
# GH#62361
971+
972+
df = DataFrame(
973+
[[1, 2], [3, 4]],
974+
columns=MultiIndex.from_tuples(
975+
[("alpha10", "top10"), ("alpha3", "top2")], names=("A", "B")
976+
),
977+
)
978+
979+
expected = DataFrame(
980+
[[2, 1], [4, 3]],
981+
columns=MultiIndex.from_tuples(
982+
[("alpha3", "top2"), ("alpha10", "top10")], names=("A", "B")
983+
),
984+
)
985+
986+
sorted_df = df.sort_index(
987+
axis=1, level=0, key=lambda x: np.argsort(index_natsorted(x))
988+
)
989+
tm.assert_frame_equal(sorted_df, expected)
990+
991+
sorted_df = df.sort_index(
992+
axis=1, level="A", key=lambda x: np.argsort(index_natsorted(x))
993+
)
994+
tm.assert_frame_equal(sorted_df, expected)
995+
996+
sorted_df = df.sort_index(
997+
axis=1, level=1, key=lambda x: np.argsort(index_natsorted(x))
998+
)
999+
tm.assert_frame_equal(sorted_df, expected)
1000+
1001+
sorted_df = df.sort_index(
1002+
axis=1, level="B", key=lambda x: np.argsort(index_natsorted(x))
1003+
)
1004+
tm.assert_frame_equal(sorted_df, expected)
1005+
1006+
sorted_df = df.sort_index(
1007+
axis=1, level=[0, 1], key=lambda x: np.argsort(index_natsorted(x))
1008+
)
1009+
tm.assert_frame_equal(sorted_df, expected)
1010+
1011+
sorted_df = df.sort_index(
1012+
axis=1, level=[1, 0], key=lambda x: np.argsort(index_natsorted(x))
1013+
)
1014+
tm.assert_frame_equal(sorted_df, expected)
1015+
1016+
sorted_df = df.sort_index(
1017+
axis=1, level=[1, "A"], key=lambda x: np.argsort(index_natsorted(x))
1018+
)
1019+
tm.assert_frame_equal(sorted_df, expected)
1020+
1021+
# repetition does not matter
1022+
sorted_df = df.sort_index(
1023+
axis=1,
1024+
level=["A", "B", 0, 1, "B"],
1025+
key=lambda x: np.argsort(index_natsorted(x)),
1026+
)
1027+
tm.assert_frame_equal(sorted_df, expected)
1028+
9461029

9471030
def test_sort_index_with_sliced_multiindex():
9481031
# GH 55379

0 commit comments

Comments
 (0)