Skip to content

Commit ea0db5d

Browse files
authored
Converter binary: call tf.app.run() (#410)
- instead of calling main() directly. Fixes b/139308919 BUG
1 parent 8efdabd commit ea0db5d

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

python/run-python-tests.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ TEST_FILES="$(find "${SCRIPTS_DIR}" -name '*_test.py')"
2424

2525
pip install virtualenv
2626

27-
TMP_VENV_DIR="$(mktemp -d)"
27+
TMP_VENV_DIR="$(mktemp -d --suffix=_venv)"
2828
virtualenv -p "python" "${TMP_VENV_DIR}"
2929
source "${TMP_VENV_DIR}/bin/activate"
3030

python/tensorflowjs/converters/converter.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import json
2323
import os
2424
import shutil
25+
import sys
2526
import tempfile
2627

2728
import h5py
@@ -434,7 +435,8 @@ def _parse_quantization_bytes(quantization_bytes):
434435
raise ValueError('Unsupported quantization bytes: %s' % quantization_bytes)
435436

436437

437-
def setup_arguments():
438+
def get_arg_parser():
439+
"""Create the argument parser for the converter binary."""
438440
parser = argparse.ArgumentParser('TensorFlow.js model converters.')
439441
parser.add_argument(
440442
'input_path',
@@ -522,11 +524,22 @@ def setup_arguments():
522524
default=None,
523525
help='Shard size (in bytes) of the weight files. Currently applicable '
524526
'only to output_format=tfjs_layers_model.')
525-
return parser.parse_args()
527+
return parser
526528

527529

528530
def main():
529-
FLAGS = setup_arguments()
531+
try:
532+
FLAGS
533+
except NameError:
534+
# This code path is added in addition to to the flags-parsing
535+
# code in the `__name__ == '__main__'` branch below, because it is
536+
# required by the pip-packaged binary case. The pip-packaged binary calls
537+
# the `main()` method directly and therefore by passes the
538+
# `__name__ == '__main__'` branch.
539+
# pylint: disable=redefined-outer-name,invalid-name
540+
FLAGS = get_arg_parser().parse_args()
541+
# pylint: enable=redefined-outer-name,invalid-name
542+
530543
if FLAGS.show_version:
531544
print('\ntensorflowjs %s\n' % version.version)
532545
print('Dependency versions:')
@@ -623,4 +636,7 @@ def main():
623636

624637

625638
if __name__ == '__main__':
626-
main()
639+
# pylint: disable=redefined-outer-name,invalid-name
640+
FLAGS, unparsed = get_arg_parser().parse_known_args()
641+
# pylint: enable=redefined-outer-name,invalid-name
642+
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

0 commit comments

Comments
 (0)