@@ -223,15 +223,15 @@ def split(
223223 key_column : str ,
224224 proportions : Dict [str , float ],
225225 stratify_column : Optional [str ] = None ,
226- save_directory : Optional [Union [str , Path ]] = None ,
226+ filepath : Optional [Union [str , Path ]] = None ,
227227 frozen : Optional [bool ] = False ,
228228 seed : Optional [int ] = None ,
229229 ) -> Dict [str , Dataset [T ]]:
230230 '''
231231 Split dataset into multiple parts. Optionally you can choose to stratify
232232 on a column in the source dataframe or save the split to a json file.
233233 If you are sure that the split strategy will not change then you can
234- safely use a seed instead of a save_directory .
234+ safely use a seed instead of a filepath .
235235
236236 Saved splits can continue from the old split and handle:
237237
@@ -257,17 +257,17 @@ def split(
257257 >>> split_datasets['test'][0]
258258 3
259259 '''
260- if save_directory is not None :
261- save_directory = Path (save_directory )
262- save_directory .mkdir (parents = True , exist_ok = True )
260+ if filepath is not None :
261+ filepath = Path (filepath )
262+ filepath . parent .mkdir (parents = True , exist_ok = True )
263263
264264 if stratify_column is not None :
265265 return tools .stratified_split (
266266 self ,
267267 key_column = key_column ,
268268 proportions = proportions ,
269269 stratify_column = stratify_column ,
270- save_directory = save_directory ,
270+ filepath = filepath ,
271271 seed = seed ,
272272 frozen = frozen ,
273273 )
@@ -276,10 +276,7 @@ def split(
276276 self ,
277277 key_column = key_column ,
278278 proportions = proportions ,
279- filepath = (
280- save_directory / 'split.json'
281- if save_directory is not None else None
282- ),
279+ filepath = filepath ,
283280 seed = seed ,
284281 frozen = frozen ,
285282 )
@@ -627,14 +624,13 @@ def test_combine_dataset():
627624
628625
629626def test_split_dataset ():
630- import shutil
631627 dataset = Dataset .from_dataframe (pd .DataFrame (dict (
632628 index = np .arange (100 ),
633629 number = np .random .randn (100 ),
634630 stratify = np .concatenate ([np .ones (50 ), np .zeros (50 )]),
635631 ))).map (tuple )
636632
637- save_directory = Path ('test_split_dataset' )
633+ filepath = Path ('test_split_dataset.json ' )
638634 proportions = dict (
639635 gradient = 0.7 ,
640636 early_stopping = 0.15 ,
@@ -644,7 +640,7 @@ def test_split_dataset():
644640 kwargs = dict (
645641 key_column = 'index' ,
646642 proportions = proportions ,
647- save_directory = save_directory ,
643+ filepath = filepath ,
648644 stratify_column = 'stratify' ,
649645 )
650646
@@ -668,7 +664,7 @@ def test_split_dataset():
668664 stratify_column = 'stratify' ,
669665 seed = 800 ,
670666 )
671- shutil . rmtree ( save_directory )
667+ filepath . unlink ( )
672668
673669 assert split_datasets1 == split_datasets2
674670 assert split_datasets1 != split_datasets3
@@ -677,13 +673,12 @@ def test_split_dataset():
677673
678674
679675def test_group_split_dataset ():
680- import shutil
681676 dataset = Dataset .from_dataframe (pd .DataFrame (dict (
682677 group = np .arange (100 ) // 4 ,
683678 number = np .random .randn (100 ),
684679 ))).map (tuple )
685680
686- save_directory = Path ('test_split_dataset' )
681+ filepath = Path ('test_split_dataset.json ' )
687682 proportions = dict (
688683 gradient = 0.7 ,
689684 early_stopping = 0.15 ,
@@ -693,7 +688,7 @@ def test_group_split_dataset():
693688 kwargs = dict (
694689 key_column = 'group' ,
695690 proportions = proportions ,
696- save_directory = save_directory ,
691+ filepath = filepath ,
697692 )
698693
699694 split_datasets1 = dataset .split (** kwargs )
@@ -714,7 +709,7 @@ def test_group_split_dataset():
714709 seed = 800 ,
715710 )
716711
717- shutil . rmtree ( save_directory )
712+ filepath . unlink ( )
718713
719714 assert split_datasets1 == split_datasets2
720715 assert split_datasets1 != split_datasets3
@@ -773,8 +768,7 @@ def test_with_columns_split():
773768 assert splits ['train' ][0 ][0 ] * 2 == splits ['train' ][0 ][2 ]
774769
775770
776- def test_split_save_directory ():
777- import shutil
771+ def test_split_filepath ():
778772
779773 dataset = (
780774 Dataset .from_dataframe (pd .DataFrame (dict (
@@ -785,20 +779,70 @@ def test_split_save_directory():
785779 .map (tuple )
786780 )
787781
788- save_directory = Path ('tmp_test_directory ' )
782+ filepath = Path ('tmp_test_split.json ' )
789783 splits1 = dataset .split (
790784 key_column = 'index' ,
791785 proportions = dict (train = 0.8 , test = 0.2 ),
792- save_directory = save_directory ,
786+ filepath = filepath ,
793787 )
794788
795789 splits2 = dataset .split (
796790 key_column = 'index' ,
797791 proportions = dict (train = 0.8 , test = 0.2 ),
798- save_directory = save_directory ,
792+ filepath = filepath ,
799793 )
800794
801795 assert splits1 ['train' ][0 ] == splits2 ['train' ][0 ]
802796 assert splits1 ['test' ][0 ] == splits2 ['test' ][0 ]
803797
804- shutil .rmtree (save_directory )
798+ filepath .unlink ()
799+
800+
def test_update_stratified_split():
    '''
    Split a subset of the dataset and save the split to a json file, then
    split the full dataset against the same file while stratifying on a
    different column.

    Every key that the first split placed in a part must be found in the
    part with the same name after the second split.
    '''
    dataset = (
        Dataset.from_dataframe(pd.DataFrame(dict(
            index=np.arange(100),
            number=np.random.randn(100),
            stratify1=np.random.randint(0, 10, 100),
            stratify2=np.random.randint(0, 10, 100),
        )))
        .map(tuple)
    )

    filepath = Path('tmp_test_split.json')
    proportions = dict(train=0.8, test=0.2)

    try:
        # First split only sees the first half of the keys.
        splits1 = (
            dataset
            .subset(lambda df: df['index'] < 50)
            .split(
                key_column='index',
                proportions=proportions,
                filepath=filepath,
                stratify_column='stratify1',
            )
        )

        # Second split sees all keys, reuses the saved file and stratifies
        # on another column.
        splits2 = dataset.split(
            key_column='index',
            proportions=proportions,
            filepath=filepath,
            stratify_column='stratify2',
        )

        # Check exactly the parts that were requested; indexing a part
        # name that is not in proportions would fail on the lookup itself
        # instead of testing the split.
        for split_name in proportions:
            assert (
                splits1[split_name].dataframe['index']
                .isin(splits2[split_name].dataframe['index'])
                .all()
            ), split_name
    finally:
        # Other tests use the same file name, so always remove the saved
        # split, also when an assertion above fails.
        if filepath.exists():
            filepath.unlink()
0 commit comments