33import itertools
44import warnings
55from functools import partial , reduce
6- from typing import TYPE_CHECKING
6+ from typing import TYPE_CHECKING , Callable
77
88import numpy as np
99import pandas as pd
1010import pytest
1111from numpy_groupies .aggregate_numpy import aggregate
1212
13+ from flox import xrutils
1314from flox .aggregations import Aggregation
1415from flox .core import (
1516 _convert_expected_groups_to_index ,
@@ -53,6 +54,7 @@ def dask_array_ones(*args):
5354 "sum" ,
5455 "nansum" ,
5556 "argmax" ,
57+ "nanfirst" ,
5658 pytest .param ("nanargmax" , marks = (pytest .mark .skip ,)),
5759 "prod" ,
5860 "nanprod" ,
@@ -70,6 +72,7 @@ def dask_array_ones(*args):
7072 pytest .param ("nanargmin" , marks = (pytest .mark .skip ,)),
7173 "any" ,
7274 "all" ,
75+ "nanlast" ,
7376 pytest .param ("median" , marks = (pytest .mark .skip ,)),
7477 pytest .param ("nanmedian" , marks = (pytest .mark .skip ,)),
7578)
@@ -78,6 +81,21 @@ def dask_array_ones(*args):
7881 from flox .core import T_Engine , T_ExpectedGroupsOpt , T_Func2
7982
8083
def _get_array_func(func: str) -> Callable:
    """Return a reference reduction callable for the reduction named ``func``.

    Parameters
    ----------
    func : str
        Name of the reduction. ``"count"`` is handled locally; ``"nanfirst"``
        and ``"nanlast"`` are looked up on ``xrutils``; any other name is
        looked up on ``numpy``.

    Returns
    -------
    Callable
        A callable taking an array-like as its first argument. The ``"count"``
        callable forwards any keyword arguments (e.g. ``axis``) to
        ``ndarray.sum``.

    Raises
    ------
    AttributeError
        If ``func`` is not an attribute of the module it is looked up on.
    """
    if func == "count":

        def npfunc(x, **kwargs):
            # Callers invoke the returned function as ``npfunc(x, axis=0)``, so
            # the count reference has to accept keyword arguments just like
            # the numpy / xrutils reductions returned by the other branches.
            x = np.asarray(x)
            return (~np.isnan(x)).sum(**kwargs)

    elif func in ["nanfirst", "nanlast"]:
        npfunc = getattr(xrutils, func)
    else:
        npfunc = getattr(np, func)

    return npfunc
97+
98+
8199def test_alignment_error ():
82100 da = np .ones ((12 ,))
83101 labels = np .ones ((5 ,))
@@ -217,6 +235,10 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
217235 if "arg" in func and add_nan_by :
218236 array_ [..., nanmask ] = np .nan
219237 expected = getattr (np , "nan" + func )(array_ , axis = - 1 , ** kwargs )
238+ # elif func in ["first", "last"]:
239+ # expected = getattr(xrutils, f"nan{func}")(array_[..., ~nanmask], axis=-1, **kwargs)
240+ elif func in ["nanfirst" , "nanlast" ]:
241+ expected = getattr (xrutils , func )(array_ [..., ~ nanmask ], axis = - 1 , ** kwargs )
220242 else :
221243 expected = getattr (np , func )(array_ [..., ~ nanmask ], axis = - 1 , ** kwargs )
222244 for _ in range (nby ):
@@ -241,7 +263,7 @@ def test_groupby_reduce_all(nby, size, chunks, func, add_nan_by, engine):
241263 call = partial (
242264 groupby_reduce , array , * by , method = method , reindex = reindex , ** flox_kwargs
243265 )
244- if "arg" in func and reindex is True :
266+ if ( "arg" in func or func in [ "first" , "last" ]) and reindex is True :
245267 # simple_combine with argreductions not supported right now
246268 with pytest .raises (NotImplementedError ):
247269 call ()
@@ -486,6 +508,28 @@ def test_dask_reduce_axis_subset():
486508 )
487509
488510
@pytest.mark.parametrize("func", ["first", "last", "nanfirst", "nanlast"])
@pytest.mark.parametrize("axis", [(0, 1)])
def test_first_last_disallowed(axis, func):
    """Check that ``groupby_reduce`` raises ``ValueError`` for each parametrized func/axis pair."""
    shape = (2, 3, 2)
    values = np.empty(shape)
    labels = np.ones(shape)
    with pytest.raises(ValueError):
        groupby_reduce(values, labels, func=func, axis=axis)
516+
517+
@requires_dask
@pytest.mark.parametrize("func", ["nanfirst", "nanlast"])
@pytest.mark.parametrize("axis", [None, (0, 1, 2)])
def test_nanfirst_nanlast_disallowed_dask(axis, func):
    """Check that ``groupby_reduce`` on a dask array raises ``ValueError`` for each parametrized func/axis pair."""
    shape = (2, 3, 2)
    values = dask.array.empty(shape)
    labels = np.ones(shape)
    with pytest.raises(ValueError):
        groupby_reduce(values, labels, func=func, axis=axis)
524+
525+
@requires_dask
@pytest.mark.parametrize("func", ["first", "last"])
def test_first_last_disallowed_dask(func):
    """Check that ``groupby_reduce`` on a dask array with ``axis=-1`` raises ``NotImplementedError`` for each parametrized func."""
    shape = (2, 3, 2)
    values = dask.array.empty(shape)
    labels = np.ones(shape)
    with pytest.raises(NotImplementedError):
        groupby_reduce(values, labels, func=func, axis=-1)
