Skip to content

Commit 74057eb

Browse files
author
T. Koskamp
committed
Update indices property from groupby
1 parent d2046e9 commit 74057eb

File tree

2 files changed: +32 additions, -39 deletions (lines changed)

pandas/core/groupby/groupby.py

Lines changed: 14 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -637,7 +637,7 @@ def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]:
637637
return self._grouper.indices
638638

639639
@final
640-
def _get_indices(self, names):
640+
def _get_indices(self, name):
641641
"""
642642
Safe get multiple indices, translate keys for
643643
datelike to underlying repr.
@@ -650,28 +650,27 @@ def get_converter(s):
650650
return lambda key: Timestamp(key)
651651
elif isinstance(s, np.datetime64):
652652
return lambda key: Timestamp(key).asm8
653-
elif isna(s):
654-
return lambda key: np.nan
655653
else:
656654
return lambda key: key
657655

658-
if len(names) == 0:
659-
return []
656+
if isna(name):
657+
return self.indices.get(np.nan, [])
658+
if isinstance(name, tuple):
659+
name = tuple(np.nan if isna(comp) else comp for comp in name)
660660

661661
if len(self.indices) > 0:
662662
index_sample = next(iter(self.indices))
663663
else:
664664
index_sample = None # Dummy sample
665665

666-
name_sample = names[0]
667666
if isinstance(index_sample, tuple):
668-
if not isinstance(name_sample, tuple):
667+
if not isinstance(name, tuple):
669668
msg = "must supply a tuple to get_group with multiple grouping keys"
670669
raise ValueError(msg)
671-
if not len(name_sample) == len(index_sample):
670+
if not len(name) == len(index_sample):
672671
try:
673672
# If the original grouper was a tuple
674-
return [self.indices[name] for name in names]
673+
return self.indices[name]
675674
except KeyError as err:
676675
# turns out it wasn't a tuple
677676
msg = (
@@ -680,41 +679,20 @@ def get_converter(s):
680679
)
681680
raise ValueError(msg) from err
682681

683-
has_nan = any(isna(n) for n in name_sample)
684-
685-
sample = name_sample if has_nan else index_sample
686-
converters = (get_converter(s) for s in sample)
687-
688-
names = (
689-
tuple(f(n) for f, n in zip(converters, name, strict=True))
690-
for name in names
691-
)
692-
693-
indices = self.indices
694-
if not self.dropna and has_nan:
695-
indices = {}
696-
for k, v in self.indices.items():
697-
k = tuple(np.nan if isna(e) else e for e in k)
698-
indices[k] = v
682+
converters = (get_converter(s) for s in index_sample)
683+
name = tuple(f(n) for f, n in zip(converters, name, strict=True))
699684
else:
700-
has_nan = isna(name_sample)
701-
702-
convert_sample = name_sample if has_nan else index_sample
703-
converter = get_converter(convert_sample)
704-
names = (converter(name) for name in names)
705-
706-
indices = self.indices
707-
if not self.dropna and has_nan:
708-
indices = {np.nan if isna(k) else k: v for k, v in indices.items()}
685+
converter = get_converter(index_sample)
686+
name = converter(name)
709687

710-
return [indices.get(name, []) for name in names]
688+
return self.indices.get(name, [])
711689

712690
@final
713691
def _get_index(self, name):
714692
"""
715693
Safe get index, translate keys for datelike to underlying repr.
716694
"""
717-
return self._get_indices([name])[0]
695+
return self._get_indices(name)
718696

719697
@final
720698
@cache_readonly

pandas/core/groupby/ops.py

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -652,9 +652,24 @@ def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]:
652652
"""dict {group name -> group indices}"""
653653
if len(self.groupings) == 1 and isinstance(self.result_index, CategoricalIndex):
654654
# This shows unused categories in indices GH#38642
655-
return self.groupings[0].indices
656-
codes_list = [ping.codes for ping in self.groupings]
657-
return get_indexer_dict(codes_list, self.levels)
655+
result = self.groupings[0].indices
656+
else:
657+
codes_list = [ping.codes for ping in self.groupings]
658+
result = get_indexer_dict(codes_list, self.levels)
659+
if not self.dropna:
660+
has_mi = isinstance(self.result_index, MultiIndex)
661+
if not has_mi and self.result_index.hasnans:
662+
result = {
663+
np.nan if isna(key) else key: value for key, value in result.items()
664+
}
665+
elif has_mi:
666+
# MultiIndex has no efficient way to tell if there are NAs
667+
result = {
668+
tuple(np.nan if isna(comp) else comp for comp in key): value
669+
for key, value in result.items()
670+
}
671+
672+
return result
658673

659674
@final
660675
@cache_readonly

0 commit comments

Comments (0)