55)
66from pathlib import Path
77from functools import lru_cache
8- import warnings
98import textwrap
109import inspect
1110import numpy as np
@@ -117,6 +116,11 @@ def __eq__(self: Dataset[T], other: Dataset[R]) -> bool:
117116 return False
118117 return True
119118
119+ def replace (self , ** kwargs ):
120+ new_dict = self .dict ()
121+ new_dict .update (** kwargs )
122+ return type (self )(** new_dict )
123+
120124 def map (
121125 self : Dataset [T ], function : Callable [[T ], R ]
122126 ) -> Dataset [R ]:
@@ -258,7 +262,8 @@ def split(
258262 save_directory .mkdir (parents = True , exist_ok = True )
259263
260264 if stratify_column is not None :
261- return self ._stratified_split (
265+ return tools .stratified_split (
266+ self ,
262267 key_column = key_column ,
263268 proportions = proportions ,
264269 stratify_column = stratify_column ,
@@ -267,7 +272,8 @@ def split(
267272 frozen = frozen ,
268273 )
269274 else :
270- return self ._unstratified_split (
275+ return tools .unstratified_split (
276+ self ,
271277 key_column = key_column ,
272278 proportions = proportions ,
273279 filepath = (
@@ -278,74 +284,6 @@ def split(
278284 frozen = frozen ,
279285 )
280286
281- def _unstratified_split (
282- self ,
283- key_column : str ,
284- proportions : Dict [str , float ],
285- filepath : Optional [Path ] = None ,
286- seed : Optional [int ] = None ,
287- frozen : Optional [bool ] = False ,
288- ):
289- split_dataframes = tools .numpy_seed (seed )(tools .split_dataframes )
290- return {
291- split_name : Dataset (
292- dataframe = dataframe ,
293- length = len (dataframe ),
294- get_item = self .get_item ,
295- )
296- for split_name , dataframe in split_dataframes (
297- self .dataframe ,
298- key_column ,
299- proportions ,
300- filepath = filepath ,
301- frozen = frozen ,
302- ).items ()
303- }
304-
305- def _stratified_split (
306- self ,
307- key_column : str ,
308- proportions : Dict [str , float ],
309- stratify_column : Optional [str ] = None ,
310- save_directory : Optional [Path ] = None ,
311- seed : Optional [int ] = None ,
312- frozen : Optional [bool ] = False ,
313- ):
314- if (
315- stratify_column is not None
316- and any (self .dataframe [key_column ].duplicated ())
317- ):
318- # mathematically impossible in the general case
319- warnings .warn (
320- 'Trying to do stratified split with non-unique key column'
321- ' - cannot guarantee correct splitting of key values.'
322- )
323- strata = {
324- stratum_value : self .subset (
325- lambda df : df [stratify_column ] == stratum_value
326- )
327- for stratum_value in self .dataframe [stratify_column ].unique ()
328- }
329- split_strata = [
330- stratum ._unstratified_split (
331- key_column = key_column ,
332- proportions = proportions ,
333- filepath = (
334- save_directory / f'{ hash (stratum_value )} .json'
335- if save_directory is not None else None
336- ),
337- seed = seed ,
338- frozen = frozen ,
339- )
340- for stratum_value , stratum in strata .items ()
341- ]
342- return {
343- split_name : Dataset .concat (
344- [split_stratum [split_name ] for split_stratum in split_strata ]
345- )
346- for split_name in proportions .keys ()
347- }
348-
349287 def with_columns (
350288 self : Dataset [T ], ** kwargs : Callable [pd .Dataframe , pd .Series ]
351289 ) -> Dataset [T ]:
0 commit comments