531+
532+
489533@requires_dask
490534@pytest .mark .parametrize ("func" , ALL_FUNCS )
491535@pytest .mark .parametrize (
@@ -495,8 +539,34 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
495539 if "arg" in func and engine == "flox" :
496540 pytest .skip ()
497541
498- if not isinstance (axis , int ) and "arg" in func and (axis is None or len (axis ) > 1 ):
499- pytest .skip ()
542+ if not isinstance (axis , int ):
543+ if "arg" in func and (axis is None or len (axis ) > 1 ):
544+ pytest .skip ()
545+ if ("first" in func or "last" in func ) and (axis is not None and len (axis ) not in [1 , 3 ]):
546+ pytest .skip ()
547+
548+ if func in ["all" , "any" ]:
549+ fill_value = False
550+ else :
551+ fill_value = 123
552+
553+ if "var" in func or "std" in func :
554+ tolerance = {"rtol" : 1e-14 , "atol" : 1e-16 }
555+ else :
556+ tolerance = None
557+ # tests against the numpy output to make sure dask compute matches
558+ by = np .broadcast_to (labels2d , (3 , * labels2d .shape ))
559+ rng = np .random .default_rng (12345 )
560+ array = rng .random (by .shape )
561+ kwargs = dict (
562+ func = func , axis = axis , expected_groups = [0 , 2 ], fill_value = fill_value , engine = engine
563+ )
564+ expected , _ = groupby_reduce (array , by , ** kwargs )
565+ if engine == "flox" :
566+ kwargs .pop ("engine" )
567+ expected_npg , _ = groupby_reduce (array , by , ** kwargs , engine = "numpy" )
568+ assert_equal (expected_npg , expected )
569+
500570 if func in ["all" , "any" ]:
501571 fill_value = False
502572 else :
@@ -513,17 +583,23 @@ def test_groupby_reduce_axis_subset_against_numpy(func, axis, engine):
513583 kwargs = dict (
514584 func = func , axis = axis , expected_groups = [0 , 2 ], fill_value = fill_value , engine = engine
515585 )
586+ expected , _ = groupby_reduce (array , by , ** kwargs )
587+ if engine == "flox" :
588+ kwargs .pop ("engine" )
589+ expected_npg , _ = groupby_reduce (array , by , ** kwargs , engine = "numpy" )
590+ assert_equal (expected_npg , expected )
591+
592+ if ("first" in func or "last" in func ) and (
593+ axis is None or (not isinstance (axis , int ) and len (axis ) != 1 )
594+ ):
595+ return
596+
516597 with raise_if_dask_computes ():
517598 actual , _ = groupby_reduce (
518599 da .from_array (array , chunks = (- 1 , 2 , 3 )),
519600 da .from_array (by , chunks = (- 1 , 2 , 2 )),
520601 ** kwargs ,
521602 )
522- expected , _ = groupby_reduce (array , by , ** kwargs )
523- if engine == "flox" :
524- kwargs .pop ("engine" )
525- expected_npg , _ = groupby_reduce (array , by , ** kwargs , engine = "numpy" )
526- assert_equal (expected_npg , expected )
527603 assert_equal (actual , expected , tolerance )
528604
529605
@@ -751,23 +827,17 @@ def test_fill_value_behaviour(func, chunks, fill_value, engine):
751827 if chunks is not None and not has_dask :
752828 pytest .skip ()
753829
754- if func == "count" :
755-
756- def npfunc (x ):
757- x = np .asarray (x )
758- return (~ np .isnan (x )).sum ()
759-
760- else :
761- npfunc = getattr (np , func )
762-
830+ npfunc = _get_array_func (func )
763831 by = np .array ([1 , 2 , 3 , 1 , 2 , 3 ])
764832 array = np .array ([np .nan , 1 , 1 , np .nan , 1 , 1 ])
765833 if chunks :
766834 array = dask .array .from_array (array , chunks )
767835 actual , _ = groupby_reduce (
768836 array , by , func = func , engine = engine , fill_value = fill_value , expected_groups = [0 , 1 , 2 , 3 ]
769837 )
770- expected = np .array ([fill_value , fill_value , npfunc ([1.0 , 1.0 ]), npfunc ([1.0 , 1.0 ])])
838+ expected = np .array (
839+ [fill_value , fill_value , npfunc ([1.0 , 1.0 ], axis = 0 ), npfunc ([1.0 , 1.0 ], axis = 0 )]
840+ )
771841 assert_equal (actual , expected )
772842
773843
@@ -832,6 +902,8 @@ def test_cohorts_nd_by(func, method, axis, engine):
832902
833903 if axis is not None and method != "map-reduce" :
834904 pytest .xfail ()
905+ if axis is None and ("first" in func or "last" in func ):
906+ pytest .skip ()
835907
836908 kwargs = dict (func = func , engine = engine , method = method , axis = axis , fill_value = fill_value )
837909 actual , groups = groupby_reduce (array , by , ** kwargs )
@@ -897,7 +969,8 @@ def test_bool_reductions(func, engine):
897969 pytest .skip ()
898970 groups = np .array ([1 , 1 , 1 ])
899971 data = np .array ([True , True , False ])
900- expected = np .expand_dims (getattr (np , func )(data ), - 1 )
972+ npfunc = _get_array_func (func )
973+ expected = np .expand_dims (npfunc (data , axis = 0 ), - 1 )
901974 actual , _ = groupby_reduce (data , groups , func = func , engine = engine )
902975 assert_equal (expected , actual )
903976
0 commit comments