@@ -527,47 +527,54 @@ def get_arg_parser():
527527 return parser
528528
529529
530- def main ():
531- try :
532- FLAGS
533- except NameError :
534- # This code path is added in addition to to the flags-parsing
535- # code in the `__name__ == '__main__'` branch below, because it is
536- # required by the pip-packaged binary case. The pip-packaged binary calls
537- # the ` main()` method directly and therefore by passes the
538- # `__name__ == '__main__'` branch.
539- # pylint: disable=redefined-outer-name,invalid-name
540- FLAGS = get_arg_parser (). parse_args ()
541- # pylint: enable=redefined-outer-name,invalid-name
542-
543- if FLAGS .show_version :
530+ def pip_main ():
531+ """Entry point for pip-packaged binary.
532+
533+ Note that pip-packaged binary calls the entry method without
534+ any arguments, which is why this method is needed in addition to the
535+ `main` method below.
536+ """
537+ main ([ ' ' . join ( sys . argv [ 1 :])])
538+
539+
540+ def main ( argv ):
541+ args = get_arg_parser (). parse_args ( argv [ 0 ]. split ( ' ' ))
542+
543+ if args .show_version :
544544 print ('\n tensorflowjs %s\n ' % version .version )
545545 print ('Dependency versions:' )
546546 print (' keras %s' % keras .__version__ )
547547 print (' tensorflow %s' % tf .__version__ )
548548 return
549549
550+ if not args .input_path :
551+ raise ValueError (
552+ 'Missing input_path argument. For usage, use the --help flag.' )
553+ if not args .output_path :
554+ raise ValueError (
555+ 'Missing output_path argument. For usage, use the --help flag.' )
556+
550557 weight_shard_size_bytes = 1024 * 1024 * 4
551- if FLAGS .weight_shard_size_bytes :
552- if FLAGS .output_format != 'tfjs_layers_model' :
558+ if args .weight_shard_size_bytes :
559+ if args .output_format != 'tfjs_layers_model' :
553560 raise ValueError (
554561 'The --weight_shard_size_byte flag is only supported under '
555562 'output_format=tfjs_layers_model.' )
556- weight_shard_size_bytes = FLAGS .weight_shard_size_bytes
563+ weight_shard_size_bytes = args .weight_shard_size_bytes
557564
558- if FLAGS .input_path is None :
565+ if args .input_path is None :
559566 raise ValueError (
560567 'Error: The input_path argument must be set. '
561568 'Run with --help flag for usage information.' )
562569
563570 input_format , output_format = _standardize_input_output_formats (
564- FLAGS .input_format , FLAGS .output_format )
571+ args .input_format , args .output_format )
565572
566573 quantization_dtype = (
567- quantization .QUANTIZATION_BYTES_TO_DTYPES [FLAGS .quantization_bytes ]
568- if FLAGS .quantization_bytes else None )
574+ quantization .QUANTIZATION_BYTES_TO_DTYPES [args .quantization_bytes ]
575+ if args .quantization_bytes else None )
569576
570- if (FLAGS .signature_name and input_format not in
577+ if (args .signature_name and input_format not in
571578 ('tf_saved_model' , 'tf_hub' )):
572579 raise ValueError (
573580 'The --signature_name flag is applicable only to "tf_saved_model" and '
@@ -578,65 +585,62 @@ def main():
578585 # branches below.
579586 if input_format == 'keras' and output_format == 'tfjs_layers_model' :
580587 dispatch_keras_h5_to_tfjs_layers_model_conversion (
581- FLAGS .input_path , output_dir = FLAGS .output_path ,
588+ args .input_path , output_dir = args .output_path ,
582589 quantization_dtype = quantization_dtype ,
583- split_weights_by_layer = FLAGS .split_weights_by_layer )
590+ split_weights_by_layer = args .split_weights_by_layer )
584591 elif input_format == 'keras' and output_format == 'tfjs_graph_model' :
585592 dispatch_keras_h5_to_tfjs_graph_model_conversion (
586- FLAGS .input_path , output_dir = FLAGS .output_path ,
593+ args .input_path , output_dir = args .output_path ,
587594 quantization_dtype = quantization_dtype ,
588- skip_op_check = FLAGS .skip_op_check ,
589- strip_debug_ops = FLAGS .strip_debug_ops )
595+ skip_op_check = args .skip_op_check ,
596+ strip_debug_ops = args .strip_debug_ops )
590597 elif (input_format == 'keras_saved_model' and
591598 output_format == 'tfjs_layers_model' ):
592599 dispatch_keras_saved_model_to_tensorflowjs_conversion (
593- FLAGS .input_path , FLAGS .output_path ,
600+ args .input_path , args .output_path ,
594601 quantization_dtype = quantization_dtype ,
595- split_weights_by_layer = FLAGS .split_weights_by_layer )
602+ split_weights_by_layer = args .split_weights_by_layer )
596603 elif (input_format == 'tf_saved_model' and
597604 output_format == 'tfjs_graph_model' ):
598605 tf_saved_model_conversion_v2 .convert_tf_saved_model (
599- FLAGS .input_path , FLAGS .output_path ,
600- signature_def = FLAGS .signature_name ,
601- saved_model_tags = FLAGS .saved_model_tags ,
606+ args .input_path , args .output_path ,
607+ signature_def = args .signature_name ,
608+ saved_model_tags = args .saved_model_tags ,
602609 quantization_dtype = quantization_dtype ,
603- skip_op_check = FLAGS .skip_op_check ,
604- strip_debug_ops = FLAGS .strip_debug_ops )
610+ skip_op_check = args .skip_op_check ,
611+ strip_debug_ops = args .strip_debug_ops )
605612 elif (input_format == 'tf_hub' and
606613 output_format == 'tfjs_graph_model' ):
607614 tf_saved_model_conversion_v2 .convert_tf_hub_module (
608- FLAGS .input_path , FLAGS .output_path , FLAGS .signature_name ,
609- FLAGS .saved_model_tags , skip_op_check = FLAGS .skip_op_check ,
610- strip_debug_ops = FLAGS .strip_debug_ops )
615+ args .input_path , args .output_path , args .signature_name ,
616+ args .saved_model_tags , skip_op_check = args .skip_op_check ,
617+ strip_debug_ops = args .strip_debug_ops )
611618 elif (input_format == 'tfjs_layers_model' and
612619 output_format == 'keras' ):
613- dispatch_tensorflowjs_to_keras_h5_conversion (FLAGS .input_path ,
614- FLAGS .output_path )
620+ dispatch_tensorflowjs_to_keras_h5_conversion (args .input_path ,
621+ args .output_path )
615622 elif (input_format == 'tfjs_layers_model' and
616623 output_format == 'keras_saved_model' ):
617- dispatch_tensorflowjs_to_keras_saved_model_conversion (FLAGS .input_path ,
618- FLAGS .output_path )
624+ dispatch_tensorflowjs_to_keras_saved_model_conversion (args .input_path ,
625+ args .output_path )
619626 elif (input_format == 'tfjs_layers_model' and
620627 output_format == 'tfjs_layers_model' ):
621628 dispatch_tensorflowjs_to_tensorflowjs_conversion (
622- FLAGS .input_path , FLAGS .output_path ,
623- quantization_dtype = _parse_quantization_bytes (FLAGS .quantization_bytes ),
629+ args .input_path , args .output_path ,
630+ quantization_dtype = _parse_quantization_bytes (args .quantization_bytes ),
624631 weight_shard_size_bytes = weight_shard_size_bytes )
625632 elif (input_format == 'tfjs_layers_model' and
626633 output_format == 'tfjs_graph_model' ):
627634 dispatch_tfjs_layers_model_to_tfjs_graph_conversion (
628- FLAGS .input_path , FLAGS .output_path ,
629- quantization_dtype = _parse_quantization_bytes (FLAGS .quantization_bytes ),
630- skip_op_check = FLAGS .skip_op_check ,
631- strip_debug_ops = FLAGS .strip_debug_ops )
635+ args .input_path , args .output_path ,
636+ quantization_dtype = _parse_quantization_bytes (args .quantization_bytes ),
637+ skip_op_check = args .skip_op_check ,
638+ strip_debug_ops = args .strip_debug_ops )
632639 else :
633640 raise ValueError (
634641 'Unsupported input_format - output_format pair: %s - %s' %
635642 (input_format , output_format ))
636643
637644
638645if __name__ == '__main__' :
639- # pylint: disable=redefined-outer-name,invalid-name
640- FLAGS , unparsed = get_arg_parser ().parse_known_args ()
641- # pylint: enable=redefined-outer-name,invalid-name
642- tf .app .run (main = main , argv = [sys .argv [0 ]] + unparsed )
646+ tf .app .run (main = main , argv = [' ' .join (sys .argv [1 :])])
0 commit comments