Skip to content

Commit d814969

Browse files
authored
[ModelZoo] Set Saver's parameter sharded=True in distributed training. (#954)
Signed-off-by: 泊霆 <hujunqi.hjq@alibaba-inc.com>
1 parent 3bc9888 commit d814969

File tree

16 files changed

+36
-22
lines changed

16 files changed

+36
-22
lines changed

modelzoo/bst/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -612,9 +612,10 @@ def train(sess_config,
612612
hooks = []
613613
hooks.extend(input_hooks)
614614

615+
sharded_saver = tf_config != None
615616
scaffold = tf.train.Scaffold(
616617
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
617-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max))
618+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
618619

619620
stop_hook = tf.train.StopAtStepHook(last_step=steps)
620621
log_hook = tf.train.LoggingTensorHook(

modelzoo/dbmtl/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,9 +527,10 @@ def train(sess_config,
527527
hooks = []
528528
hooks.extend(input_hooks)
529529

530+
sharded_saver = tf_config != None
530531
scaffold = tf.train.Scaffold(
531532
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
532-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max))
533+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
533534

534535
stop_hook = tf.train.StopAtStepHook(last_step=steps)
535536
log_hook = tf.train.LoggingTensorHook(

modelzoo/dcn/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,9 +594,10 @@ def train(sess_config,
594594
hooks = []
595595
hooks.extend(input_hooks)
596596

597+
sharded_saver = tf_config != None
597598
scaffold = tf.train.Scaffold(
598599
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
599-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max))
600+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
600601

601602
stop_hook = tf.train.StopAtStepHook(last_step=steps)
602603
log_hook = tf.train.LoggingTensorHook(

modelzoo/dcnv2/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,9 +610,10 @@ def train(sess_config,
610610
hooks = []
611611
hooks.extend(input_hooks)
612612

613+
sharded_saver = tf_config != None
613614
scaffold = tf.train.Scaffold(
614615
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
615-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max))
616+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
616617

617618
stop_hook = tf.train.StopAtStepHook(last_step=steps)
618619
log_hook = tf.train.LoggingTensorHook(

modelzoo/deepfm/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,9 +472,10 @@ def train(sess_config,
472472
hooks = []
473473
hooks.extend(input_hooks)
474474

475+
sharded_saver = tf_config != None
475476
scaffold = tf.train.Scaffold(
476477
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
477-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max))
478+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
478479

479480
stop_hook = tf.train.StopAtStepHook(last_step=steps)
480481
log_hook = tf.train.LoggingTensorHook(

modelzoo/dien/train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -776,10 +776,10 @@ def train(sess_config,
776776
hooks = []
777777
hooks.extend(input_hooks)
778778

779+
sharded_saver = tf_config != None
779780
scaffold = tf.train.Scaffold(
780-
local_init_op=tf.group(tf.tables_initializer(),
781-
tf.local_variables_initializer(), data_init_op),
782-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max))
781+
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
782+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
783783

784784
stop_hook = tf.train.StopAtStepHook(last_step=steps)
785785
log_hook = tf.train.LoggingTensorHook(

modelzoo/din/train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -594,10 +594,10 @@ def train(sess_config,
594594
hooks = []
595595
hooks.extend(input_hooks)
596596

597+
sharded_saver = tf_config != None
597598
scaffold = tf.train.Scaffold(
598-
local_init_op=tf.group(tf.tables_initializer(),
599-
tf.local_variables_initializer(), data_init_op),
600-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max))
599+
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
600+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
601601

602602
stop_hook = tf.train.StopAtStepHook(last_step=steps)
603603
log_hook = tf.train.LoggingTensorHook(

modelzoo/dlrm/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,9 +507,10 @@ def train(sess_config,
507507
hooks = []
508508
hooks.extend(input_hooks)
509509

510+
sharded_saver = tf_config != None
510511
scaffold = tf.train.Scaffold(
511512
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
512-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max))
513+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
513514

514515
stop_hook = tf.train.StopAtStepHook(last_step=steps)
515516
log_hook = tf.train.LoggingTensorHook(

modelzoo/dssm/train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,9 +478,10 @@ def train(sess_config,
478478
hooks = []
479479
hooks.extend(input_hooks)
480480

481+
sharded_saver = tf_config != None
481482
scaffold = tf.train.Scaffold(
482483
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
483-
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max))
484+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
484485

485486
stop_hook = tf.train.StopAtStepHook(last_step=steps)
486487
log_hook = tf.train.LoggingTensorHook(

modelzoo/esmm/train.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -534,9 +534,10 @@ def train(sess_config,
534534
hooks = []
535535
hooks.extend(input_hooks)
536536

537+
sharded_saver = tf_config != None
537538
scaffold = tf.train.Scaffold(
538-
local_init_op=tf.group(tf.local_variables_initializer(), train_init_op),
539-
saver=tf.train.Saver(max_to_keep=keep_checkpoint_max))
539+
local_init_op=tf.group(tf.local_variables_initializer(), data_init_op),
540+
saver=tf.train.Saver(max_to_keep=args.keep_checkpoint_max, sharded=sharded_saver))
540541

541542
stop_hook = tf.train.StopAtStepHook(last_step=train_steps)
542543
log_hook = tf.train.LoggingTensorHook(

0 commit comments

Comments
 (0)