Skip to content

Commit 9666187

Browse files
fineguyThe TensorFlow Datasets Authors
authored andcommitted
Refactor TFDS CLI commands.
PiperOrigin-RevId: 793983080
1 parent 5b94090 commit 9666187

File tree

8 files changed

+105
-108
lines changed

8 files changed

+105
-108
lines changed

tensorflow_datasets/scripts/cli/build.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
"""`tfds build` command."""
1717

18-
import argparse
1918
from collections.abc import Iterator
2019
import dataclasses
2120
import functools
@@ -24,7 +23,6 @@
2423
import json
2524
import multiprocessing
2625
import os
27-
import typing
2826
from typing import Any, Type
2927

3028
from absl import logging
@@ -34,8 +32,8 @@
3432

3533

3634
@dataclasses.dataclass(frozen=True, kw_only=True)
37-
class Args:
38-
"""CLI arguments for building datasets.
35+
class Args(cli_utils.Args):
36+
"""Commands for downloading and preparing datasets.
3937
4038
Attributes:
4139
positional_datasets: Name(s) of the dataset(s) to build. Default to current
@@ -120,16 +118,6 @@ def execute(self) -> None:
120118
pool.map(process_builder_fn, builders)
121119

122120

123-
def register_subparser(parsers: argparse._SubParsersAction) -> None: # pylint: disable=protected-access
124-
"""Add subparser for `build` command."""
125-
parser = parsers.add_parser(
126-
'build', help='Commands for downloading and preparing datasets.'
127-
)
128-
parser = typing.cast(simple_parsing.ArgumentParser, parser)
129-
parser.add_arguments(Args, dest='args')
130-
parser.set_defaults(subparser_fn=lambda args: args.args.execute())
131-
132-
133121
def _make_builders(
134122
args: Args,
135123
builder_cls: Type[tfds.core.DatasetBuilder],

tensorflow_datasets/scripts/cli/build_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import functools
2020
import multiprocessing
2121
import os
22+
import typing
2223
from unittest import mock
2324

2425
from etils import epath
@@ -311,7 +312,7 @@ def test_download_only(build):
311312
)
312313
def test_make_download_config(args: str, download_config_kwargs):
313314
args = main._parse_flags(f'tfds build x {args}'.split())
314-
cmd_args: build_lib.Args = args.args
315+
cmd_args = typing.cast(build_lib.Args, args.command)
315316
actual = build_lib._make_download_config(cmd_args, dataset_name='x')
316317
# Ignore the beam runner
317318
actual = actual.replace(beam_runner=None)

tensorflow_datasets/scripts/cli/cli_utils.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515

1616
"""Utility functions for TFDS CLI."""
1717

18+
import abc
1819
import argparse
19-
from collections.abc import Sequence
20+
from collections.abc import Callable, Sequence
2021
import dataclasses
2122
import itertools
2223
import pathlib
24+
from typing import TypeVar
2325

2426
from absl import logging
2527
from absl.flags import argparse_flags
@@ -33,6 +35,8 @@
3335
from tensorflow_datasets.core.utils import file_utils
3436
from tensorflow_datasets.scripts.utils import flag_utils
3537

38+
_DataclassT = TypeVar('_DataclassT')
39+
3640

