Commit 56d65f0

Author: Ryan Sepassi
Working open-source distributed training

PiperOrigin-RevId: 161731856
1 parent 29f2e2e · commit 56d65f0

3 files changed, 42 additions and 28 deletions


tensor2tensor/bin/t2t-make-tf-configs

Lines changed: 13 additions & 12 deletions
@@ -17,13 +17,13 @@

 Usage:

-`t2t-make-tf-configs --workers="server1:1234" --ps="server3:2134,server4:2334"`
+`t2t-make-tf-configs --masters="server1:1234" --ps="server3:2134,server4:2334"`

-Outputs 1 line per job to stdout, first the workers, then the parameter servers.
+Outputs 1 line per job to stdout, first the masters, then the parameter servers.
 Each line has the TF_CONFIG, then a tab, then the command line flags for that
 job.

-If there is a single worker, workers will have the `--sync` flag.
+If there is a single master, it will have the `--sync` flag.
 """
 from __future__ import absolute_import
 from __future__ import division
@@ -38,31 +38,32 @@ import tensorflow as tf
 flags = tf.flags
 FLAGS = flags.FLAGS

-flags.DEFINE_string("workers", "", "Comma-separated list of worker addresses")
+flags.DEFINE_string("masters", "", "Comma-separated list of master addresses")
 flags.DEFINE_string("ps", "", "Comma-separated list of ps addresses")


 def main(_):
-  if not (FLAGS.workers and FLAGS.ps):
-    raise ValueError("Must provide --workers and --ps")
+  if not (FLAGS.masters and FLAGS.ps):
+    raise ValueError("Must provide --masters and --ps")

-  workers = FLAGS.workers.split(",")
+  masters = FLAGS.masters.split(",")
   ps = FLAGS.ps.split(",")

-  cluster = {"ps": ps, "worker": workers}
+  cluster = {"ps": ps, "master": masters}

-  for task_type, jobs in (("worker", workers), ("ps", ps)):
+  for task_type, jobs in (("master", masters), ("ps", ps)):
     for idx, job in enumerate(jobs):
-      if task_type == "worker":
+      if task_type == "master":
         cmd_line_flags = " ".join([
             "--master=grpc://%s" % job,
             "--ps_replicas=%d" % len(ps),
-            "--worker_replicas=%d" % len(workers),
+            "--worker_replicas=%d" % len(masters),
             "--worker_gpu=1",
             "--worker_id=%d" % idx,
+            "--worker_job='/job:master'",
             "--ps_gpu=1",
             "--schedule=train",
-            "--sync" if len(workers) == 1 else "",
+            "--sync" if len(masters) == 1 else "",
         ])
       else:
         cmd_line_flags = " ".join([
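
For reference, below is a rough standalone sketch of the single output line this script would emit for a master job, assembled from the flag logic in this diff and the `TF_CONFIG` layout documented in `distributed_training.md`. The addresses come from the usage string above; the `'environment': 'cloud'` key mirrors the docs example and is an assumption about the script's exact output.

```
# Hypothetical sketch (not part of this commit): approximates what
# `t2t-make-tf-configs --masters="server1:1234" --ps="server3:2134,server4:2334"`
# would print for its single master job.
import json

masters = ["server1:1234"]
ps = ["server3:2134", "server4:2334"]
cluster = {"ps": ps, "master": masters}

idx, job = 0, masters[0]
tf_config = json.dumps({
    "cluster": cluster,
    "task": {"type": "master", "index": idx},
    "environment": "cloud",  # assumed; mirrors the docs example
})
cmd_line_flags = " ".join([
    "--master=grpc://%s" % job,
    "--ps_replicas=%d" % len(ps),
    "--worker_replicas=%d" % len(masters),
    "--worker_gpu=1",
    "--worker_id=%d" % idx,
    "--worker_job='/job:master'",
    "--ps_gpu=1",
    "--schedule=train",
    "--sync" if len(masters) == 1 else "",
])
# One line per job: TF_CONFIG, then a tab, then the command-line flags.
print("%s\t%s" % (tf_config, cmd_line_flags))
```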

tensor2tensor/docs/distributed_training.md

Lines changed: 19 additions & 14 deletions
@@ -10,52 +10,54 @@ along with a set of flags.

 ## `TF_CONFIG`

-Both workers and parameter servers must have the `TF_CONFIG` environment
+Both masters and parameter servers must have the `TF_CONFIG` environment
 variable set.

 The `TF_CONFIG` environment variable is a json-encoded string with the addresses
-of the workers and parameter servers (in the `'cluster'` key) and the
+of the masters and parameter servers (in the `'cluster'` key) and the
 identification of the current task (in the `'task'` key).

 For example:

 ```
 cluster = {
     'ps': ['host1:2222', 'host2:2222'],
-    'worker': ['host3:2222', 'host4:2222', 'host5:2222']
+    'master': ['host3:2222', 'host4:2222', 'host5:2222']
 }
 os.environ['TF_CONFIG'] = json.dumps({
     'cluster': cluster,
-    'task': {'type': 'worker', 'index': 1}
+    'task': {'type': 'master', 'index': 1},
+    'environment': 'cloud',
 })
 ```

 ## Command-line flags

-The following T2T command-line flags must also be set on the workers for
+The following T2T command-line flags must also be set on the masters for
 distributed training:

 - `--master=grpc://$ADDRESS`
-- `--worker_replicas=$NUM_WORKERS`
-- `--worker_gpu=$NUM_GPUS_PER_WORKER`
-- `--worker_id=$WORKER_ID`
+- `--worker_replicas=$NUM_MASTERS`
+- `--worker_gpu=$NUM_GPUS_PER_MASTER`
+- `--worker_id=$MASTER_ID`
+- `--worker_job='/job:master'`
 - `--ps_replicas=$NUM_PS`
 - `--ps_gpu=$NUM_GPUS_PER_PS`
 - `--schedule=train`
 - `--sync`, if you want synchronous training, i.e. for there to be a single
-  master worker coordinating the work across "ps" jobs (yes, the naming is
-  unfortunate). If not set, then each worker operates independently while
-  variables are shared on the parameter servers.
+  master coordinating the work across "ps" jobs. If not set, then each master
+  operates independently while variables are shared on the parameter servers.

-Parameter servers only need `--schedule=run_std_server`.
+Parameter servers only need `--master=grpc://$ADDRESS` and
+`--schedule=run_std_server`.

 ## Utility to produce `TF_CONFIG` and flags

 [`t2t-make-tf-configs`](https://github.com/tensorflow/tensor2tensor/tree/master/tensor2tensor/bin/t2t-make-tf-configs))
 generates the `TF_CONFIG` json strings and the above-mentioned command-line
-flags for the workers and parameter servers.
+flags for the masters and parameter servers.

-Given a set of worker and parameter server addresses, the script outputs, for
+Given a set of master and parameter server addresses, the script outputs, for
 each job, a line with the `TF_CONFIG` environment variable and the command-line
 flags necessary for distributed training. For each job, you should invoke the
 `t2t-trainer` with the `TF_CONFIG` value and flags that are output.
@@ -66,6 +68,9 @@ For example:
 TF_CONFIG=$JOB_TF_CONFIG t2t-trainer $JOB_FLAGS --model=transformer ...
 ```

+Modify the `--worker_gpu` and `--ps_gpu` flags, which specify how many gpus are
+on each master and ps, respectively, as needed for your machine/cluster setup.
+
 ## Command-line flags for eval jobs

 Eval jobs should set the following flags and do not need the `TF_CONFIG`
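
As a point of reference for the "Parameter servers only need `--master=grpc://$ADDRESS` and `--schedule=run_std_server`" note in the doc above, here is a minimal sketch of what a ps job amounts to, assuming TF 1.x `tf.train.ClusterSpec`/`tf.train.Server`. The actual `t2t-trainer` reaches this through `tf.contrib.learn`'s `run_std_server` schedule, so treat this as illustrative only, not the t2t code path.

```
# Minimal sketch (assumption, not t2t code): a ps job under
# --schedule=run_std_server effectively starts a standard TF server
# from the TF_CONFIG cluster description and blocks forever.
import json
import os

import tensorflow as tf

tf_config = json.loads(os.environ["TF_CONFIG"])
cluster = tf.train.ClusterSpec(tf_config["cluster"])
task = tf_config["task"]

# Serve variables to the master jobs until the process is killed.
server = tf.train.Server(cluster,
                         job_name=task["type"],
                         task_index=task["index"])
server.join()
```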

tensor2tensor/utils/trainer_utils.py

Lines changed: 10 additions & 2 deletions
@@ -91,6 +91,8 @@
 flags.DEFINE_integer("worker_gpu", 1, "How many GPUs to use.")
 flags.DEFINE_integer("worker_replicas", 1, "How many workers to use.")
 flags.DEFINE_integer("worker_id", 0, "Which worker task are we.")
+flags.DEFINE_float("worker_gpu_memory_fraction", 1.,
+                   "Fraction of GPU memory to allocate.")
 flags.DEFINE_integer("ps_gpu", 0, "How many GPUs to use per ps.")
 flags.DEFINE_string("gpu_order", "", "Optional order for daisy-chaining gpus."
                     " e.g. \"1 3 2 4\"")
@@ -177,6 +179,7 @@ def create_experiment_components(hparams, output_dir, data_dir, model_name):
       config=tf.contrib.learn.RunConfig(
           master=FLAGS.master,
           model_dir=output_dir,
+          gpu_memory_fraction=FLAGS.worker_gpu_memory_fraction,
           session_config=session_config(),
           keep_checkpoint_max=FLAGS.keep_checkpoint_max))
   # Store the hparams in the estimator as well
@@ -270,16 +273,21 @@ def session_config():
   """The TensorFlow Session config to use."""
   graph_options = tf.GraphOptions(optimizer_options=tf.OptimizerOptions(
       opt_level=tf.OptimizerOptions.L1, do_function_inlining=False))
+
   if FLAGS.experimental_optimize_placement:
     rewrite_options = tf.RewriterConfig(optimize_tensor_layout=True)
     rewrite_options.optimizers.append("pruning")
     rewrite_options.optimizers.append("constfold")
     rewrite_options.optimizers.append("layout")
     graph_options = tf.GraphOptions(
         rewrite_options=rewrite_options, infer_shapes=True)
-  config = tf.ConfigProto(
-      allow_soft_placement=True, graph_options=graph_options)

+  gpu_options = tf.GPUOptions(
+      per_process_gpu_memory_fraction=FLAGS.worker_gpu_memory_fraction)
+
+  config = tf.ConfigProto(allow_soft_placement=True,
+                          graph_options=graph_options,
+                          gpu_options=gpu_options)
   return config
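
The new `--worker_gpu_memory_fraction` flag (default 1.0) flows into `tf.GPUOptions`, which caps how much GPU memory the process allocates; this matters when several jobs share one GPU, e.g. a master and a ps with `--ps_gpu=1` on the same machine. Below is a standalone illustration of that effect, assuming TF 1.x APIs; it is not t2t code and the 0.45 value is just an example.

```
# Standalone illustration (assumption, not t2t code): cap this process at
# roughly 45% of GPU memory so another process can share the same device.
import tensorflow as tf

gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.45)
config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)

with tf.Session(config=config) as sess:
    # Ops run in this session allocate at most the configured fraction.
    print(sess.run(tf.constant("memory-capped session")))
```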