
Commit 998cc54

Add sh_test to cover the converter binary (#412)
- Also in this change: Simplify the argument parsing code in converter.py
- Replace variable name "FLAGS" with "args" to remove the need for disabling pylint

Fixes b/139308919
1 parent ea0db5d commit 998cc54

File tree

5 files changed: +204 additions, -52 deletions


python/setup.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ def _get_requirements(file):
     return requirements.readlines()

 CONSOLE_SCRIPTS = [
-    'tensorflowjs_converter = tensorflowjs.converters.converter:main',
+    'tensorflowjs_converter = tensorflowjs.converters.converter:pip_main',
 ]

 setuptools.setup(
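
The entry-point change above matters because a setuptools console_scripts wrapper invokes the target callable with no arguments, which is why the entry now points at pip_main (see the converter.py diff below, where pip_main forwards sys.argv[1:] itself). As a rough, hedged sketch, the wrapper script that pip/setuptools generates for this entry behaves roughly like the following; the exact generated code varies by setuptools and pip version:

# Hypothetical sketch of the generated 'tensorflowjs_converter' wrapper;
# the real wrapper emitted by pip/setuptools differs in detail.
import sys

from tensorflowjs.converters.converter import pip_main

if __name__ == '__main__':
  # Console-script wrappers call the entry point with no arguments.
  sys.exit(pip_main())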

python/tensorflowjs/converters/BUILD

Lines changed: 19 additions & 0 deletions
@@ -101,6 +101,16 @@ py_binary(
     visibility = ["//visibility:public"],
 )

+py_binary(
+    name = "generate_test_model",
+    srcs = ["generate_test_model.py"],
+    testonly = True,
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflowjs:expect_tensorflow_installed",
+    ]
+)
+
 py_test(
     name = "converter_test",
     srcs = ["converter_test.py"],
@@ -113,3 +123,12 @@ py_test(
         "//tensorflowjs:version",
     ],
 )
+
+sh_test(
+    name = "converter_binary_test",
+    srcs = ["converter_binary_test.sh"],
+    data = [
+        ":converter",
+        ":generate_test_model",
+    ],
+)

python/tensorflowjs/converters/converter.py

Lines changed: 55 additions & 51 deletions
@@ -527,47 +527,54 @@ def get_arg_parser():
   return parser


-def main():
-  try:
-    FLAGS
-  except NameError:
-    # This code path is added in addition to to the flags-parsing
-    # code in the `__name__ == '__main__'` branch below, because it is
-    # required by the pip-packaged binary case. The pip-packaged binary calls
-    # the `main()` method directly and therefore by passes the
-    # `__name__ == '__main__'` branch.
-    # pylint: disable=redefined-outer-name,invalid-name
-    FLAGS = get_arg_parser().parse_args()
-    # pylint: enable=redefined-outer-name,invalid-name
-
-  if FLAGS.show_version:
+def pip_main():
+  """Entry point for pip-packaged binary.
+
+  Note that pip-packaged binary calls the entry method without
+  any arguments, which is why this method is needed in addition to the
+  `main` method below.
+  """
+  main([' '.join(sys.argv[1:])])
+
+
+def main(argv):
+  args = get_arg_parser().parse_args(argv[0].split(' '))
+
+  if args.show_version:
     print('\ntensorflowjs %s\n' % version.version)
     print('Dependency versions:')
     print(' keras %s' % keras.__version__)
     print(' tensorflow %s' % tf.__version__)
     return

+  if not args.input_path:
+    raise ValueError(
+        'Missing input_path argument. For usage, use the --help flag.')
+  if not args.output_path:
+    raise ValueError(
+        'Missing output_path argument. For usage, use the --help flag.')
+
   weight_shard_size_bytes = 1024 * 1024 * 4
-  if FLAGS.weight_shard_size_bytes:
-    if FLAGS.output_format != 'tfjs_layers_model':
+  if args.weight_shard_size_bytes:
+    if args.output_format != 'tfjs_layers_model':
       raise ValueError(
           'The --weight_shard_size_byte flag is only supported under '
           'output_format=tfjs_layers_model.')
-    weight_shard_size_bytes = FLAGS.weight_shard_size_bytes
+    weight_shard_size_bytes = args.weight_shard_size_bytes

-  if FLAGS.input_path is None:
+  if args.input_path is None:
     raise ValueError(
         'Error: The input_path argument must be set. '
         'Run with --help flag for usage information.')

   input_format, output_format = _standardize_input_output_formats(
-      FLAGS.input_format, FLAGS.output_format)
+      args.input_format, args.output_format)

   quantization_dtype = (
-      quantization.QUANTIZATION_BYTES_TO_DTYPES[FLAGS.quantization_bytes]
-      if FLAGS.quantization_bytes else None)
+      quantization.QUANTIZATION_BYTES_TO_DTYPES[args.quantization_bytes]
+      if args.quantization_bytes else None)

-  if (FLAGS.signature_name and input_format not in
+  if (args.signature_name and input_format not in
       ('tf_saved_model', 'tf_hub')):
     raise ValueError(
         'The --signature_name flag is applicable only to "tf_saved_model" and '
@@ -578,65 +585,62 @@ def main():
   # branches below.
   if input_format == 'keras' and output_format == 'tfjs_layers_model':
     dispatch_keras_h5_to_tfjs_layers_model_conversion(
-        FLAGS.input_path, output_dir=FLAGS.output_path,
+        args.input_path, output_dir=args.output_path,
         quantization_dtype=quantization_dtype,
-        split_weights_by_layer=FLAGS.split_weights_by_layer)
+        split_weights_by_layer=args.split_weights_by_layer)
   elif input_format == 'keras' and output_format == 'tfjs_graph_model':
     dispatch_keras_h5_to_tfjs_graph_model_conversion(
-        FLAGS.input_path, output_dir=FLAGS.output_path,
+        args.input_path, output_dir=args.output_path,
         quantization_dtype=quantization_dtype,
-        skip_op_check=FLAGS.skip_op_check,
-        strip_debug_ops=FLAGS.strip_debug_ops)
+        skip_op_check=args.skip_op_check,
+        strip_debug_ops=args.strip_debug_ops)
   elif (input_format == 'keras_saved_model' and
         output_format == 'tfjs_layers_model'):
     dispatch_keras_saved_model_to_tensorflowjs_conversion(
-        FLAGS.input_path, FLAGS.output_path,
+        args.input_path, args.output_path,
         quantization_dtype=quantization_dtype,
-        split_weights_by_layer=FLAGS.split_weights_by_layer)
+        split_weights_by_layer=args.split_weights_by_layer)
   elif (input_format == 'tf_saved_model' and
         output_format == 'tfjs_graph_model'):
     tf_saved_model_conversion_v2.convert_tf_saved_model(
-        FLAGS.input_path, FLAGS.output_path,
-        signature_def=FLAGS.signature_name,
-        saved_model_tags=FLAGS.saved_model_tags,
+        args.input_path, args.output_path,
+        signature_def=args.signature_name,
+        saved_model_tags=args.saved_model_tags,
         quantization_dtype=quantization_dtype,
-        skip_op_check=FLAGS.skip_op_check,
-        strip_debug_ops=FLAGS.strip_debug_ops)
+        skip_op_check=args.skip_op_check,
+        strip_debug_ops=args.strip_debug_ops)
   elif (input_format == 'tf_hub' and
         output_format == 'tfjs_graph_model'):
     tf_saved_model_conversion_v2.convert_tf_hub_module(
-        FLAGS.input_path, FLAGS.output_path, FLAGS.signature_name,
-        FLAGS.saved_model_tags, skip_op_check=FLAGS.skip_op_check,
-        strip_debug_ops=FLAGS.strip_debug_ops)
+        args.input_path, args.output_path, args.signature_name,
+        args.saved_model_tags, skip_op_check=args.skip_op_check,
+        strip_debug_ops=args.strip_debug_ops)
   elif (input_format == 'tfjs_layers_model' and
         output_format == 'keras'):
-    dispatch_tensorflowjs_to_keras_h5_conversion(FLAGS.input_path,
-                                                 FLAGS.output_path)
+    dispatch_tensorflowjs_to_keras_h5_conversion(args.input_path,
+                                                 args.output_path)
   elif (input_format == 'tfjs_layers_model' and
         output_format == 'keras_saved_model'):
-    dispatch_tensorflowjs_to_keras_saved_model_conversion(FLAGS.input_path,
-                                                          FLAGS.output_path)
+    dispatch_tensorflowjs_to_keras_saved_model_conversion(args.input_path,
+                                                          args.output_path)
   elif (input_format == 'tfjs_layers_model' and
         output_format == 'tfjs_layers_model'):
     dispatch_tensorflowjs_to_tensorflowjs_conversion(
-        FLAGS.input_path, FLAGS.output_path,
-        quantization_dtype=_parse_quantization_bytes(FLAGS.quantization_bytes),
+        args.input_path, args.output_path,
+        quantization_dtype=_parse_quantization_bytes(args.quantization_bytes),
         weight_shard_size_bytes=weight_shard_size_bytes)
   elif (input_format == 'tfjs_layers_model' and
         output_format == 'tfjs_graph_model'):
     dispatch_tfjs_layers_model_to_tfjs_graph_conversion(
-        FLAGS.input_path, FLAGS.output_path,
-        quantization_dtype=_parse_quantization_bytes(FLAGS.quantization_bytes),
-        skip_op_check=FLAGS.skip_op_check,
-        strip_debug_ops=FLAGS.strip_debug_ops)
+        args.input_path, args.output_path,
+        quantization_dtype=_parse_quantization_bytes(args.quantization_bytes),
+        skip_op_check=args.skip_op_check,
+        strip_debug_ops=args.strip_debug_ops)
   else:
     raise ValueError(
         'Unsupported input_format - output_format pair: %s - %s' %
         (input_format, output_format))


 if __name__ == '__main__':
-  # pylint: disable=redefined-outer-name,invalid-name
-  FLAGS, unparsed = get_arg_parser().parse_known_args()
-  # pylint: enable=redefined-outer-name,invalid-name
-  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
+  tf.app.run(main=main, argv=[' '.join(sys.argv[1:])])
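
Both entry paths now funnel the raw command line into main() as a single space-joined string, which main() re-splits before parsing; this is what removes the module-level FLAGS global and the pylint disables mentioned in the commit message. A minimal, self-contained sketch of that round-trip, using a hypothetical two-argument parser rather than the converter's real one:

# Minimal sketch of the argv round-trip; the parser below is hypothetical,
# not the converter's actual argument parser.
import argparse
import sys


def get_arg_parser():
  parser = argparse.ArgumentParser('demo')
  parser.add_argument('input_path', nargs='?', default=None)
  parser.add_argument('--show_version', action='store_true')
  return parser


def main(argv):
  # argv[0] holds all CLI tokens joined by single spaces; re-split them
  # before handing the list to argparse.
  args = get_arg_parser().parse_args(argv[0].split(' '))
  print(args.input_path, args.show_version)


if __name__ == '__main__':
  # Mirrors the __main__ branch above: join sys.argv[1:] into one string so
  # that main() receives a single-element argv list.
  main([' '.join(sys.argv[1:])])

One consequence of this scheme is that individual arguments must not themselves contain spaces, since the re-split is on single spaces.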
python/tensorflowjs/converters/converter_binary_test.sh

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+set -e
+
+GENERATE_BIN="${TEST_SRCDIR}/org_tensorflow_js/tensorflowjs/converters/generate_test_model"
+CONVERTER_BIN="${TEST_SRCDIR}/org_tensorflow_js/tensorflowjs/converters/converter"
+
+# 1. Test tf_saved_model --> tfjs_graph_model conversion.
+SAVED_MODEL_DIR="$(mktemp -d)"
+echo "Generating TF SavedModel for testing..."
+"${GENERATE_BIN}" "${SAVED_MODEL_DIR}" --model_type tf_saved_model
+echo "Done generating TF SavedModel for testing at ${SAVED_MODEL_DIR}"
+
+OUTPUT_DIR="${SAVED_MODEL_DIR}_converted"
+"${CONVERTER_BIN}" \
+    --input_format tf_saved_model \
+    --output_format tfjs_graph_model \
+    "${SAVED_MODEL_DIR}" \
+    "${OUTPUT_DIR}"
+
+if [[ ! -d "${OUTPUT_DIR}" ]]; then
+  echo "ERROR: Failed to find conversion output directory: ${OUTPUT_DIR}" 1>&2
+  exit 1
+fi
+
+# Clean up files.
+rm -rf "${SAVED_MODEL_DIR}" "${OUTPUT_DIR}"
+
+# 2. Test keras HDF5 --> tfjs_layers_model conversion.
+KERAS_H5_PATH="$(mktemp).h5"
+echo "Generating Keras HDF5 model for testing..."
+"${GENERATE_BIN}" "${KERAS_H5_PATH}" --model_type tf_keras_h5
+echo "Done generating Keras HDF5 model for testing at ${KERAS_H5_PATH}"
+
+OUTPUT_H5_PATH="${KERAS_H5_PATH}_converted.h5"
+"${CONVERTER_BIN}" \
+    --input_format keras \
+    --output_format tfjs_layers_model \
+    "${KERAS_H5_PATH}" \
+    "${OUTPUT_H5_PATH}"
+
+if [[ ! -d "${OUTPUT_H5_PATH}" ]]; then
+  echo "ERROR: Failed to find conversion output directory: ${OUTPUT_H5_PATH}" 1>&2
+  exit 1
+fi
+
+# Clean up files.
+rm -rf "${KERAS_H5_PATH}" "${OUTPUT_H5_PATH}"

python/tensorflowjs/converters/generate_test_model.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+# Copyright 2019 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A binary that generates saved model artifacts for testing."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import sys
+
+import tensorflow as tf
+
+tf.enable_eager_execution()
+
+
+def parse_args():
+  parser = argparse.ArgumentParser(
+      'Generates saved model artifacts for testing.')
+  parser.add_argument(
+      'output_path',
+      type=str,
+      help='Model output path.')
+  parser.add_argument(
+      '--model_type',
+      type=str,
+      required=True,
+      choices=set(['tf_keras_h5', 'tf_saved_model']),
+      help='Model format to generate.')
+  return parser.parse_known_args()
+
+
+def main(_):
+
+  if args.model_type == 'tf_keras_h5':
+    model = tf.keras.Sequential()
+    model.add(tf.keras.layers.Dense(5, activation='relu', input_shape=(8,)))
+    model.add(tf.keras.layers.Dense(1, activation='sigmoid'))
+    model.save(os.path.join(args.output_path))
+  elif args.model_type == 'tf_saved_model':
+    class TimesThreePlusOne(tf.train.Checkpoint):
+
+      @tf.function(input_signature=[
+          tf.TensorSpec(shape=None, dtype=tf.float32)])
+      def compute(self, x):
+        return x * 3.0 + 1.0
+
+    tf.saved_model.save(TimesThreePlusOne(), args.output_path)
+  else:
+    raise ValueError('Unrecognized model type: %s' % args.model_type)
+
+
+if __name__ == '__main__':
+  args, unparsed = parse_args()
+  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
