Commit c05d4ec

Author: Frederik Mellbye (committed)

Add popdist.init() to PopDistStrategy

Reviewers: christiana, #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved
Reviewed By: christiana, #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved
Maniphest Tasks: T67743
Differential Revision: https://phabricator.sourcevertex.net/D73447
1 parent fccb3bf commit c05d4ec
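
With this change, PopDistStrategy calls popdist.init() in its constructor, so callers presumably no longer need to call it themselves before creating the strategy. A minimal usage sketch, assuming the import paths implied by the files changed below; the IPUConfig import path and the ordering of configure_ipu_system() relative to update_ipu_config() are assumptions, not part of this commit:

    from tensorflow.python.ipu.config import IPUConfig  # import path assumed
    from tensorflow.python.ipu.distributed.popdist_strategy import PopDistStrategy

    strategy = PopDistStrategy()  # the constructor now calls popdist.init() itself

    config = IPUConfig()
    config.auto_select_ipus = 1
    strategy.update_ipu_config(config)  # fills in the multi-replica distribution fields
    config.configure_ipu_system()

    with strategy.scope():
      # Build variables and models here; run per-replica work with strategy.run(...)
      # and combine results with strategy.reduce(...), as the tests below do.
      pass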

File tree

2 files changed: +54 -47 lines


tensorflow/python/ipu/distributed/popdist_strategy.py

Lines changed: 2 additions & 0 deletions
@@ -73,6 +73,8 @@ def __init__(self,
                add_ipu_cross_replica_reductions=True,
                enable_dataset_iterators=True,
                enable_keras_extensions=True):
+    popdist.init()
+
     # We create an empty cluster here since we will not be using gRPC for communication.
     # All the communication is delegated to either GCL or Horovod (MPI) below.
     cluster_resolver = cluster_resolver_lib.SimpleClusterResolver(

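Because the constructor now calls popdist.init(), and the test comment below notes that instantiation should happen once so that popdist.init() and popdist.finalizeBackend() each run once per process, the test suite switches to a single class-level PopDistStrategy instance. A hedged sketch of the same construct-once pattern for application code; the helper name and module-level cache are illustrative, not part of this commit:

    from tensorflow.python.ipu.distributed.popdist_strategy import PopDistStrategy

    _strategy = None  # hypothetical module-level cache

    def get_popdist_strategy():
      """Return a process-wide PopDistStrategy, constructing it only once."""
      global _strategy
      if _strategy is None:
        _strategy = PopDistStrategy()  # triggers popdist.init() exactly once
      return _strategy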
tensorflow/python/ipu/distributed/popdist_strategy_test.py

Lines changed: 52 additions & 47 deletions
@@ -58,27 +58,23 @@ class PopDistStrategyTest(test_util.TensorFlowTestCase): # pylint: disable=abst
   @classmethod
   def setUpClass(cls):
     hvd.init()
-    popdist.init()
+    # Instantiate here to call popdist.init() and popdist.finalizeBackend() once
+    cls.strategy = PopDistStrategy()

   @classmethod
   def tearDownClass(cls):
     hvd.shutdown()

-  def _create_test_objects(self,
-                           auto_select_ipus=1,
-                           add_ipu_cross_replica_reductions=True):
+  def _create_test_objects(self, auto_select_ipus=1):
     config = IPUConfig()
     config.auto_select_ipus = auto_select_ipus
     config.configure_ipu_system()

-    strategy = PopDistStrategy(
-        add_ipu_cross_replica_reductions=add_ipu_cross_replica_reductions)
-
-    return config, strategy
+    return config

   def test_update_ipu_config(self):
-    config, strategy = self._create_test_objects()
-    strategy.update_ipu_config(config)
+    config = self._create_test_objects()
+    self.strategy.update_ipu_config(config)
     self.assertEqual(
         config.experimental.multi_replica_distribution.process_count,
         popdist.getNumInstances())
@@ -87,9 +83,10 @@ def test_update_ipu_config(self):
         popdist.getInstanceIndex())

   def test_strategy(self):
-    config, strategy = self._create_test_objects()
+    config = self._create_test_objects()
+    self.strategy.update_ipu_config(config)

-    with strategy.scope():
+    with self.strategy.scope():
       v = variables.Variable(initial_value=popdist.getInstanceIndex() + 1,
                              dtype=np.float32)
       self.assertEndsWith(v.device, "/device:IPU:0")
@@ -108,11 +105,11 @@ def per_replica_fn(x):

         return y_all_reduced

-      per_replica_value = strategy.run(per_replica_fn,
-                                       args=[constant_op.constant(2.0)])
+      per_replica_value = self.strategy.run(per_replica_fn,
+                                            args=[constant_op.constant(2.0)])

       # This reduction is performed on CPU, and hence uses Horovod.
-      value_all_reduced = strategy.reduce(ReduceOp.SUM, per_replica_value)
+      value_all_reduced = self.strategy.reduce(ReduceOp.SUM, per_replica_value)

       # The initial value should be broadcast from rank 0.
       self.assertEqual(v, 1.0)
@@ -121,10 +118,13 @@ def per_replica_fn(x):
       self.assertEqual(value_all_reduced, popdist.getNumInstances() * 2.0)

   def test_strategy_without_ipu_reduction(self):
-    config, strategy = self._create_test_objects(
-        add_ipu_cross_replica_reductions=False)
+    config = self._create_test_objects()
+    self.strategy.update_ipu_config(config)
+
+    # Modify to keep one instance of the strategy
+    self.strategy._extended._add_ipu_cross_replica_reductions = False  # pylint: disable=protected-access

-    with strategy.scope():
+    with self.strategy.scope():
       v = variables.Variable(initial_value=1.0, dtype=np.float32)

       @def_function.function
@@ -139,12 +139,16 @@ def per_replica_fn(x):
         return y_out

       # It is sufficient to test the TF graph construction.
-      strategy.run(per_replica_fn, args=[constant_op.constant(2.0)])
+      self.strategy.run(per_replica_fn, args=[constant_op.constant(2.0)])
+
+    # Set back to default value
+    self.strategy._extended._add_ipu_cross_replica_reductions = True  # pylint: disable=protected-access

   def test_strategy_with_sync_on_read_variable(self):
-    config, strategy = self._create_test_objects()
+    config = self._create_test_objects()
+    self.strategy.update_ipu_config(config)

-    with strategy.scope():
+    with self.strategy.scope():
       w = variables.Variable(initial_value=float(popdist.getInstanceIndex() +
                                                  1),
                              dtype=np.float32,
@@ -160,15 +164,16 @@ def per_replica_fn(x):

       # Both should have initial value from first worker
       debugging.assert_equal([1.0], w)
-      strategy.run(
+      self.strategy.run(
           per_replica_fn,
           args=[constant_op.constant(popdist.getInstanceIndex() + 1.0)])
       debugging.assert_equal([2.5], w.read_value())

   def test_strategy_with_mirrored_variable(self):
-    config, strategy = self._create_test_objects()
+    config = self._create_test_objects()
+    self.strategy.update_ipu_config(config)

-    with strategy.scope():
+    with self.strategy.scope():
       w = variables.Variable(initial_value=float(popdist.getInstanceIndex() +
                                                  1),
                              dtype=np.float32,
@@ -180,19 +185,20 @@ def per_replica_fn():
         self.assertIsInstance(w, IPUMirroredVariable)
         return w * w

-      per_replica_ret = strategy.run(per_replica_fn, args=[])
-      sum_ret = strategy.reduce(ReduceOp.SUM, per_replica_ret, axis=None)
+      per_replica_ret = self.strategy.run(per_replica_fn, args=[])
+      sum_ret = self.strategy.reduce(ReduceOp.SUM, per_replica_ret, axis=None)
       self.assertEqual([1.0], per_replica_ret)
       self.assertEqual(2.0, sum_ret)

   def test_distribute_dataset(self):
-    config, strategy = self._create_test_objects()
+    config = self._create_test_objects()
+    self.strategy.update_ipu_config(config)

     dataset = dataset_ops.Dataset.range(10, output_type=np.float32)
     dataset = dataset.shard(num_shards=popdist.getNumInstances(),
                             index=popdist.getInstanceIndex())

-    with strategy.scope():
+    with self.strategy.scope():

       @def_function.function
       def step_fn(iterator):
@@ -203,31 +209,33 @@ def step_fn(iterator):
       dist_iterator = ipu_infeed_queue.IPUIterator(dataset=dataset)

       def run_fn(iterator):
-        per_replica_y = strategy.run(step_fn, args=[iterator])
-        return strategy.reduce(ReduceOp.SUM, per_replica_y, axis=None)
+        per_replica_y = self.strategy.run(step_fn, args=[iterator])
+        return self.strategy.reduce(ReduceOp.SUM, per_replica_y, axis=None)

       self.assertEqual(1.0, run_fn(dist_iterator))
       self.assertEqual(13.0, run_fn(dist_iterator))
       self.assertEqual(41.0, run_fn(dist_iterator))

   def test_all_reduce(self):
-    config, strategy = self._create_test_objects()
+    config = self._create_test_objects()
+    self.strategy.update_ipu_config(config)

-    with strategy.scope():
+    with self.strategy.scope():

       @def_function.function
       def per_replica_fn(x):
         return x * x

-      per_replica_y = strategy.run(per_replica_fn,
-                                   args=[popdist.getInstanceIndex() + 1])
-      sum_y = strategy.reduce(ReduceOp.SUM, per_replica_y, axis=None)
+      per_replica_y = self.strategy.run(per_replica_fn,
+                                        args=[popdist.getInstanceIndex() + 1])
+      sum_y = self.strategy.reduce(ReduceOp.SUM, per_replica_y, axis=None)
       self.assertEqual(5, sum_y)  # 1*1 + 2*2

   def test_batch_normalization(self):
-    config, strategy = self._create_test_objects()
+    config = self._create_test_objects()
+    self.strategy.update_ipu_config(config)

-    with strategy.scope():
+    with self.strategy.scope():
       batch_norm = layers.BatchNormalization(momentum=0.0)

       @def_function.function
@@ -245,35 +253,32 @@ def per_replica_fn(x):

       x = constant_op.constant([[2.0 * (popdist.getInstanceIndex() + 1)],
                                 [0.0]])
-      per_replica_y = strategy.run(per_replica_fn, args=(x,))
-      sum_y = strategy.reduce(ReduceOp.SUM, per_replica_y, axis=None)
+      per_replica_y = self.strategy.run(per_replica_fn, args=(x,))
+      sum_y = self.strategy.reduce(ReduceOp.SUM, per_replica_y, axis=None)

       # mean(mean(2, 0), mean(4, 0)) = mean(1, 3) = 1.5
       self.assertAllEqual([1.5], batch_norm.moving_mean)
       # mean(var(2, 0), var(4, 0)) = mean(1, 4) = 2.5
       self.assertAllEqual([2.5], batch_norm.moving_variance)

   def test_dataset_infeed(self):
-    config, strategy = self._create_test_objects()
+    config = self._create_test_objects()
+    self.strategy.update_ipu_config(config)

     dataset = dataset_ops.Dataset.from_tensor_slices([0.0]).repeat()
     # Test with a dataset host op.
     dataset = dataset.map(lambda x: x + popdist.getInstanceIndex())
     infeed_queue = iter(dataset)

-    @def_function.function
-    def body(x):
-      return 2 * x
-
     @def_function.function
     def net(iterator):
       s = 0.0
       for _ in math_ops.range(10):
-        s += body(next(iterator))
+        s += 2 * next(iterator)
       return s

-    with strategy.scope():
-      res = strategy.run(net, args=(infeed_queue,))
+    with self.strategy.scope():
+      res = self.strategy.run(net, args=(infeed_queue,))
       self.assertEqual(popdist.getInstanceIndex() * 20, res)

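One consequence of sharing a single strategy instance is visible in test_strategy_without_ipu_reduction: it flips the private _add_ipu_cross_replica_reductions flag on the shared strategy and sets it back to the default afterwards. A hedged sketch of the same idea with the restore made exception-safe; the try/finally variant is illustrative only and is not part of this commit:

    # Illustrative fragment for inside a test method: toggle the private flag on
    # the shared strategy and guarantee it is restored even if the body raises.
    self.strategy._extended._add_ipu_cross_replica_reductions = False  # pylint: disable=protected-access
    try:
      with self.strategy.scope():
        self.strategy.run(per_replica_fn, args=[constant_op.constant(2.0)])
    finally:
      self.strategy._extended._add_ipu_cross_replica_reductions = True  # pylint: disable=protected-access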