@@ -282,9 +282,11 @@ class MultiSplitInfo(SplitInfo):
282282 This should only be used to read data and not when producing data.
283283 """
284284
285- split_infos : list [SplitInfo ] = dataclasses .field (default_factory = list )
285+ split_infos : list [SplitInfo | SubSplitInfo ] = dataclasses .field (
286+ default_factory = list
287+ )
286288
287- def __init__ (self , name : str , split_infos : list [SplitInfo ]):
289+ def __init__ (self , name : str , split_infos : list [SplitInfo | SubSplitInfo ]):
288290 if not split_infos :
289291 raise ValueError ('Need to pass a non-empty list of SplitInfos' )
290292 object .__setattr__ (self , 'split_infos' , split_infos )
@@ -315,6 +317,16 @@ def __repr__(self) -> str:
315317 f'split_infos={ self .split_infos !r} )'
316318 )
317319
320+ @property
321+ def examples_in_shards (self ) -> list [int ]:
322+ result = []
323+ for split_info in self .split_infos :
324+ if isinstance (split_info , (SubSplitInfo , MultiSplitInfo )):
325+ result .extend (split_info .examples_in_shards )
326+ else :
327+ result .extend (split_info .shard_lengths )
328+ return result
329+
318330 @property
319331 def file_instructions (self ) -> list [shard_utils .FileInstruction ]:
320332 result = []
@@ -361,6 +373,10 @@ class SubSplitInfo:
361373 def shard_lengths (self ) -> list [int ]:
362374 return [f .take for f in self .file_instructions ]
363375
376+ @property
377+ def examples_in_shards (self ) -> list [int ]:
378+ return [f .examples_in_shard for f in self .file_instructions ]
379+
364380 @property
365381 def num_examples (self ) -> int :
366382 """Returns the number of example in the subsplit."""
@@ -526,7 +542,7 @@ def _make_absolute_instructions(
526542
527543def _file_instructions_for_split (
528544 instruction : _AbsoluteInstruction ,
529- split_info : SplitInfo ,
545+ split_info : SplitInfo | SubSplitInfo ,
530546) -> list [shard_utils .FileInstruction ]:
531547 """Returns the file instructions from the given instruction applied to the given split info."""
532548 if not split_info .num_examples :
@@ -537,9 +553,7 @@ def _file_instructions_for_split(
537553 return []
538554 to = split_info .num_examples if instruction .to is None else instruction .to
539555 if isinstance (split_info , (SubSplitInfo , MultiSplitInfo )):
540- examples_in_shards = [
541- f .examples_in_shard for f in split_info .file_instructions
542- ]
556+ examples_in_shards = split_info .examples_in_shards
543557 else :
544558 examples_in_shards = None
545559 return shard_utils .get_file_instructions (
0 commit comments