Skip to content

Commit 7361570

Browse files
committed
review comments
1 parent fcbdee4 commit 7361570

File tree

4 files changed

+28
-15
lines changed

4 files changed

+28
-15
lines changed

doc/source/whatsnew/v3.0.0.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ All warnings for upcoming changes in pandas will have the base class :class:`pan
156156

157157
Other enhancements
158158
^^^^^^^^^^^^^^^^^^
159-
- :class:`pandas.NamedAgg` now forwards any ``*args`` and ``**kwargs``
159+
- :class:`pandas.NamedAgg` now supports passing ``*args`` and ``**kwargs``
160160
to calls of ``aggfunc`` (:issue:`58283`)
161161
- :func:`pandas.merge` propagates the ``attrs`` attribute to the result if all
162162
inputs have identical ``attrs``, as has so far already been the case for

pandas/core/apply.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
from pandas.core._numba.executor import generate_apply_looper
5050
import pandas.core.common as com
5151
from pandas.core.construction import ensure_wrapped_if_datetimelike
52+
from pandas.core.groupby.generic import NamedAgg
5253
from pandas.core.util.numba_ import (
5354
get_jit_arguments,
5455
prepare_function_arguments,
@@ -1714,7 +1715,12 @@ def reconstruct_func(
17141715
or not and also normalize the keyword to get new order of columns.
17151716
17161717
If named aggregation is applied, `func` will be None, and kwargs contains the
1717-
column and aggregation function information to be parsed;
1718+
column and aggregation function information to be parsed.
1719+
Each value in kwargs can be either:
1720+
- a tuple of (column, aggfunc)
1721+
- or a NamedAgg instance, which may also include additional *args and **kwargs
1722+
to be passed to the aggregation function.
1723+
17181724
If named aggregation is not applied, `func` is either string (e.g. 'min') or
17191725
Callable, or list of them (e.g. ['min', np.max]), or the dictionary of column name
17201726
and str/Callable/list of them (e.g. {'A': 'min'}, or {'A': [np.min, lambda x: x]})
@@ -1727,8 +1733,9 @@ def reconstruct_func(
17271733
----------
17281734
func: agg function (e.g. 'min' or Callable) or list of agg functions
17291735
(e.g. ['min', np.max]) or dictionary (e.g. {'A': ['min', np.max]}).
1730-
**kwargs: dict, kwargs used in is_multi_agg_with_relabel and
1731-
normalize_keyword_aggregation function for relabelling
1736+
**kwargs : dict
1737+
Keyword arguments used in is_multi_agg_with_relabel and
1738+
normalize_keyword_aggregation functions for relabelling.
17321739
17331740
Returns
17341741
-------
@@ -1745,7 +1752,6 @@ def reconstruct_func(
17451752
>>> reconstruct_func("min")
17461753
(False, 'min', None, None)
17471754
"""
1748-
from pandas.core.groupby.generic import NamedAgg
17491755

17501756
relabeling = func is None and (
17511757
is_multi_agg_with_relabel(**kwargs)
@@ -1776,10 +1782,10 @@ def reconstruct_func(
17761782
for key, val in kwargs.items():
17771783
if isinstance(val, NamedAgg):
17781784
aggfunc = val.aggfunc
1779-
if getattr(val, "args", ()) or getattr(val, "kwargs", {}):
1780-
a = getattr(val, "args", ())
1781-
kw = getattr(val, "kwargs", {})
1782-
aggfunc = lambda x, func=aggfunc, a=a, kw=kw: func(x, *a, **kw)
1785+
if val.args or val.kwargs:
1786+
aggfunc = lambda x, func=aggfunc, a=val.args, kw=val.kwargs: func(
1787+
x, *a, **kw
1788+
)
17831789
converted_kwargs[key] = (val.column, aggfunc)
17841790
else:
17851791
converted_kwargs[key] = val

pandas/core/groupby/generic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ class NamedAgg:
125125
aggfunc : function or str
126126
Function to apply to the provided column. If string, the name of a built-in
127127
pandas function.
128-
*args, **kwargs :
128+
*args, **kwargs : Any
129129
Optional positional and keyword arguments passed to ``aggfunc``.
130130
131131
See Also
@@ -163,7 +163,7 @@ class NamedAgg:
163163

164164
column: Hashable
165165
aggfunc: AggScalar
166-
args: tuple[Any, ...] = dataclasses.field(default_factory=tuple)
166+
args: tuple[Any, ...] = ()
167167
kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
168168

169169
def __init__(

pandas/tests/groupby/aggregate/test_aggregate.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -870,6 +870,7 @@ def n_between(self, ser, low, high, **kwargs):
870870
return ser.between(low, high, **kwargs).sum()
871871

872872
def test_namedagg_args(self):
873+
# https://github.com/pandas-dev/pandas/issues/58283
873874
df = DataFrame({"A": [0, 0, 1, 1], "B": [-1, 0, 1, 2]})
874875

875876
result = df.groupby("A").agg(
@@ -879,6 +880,7 @@ def test_namedagg_args(self):
879880
tm.assert_frame_equal(result, expected)
880881

881882
def test_namedagg_kwargs(self):
883+
# https://github.com/pandas-dev/pandas/issues/58283
882884
df = DataFrame({"A": [0, 0, 1, 1], "B": [-1, 0, 1, 2]})
883885

884886
result = df.groupby("A").agg(
@@ -890,6 +892,7 @@ def test_namedagg_kwargs(self):
890892
tm.assert_frame_equal(result, expected)
891893

892894
def test_namedagg_args_and_kwargs(self):
895+
# https://github.com/pandas-dev/pandas/issues/58283
893896
df = DataFrame({"A": [0, 0, 1, 1], "B": [-1, 0, 1, 2]})
894897

895898
result = df.groupby("A").agg(
@@ -903,17 +906,21 @@ def test_namedagg_args_and_kwargs(self):
903906
tm.assert_frame_equal(result, expected)
904907

905908
def test_multiple_named_agg_with_args_and_kwargs(self):
909+
# https://github.com/pandas-dev/pandas/issues/58283
906910
df = DataFrame({"A": [0, 1, 2, 3], "B": [1, 2, 3, 4]})
907911

908912
result = df.groupby("A").agg(
909913
n_between01=pd.NamedAgg("B", self.n_between, 0, 1),
910914
n_between13=pd.NamedAgg("B", self.n_between, 1, 3),
911915
n_between02=pd.NamedAgg("B", self.n_between, 0, 2),
912916
)
913-
expected = df.groupby("A").agg(
914-
n_between01=("B", lambda x: x.between(0, 1).sum()),
915-
n_between13=("B", lambda x: x.between(0, 3).sum()),
916-
n_between02=("B", lambda x: x.between(0, 2).sum()),
917+
expected = DataFrame(
918+
{
919+
"n_between01": [1, 0, 0, 0],
920+
"n_between13": [1, 1, 1, 0],
921+
"n_between02": [1, 1, 0, 0],
922+
},
923+
index=Index([0, 1, 2, 3], name="A"),
917924
)
918925
tm.assert_frame_equal(result, expected)
919926

0 commit comments

Comments
 (0)