@@ -5,6 +5,7 @@
 )
 from pathlib import Path
 from functools import lru_cache
+import warnings
 import textwrap
 import inspect
 import numpy as np
@@ -218,15 +219,15 @@ def split(
         key_column: str,
         proportions: Dict[str, float],
         stratify_column: Optional[str] = None,
-        filepath: Optional[Union[str, Path]] = None,
+        save_directory: Optional[Union[str, Path]] = None,
         frozen: Optional[bool] = False,
         seed: Optional[int] = None,
     ) -> Dict[str, Dataset[T]]:
         '''
         Split dataset into multiple parts. Optionally you can choose to stratify
         on a column in the source dataframe or save the split to a json file.
         If you are sure that the split strategy will not change then you can
-        safely use a seed instead of a filepath.
+        safely use a seed instead of a save_directory.

         Saved splits can continue from the old split and handle:
@@ -252,14 +253,40 @@ def split(
         >>> split_datasets['test'][0]
         3
         '''
-        if filepath is not None:
-            filepath = Path(filepath)
-
-        if seed is None:
-            split_dataframes = tools.split_dataframes
+        if save_directory is not None:
+            save_directory = Path(save_directory)
+            save_directory.mkdir(parents=True, exist_ok=True)
+
+        if stratify_column is not None:
+            return self._stratified_split(
+                key_column=key_column,
+                proportions=proportions,
+                stratify_column=stratify_column,
+                save_directory=save_directory,
+                seed=seed,
+                frozen=frozen,
+            )
         else:
-            split_dataframes = tools.numpy_seed(seed)(tools.split_dataframes)
+            return self._unstratified_split(
+                key_column=key_column,
+                proportions=proportions,
+                filepath=(
+                    save_directory / 'split.json'
+                    if save_directory is not None else None
+                ),
+                seed=seed,
+                frozen=frozen,
+            )

+    def _unstratified_split(
+        self,
+        key_column: str,
+        proportions: Dict[str, float],
+        filepath: Optional[Path] = None,
+        seed: Optional[int] = None,
+        frozen: Optional[bool] = False,
+    ):
+        split_dataframes = tools.numpy_seed(seed)(tools.split_dataframes)
         return {
             split_name: Dataset(
                 dataframe=dataframe,
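tools.numpy_seed is used above but its implementation is not part of this diff. A plausible sketch, assuming it is a decorator factory that temporarily seeds numpy's global random state around the wrapped call and is a no-op when seed is None:

    import functools
    import numpy as np

    def numpy_seed(seed):
        # Assumption: seed=None means "do not touch the RNG"; otherwise the global
        # numpy random state is saved, seeded, and restored afterwards, which makes
        # the wrapped split_dataframes call deterministic for a given seed.
        def decorator(function):
            @functools.wraps(function)
            def wrapper(*args, **kwargs):
                if seed is None:
                    return function(*args, **kwargs)
                state = np.random.get_state()
                np.random.seed(seed)
                try:
                    return function(*args, **kwargs)
                finally:
                    np.random.set_state(state)
            return wrapper
        return decorator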
@@ -270,63 +297,53 @@ def split(
                 self.dataframe,
                 key_column,
                 proportions,
-                stratify_column,
-                filepath,
-                frozen,
+                filepath=filepath,
+                frozen=frozen,
             ).items()
         }

-    def group_split(
+    def _stratified_split(
         self,
-        split_column: str,
+        key_column: str,
         proportions: Dict[str, float],
-        filepath: Optional[Union[str, Path]] = None,
-        frozen: Optional[bool] = False,
+        stratify_column: Optional[str] = None,
+        save_directory: Optional[Path] = None,
         seed: Optional[int] = None,
-    ) -> Dict[str, Dataset[T]]:
-        '''
-        Similar to :func:`Dataset.split`, but uses a non-unique split column
-        instead of a unique key column. This is useful for example when you
-        have a dataset with examples that come from separate sources and you
-        don't want to have examples from the same source in different splits.
-        Does not support stratification.
-
-        >>> split_file = Path('doctest_split_dataset.json')
-        >>> split_datasets = (
-        ...     Dataset.from_dataframe(pd.DataFrame(dict(
-        ...         source=np.arange(100) // 4,
-        ...         number=np.random.randn(100),
-        ...     )))
-        ...     .group_split(
-        ...         split_column='source',
-        ...         proportions=dict(train=0.8, test=0.2),
-        ...         filepath=split_file,
-        ...     )
-        ... )
-        >>> len(split_datasets['train'])
-        80
-        >>> split_file.unlink()  # clean up after doctest
-        '''
-        if filepath is not None:
-            filepath = Path(filepath)
-
-        split_dataframes = tools.group_split_dataframes
-        if seed is not None:
-            split_dataframes = tools.numpy_seed(seed)(split_dataframes)
-
+        frozen: Optional[bool] = False,
+    ):
+        if (
+            stratify_column is not None
+            and any(self.dataframe[key_column].duplicated())
+        ):
+            # mathematically impossible in the general case
+            warnings.warn(
+                'Trying to do stratified split with non-unique key column'
+                ' - cannot guarantee correct splitting of key values.'
+            )
+        strata = {
+            stratum_value: self.subset(
+                lambda df: df[stratify_column] == stratum_value
+            )
+            for stratum_value in self.dataframe[stratify_column].unique()
+        }
+        split_strata = [
+            stratum._unstratified_split(
+                key_column=key_column,
+                proportions=proportions,
+                filepath=(
+                    save_directory / f'{hash(stratum_value)}.json'
+                    if save_directory is not None else None
+                ),
+                seed=seed,
+                frozen=frozen,
+            )
+            for stratum_value, stratum in strata.items()
+        ]
         return {
-            split_name: Dataset(
-                dataframe=dataframe,
-                length=len(dataframe),
-                get_item=self.get_item,
+            split_name: Dataset.concat(
+                [split_stratum[split_name] for split_stratum in split_strata]
             )
-            for split_name, dataframe in split_dataframes(
-                self.dataframe,
-                split_column,
-                proportions,
-                filepath,
-                frozen,
-            ).items()
+            for split_name in proportions.keys()
         }

     def with_columns(
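The strategy in _stratified_split above: split every stratum independently with the same proportions and seed (each with its own json file when save_directory is given), then stitch the parts back together per split name with Dataset.concat, so each split's class distribution stays close to the full dataset's. A standalone pandas-only sketch of the same idea, using plain dataframes instead of Dataset (function name and details are illustrative, not the library's API):

    import numpy as np
    import pandas as pd

    def stratified_split_frames(dataframe, proportions, stratify_column, seed=None):
        # Split each stratum independently, then concatenate per split name,
        # so every split keeps roughly the same class balance.
        rng = np.random.default_rng(seed)
        parts = {name: [] for name in proportions}
        for _, stratum in dataframe.groupby(stratify_column):
            shuffled = stratum.sample(frac=1, random_state=int(rng.integers(2 ** 32)))
            counts = (np.cumsum(list(proportions.values())) * len(shuffled)).astype(int)
            counts[-1] = len(shuffled)  # guard against floating point rounding
            boundaries = [0, *counts]
            for name, start, end in zip(proportions, boundaries[:-1], boundaries[1:]):
                parts[name].append(shuffled.iloc[start:end])
        return {name: pd.concat(frames) for name, frames in parts.items()}

With proportions summing to 1, each returned frame holds approximately its share of every stratum, which is why splitting per stratum and concatenating preserves the overall proportions.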
@@ -672,13 +689,14 @@ def test_combine_dataset():


 def test_split_dataset():
+    import shutil
     dataset = Dataset.from_dataframe(pd.DataFrame(dict(
         index=np.arange(100),
         number=np.random.randn(100),
         stratify=np.concatenate([np.ones(50), np.zeros(50)]),
     ))).map(tuple)

-    split_file = Path('test_split_dataset.json')
+    save_directory = Path('test_split_dataset')
     proportions = dict(
         gradient=0.7,
         early_stopping=0.15,
@@ -688,7 +706,7 @@ def test_split_dataset():
     kwargs = dict(
         key_column='index',
         proportions=proportions,
-        filepath=split_file,
+        save_directory=save_directory,
         stratify_column='stratify',
     )

@@ -712,8 +730,7 @@ def test_split_dataset():
         stratify_column='stratify',
         seed=800,
     )
-
-    split_file.unlink()
+    shutil.rmtree(save_directory)

     assert split_datasets1 == split_datasets2
     assert split_datasets1 != split_datasets3
@@ -722,45 +739,128 @@ def test_split_dataset():


 def test_group_split_dataset():
+    import shutil
     dataset = Dataset.from_dataframe(pd.DataFrame(dict(
         group=np.arange(100) // 4,
         number=np.random.randn(100),
     ))).map(tuple)

-    split_file = Path('test_split_dataset.json')
+    save_directory = Path('test_split_dataset')
     proportions = dict(
         gradient=0.7,
         early_stopping=0.15,
         compare=0.15,
     )

     kwargs = dict(
-        split_column='group',
+        key_column='group',
         proportions=proportions,
-        filepath=split_file,
+        save_directory=save_directory,
     )

-    split_datasets1 = dataset.group_split(**kwargs)
-    split_datasets2 = dataset.group_split(**kwargs)
-    split_datasets3 = dataset.group_split(
-        split_column='group',
+    split_datasets1 = dataset.split(**kwargs)
+    split_datasets2 = dataset.split(**kwargs)
+    split_datasets3 = dataset.split(
+        key_column='group',
         proportions=proportions,
         seed=100,
     )
-    split_datasets4 = dataset.group_split(
-        split_column='group',
+    split_datasets4 = dataset.split(
+        key_column='group',
         proportions=proportions,
         seed=100,
     )
-    split_datasets5 = dataset.group_split(
-        split_column='group',
+    split_datasets5 = dataset.split(
+        key_column='group',
         proportions=proportions,
         seed=800,
     )

-    split_file.unlink()
+    shutil.rmtree(save_directory)

     assert split_datasets1 == split_datasets2
     assert split_datasets1 != split_datasets3
     assert split_datasets3 == split_datasets4
     assert split_datasets3 != split_datasets5
+
+
+def test_missing_stratify_column():
+    from pytest import raises
+
+    dataset = Dataset.from_dataframe(pd.DataFrame(dict(
+        index=np.arange(100),
+        number=np.random.randn(100),
+    ))).map(tuple)
+
+    with raises(KeyError):
+        dataset.split(
+            key_column='index',
+            proportions=dict(train=0.8, test=0.2),
+            stratify_column='should_fail',
+        )
+
+
+def test_split_proportions():
+    dataset = Dataset.from_dataframe(pd.DataFrame(dict(
+        index=np.arange(100),
+        number=np.random.randn(100),
+        stratify=np.arange(100) // 10,
+    ))).map(tuple)
+
+    splits = dataset.split(
+        key_column='index',
+        proportions=dict(train=0.8, test=0.2),
+        stratify_column='stratify',
+    )
+
+    assert len(splits['train']) == 80
+    assert len(splits['test']) == 20
+
+
+def test_with_columns_split():
+    dataset = (
+        Dataset.from_dataframe(pd.DataFrame(dict(
+            index=np.arange(100),
+            number=np.arange(100),
+        )))
+        .map(tuple)
+        .with_columns(split=lambda df: df['index'] * 2)
+    )
+
+    splits = dataset.split(
+        key_column='index',
+        proportions=dict(train=0.8, test=0.2),
+    )
+
+    assert splits['train'][0][0] * 2 == splits['train'][0][2]
+
+
+def test_split_save_directory():
+    import shutil
+
+    dataset = (
+        Dataset.from_dataframe(pd.DataFrame(dict(
+            index=np.arange(100),
+            number=np.random.randn(100),
+            stratify=np.arange(100) // 10,
+        )))
+        .map(tuple)
+    )
+
+    save_directory = Path('tmp_test_directory')
+    splits1 = dataset.split(
+        key_column='index',
+        proportions=dict(train=0.8, test=0.2),
+        save_directory=save_directory,
+    )
+
+    splits2 = dataset.split(
+        key_column='index',
+        proportions=dict(train=0.8, test=0.2),
+        save_directory=save_directory,
+    )
+
+    assert splits1['train'][0] == splits2['train'][0]
+    assert splits1['test'][0] == splits2['test'][0]
+
+    shutil.rmtree(save_directory)
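A sketch of a test exercising the duplicated-key warning in _stratified_split with pytest.warns; this is illustrative only and assumes the split itself still completes when key values repeat (warnings.warn without an explicit category defaults to UserWarning):

    def test_stratified_split_warns_on_duplicate_keys():
        import pytest

        dataset = Dataset.from_dataframe(pd.DataFrame(dict(
            index=np.repeat(np.arange(50), 2),  # deliberately non-unique keys
            number=np.random.randn(100),
            stratify=np.arange(100) // 10,
        ))).map(tuple)

        with pytest.warns(UserWarning):
            dataset.split(
                key_column='index',
                proportions=dict(train=0.8, test=0.2),
                stratify_column='stratify',
            )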