@@ -70,10 +70,18 @@ def test_check_sampling_strategy_warning():
7070 }, multiclass_target , 'clean-sampling' )
7171
7272
73- def test_check_sampling_strategy_float_error ():
74- msg = "'clean-sampling' methods do let the user specify the sampling ratio"
75- with pytest .raises (ValueError , match = msg ):
76- check_sampling_strategy (0.5 , binary_target , 'clean-sampling' )
73+ @pytest .mark .parametrize (
74+ "ratio, y, type, err_msg" ,
75+ [(0.5 , binary_target , 'clean-sampling' ,
76+ "'clean-sampling' methods do let the user specify the sampling ratio" ),
77+ (0.1 , np .array ([0 ] * 10 + [1 ] * 20 ), 'over-sampling' ,
78+ "remove samples from the minority class while trying to generate new" ),
79+ (0.1 , np .array ([0 ] * 10 + [1 ] * 20 ), 'under-sampling' ,
80+ "generate new sample in the majority class while trying to remove" )]
81+ )
82+ def test_check_sampling_strategy_float_error (ratio , y , type , err_msg ):
83+ with pytest .raises (ValueError , match = err_msg ):
84+ check_sampling_strategy (ratio , y , type )
7785
7886
7987def test_check_sampling_strategy_error ():
@@ -329,9 +337,9 @@ def test_check_ratio(ratio, sampling_type, expected_ratio, target):
329337def test_sampling_strategy_dict_over_sampling ():
330338 y = np .array ([1 ] * 50 + [2 ] * 100 + [3 ] * 25 )
331339 sampling_strategy = {1 : 70 , 2 : 140 , 3 : 70 }
332- expected_msg = ("After over-sampling, the number of samples \(140\) in"
333- " class 2 will be larger than the number of samples in the "
334- " majority class \(class #2 -> 100\)" )
340+ expected_msg = (r "After over-sampling, the number of samples \(140\) in"
341+ r " class 2 will be larger than the number of samples in"
342+ r" the majority class \(class #2 -> 100\)" )
335343 with warns (UserWarning , expected_msg ):
336344 check_sampling_strategy (sampling_strategy , y , 'over-sampling' )
337345
0 commit comments