|
21 | 21 | See: https://www.tensorflow.org/datasets/cli |
22 | 22 | """ |
23 | 23 |
|
24 | | -import argparse |
| 24 | +import dataclasses |
25 | 25 | import logging as python_logging |
26 | | -from typing import List |
27 | 26 |
|
28 | 27 | from absl import app |
29 | 28 | from absl import flags |
30 | 29 | from absl import logging |
| 30 | +import simple_parsing |
31 | 31 |
|
32 | 32 | import tensorflow_datasets.public_api as tfds |
33 | 33 |
|
|
41 | 41 | FLAGS = flags.FLAGS |
42 | 42 |
|
43 | 43 |
|
44 | | -def _parse_flags(argv: List[str]) -> argparse.Namespace: |
45 | | - """Command lines flag parsing.""" |
46 | | - parser = cli_utils.ArgumentParser( |
47 | | - description='Tensorflow Datasets CLI tool', |
48 | | - allow_abbrev=False, |
49 | | - ) |
50 | | - parser.add_argument( |
51 | | - '--version', |
52 | | - action='version', |
53 | | - version='TensorFlow Datasets: ' + tfds.__version__, |
54 | | - ) |
55 | | - parser.add_argument( |
56 | | - '--dry_run', |
57 | | - action='store_true', |
58 | | - help='If True, print the parsed arguments.', |
| 44 | +@dataclasses.dataclass(frozen=True, kw_only=True) |
| 45 | +class _DummyCommand: |
| 46 | + """Dummy command to avoid `command is MISSING` error.""" |
| 47 | + |
| 48 | + pass |
| 49 | + |
| 50 | + |
| 51 | +version_field = simple_parsing.field( |
| 52 | + action='version', |
| 53 | + version='TensorFlow Datasets: ' + tfds.__version__, |
| 54 | + help='The version of the TensorFlow Datasets package.', |
| 55 | +) |
| 56 | + |
| 57 | + |
| 58 | +@dataclasses.dataclass(frozen=True, kw_only=True) |
| 59 | +class Args(cli_utils.Args): |
| 60 | + """Tensorflow Datasets CLI tool.""" |
| 61 | + |
| 62 | + version: str = version_field |
| 63 | + """The version of the TensorFlow Datasets package.""" |
| 64 | + |
| 65 | + dry_run: bool = simple_parsing.flag(default=False) |
| 66 | + """If True, print the parsed arguments and exit.""" |
| 67 | + |
| 68 | + command: build.Args | new.Args | convert_format.Args | croissant.CmdArgs = ( |
| 69 | + simple_parsing.subparsers( |
| 70 | + { |
| 71 | + 'build': build.Args, |
| 72 | + 'new': new.Args, |
| 73 | + 'convert_format': convert_format.Args, |
| 74 | + 'build_croissant': croissant.CmdArgs, |
| 75 | + }, |
| 76 | + default_factory=_DummyCommand, |
| 77 | + ) |
59 | 78 | ) |
60 | | - parser.set_defaults(subparser_fn=lambda _: parser.print_help()) |
61 | | - # Register sub-commands |
62 | | - subparser = parser.add_subparsers(title='command') |
63 | | - build.register_subparser(subparser) |
64 | | - new.register_subparser(subparser) |
65 | | - convert_format.register_subparser(subparser) |
66 | | - croissant.register_subparser(subparser) |
67 | | - return parser.parse_args(argv[1:]) |
| 79 | + """The command to execute.""" |
| 80 | + |
| 81 | + def execute(self) -> None: |
| 82 | + """Run the command.""" |
| 83 | + if self.dry_run: |
| 84 | + print(self) |
| 85 | + # When no command is given, print the help message. |
| 86 | + elif isinstance(self.command, _DummyCommand): |
| 87 | + _parse_flags(['', '--help']) |
| 88 | + else: |
| 89 | + self.command.execute() |
| 90 | + |
| 91 | + |
| 92 | +_parse_flags = cli_utils.make_flags_parser( |
| 93 | + Args, description='Tensorflow Datasets CLI tool' |
| 94 | +) |
68 | 95 |
|
69 | 96 |
|
70 | | -def main(args: argparse.Namespace) -> None: |
| 97 | +def main(args: Args) -> None: |
71 | 98 |
|
72 | 99 | # From the CLI, all datasets are visible |
73 | 100 | tfds.core.visibility.set_availables([ |
@@ -98,11 +125,7 @@ def main(args: argparse.Namespace) -> None: |
98 | 125 | new_stream = tfds.core.utils.tqdm_utils.TqdmStream() |
99 | 126 | python_handler.setStream(new_stream) |
100 | 127 |
|
101 | | - if args.dry_run: |
102 | | - print(args) |
103 | | - else: |
104 | | - # Launch the subcommand defined in the subparser (or default to print help) |
105 | | - args.subparser_fn(args) |
| 128 | + args.execute() |
106 | 129 |
|
107 | 130 |
|
108 | 131 | def launch_cli() -> None: |
|
0 commit comments