1818from pandas .tests .frame .common import zip_frames
1919
2020
21+ @pytest .fixture (params = ["python" , "numba" ])
22+ def engine (request ):
23+ if request .param == "numba" :
24+ pytest .importorskip ("numba" )
25+ return request .param
26+
27+
2128def test_apply (float_frame ):
2229 with np .errstate (all = "ignore" ):
2330 # ufunc
@@ -234,36 +241,42 @@ def test_apply_broadcast_series_lambda_func(int_frame_const_col):
234241
235242
236243@pytest .mark .parametrize ("axis" , [0 , 1 ])
237- def test_apply_raw_float_frame (float_frame , axis ):
244+ def test_apply_raw_float_frame (float_frame , axis , engine ):
245+ if engine == "numba" :
246+ pytest .skip ("numba can't handle when UDF returns None." )
247+
238248 def _assert_raw (x ):
239249 assert isinstance (x , np .ndarray )
240250 assert x .ndim == 1
241251
242- float_frame .apply (_assert_raw , axis = axis , raw = True )
252+ float_frame .apply (_assert_raw , axis = axis , engine = engine , raw = True )
243253
244254
245255@pytest .mark .parametrize ("axis" , [0 , 1 ])
246- def test_apply_raw_float_frame_lambda (float_frame , axis ):
247- result = float_frame .apply (np .mean , axis = axis , raw = True )
256+ def test_apply_raw_float_frame_lambda (float_frame , axis , engine ):
257+ result = float_frame .apply (np .mean , axis = axis , engine = engine , raw = True )
248258 expected = float_frame .apply (lambda x : x .values .mean (), axis = axis )
249259 tm .assert_series_equal (result , expected )
250260
251261
252- def test_apply_raw_float_frame_no_reduction (float_frame ):
262+ def test_apply_raw_float_frame_no_reduction (float_frame , engine ):
253263 # no reduction
254- result = float_frame .apply (lambda x : x * 2 , raw = True )
264+ result = float_frame .apply (lambda x : x * 2 , engine = engine , raw = True )
255265 expected = float_frame * 2
256266 tm .assert_frame_equal (result , expected )
257267
258268
259269@pytest .mark .parametrize ("axis" , [0 , 1 ])
260- def test_apply_raw_mixed_type_frame (mixed_type_frame , axis ):
270+ def test_apply_raw_mixed_type_frame (mixed_type_frame , axis , engine ):
271+ if engine == "numba" :
272+ pytest .skip ("isinstance check doesn't work with numba" )
273+
261274 def _assert_raw (x ):
262275 assert isinstance (x , np .ndarray )
263276 assert x .ndim == 1
264277
265278 # Mixed dtype (GH-32423)
266- mixed_type_frame .apply (_assert_raw , axis = axis , raw = True )
279+ mixed_type_frame .apply (_assert_raw , axis = axis , engine = engine , raw = True )
267280
268281
269282def test_apply_axis1 (float_frame ):
@@ -300,14 +313,20 @@ def test_apply_mixed_dtype_corner_indexing():
300313)
301314@pytest .mark .parametrize ("raw" , [True , False ])
302315@pytest .mark .parametrize ("axis" , [0 , 1 ])
303- def test_apply_empty_infer_type (ax , func , raw , axis ):
316+ def test_apply_empty_infer_type (ax , func , raw , axis , engine , request ):
304317 df = DataFrame (** {ax : ["a" , "b" , "c" ]})
305318
306319 with np .errstate (all = "ignore" ):
307320 test_res = func (np .array ([], dtype = "f8" ))
308321 is_reduction = not isinstance (test_res , np .ndarray )
309322
310- result = df .apply (func , axis = axis , raw = raw )
323+ if engine == "numba" and raw is False :
324+ mark = pytest .mark .xfail (
325+ reason = "numba engine only supports raw=True at the moment"
326+ )
327+ request .node .add_marker (mark )
328+
329+ result = df .apply (func , axis = axis , engine = engine , raw = raw )
311330 if is_reduction :
312331 agg_axis = df ._get_agg_axis (axis )
313332 assert isinstance (result , Series )
@@ -607,8 +626,10 @@ def non_reducing_function(row):
607626 assert names == list (df .index )
608627
609628
610- def test_apply_raw_function_runs_once ():
629+ def test_apply_raw_function_runs_once (engine ):
611630 # https://github.com/pandas-dev/pandas/issues/34506
631+ if engine == "numba" :
632+ pytest .skip ("appending to list outside of numba func is not supported" )
612633
613634 df = DataFrame ({"a" : [1 , 2 , 3 ]})
614635 values = [] # Save row values function is applied to
@@ -623,7 +644,7 @@ def non_reducing_function(row):
623644 for func in [reducing_function , non_reducing_function ]:
624645 del values [:]
625646
626- df .apply (func , raw = True , axis = 1 )
647+ df .apply (func , engine = engine , raw = True , axis = 1 )
627648 assert values == list (df .a .to_list ())
628649
629650
@@ -1449,10 +1470,12 @@ def test_apply_no_suffix_index():
14491470 tm .assert_frame_equal (result , expected )
14501471
14511472
1452- def test_apply_raw_returns_string ():
1473+ def test_apply_raw_returns_string (engine ):
14531474 # https://github.com/pandas-dev/pandas/issues/35940
1475+ if engine == "numba" :
1476+ pytest .skip ("No object dtype support in numba" )
14541477 df = DataFrame ({"A" : ["aa" , "bbb" ]})
1455- result = df .apply (lambda x : x [0 ], axis = 1 , raw = True )
1478+ result = df .apply (lambda x : x [0 ], engine = engine , axis = 1 , raw = True )
14561479 expected = Series (["aa" , "bbb" ])
14571480 tm .assert_series_equal (result , expected )
14581481
@@ -1632,3 +1655,14 @@ def test_agg_dist_like_and_nonunique_columns():
16321655 result = df .agg ({"A" : "count" })
16331656 expected = df ["A" ].count ()
16341657 tm .assert_series_equal (result , expected )
1658+
1659+
1660+ def test_numba_unsupported ():
1661+ df = DataFrame (
1662+ {"A" : [None , 2 , 3 ], "B" : [1.0 , np .nan , 3.0 ], "C" : ["foo" , None , "bar" ]}
1663+ )
1664+ with pytest .raises (
1665+ ValueError ,
1666+ match = "The numba engine in DataFrame.apply can only be used when raw=True" ,
1667+ ):
1668+ df .apply (lambda x : x , engine = "numba" , raw = False )
0 commit comments