|
26 | 26 | ``` |
27 | 27 | """ |
28 | 28 |
|
| 29 | +import dataclasses |
| 30 | + |
29 | 31 | from absl import app |
30 | | -from absl import flags |
| 32 | +from etils import eapp |
| 33 | +from etils import epath |
| 34 | +import simple_parsing |
31 | 35 | from tensorflow_datasets.core import file_adapters |
32 | 36 | from tensorflow_datasets.scripts.cli import croissant |
33 | 37 |
|
34 | 38 |
|
35 | | -_JSONLD = flags.DEFINE_string( |
36 | | - name='jsonld', default=None, help='Path to the JSONLD file.', required=True |
37 | | -) |
38 | | -_DATA_DIR = flags.DEFINE_string( |
39 | | - name='data_dir', |
40 | | - default=None, |
41 | | - help='Path where the converted dataset will be stored.', |
42 | | - required=True, |
43 | | -) |
44 | | -_FILE_FORMAT = flags.DEFINE_enum_class( |
45 | | - name='file_format', |
46 | | - default=file_adapters.FileFormat.ARRAY_RECORD, |
47 | | - enum_class=file_adapters.FileFormat, |
48 | | - help='File format to convert the dataset to.', |
49 | | -) |
50 | | -_RECORD_SETS = flags.DEFINE_list( |
51 | | - name='record_sets', |
52 | | - default=[], |
53 | | - help=( |
54 | | - 'The names of the record sets to generate. Each record set will' |
55 | | - ' correspond to a separate config. If not specified, it will use all' |
56 | | - ' the record sets.' |
57 | | - ), |
58 | | -) |
59 | | -_MAPPING = flags.DEFINE_string( |
60 | | - name='mapping', |
61 | | - default=None, |
62 | | - help=( |
63 | | - 'Mapping filename->filepath as a Python dict[str, str] to handle' |
64 | | - ' manual downloads. If `document.csv` is the FileObject and you' |
65 | | - ' downloaded it to `~/Downloads/document.csv`, you can' |
66 | | - ' specify`--mapping=\'{"document.csv": "~/Downloads/document.csv"}\'' |
67 | | - ), |
68 | | -) |
69 | | -_DOWNLOAD_DIR = flags.DEFINE_string( |
70 | | - name='download_dir', |
71 | | - default=None, |
72 | | - help='Where to place downloads. Default to `<data_dir>/downloads/`.', |
73 | | -) |
74 | | -_PUBLISH_DIR = flags.DEFINE_string( |
75 | | - name='publish_dir', |
76 | | - default=None, |
77 | | - help=( |
78 | | - 'Where to optionally publish the dataset after it has been generated ' |
79 | | - 'successfully. Should be the root data dir under which datasets are ' |
80 | | - 'stored. If unspecified, dataset will not be published.' |
81 | | - ), |
82 | | -) |
83 | | -_SKIP_IF_PUBLISHED = flags.DEFINE_bool( |
84 | | - name='skip_if_published', |
85 | | - default=False, |
86 | | - help=( |
87 | | - 'If the dataset with the same version and config is already published, ' |
88 | | - 'then it will not be regenerated.' |
89 | | - ), |
90 | | -) |
91 | | -_OVERWRITE = flags.DEFINE_bool( |
92 | | - name='overwrite', |
93 | | - default=False, |
94 | | - help='Delete pre-existing dataset if it exists.', |
95 | | -) |
@dataclasses.dataclass
class CmdArgs:
  """CLI arguments for preparing a Croissant dataset.

  Attributes:
    jsonld: Path to the JSONLD file.
    data_dir: Path where the converted dataset will be stored.
    file_format: File format to convert the dataset to.
    record_sets: The names of the record sets to generate. Each record set will
      correspond to a separate config. If not specified, it will use all the
      record sets.
    mapping: Mapping filename->filepath as a Python dict[str, str] to handle
      manual downloads. If `document.csv` is the FileObject and you downloaded
      it to `~/Downloads/document.csv`, you can specify
      `--mapping='{"document.csv": "~/Downloads/document.csv"}'`
    download_dir: Where to place downloads. Default to `<data_dir>/downloads/`.
    publish_dir: Where to optionally publish the dataset after it has been
      generated successfully. Should be the root data dir under which datasets
      are stored. If unspecified, dataset will not be published.
    skip_if_published: If the dataset with the same version and config is
      already published, then it will not be regenerated.
    overwrite: Delete pre-existing dataset if it exists.
  """

  jsonld: epath.PathLike
  data_dir: epath.PathLike
  # By default `simple_parsing` exposes `Enum.name` as the accepted choices;
  # we want the enum *values* (matching the old absl flag) on the CLI instead.
  file_format: str = simple_parsing.choice(
      *[fmt.value for fmt in file_adapters.FileFormat],
      default=file_adapters.FileFormat.ARRAY_RECORD.value,
  )
  # Comma-separated lists are not parsed natively, so split the raw string by
  # hand. See: https://github.com/lebrice/SimpleParsing/issues/142.
  record_sets: list[str] = simple_parsing.field(
      default_factory=list,
      type=lambda raw: raw.split(','),
      nargs='?',
  )
  mapping: str | None = None
  download_dir: epath.PathLike | None = None
  publish_dir: epath.PathLike | None = None
  skip_if_published: bool = False
  overwrite: bool = False
| 83 | +parse_flags = eapp.make_flags_parser(CmdArgs) |
96 | 84 |
|
97 | 85 |
|
def main(args: CmdArgs):
  """Prepares the Croissant dataset described by the parsed CLI arguments.

  Args:
    args: The parsed command-line arguments (see `CmdArgs`).
  """
  # Each `CmdArgs` field maps one-to-one onto a keyword argument of
  # `prepare_croissant_builder`; gather them explicitly and forward.
  builder_kwargs = dict(
      jsonld=args.jsonld,
      data_dir=args.data_dir,
      file_format=args.file_format,
      record_sets=args.record_sets,
      mapping=args.mapping,
      download_dir=args.download_dir,
      publish_dir=args.publish_dir,
      skip_if_published=args.skip_if_published,
      overwrite=args.overwrite,
  )
  croissant.prepare_croissant_builder(**builder_kwargs)
110 | 98 |
|
111 | 99 |
|
if __name__ == '__main__':
  # `parse_flags` builds the `CmdArgs` instance that `app.run` passes to
  # `main` instead of the default absl flag parsing.
  app.run(main, flags_parser=parse_flags)
0 commit comments