3741
class ArgumentParser(
3842
argparse_flags.ArgumentParser, simple_parsing.ArgumentParser
@@ -77,6 +81,33 @@ def parse_known_args(
7781
return super().parse_known_args(args, namespace)
7882

7983

84+
def make_flags_parser(
85+
args_dataclass: type[_DataclassT], description: str
86+
) -> Callable[[list[str]], _DataclassT]:
87+
"""Returns a function that parses flags and returns the dataclass instance."""
88+
89+
def _parse_flags(argv: list[str]) -> _DataclassT:
90+
"""Command lines flag parsing."""
91+
parser = ArgumentParser(
92+
description=description,
93+
allow_abbrev=False,
94+
)
95+
parser.add_arguments(args_dataclass, dest='args')
96+
return parser.parse_args(argv[1:]).args
97+
98+
return _parse_flags
99+
100+
101+
@dataclasses.dataclass(frozen=True, kw_only=True)
102+
class Args(abc.ABC):
103+
"""CLI arguments for TFDS CLI commands."""
104+
105+
@abc.abstractmethod
106+
def execute(self) -> None:
107+
"""Execute the CLI command."""
108+
...
109+
110+
80111
@dataclasses.dataclass
81112
class DatasetInfo:
82113
"""Structure for common string used for formatting.

tensorflow_datasets/scripts/cli/convert_format.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,18 @@
2525
```
2626
"""
2727

28-
import argparse
2928
import dataclasses
30-
import typing
3129

3230
from etils import epath
3331
import simple_parsing
3432
from tensorflow_datasets.core import file_adapters
33+
from tensorflow_datasets.scripts.cli import cli_utils
3534
from tensorflow_datasets.scripts.cli import convert_format_utils
3635

3736

3837
@dataclasses.dataclass(frozen=True, kw_only=True)
39-
class Args:
40-
"""CLI arguments for converting datasets from one file format to another.
38+
class Args(cli_utils.Args):
39+
"""Converts a dataset from one file format to another format.
4140
4241
Attributes:
4342
root_data_dir: Root data dir that contains all datasets. All datasets and
@@ -94,14 +93,3 @@ def execute(self) -> None:
9493
num_workers=self.num_workers,
9594
fail_on_error=not self.only_log_errors,
9695
)
97-
98-
99-
def register_subparser(parsers: argparse._SubParsersAction) -> None:
100-
"""Add subparser for `convert_format` command."""
101-
parser = parsers.add_parser(
102-
'convert_format',
103-
help='Converts a dataset from one file format to another format.',
104-
)
105-
parser = typing.cast(simple_parsing.ArgumentParser, parser)
106-
parser.add_arguments(Args, dest='args')
107-
parser.set_defaults(subparser_fn=lambda args: args.args.execute())

tensorflow_datasets/scripts/cli/croissant.py

Lines changed: 6 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,9 @@
2626
```
2727
"""
2828

29-
import argparse
3029
import dataclasses
3130
import functools
3231
import json
33-
import typing
3432

3533
from etils import epath
3634
import mlcroissant as mlc
@@ -43,8 +41,8 @@
4341

4442

4543
@dataclasses.dataclass(frozen=True, kw_only=True)
46-
class CmdArgs(simple_parsing.helpers.FrozenSerializable):
47-
"""CLI arguments for preparing a Croissant dataset.
44+
class CmdArgs(simple_parsing.helpers.FrozenSerializable, cli_utils.Args):
45+
"""Prepares a Croissant dataset.
4846
4947
Attributes:
5048
jsonld: Path to the JSONLD file.
@@ -122,18 +120,10 @@ def version(self) -> version_lib.Version:
122120
self.overwrite_version or self.dataset.metadata.version or '1.0.0'
123121
)
124122

125-
126-
def register_subparser(parsers: argparse._SubParsersAction):
127-
"""Add subparser for `convert_format` command."""
128-
parser = parsers.add_parser(
129-
'build_croissant',
130-
help='Prepares a croissant dataset',
131-
)
132-
parser = typing.cast(simple_parsing.ArgumentParser, parser)
133-
parser.add_arguments(CmdArgs, dest='args')
134-
parser.set_defaults(
135-
subparser_fn=lambda args: prepare_croissant_builders(args.args)
136-
)
123+
def execute(self) -> None:
124+
"""Creates Croissant Builders and prepares them."""
125+
for record_set_id in self.record_set_ids:
126+
prepare_croissant_builder(args=self, record_set_id=record_set_id)
137127

138128

139129
def prepare_croissant_builder(
@@ -163,14 +153,3 @@ def prepare_croissant_builder(
163153
beam_pipeline_options=None,
164154
)
165155
return builder
166-
167-
168-
def prepare_croissant_builders(args: CmdArgs):
169-
"""Creates Croissant Builders and prepares them.
170-
171-
Args:
172-
args: CLI arguments.
173-
"""
174-
# Generate each config sequentially.
175-
for record_set_id in args.record_set_ids:
176-
prepare_croissant_builder(args=args, record_set_id=record_set_id)

tensorflow_datasets/scripts/cli/main.py

Lines changed: 54 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@
2121
See: https://www.tensorflow.org/datasets/cli
2222
"""
2323

24-
import argparse
24+
import dataclasses
2525
import logging as python_logging
26-
from typing import List
2726

2827
from absl import app
2928
from absl import flags
3029
from absl import logging
30+
import simple_parsing
3131

3232
import tensorflow_datasets.public_api as tfds
3333

@@ -41,33 +41,60 @@
4141
FLAGS = flags.FLAGS
4242

4343

44-
def _parse_flags(argv: List[str]) -> argparse.Namespace:
45-
"""Command lines flag parsing."""
46-
parser = cli_utils.ArgumentParser(
47-
description='Tensorflow Datasets CLI tool',
48-
allow_abbrev=False,
49-
)
50-
parser.add_argument(
51-
'--version',
52-
action='version',
53-
version='TensorFlow Datasets: ' + tfds.__version__,
54-
)
55-
parser.add_argument(
56-
'--dry_run',
57-
action='store_true',
58-
help='If True, print the parsed arguments.',
44+
@dataclasses.dataclass(frozen=True, kw_only=True)
45+
class _DummyCommand:
46+
"""Dummy command to avoid `command is MISSING` error."""
47+
48+
pass
49+
50+
51+
version_field = simple_parsing.field(
52+
action='version',
53+
version='TensorFlow Datasets: ' + tfds.__version__,
54+
help='The version of the TensorFlow Datasets package.',
55+
)
56+
57+
58+
@dataclasses.dataclass(frozen=True, kw_only=True)
59+
class Args(cli_utils.Args):
60+
"""Tensorflow Datasets CLI tool."""
61+
62+
version: str = version_field
63+
"""The version of the TensorFlow Datasets package."""
64+
65+
dry_run: bool = simple_parsing.flag(default=False)
66+
"""If True, print the parsed arguments and exit."""
67+
68+
command: build.Args | new.Args | convert_format.Args | croissant.CmdArgs = (
69+
simple_parsing.subparsers(
70+
{
71+
'build': build.Args,
72+
'new': new.Args,
73+
'convert_format': convert_format.Args,
74+
'build_croissant': croissant.CmdArgs,
75+
},
76+
default_factory=_DummyCommand,
77+
)
5978
)
60-
parser.set_defaults(subparser_fn=lambda _: parser.print_help())
61-
# Register sub-commands
62-
subparser = parser.add_subparsers(title='command')
63-
build.register_subparser(subparser)
64-
new.register_subparser(subparser)
65-
convert_format.register_subparser(subparser)
66-
croissant.register_subparser(subparser)
67-
return parser.parse_args(argv[1:])
79+
"""The command to execute."""
80+
81+
def execute(self) -> None:
82+
"""Run the command."""
83+
if self.dry_run:
84+
print(self)
85+
# When no command is given, print the help message.
86+
elif isinstance(self.command, _DummyCommand):
87+
_parse_flags(['', '--help'])
88+
else:
89+
self.command.execute()
90+
91+
92+
_parse_flags = cli_utils.make_flags_parser(
93+
Args, description='Tensorflow Datasets CLI tool'
94+
)
6895

6996

70-
def main(args: argparse.Namespace) -> None:
97+
def main(args: Args) -> None:
7198

7299
# From the CLI, all datasets are visible
73100
tfds.core.visibility.set_availables([
@@ -98,11 +125,7 @@ def main(args: argparse.Namespace) -> None:
98125
new_stream = tfds.core.utils.tqdm_utils.TqdmStream()
99126
python_handler.setStream(new_stream)
100127

101-
if args.dry_run:
102-
print(args)
103-
else:
104-
# Launch the subcommand defined in the subparser (or default to print help)
105-
args.subparser_fn(args)
128+
args.execute()
106129

107130

108131
def launch_cli() -> None:

tensorflow_datasets/scripts/cli/new.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,11 @@
1515

1616
"""`tfds new` command."""
1717

18-
import argparse
1918
import dataclasses
2019
import os
2120
import pathlib
2221
import subprocess
2322
import textwrap
24-
import typing
2523

2624
import simple_parsing
2725
from tensorflow_datasets.core import constants
@@ -33,8 +31,8 @@
3331

3432

3533
@dataclasses.dataclass(frozen=True, kw_only=True)
36-
class Args:
37-
"""CLI arguments for creating a new dataset directory.
34+
class Args(utils.Args):
35+
"""Creates a new dataset directory from the template.
3836
3937
Attributes:
4038
dataset_name: Name of the dataset to be created (in snake_case).
@@ -71,17 +69,6 @@ def execute(self) -> None:
7169
)
7270

7371

74-
def register_subparser(parsers: argparse._SubParsersAction) -> None:
75-
"""Add subparser for `new` command."""
76-
parser = parsers.add_parser(
77-
'new',
78-
help='Creates a new dataset directory from the template.',
79-
)
80-
parser = typing.cast(simple_parsing.ArgumentParser, parser)
81-
parser.add_arguments(Args, dest='args')
82-
parser.set_defaults(subparser_fn=lambda args: args.args.execute())
83-
84-
8572
def create_dataset_files(
8673
dataset_name: str,
8774
dataset_dir: pathlib.Path,

0 commit comments

Comments
 (0)