import argparse
from collections.abc import Iterator
+import dataclasses
import functools
import importlib
import itertools
import json
import multiprocessing
import os
+import typing
from typing import Any, Type

from absl import logging
+import simple_parsing
import tensorflow_datasets as tfds
from tensorflow_datasets.scripts.cli import cli_utils

-# pylint: disable=logging-fstring-interpolation

-
-def register_subparser(parsers: argparse._SubParsersAction) -> None:  # pylint: disable=protected-access
-  """Add subparser for `build` command.
-
-  New flags should be added to `cli_utils` module.
-
-  Args:
-    parsers: The subparsers object to add the parser to.
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class Args:
+  """CLI arguments for building datasets.
+
+  Attributes:
+    positional_datasets: Name(s) of the dataset(s) to build. Defaults to the
+      current dir. See https://www.tensorflow.org/datasets/cli for accepted
+      values.
+    datasets: Datasets can also be provided as a keyword argument.
+    debug: Debug and test options. Use --pdb to enter post-mortem debugging
+      mode if an exception is raised.
+    paths: Path options.
+    generation: Generation options.
+    publishing: Publishing options.
+    automation: Automation options.
  """
-  build_parser = parsers.add_parser(
-      'build', help='Commands for downloading and preparing datasets.'
-  )
-  build_parser.add_argument(
-      'datasets',  # Positional arguments
-      type=str,
+
+  positional_datasets: list[str] = simple_parsing.field(
+      positional=True,
      nargs='*',
-      help=(
-          'Name(s) of the dataset(s) to build. Default to current dir. '
-          'See https://www.tensorflow.org/datasets/cli for accepted values.'
-      ),
-  )
-  build_parser.add_argument(  # Also accept keyword arguments
-      '--datasets',
-      type=str,
-      nargs='+',
-      dest='datasets_keyword',
-      help='Datasets can also be provided as keyword argument.',
+      default_factory=list,
+      # Need to explicitly set metavar for command-line help.
+      metavar='datasets',
  )
+  datasets: list[str] = simple_parsing.field(nargs='*', default_factory=list)

-  cli_utils.add_debug_argument_group(build_parser)
-  cli_utils.add_path_argument_group(build_parser)
-  cli_utils.add_generation_argument_group(build_parser)
-  cli_utils.add_publish_argument_group(build_parser)
-
-  # **** Automation options ****
-  automation_group = build_parser.add_argument_group(
-      'Automation', description='Used by automated scripts.'
+  debug: cli_utils.DebugOptions = cli_utils.DebugOptions()
+  paths: cli_utils.PathOptions = simple_parsing.field(
+      default_factory=cli_utils.PathOptions
+  )
+  generation: cli_utils.GenerationOptions = simple_parsing.field(
+      default_factory=cli_utils.GenerationOptions
  )
-  automation_group.add_argument(
-      '--exclude_datasets',
-      type=str,
-      help=(
-          'If set, generate all datasets except the one defined here. '
-          'Comma separated list of datasets to exclude. '
-      ),
+  publishing: cli_utils.PublishingOptions = simple_parsing.field(
+      default_factory=cli_utils.PublishingOptions
  )
-  automation_group.add_argument(
-      '--experimental_latest_version',
-      action='store_true',
-      help=(
-          'Build the latest Version(experiments=...) available rather than '
-          'default version.'
-      ),
+  automation: cli_utils.AutomationOptions = simple_parsing.field(
+      default_factory=cli_utils.AutomationOptions
  )

-  build_parser.set_defaults(subparser_fn=_build_datasets)
+  def execute(self) -> None:
+    """Build the given datasets."""
+    # Register any additional dataset imports.
+    if self.generation.imports:
+      for module in self.generation.imports.split(','):
+        importlib.import_module(module)

+    # Select datasets to generate.
+    datasets = self.positional_datasets + self.datasets
+    # Generate all datasets if `--exclude_datasets` is set.
+    if self.automation.exclude_datasets:
+      if datasets:
+        raise ValueError("--exclude_datasets can't be used with `datasets`")
+      datasets = set(tfds.list_builders(with_community_datasets=False)) - set(
+          self.automation.exclude_datasets.split(',')
+      )
+      datasets = sorted(datasets)  # `set` is not deterministic
+    else:
+      datasets = datasets or ['']  # Empty string for default
+
+    # Import builder classes.
+    builders_cls_and_kwargs = [
+        _get_builder_cls_and_kwargs(
+            dataset, has_imports=bool(self.generation.imports)
+        )
+        for dataset in datasets
+    ]
+
+    # Parallelize dataset generation.
+    builders = itertools.chain(*(
+        _make_builders(self, builder_cls, builder_kwargs)
+        for (builder_cls, builder_kwargs) in builders_cls_and_kwargs
+    ))
+    process_builder_fn = functools.partial(
+        _download if self.generation.download_only else _download_and_prepare,
+        self,
+    )

-def _build_datasets(args: argparse.Namespace) -> None:
-  """Build the given datasets."""
-  # Eventually register additional datasets imports
-  if args.imports:
-    list(importlib.import_module(m) for m in args.imports.split(','))
+    if self.generation.num_processes == 1:
+      for builder in builders:
+        process_builder_fn(builder)
+    else:
+      with multiprocessing.Pool(self.generation.num_processes) as pool:
+        pool.map(process_builder_fn, builders)

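The `execute` body above flattens per-dataset builders with `itertools.chain`, binds the shared argument with `functools.partial`, and fans the work out over a process pool. A minimal, self-contained sketch of that pattern (the `process` function and its inputs are illustrative, not part of the CLI):

```python
import functools
import itertools
import multiprocessing


def process(prefix: str, item: str) -> None:
  """Stands in for _download_and_prepare: first arg bound, second mapped."""
  print(f'{prefix}: {item}')


if __name__ == '__main__':
  # Flatten several iterables of work items into a single stream.
  items = itertools.chain(*(['a', 'b'], ['c']))
  process_fn = functools.partial(process, 'building')
  with multiprocessing.Pool(2) as pool:
    pool.map(process_fn, items)  # prints one line per item
```

Note that `functools.partial` of a top-level function is picklable, which is what makes this compatible with `multiprocessing.Pool`.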
-  # Select datasets to generate
-  datasets = (args.datasets or []) + (args.datasets_keyword or [])
-  if args.exclude_datasets:  # Generate all datasets if `--exclude_datasets` set
-    if datasets:
-      raise ValueError("--exclude_datasets can't be used with `datasets`")
-    datasets = set(tfds.list_builders(with_community_datasets=False)) - set(
-        args.exclude_datasets.split(',')
-    )
-    datasets = sorted(datasets)  # `set` is not deterministic
-  else:
-    datasets = datasets or ['']  # Empty string for default
-
-  # Import builder classes
-  builders_cls_and_kwargs = [
-      _get_builder_cls_and_kwargs(dataset, has_imports=bool(args.imports))
-      for dataset in datasets
-  ]
-
-  # Parallelize datasets generation.
-  builders = itertools.chain(*(
-      _make_builders(args, builder_cls, builder_kwargs)
-      for (builder_cls, builder_kwargs) in builders_cls_and_kwargs
-  ))
-  process_builder_fn = functools.partial(
-      _download if args.download_only else _download_and_prepare, args
-  )

-  if args.num_processes == 1:
-    for builder in builders:
-      process_builder_fn(builder)
-  else:
-    with multiprocessing.Pool(args.num_processes) as pool:
-      pool.map(process_builder_fn, builders)
+def register_subparser(parsers: argparse._SubParsersAction) -> None:  # pylint: disable=protected-access
+  """Add subparser for `build` command."""
+  parser = parsers.add_parser(
+      'build', help='Commands for downloading and preparing datasets.'
+  )
+  parser = typing.cast(simple_parsing.ArgumentParser, parser)
+  parser.add_arguments(Args, dest='args')
+  parser.set_defaults(subparser_fn=lambda args: args.args.execute())


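For readers unfamiliar with `simple_parsing`: `add_arguments` turns a dataclass into argparse options, with nested dataclass fields (like `debug` or `paths` above) becoming option groups, and the parsed instance landing on the namespace under `dest`. A rough sketch of the mechanism, assuming `simple_parsing` is installed (the `DemoArgs` class is hypothetical):

```python
import dataclasses

import simple_parsing


@dataclasses.dataclass(frozen=True, kw_only=True)
class DemoArgs:
  """Hypothetical arguments, mirroring the shape of `Args` above."""

  num_processes: int = 1  # Exposed as `--num_processes`.
  overwrite: bool = False  # Exposed as `--overwrite`.


parser = simple_parsing.ArgumentParser()
parser.add_arguments(DemoArgs, dest='args')
namespace = parser.parse_args(['--num_processes', '4'])
assert namespace.args.num_processes == 4
```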
def _make_builders(
-    args: argparse.Namespace,
+    args: Args,
    builder_cls: Type[tfds.core.DatasetBuilder],
    builder_kwargs: dict[str, Any],
) -> Iterator[tfds.core.DatasetBuilder]:
@@ -146,7 +146,7 @@ def _make_builders(
    Initialized dataset builders.
  """
  # Eventually overwrite version
-  if args.experimental_latest_version:
+  if args.automation.experimental_latest_version:
    if 'version' in builder_kwargs:
      raise ValueError(
          "Can't have both `--experimental_latest` and version set (`:1.0.0`)"
@@ -157,19 +157,19 @@ def _make_builders(
  builder_kwargs['config'] = _get_config_name(
      builder_cls=builder_cls,
      config_kwarg=builder_kwargs.get('config'),
-      config_name=args.config,
-      config_idx=args.config_idx,
+      config_name=args.generation.config,
+      config_idx=args.generation.config_idx,
  )

-  if args.file_format:
-    builder_kwargs['file_format'] = args.file_format
+  if args.generation.file_format:
+    builder_kwargs['file_format'] = args.generation.file_format

  make_builder = functools.partial(
      _make_builder,
      builder_cls,
-      overwrite=args.overwrite,
-      fail_if_exists=args.fail_if_exists,
-      data_dir=args.data_dir,
+      overwrite=args.debug.overwrite,
+      fail_if_exists=args.debug.fail_if_exists,
+      data_dir=args.paths.data_dir,
      **builder_kwargs,
  )

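The `make_builder` partial above freezes the builder class and the common keyword arguments now, while per-version or per-config keywords can still be supplied at each call site. A small illustration of that binding style (the `describe` function is a hypothetical stand-in):

```python
import functools


def describe(name: str, *, data_dir: str, version: str) -> str:
  """Hypothetical stand-in for _make_builder."""
  return f'{name}:{version} in {data_dir}'


make = functools.partial(describe, 'mnist', data_dir='/tmp/data')
assert make(version='3.0.1') == 'mnist:3.0.1 in /tmp/data'
```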
@@ -203,7 +203,7 @@ def _get_builder_cls_and_kwargs(
  if not has_imports:
    path = _search_script_path(ds_to_build)
    if path is not None:
-      logging.info(f'Loading dataset {ds_to_build} from path: {path}')
+      logging.info('Loading dataset %s from path: %s', ds_to_build, path)
      # Dynamically load user dataset script
      # When possible, load from the parent's parent, so module is named
      # "foo.foo_dataset_builder".
@@ -228,7 +228,9 @@ def _get_builder_cls_and_kwargs(
  name, builder_kwargs = tfds.core.naming.parse_builder_name_kwargs(ds_to_build)
  builder_cls = tfds.builder_cls(str(name))
  logging.info(
-      f'Loading dataset {ds_to_build} from imports: {builder_cls.__module__}'
+      'Loading dataset %s from imports: %s',
+      ds_to_build,
+      builder_cls.__module__,
  )
  return builder_cls, builder_kwargs

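The f-string to `%s` change above is not just style: `absl.logging`, like the stdlib `logging` module, defers `%`-formatting until a record is actually emitted, so suppressed log levels never pay the formatting cost. It also makes the deleted `logging-fstring-interpolation` pylint disable at the top of the file unnecessary. For example:

```python
from absl import logging

dataset, path = 'mnist', '/datasets/mnist'
# Formatted lazily, only if an INFO record is emitted:
logging.info('Loading dataset %s from path: %s', dataset, path)
# Formatted eagerly, even if INFO is filtered out:
logging.info(f'Loading dataset {dataset} from path: {path}')
```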
@@ -308,7 +310,7 @@ def _make_builder(


def _download(
-    args: argparse.Namespace,
+    args: Args,
    builder: tfds.core.DatasetBuilder,
) -> None:
  """Downloads all files of the given builder."""
@@ -330,7 +332,7 @@ def _download(
  if builder.MAX_SIMULTANEOUS_DOWNLOADS is not None:
    max_simultaneous_downloads = builder.MAX_SIMULTANEOUS_DOWNLOADS

-  download_dir = args.download_dir or os.path.join(
+  download_dir = args.paths.download_dir or os.path.join(
      builder._data_dir_root, 'downloads'  # pylint: disable=protected-access
  )
  dl_manager = tfds.download.DownloadManager(
@@ -352,55 +354,55 @@ def _download(


def _download_and_prepare(
-    args: argparse.Namespace,
+    args: Args,
    builder: tfds.core.DatasetBuilder,
) -> None:
  """Generate a single builder."""
  cli_utils.download_and_prepare(
      builder=builder,
      download_config=_make_download_config(args, dataset_name=builder.name),
-      download_dir=args.download_dir,
-      publish_dir=args.publish_dir,
-      skip_if_published=args.skip_if_published,
-      overwrite=args.overwrite,
-      beam_pipeline_options=args.beam_pipeline_options,
-      nondeterministic_order=args.nondeterministic_order,
+      download_dir=args.paths.download_dir,
+      publish_dir=args.publishing.publish_dir,
+      skip_if_published=args.publishing.skip_if_published,
+      overwrite=args.debug.overwrite,
+      beam_pipeline_options=args.generation.beam_pipeline_options,
+      nondeterministic_order=args.generation.nondeterministic_order,
  )


def _make_download_config(
-    args: argparse.Namespace,
+    args: Args,
    dataset_name: str,
) -> tfds.download.DownloadConfig:
  """Generate the download and prepare configuration."""
  # Load the download config
-  manual_dir = args.manual_dir
-  if args.add_name_to_manual_dir:
+  manual_dir = args.paths.manual_dir
+  if args.paths.add_name_to_manual_dir:
    manual_dir = manual_dir / dataset_name

  kwargs = {}
-  if args.max_shard_size_mb:
-    kwargs['max_shard_size'] = args.max_shard_size_mb << 20
-  if args.num_shards:
-    kwargs['num_shards'] = args.num_shards
-  if args.download_config:
-    kwargs.update(json.loads(args.download_config))
+  if args.generation.max_shard_size_mb:
+    kwargs['max_shard_size'] = args.generation.max_shard_size_mb << 20
+  if args.generation.num_shards:
+    kwargs['num_shards'] = args.generation.num_shards
+  if args.generation.download_config:
+    kwargs.update(json.loads(args.generation.download_config))

  if 'download_mode' in kwargs:
    kwargs['download_mode'] = tfds.download.GenerateMode(
        kwargs['download_mode']
    )
  else:
    kwargs['download_mode'] = tfds.download.GenerateMode.REUSE_DATASET_IF_EXISTS
-  if args.update_metadata_only:
+  if args.generation.update_metadata_only:
    kwargs['download_mode'] = tfds.download.GenerateMode.UPDATE_DATASET_INFO

  return tfds.download.DownloadConfig(
-      extract_dir=args.extract_dir,
+      extract_dir=args.paths.extract_dir,
      manual_dir=manual_dir,
-      max_examples_per_split=args.max_examples_per_split,
-      register_checksums=args.register_checksums,
-      force_checksums_validation=args.force_checksums_validation,
+      max_examples_per_split=args.debug.max_examples_per_split,
+      register_checksums=args.generation.register_checksums,
+      force_checksums_validation=args.generation.force_checksums_validation,
      **kwargs,
  )

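One detail worth calling out in `_make_download_config`: `max_shard_size_mb << 20` converts mebibytes to bytes, since shifting left by 20 bits multiplies by 2**20. A quick check:

```python
# 128 MiB expressed in bytes; `x << 20` is `x * 2**20`.
assert (128 << 20) == 128 * 1024 * 1024 == 134_217_728
```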
@@ -445,11 +447,10 @@ def _get_config_name(
    else:
      return config_name
  elif config_idx is not None:  # `--config_idx 123`
-    if config_idx > len(builder_cls.BUILDER_CONFIGS):
+    if config_idx >= len(builder_cls.BUILDER_CONFIGS):
      raise ValueError(
-          f'--config_idx {config_idx} greater than number '
-          f'of configs {len(builder_cls.BUILDER_CONFIGS)} for '
-          f'{builder_cls.name}.'
+          f'--config_idx {config_idx} greater than or equal to the number of '
+          f'configs {len(builder_cls.BUILDER_CONFIGS)} for {builder_cls.name}.'
      )
    else:
      # Use `config.name` to avoid
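The `>` to `>=` change above fixes an off-by-one: `BUILDER_CONFIGS` is indexed from zero, so an index equal to the list length is already out of range. A check of the boundary case:

```python
configs = ['a', 'b', 'c']  # valid indices: 0, 1, 2
assert not (3 > len(configs))  # old check let config_idx == 3 through
assert 3 >= len(configs)       # new check rejects it before indexing
```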