|
22 | 22 | import json |
23 | 23 | import os |
24 | 24 | import shutil |
| 25 | +import sys |
25 | 26 | import tempfile |
26 | 27 |
|
27 | 28 | import h5py |
@@ -434,7 +435,8 @@ def _parse_quantization_bytes(quantization_bytes): |
434 | 435 | raise ValueError('Unsupported quantization bytes: %s' % quantization_bytes) |
435 | 436 |
|
436 | 437 |
|
437 | | -def setup_arguments(): |
| 438 | +def get_arg_parser(): |
| 439 | + """Create the argument parser for the converter binary.""" |
438 | 440 | parser = argparse.ArgumentParser('TensorFlow.js model converters.') |
439 | 441 | parser.add_argument( |
440 | 442 | 'input_path', |
@@ -522,11 +524,22 @@ def setup_arguments(): |
522 | 524 | default=None, |
523 | 525 | help='Shard size (in bytes) of the weight files. Currently applicable ' |
524 | 526 | 'only to output_format=tfjs_layers_model.') |
525 | | - return parser.parse_args() |
| 527 | + return parser |
526 | 528 |
|
527 | 529 |
|
528 | 530 | def main(): |
529 | | - FLAGS = setup_arguments() |
| 531 | + try: |
| 532 | + FLAGS |
| 533 | + except NameError: |
| 534 | + # This code path is added in addition to to the flags-parsing |
| 535 | + # code in the `__name__ == '__main__'` branch below, because it is |
| 536 | + # required by the pip-packaged binary case. The pip-packaged binary calls |
| 537 | + # the `main()` method directly and therefore by passes the |
| 538 | + # `__name__ == '__main__'` branch. |
| 539 | + # pylint: disable=redefined-outer-name,invalid-name |
| 540 | + FLAGS = get_arg_parser().parse_args() |
| 541 | + # pylint: enable=redefined-outer-name,invalid-name |
| 542 | + |
530 | 543 | if FLAGS.show_version: |
531 | 544 | print('\ntensorflowjs %s\n' % version.version) |
532 | 545 | print('Dependency versions:') |
@@ -623,4 +636,7 @@ def main(): |
623 | 636 |
|
624 | 637 |
|
625 | 638 | if __name__ == '__main__': |
626 | | - main() |
| 639 | + # pylint: disable=redefined-outer-name,invalid-name |
| 640 | + FLAGS, unparsed = get_arg_parser().parse_known_args() |
| 641 | + # pylint: enable=redefined-outer-name,invalid-name |
| 642 | + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) |
0 commit comments