@@ -804,8 +804,7 @@ def test_update_stratified_split():
804804 Dataset .from_dataframe (pd .DataFrame (dict (
805805 index = np .arange (100 ),
806806 number = np .random .randn (100 ),
807- stratify1 = np .random .randint (0 , 10 , 100 ),
808- stratify2 = np .random .randint (0 , 10 , 100 ),
807+ stratify = np .random .randint (0 , 10 , 100 ),
809808 )))
810809 .map (tuple )
811810 )
@@ -819,7 +818,7 @@ def test_update_stratified_split():
819818 key_column = 'index' ,
820819 proportions = dict (train = 0.8 , test = 0.2 ),
821820 filepath = filepath ,
822- stratify_column = 'stratify1 ' ,
821+ stratify_column = 'stratify ' ,
823822 )
824823 )
825824
@@ -829,7 +828,7 @@ def test_update_stratified_split():
829828 key_column = 'index' ,
830829 proportions = dict (train = 0.8 , test = 0.2 ),
831830 filepath = filepath ,
832- stratify_column = 'stratify2 ' ,
831+ stratify_column = 'stratify ' ,
833832 )
834833 )
835834
@@ -840,8 +839,8 @@ def test_update_stratified_split():
840839 )
841840
842841 assert (
843- splits1 ['compare ' ].dataframe ['index' ]
844- .isin (splits2 ['compare ' ].dataframe ['index' ])
842+ splits1 ['test ' ].dataframe ['index' ]
843+ .isin (splits2 ['test ' ].dataframe ['index' ])
845844 .all ()
846845 )
847846
0 commit comments