@@ -33,7 +33,8 @@ def ser():
3333 ["max" , np .array ([2 , 6 , 7 , 4 , np .nan , 4 , 2 , 8 , np .nan , 6 ])],
3434 ["first" , np .array ([1 , 5 , 7 , 3 , np .nan , 4 , 2 , 8 , np .nan , 6 ])],
3535 ["dense" , np .array ([1 , 3 , 4 , 2 , np .nan , 2 , 1 , 5 , np .nan , 3 ])],
36- ]
36+ ],
37+ ids = lambda x : x [0 ],
3738)
3839def results (request ):
3940 return request .param
@@ -48,12 +49,29 @@ def results(request):
4849 "Int64" ,
4950 pytest .param ("float64[pyarrow]" , marks = td .skip_if_no ("pyarrow" )),
5051 pytest .param ("int64[pyarrow]" , marks = td .skip_if_no ("pyarrow" )),
52+ pytest .param ("string[pyarrow]" , marks = td .skip_if_no ("pyarrow" )),
53+ "string[python]" ,
54+ "str" ,
5155 ]
5256)
5357def dtype (request ):
5458 return request .param
5559
5660
61+ def expected_dtype (dtype , method , pct = False ):
62+ exp_dtype = "float64"
63+ # elif dtype in ["Int64", "Float64", "string[pyarrow]", "string[python]"]:
64+ if dtype in ["string[pyarrow]" ]:
65+ exp_dtype = "Float64"
66+ elif dtype in ["float64[pyarrow]" , "int64[pyarrow]" ]:
67+ if method == "average" or pct :
68+ exp_dtype = "double[pyarrow]"
69+ else :
70+ exp_dtype = "uint64[pyarrow]"
71+
72+ return exp_dtype
73+
74+
5775class TestSeriesRank :
5876 def test_rank (self , datetime_series ):
5977 sp_stats = pytest .importorskip ("scipy.stats" )
@@ -241,12 +259,14 @@ def test_rank_signature(self):
241259 with pytest .raises (ValueError , match = msg ):
242260 s .rank ("average" )
243261
244- @pytest .mark .parametrize ("dtype" , [None , object ])
245- def test_rank_tie_methods (self , ser , results , dtype ):
262+ def test_rank_tie_methods (self , ser , results , dtype , using_infer_string ):
246263 method , exp = results
264+ if dtype == "int64" or (not using_infer_string and dtype == "str" ):
265+ pytest .skip ("int64/str does not support NaN" )
266+
247267 ser = ser if dtype is None else ser .astype (dtype )
248268 result = ser .rank (method = method )
249- tm .assert_series_equal (result , Series (exp ))
269+ tm .assert_series_equal (result , Series (exp , dtype = expected_dtype ( dtype , method ) ))
250270
251271 @pytest .mark .parametrize ("ascending" , [True , False ])
252272 @pytest .mark .parametrize ("method" , ["average" , "min" , "max" , "first" , "dense" ])
@@ -346,25 +366,35 @@ def test_rank_methods_series(self, method, op, value):
346366 ],
347367 )
348368 def test_rank_dense_method (self , dtype , ser , exp ):
369+ if ser [0 ] < 0 and dtype .startswith ("str" ):
370+ exp = exp [::- 1 ]
349371 s = Series (ser ).astype (dtype )
350372 result = s .rank (method = "dense" )
351- expected = Series (exp ).astype (result . dtype )
373+ expected = Series (exp ).astype (expected_dtype ( dtype , "dense" ) )
352374 tm .assert_series_equal (result , expected )
353375
354- def test_rank_descending (self , ser , results , dtype ):
376+ def test_rank_descending (self , ser , results , dtype , using_infer_string ):
355377 method , _ = results
356- if "i" in dtype :
378+ if dtype == "int64" or ( not using_infer_string and dtype == "str" ) :
357379 s = ser .dropna ()
358380 else :
359381 s = ser .astype (dtype )
360382
361383 res = s .rank (ascending = False )
362- expected = (s .max () - s ).rank ()
363- tm .assert_series_equal (res , expected )
384+ if dtype .startswith ("str" ):
385+ expected = (s .astype ("float64" ).max () - s .astype ("float64" )).rank ()
386+ else :
387+ expected = (s .max () - s ).rank ()
388+ tm .assert_series_equal (res , expected .astype (expected_dtype (dtype , "average" )))
364389
365- expected = (s .max () - s ).rank (method = method )
390+ if dtype .startswith ("str" ):
391+ expected = (s .astype ("float64" ).max () - s .astype ("float64" )).rank (
392+ method = method
393+ )
394+ else :
395+ expected = (s .max () - s ).rank (method = method )
366396 res2 = s .rank (method = method , ascending = False )
367- tm .assert_series_equal (res2 , expected )
397+ tm .assert_series_equal (res2 , expected . astype ( expected_dtype ( dtype , method )) )
368398
369399 def test_rank_int (self , ser , results ):
370400 method , exp = results
@@ -421,9 +451,11 @@ def test_rank_ea_small_values(self):
421451 ],
422452)
423453def test_rank_dense_pct (dtype , ser , exp ):
454+ if ser [0 ] < 0 and dtype .startswith ("str" ):
455+ exp = exp [::- 1 ]
424456 s = Series (ser ).astype (dtype )
425457 result = s .rank (method = "dense" , pct = True )
426- expected = Series (exp ).astype (result . dtype )
458+ expected = Series (exp ).astype (expected_dtype ( dtype , "dense" , pct = True ) )
427459 tm .assert_series_equal (result , expected )
428460
429461
@@ -442,9 +474,11 @@ def test_rank_dense_pct(dtype, ser, exp):
442474 ],
443475)
444476def test_rank_min_pct (dtype , ser , exp ):
477+ if ser [0 ] < 0 and dtype .startswith ("str" ):
478+ exp = exp [::- 1 ]
445479 s = Series (ser ).astype (dtype )
446480 result = s .rank (method = "min" , pct = True )
447- expected = Series (exp ).astype (result . dtype )
481+ expected = Series (exp ).astype (expected_dtype ( dtype , "min" , pct = True ) )
448482 tm .assert_series_equal (result , expected )
449483
450484
@@ -463,9 +497,11 @@ def test_rank_min_pct(dtype, ser, exp):
463497 ],
464498)
465499def test_rank_max_pct (dtype , ser , exp ):
500+ if ser [0 ] < 0 and dtype .startswith ("str" ):
501+ exp = exp [::- 1 ]
466502 s = Series (ser ).astype (dtype )
467503 result = s .rank (method = "max" , pct = True )
468- expected = Series (exp ).astype (result . dtype )
504+ expected = Series (exp ).astype (expected_dtype ( dtype , "max" , pct = True ) )
469505 tm .assert_series_equal (result , expected )
470506
471507
@@ -484,9 +520,11 @@ def test_rank_max_pct(dtype, ser, exp):
484520 ],
485521)
486522def test_rank_average_pct (dtype , ser , exp ):
523+ if ser [0 ] < 0 and dtype .startswith ("str" ):
524+ exp = exp [::- 1 ]
487525 s = Series (ser ).astype (dtype )
488526 result = s .rank (method = "average" , pct = True )
489- expected = Series (exp ).astype (result . dtype )
527+ expected = Series (exp ).astype (expected_dtype ( dtype , "average" , pct = True ) )
490528 tm .assert_series_equal (result , expected )
491529
492530
@@ -505,9 +543,11 @@ def test_rank_average_pct(dtype, ser, exp):
505543 ],
506544)
507545def test_rank_first_pct (dtype , ser , exp ):
546+ if ser [0 ] < 0 and dtype .startswith ("str" ):
547+ exp = exp [::- 1 ]
508548 s = Series (ser ).astype (dtype )
509549 result = s .rank (method = "first" , pct = True )
510- expected = Series (exp ).astype (result . dtype )
550+ expected = Series (exp ).astype (expected_dtype ( dtype , "first" , pct = True ) )
511551 tm .assert_series_equal (result , expected )
512552
513553
0 commit comments