2121 --jsonld=/tmp/croissant.json \
2222 --data_dir=/tmp/foo \
2323 --file_format=array_record \
24- --record_sets=record1 --record_sets= record2 \
24+ --record_sets=record1,record2 \
2525 --mapping='{"document.csv": "~/Downloads/document.csv"}'
2626```
2727"""
2828
2929import argparse
30- from collections . abc import Sequence
30+ import dataclasses
3131import json
32+ import typing
3233
3334from etils import epath
35+ import simple_parsing
36+ from tensorflow_datasets .core import file_adapters
3437from tensorflow_datasets .core .dataset_builders import croissant_builder
3538from tensorflow_datasets .scripts .cli import cli_utils
3639
3740
38- def add_parser_arguments (parser : argparse .ArgumentParser ):
39- """Add arguments for `build_croissant` subparser."""
40- parser .add_argument (
41- '--jsonld' ,
42- type = str ,
43- help = 'The Croissant config file for the given dataset.' ,
44- required = True ,
45- )
46- parser .add_argument (
47- '--record_sets' ,
48- nargs = '*' ,
49- help = (
50- 'The names of the record sets to generate. Each record set will'
51- ' correspond to a separate config. If not specified, it will use all'
52- ' the record sets'
53- ),
54- )
55- parser .add_argument (
56- '--mapping' ,
57- type = str ,
58- help = (
59- 'Mapping filename->filepath as a Python dict[str, str] to handle'
60- ' manual downloads. If `document.csv` is the FileObject and you'
61- ' downloaded it to `~/Downloads/document.csv`, you can'
62- ' specify`--mapping=\' {"document.csv": "~/Downloads/document.csv"}\' '
63- ),
64- )
65-
66- cli_utils .add_debug_argument_group (parser )
67- cli_utils .add_path_argument_group (parser )
68- cli_utils .add_generation_argument_group (parser )
69- cli_utils .add_publish_argument_group (parser )
41+ @dataclasses .dataclass
42+ class CmdArgs :
43+ """CLI arguments for preparing a Croissant dataset.
7044
71-
72- def register_subparser (parsers : argparse ._SubParsersAction ):
73- """Add subparser for `convert_format` command."""
74- parser = parsers .add_parser (
75- 'build_croissant' ,
76- help = 'Prepares a croissant dataset' ,
77- )
78- add_parser_arguments (parser )
79- parser .set_defaults (
80- subparser_fn = lambda args : prepare_croissant_builder (
81- jsonld = args .jsonld ,
82- data_dir = args .data_dir ,
83- file_format = args .file_format ,
84- record_sets = args .record_sets ,
85- mapping = args .mapping ,
86- download_dir = args .download_dir ,
87- publish_dir = args .publish_dir ,
88- skip_if_published = args .skip_if_published ,
89- overwrite = args .overwrite ,
90- )
91- )
92-
93-
94- def prepare_croissant_builder (
95- jsonld : epath .PathLike ,
96- data_dir : epath .PathLike ,
97- file_format : str ,
98- record_sets : Sequence [str ],
99- mapping : str | None ,
100- download_dir : epath .PathLike | None ,
101- publish_dir : epath .PathLike | None ,
102- skip_if_published : bool ,
103- overwrite : bool ,
104- ) -> None :
105- # pyformat: disable
106- """Creates a Croissant Builder and runs the preparation.
107-
108- Args:
109- jsonld: The Croissant config file for the given dataset
45+ Attributes:
46+ jsonld: Path to the JSONLD file.
11047 data_dir: Path where the converted dataset will be stored.
11148 file_format: File format to convert the dataset to.
112- record_sets: The `@id`s of the record sets to generate. Each record set will
49+ record_sets: The names of the record sets to generate. Each record set will
11350 correspond to a separate config. If not specified, it will use all the
114- record sets
51+ record sets.
11552 mapping: Mapping filename->filepath as a Python dict[str, str] to handle
11653 manual downloads. If `document.csv` is the FileObject and you downloaded
11754 it to `~/Downloads/document.csv`, you can specify
118- `mapping={"document.csv": "~/Downloads/document.csv"}`.,
55+ `--mapping='{"document.csv": "~/Downloads/document.csv"}'`.
11956 download_dir: Where to place downloads. Default to `<data_dir>/downloads/`.
12057 publish_dir: Where to optionally publish the dataset after it has been
12158 generated successfully. Should be the root data dir under which datasets
@@ -124,29 +61,74 @@ def prepare_croissant_builder(
12461 already published, then it will not be regenerated.
12562 overwrite: Delete pre-existing dataset if it exists.
12663 """
127- # pyformat: enable
128- if not record_sets :
129- record_sets = None
13064
131- if mapping :
65+ jsonld : epath .PathLike
66+ data_dir : epath .PathLike
67+ # Need to override the default use of `Enum.name` for choice options.
68+ file_format : str = simple_parsing .choice (
69+ * (file_format .value for file_format in file_adapters .FileFormat ),
70+ default = file_adapters .FileFormat .ARRAY_RECORD .value ,
71+ )
72+ # Need to manually parse comma-separated list of values, see:
73+ # https://github.com/lebrice/SimpleParsing/issues/142.
74+ record_sets : list [str ] = simple_parsing .field (
75+ default_factory = list ,
76+ type = lambda record_sets_str : record_sets_str .split (',' ),
77+ nargs = '?' ,
78+ )
79+ mapping : str | None = None
80+ download_dir : epath .PathLike | None = None
81+ publish_dir : epath .PathLike | None = None
82+ skip_if_published : bool = False
83+ overwrite : bool = False
84+
85+
86+ def register_subparser (parsers : argparse ._SubParsersAction ):
87+ """Add subparser for `build_croissant` command."""
88+ orig_parser_class = parsers ._parser_class # pylint: disable=protected-access
89+ try :
90+ parsers ._parser_class = simple_parsing .ArgumentParser # pylint: disable=protected-access
91+ parser = parsers .add_parser (
92+ 'build_croissant' ,
93+ help = 'Prepares a croissant dataset' ,
94+ )
95+ parser = typing .cast (simple_parsing .ArgumentParser , parser )
96+ finally :
97+ parsers ._parser_class = orig_parser_class # pylint: disable=protected-access
98+ parser .add_arguments (CmdArgs , dest = 'args' )
99+ parser .set_defaults (
100+ subparser_fn = lambda args : prepare_croissant_builder (args .args )
101+ )
102+
103+
104+ def prepare_croissant_builder (args : CmdArgs ) -> None :
105+ """Creates a Croissant Builder and runs the preparation.
106+
107+ Args:
108+ args: CLI arguments.
109+ """
110+ if args .mapping :
132111 try :
133- mapping = json .loads (mapping )
112+ mapping = json .loads (args . mapping )
134113 except json .JSONDecodeError as e :
135- raise ValueError (f'Error parsing mapping parameter: { mapping } ' ) from e
114+ raise ValueError (
115+ f'Error parsing mapping parameter: { args .mapping } '
116+ ) from e
117+ else :
118+ mapping = None
136119
137120 builder = croissant_builder .CroissantBuilder (
138- jsonld = jsonld ,
139- record_set_ids = record_sets ,
140- file_format = file_format ,
141- data_dir = data_dir ,
121+ jsonld = args . jsonld ,
122+ record_set_ids = args . record_sets or None ,
123+ file_format = args . file_format ,
124+ data_dir = args . data_dir ,
142125 mapping = mapping ,
143126 )
144127 cli_utils .download_and_prepare (
145128 builder = builder ,
146129 download_config = None ,
147- download_dir = epath .Path (download_dir ) if download_dir else None ,
148- publish_dir = epath .Path (publish_dir ) if publish_dir else None ,
149- skip_if_published = skip_if_published ,
150- freeze_files = freeze_files ,
151- overwrite = overwrite ,
130+ download_dir = epath .Path (args .download_dir ) if args .download_dir else None ,
131+ publish_dir = epath .Path (args .publish_dir ) if args .publish_dir else None ,
132+ skip_if_published = args .skip_if_published ,
133+ overwrite = args .overwrite ,
152134 )
0 commit comments