File tree Expand file tree Collapse file tree 16 files changed +36
-22
lines changed Expand file tree Collapse file tree 16 files changed +36
-22
lines changed Original file line number Diff line number Diff line change @@ -612,9 +612,10 @@ def train(sess_config,
612612 hooks = []
613613 hooks .extend (input_hooks )
614614
615+ sharded_saver = tf_config != None
615616 scaffold = tf .train .Scaffold (
616617 local_init_op = tf .group (tf .local_variables_initializer (), data_init_op ),
617- saver = tf .train .Saver (max_to_keep = args .keep_checkpoint_max ))
618+ saver = tf .train .Saver (max_to_keep = args .keep_checkpoint_max , sharded = sharded_saver ))
618619
619620 stop_hook = tf .train .StopAtStepHook (last_step = steps )
620621 log_hook = tf .train .LoggingTensorHook (
Original file line number Diff line number Diff line change @@ -527,9 +527,10 @@ def train(sess_config,
527527 hooks = []
528528 hooks .extend (input_hooks )
529529
530+ sharded_saver = tf_config != None
530531 scaffold = tf .train .Scaffold (
531532 local_init_op = tf .group (tf .local_variables_initializer (), data_init_op ),
532- saver = tf .train .Saver (max_to_keep = args .keep_checkpoint_max ))
533+ saver = tf .train .Saver (max_to_keep = args .keep_checkpoint_max , sharded = sharded_saver ))
533534
534535 stop_hook = tf .train .StopAtStepHook (last_step = steps )
535536 log_hook = tf .train .LoggingTensorHook (
Original file line number Diff line number Diff line change @@ -594,9 +594,10 @@ def train(sess_config,
594594 hooks = []
595595 hooks .extend (input_hooks )
596596
597+ sharded_saver = tf_config != None
597598 scaffold = tf .train .Scaffold (
598599 local_init_op = tf .group (tf .local_variables_initializer (), data_init_op ),
599- saver = tf .train .Saver (max_to_keep = args .keep_checkpoint_max ))
600+ saver = tf .train .Saver (max_to_keep = args .keep_checkpoint_max , sharded = sharded_saver ))
600601
601602 stop_hook = tf .train .StopAtStepHook (last_step = steps )
602603 log_hook = tf .train .LoggingTensorHook (
Original file line number Diff line number Diff line change @@ -610,9 +610,10 @@ def train(sess_config,
610610 hooks = []
611611 hooks .extend (input_hooks )
612612
613+ sharded_saver = tf_config != None
613614 scaffold = tf .train .Scaffold (
614615 local_init_op = tf .group (tf .local_variables_initializer (), data_init_op ),
615- saver = tf .train .Saver (max_to_keep = args .keep_checkpoint_max ))
616+ saver = tf .train .Saver (max_to_keep = args .keep_checkpoint_max , sharded = sharded_saver ))
616617
617618 stop_hook = tf .train .StopAtStepHook (last_step = steps )
618619 log_hook = tf .train .LoggingTensorHook (
Original file line number Diff line number Diff line change @@ -472,9 +472,10 @@ def train(sess_config,
472472 hooks = []
473473 hooks .extend (input_hooks )
474474
475+ sharded_saver = tf_config != None
475476 scaffold = tf .train .Scaffold (
476477 local_init_op = tf .group (tf .local_variables_initializer (), data_init_op ),
477- saver = tf .train .Saver (max_to_keep = args .keep_checkpoint_max ))
478+ saver = tf .train .Saver (max_to_keep = args .keep_checkpoint_max , sharded = sharded_saver ))
478479
479480 stop_hook = tf .train .StopAtStepHook (last_step = steps )
480481 log_hook = tf .train .LoggingTensorHook (
Original file line number Diff line number Diff line change @@ -776,10 +776,10 @@ def train(sess_config,
776776 hooks = []
777777 hooks .extend (input_hooks )
778778
779+ sharded_saver = tf_config != None
779780 scaffold = tf .train .Scaffold (
780- local_init_op = tf .group (tf .tables_initializer (),
781- tf .local_variables_initializer (), data_init_op ),
782- saver = tf .train .Saver (max_to_keep = args .keep_checkpoint_max ))
781+ local_init_op = tf .group (tf .local_variables_initializer (), data_init_op ),
782+ saver = tf .train .Saver (max_to_keep = args .keep_checkpoint_max , sharded = sharded_saver ))
783783
784784 stop_hook = tf .train .StopAtStepHook (last_step = steps )
785785 log_hook = tf .train .LoggingTensorHook (
Original file line number Diff line number Diff line change @@ -594,10 +594,10 @@ def train(sess_config,
594594 hooks = []
595595 hooks .extend (input_hooks )
596596
597+ sharded_saver = tf_config != None
597598 scaffold = tf .train .Scaffold (
598- local_init_op = tf .group (tf .tables_initializer (),
599- tf .local_variables_initializer (), data_init_op ),
600- saver = tf .train .Saver (max_to_keep = args .keep_checkpoint_max ))
599+ local_init_op = tf .group (tf .local_variables_initializer (), data_init_op ),
600+ saver = tf .train .Saver (max_to_keep = args .keep_checkpoint_max , sharded = sharded_saver ))
601601
602602 stop_hook = tf .train .StopAtStepHook (last_step = steps )
603603 log_hook = tf .train .LoggingTensorHook (
Original file line number Diff line number Diff line change @@ -507,9 +507,10 @@ def train(sess_config,
507507 hooks = []
508508 hooks .extend (input_hooks )
509509
510+ sharded_saver = tf_config != None
510511 scaffold = tf .train .Scaffold (
511512 local_init_op = tf .group (tf .local_variables_initializer (), data_init_op ),
512- saver = tf .train .Saver (max_to_keep = args .keep_checkpoint_max ))
513+ saver = tf .train .Saver (max_to_keep = args .keep_checkpoint_max , sharded = sharded_saver ))
513514
514515 stop_hook = tf .train .StopAtStepHook (last_step = steps )
515516 log_hook = tf .train .LoggingTensorHook (
Original file line number Diff line number Diff line change @@ -478,9 +478,10 @@ def train(sess_config,
478478 hooks = []
479479 hooks .extend (input_hooks )
480480
481+ sharded_saver = tf_config != None
481482 scaffold = tf .train .Scaffold (
482483 local_init_op = tf .group (tf .local_variables_initializer (), data_init_op ),
483- saver = tf .train .Saver (max_to_keep = args .keep_checkpoint_max ))
484+ saver = tf .train .Saver (max_to_keep = args .keep_checkpoint_max , sharded = sharded_saver ))
484485
485486 stop_hook = tf .train .StopAtStepHook (last_step = steps )
486487 log_hook = tf .train .LoggingTensorHook (
Original file line number Diff line number Diff line change @@ -534,9 +534,10 @@ def train(sess_config,
534534 hooks = []
535535 hooks .extend (input_hooks )
536536
537+ sharded_saver = tf_config != None
537538 scaffold = tf .train .Scaffold (
538- local_init_op = tf .group (tf .local_variables_initializer (), train_init_op ),
539- saver = tf .train .Saver (max_to_keep = keep_checkpoint_max ))
539+ local_init_op = tf .group (tf .local_variables_initializer (), data_init_op ),
540+ saver = tf .train .Saver (max_to_keep = args . keep_checkpoint_max , sharded = sharded_saver ))
540541
541542 stop_hook = tf .train .StopAtStepHook (last_step = train_steps )
542543 log_hook = tf .train .LoggingTensorHook (
You can’t perform that action at this time.
0 commit comments