2222import json
2323import os
2424import shutil
25+ import sys
2526import tempfile
2627
2728import h5py
@@ -434,7 +435,8 @@ def _parse_quantization_bytes(quantization_bytes):
434435 raise ValueError ('Unsupported quantization bytes: %s' % quantization_bytes )
435436
436437
437- def setup_arguments ():
438+ def get_arg_parser ():
439+ """Create the argument parser for the converter binary."""
438440 parser = argparse .ArgumentParser ('TensorFlow.js model converters.' )
439441 parser .add_argument (
440442 'input_path' ,
@@ -522,39 +524,57 @@ def setup_arguments():
522524 default = None ,
523525 help = 'Shard size (in bytes) of the weight files. Currently applicable '
524526 'only to output_format=tfjs_layers_model.' )
525- return parser . parse_args ()
527+ return parser
526528
527529
528- def main ():
529- FLAGS = setup_arguments ()
530- if FLAGS .show_version :
530+ def pip_main ():
531+ """Entry point for pip-packaged binary.
532+
533+ Note that pip-packaged binary calls the entry method without
534+ any arguments, which is why this method is needed in addition to the
535+ `main` method below.
536+ """
537+ main ([' ' .join (sys .argv [1 :])])
538+
539+
540+ def main (argv ):
541+ args = get_arg_parser ().parse_args (argv [0 ].split (' ' ))
542+
543+ if args .show_version :
531544 print ('\n tensorflowjs %s\n ' % version .version )
532545 print ('Dependency versions:' )
533546 print (' keras %s' % keras .__version__ )
534547 print (' tensorflow %s' % tf .__version__ )
535548 return
536549
550+ if not args .input_path :
551+ raise ValueError (
552+ 'Missing input_path argument. For usage, use the --help flag.' )
553+ if not args .output_path :
554+ raise ValueError (
555+ 'Missing output_path argument. For usage, use the --help flag.' )
556+
537557 weight_shard_size_bytes = 1024 * 1024 * 4
538- if FLAGS .weight_shard_size_bytes :
539- if FLAGS .output_format != 'tfjs_layers_model' :
558+ if args .weight_shard_size_bytes :
559+ if args .output_format != 'tfjs_layers_model' :
540560 raise ValueError (
541561 'The --weight_shard_size_byte flag is only supported under '
542562 'output_format=tfjs_layers_model.' )
543- weight_shard_size_bytes = FLAGS .weight_shard_size_bytes
563+ weight_shard_size_bytes = args .weight_shard_size_bytes
544564
545- if FLAGS .input_path is None :
565+ if args .input_path is None :
546566 raise ValueError (
547567 'Error: The input_path argument must be set. '
548568 'Run with --help flag for usage information.' )
549569
550570 input_format , output_format = _standardize_input_output_formats (
551- FLAGS .input_format , FLAGS .output_format )
571+ args .input_format , args .output_format )
552572
553573 quantization_dtype = (
554- quantization .QUANTIZATION_BYTES_TO_DTYPES [FLAGS .quantization_bytes ]
555- if FLAGS .quantization_bytes else None )
574+ quantization .QUANTIZATION_BYTES_TO_DTYPES [args .quantization_bytes ]
575+ if args .quantization_bytes else None )
556576
557- if (FLAGS .signature_name and input_format not in
577+ if (args .signature_name and input_format not in
558578 ('tf_saved_model' , 'tf_hub' )):
559579 raise ValueError (
560580 'The --signature_name flag is applicable only to "tf_saved_model" and '
@@ -565,62 +585,62 @@ def main():
565585 # branches below.
566586 if input_format == 'keras' and output_format == 'tfjs_layers_model' :
567587 dispatch_keras_h5_to_tfjs_layers_model_conversion (
568- FLAGS .input_path , output_dir = FLAGS .output_path ,
588+ args .input_path , output_dir = args .output_path ,
569589 quantization_dtype = quantization_dtype ,
570- split_weights_by_layer = FLAGS .split_weights_by_layer )
590+ split_weights_by_layer = args .split_weights_by_layer )
571591 elif input_format == 'keras' and output_format == 'tfjs_graph_model' :
572592 dispatch_keras_h5_to_tfjs_graph_model_conversion (
573- FLAGS .input_path , output_dir = FLAGS .output_path ,
593+ args .input_path , output_dir = args .output_path ,
574594 quantization_dtype = quantization_dtype ,
575- skip_op_check = FLAGS .skip_op_check ,
576- strip_debug_ops = FLAGS .strip_debug_ops )
595+ skip_op_check = args .skip_op_check ,
596+ strip_debug_ops = args .strip_debug_ops )
577597 elif (input_format == 'keras_saved_model' and
578598 output_format == 'tfjs_layers_model' ):
579599 dispatch_keras_saved_model_to_tensorflowjs_conversion (
580- FLAGS .input_path , FLAGS .output_path ,
600+ args .input_path , args .output_path ,
581601 quantization_dtype = quantization_dtype ,
582- split_weights_by_layer = FLAGS .split_weights_by_layer )
602+ split_weights_by_layer = args .split_weights_by_layer )
583603 elif (input_format == 'tf_saved_model' and
584604 output_format == 'tfjs_graph_model' ):
585605 tf_saved_model_conversion_v2 .convert_tf_saved_model (
586- FLAGS .input_path , FLAGS .output_path ,
587- signature_def = FLAGS .signature_name ,
588- saved_model_tags = FLAGS .saved_model_tags ,
606+ args .input_path , args .output_path ,
607+ signature_def = args .signature_name ,
608+ saved_model_tags = args .saved_model_tags ,
589609 quantization_dtype = quantization_dtype ,
590- skip_op_check = FLAGS .skip_op_check ,
591- strip_debug_ops = FLAGS .strip_debug_ops )
610+ skip_op_check = args .skip_op_check ,
611+ strip_debug_ops = args .strip_debug_ops )
592612 elif (input_format == 'tf_hub' and
593613 output_format == 'tfjs_graph_model' ):
594614 tf_saved_model_conversion_v2 .convert_tf_hub_module (
595- FLAGS .input_path , FLAGS .output_path , FLAGS .signature_name ,
596- FLAGS .saved_model_tags , skip_op_check = FLAGS .skip_op_check ,
597- strip_debug_ops = FLAGS .strip_debug_ops )
615+ args .input_path , args .output_path , args .signature_name ,
616+ args .saved_model_tags , skip_op_check = args .skip_op_check ,
617+ strip_debug_ops = args .strip_debug_ops )
598618 elif (input_format == 'tfjs_layers_model' and
599619 output_format == 'keras' ):
600- dispatch_tensorflowjs_to_keras_h5_conversion (FLAGS .input_path ,
601- FLAGS .output_path )
620+ dispatch_tensorflowjs_to_keras_h5_conversion (args .input_path ,
621+ args .output_path )
602622 elif (input_format == 'tfjs_layers_model' and
603623 output_format == 'keras_saved_model' ):
604- dispatch_tensorflowjs_to_keras_saved_model_conversion (FLAGS .input_path ,
605- FLAGS .output_path )
624+ dispatch_tensorflowjs_to_keras_saved_model_conversion (args .input_path ,
625+ args .output_path )
606626 elif (input_format == 'tfjs_layers_model' and
607627 output_format == 'tfjs_layers_model' ):
608628 dispatch_tensorflowjs_to_tensorflowjs_conversion (
609- FLAGS .input_path , FLAGS .output_path ,
610- quantization_dtype = _parse_quantization_bytes (FLAGS .quantization_bytes ),
629+ args .input_path , args .output_path ,
630+ quantization_dtype = _parse_quantization_bytes (args .quantization_bytes ),
611631 weight_shard_size_bytes = weight_shard_size_bytes )
612632 elif (input_format == 'tfjs_layers_model' and
613633 output_format == 'tfjs_graph_model' ):
614634 dispatch_tfjs_layers_model_to_tfjs_graph_conversion (
615- FLAGS .input_path , FLAGS .output_path ,
616- quantization_dtype = _parse_quantization_bytes (FLAGS .quantization_bytes ),
617- skip_op_check = FLAGS .skip_op_check ,
618- strip_debug_ops = FLAGS .strip_debug_ops )
635+ args .input_path , args .output_path ,
636+ quantization_dtype = _parse_quantization_bytes (args .quantization_bytes ),
637+ skip_op_check = args .skip_op_check ,
638+ strip_debug_ops = args .strip_debug_ops )
619639 else :
620640 raise ValueError (
621641 'Unsupported input_format - output_format pair: %s - %s' %
622642 (input_format , output_format ))
623643
624644
625645if __name__ == '__main__' :
626- main ( )
646+ tf . app . run ( main = main , argv = [ ' ' . join ( sys . argv [ 1 :])] )
0 commit comments