@@ -800,48 +800,51 @@ def test_split_filepath():
800800
801801def test_update_stratified_split ():
802802
803- dataset = (
804- Dataset .from_dataframe (pd .DataFrame (dict (
805- index = np .arange (100 ),
806- number = np .random .randn (100 ),
807- stratify = np .random .randint (0 , 10 , 100 ),
808- )))
809- .map (tuple )
810- )
803+ for _ in range (5 ):
811804
812- filepath = Path ('tmp_test_split.json' )
805+ dataset = (
806+ Dataset .from_dataframe (pd .DataFrame (dict (
807+ index = np .arange (100 ),
808+ number = np .random .randn (100 ),
809+ stratify1 = np .random .randint (0 , 10 , 100 ),
810+ stratify2 = np .random .randint (0 , 10 , 100 ),
811+ )))
812+ .map (tuple )
813+ )
813814
814- splits1 = (
815- dataset
816- .subset (lambda df : df ['index' ] < 50 )
817- .split (
818- key_column = 'index' ,
819- proportions = dict (train = 0.8 , test = 0.2 ),
820- filepath = filepath ,
821- stratify_column = 'stratify' ,
815+ filepath = Path ('tmp_test_split.json' )
816+
817+ splits1 = (
818+ dataset
819+ .subset (lambda df : df ['index' ] < 50 )
820+ .split (
821+ key_column = 'index' ,
822+ proportions = dict (train = 0.8 , test = 0.2 ),
823+ filepath = filepath ,
824+ stratify_column = 'stratify1' ,
825+ )
822826 )
823- )
824827
825- splits2 = (
826- dataset
827- .split (
828- key_column = 'index' ,
829- proportions = dict (train = 0.8 , test = 0.2 ),
830- filepath = filepath ,
831- stratify_column = 'stratify' ,
828+ splits2 = (
829+ dataset
830+ .split (
831+ key_column = 'index' ,
832+ proportions = dict (train = 0.8 , test = 0.2 ),
833+ filepath = filepath ,
834+ stratify_column = 'stratify2' ,
835+ )
832836 )
833- )
834837
835- assert (
836- splits1 ['train' ].dataframe ['index' ]
837- .isin (splits2 ['train' ].dataframe ['index' ])
838- .all ()
839- )
838+ assert (
839+ splits1 ['train' ].dataframe ['index' ]
840+ .isin (splits2 ['train' ].dataframe ['index' ])
841+ .all ()
842+ )
840843
841- assert (
842- splits1 ['test' ].dataframe ['index' ]
843- .isin (splits2 ['test' ].dataframe ['index' ])
844- .all ()
845- )
844+ assert (
845+ splits1 ['test' ].dataframe ['index' ]
846+ .isin (splits2 ['test' ].dataframe ['index' ])
847+ .all ()
848+ )
846849
847- filepath .unlink ()
850+ filepath .unlink ()
0 commit comments