@@ -58,27 +58,23 @@ class PopDistStrategyTest(test_util.TensorFlowTestCase): # pylint: disable=abst
5858 @classmethod
5959 def setUpClass (cls ):
6060 hvd .init ()
61- popdist .init ()
61+ # Instantiate here to call popdist.init() and popdist.finalizeBackend() once
62+ cls .strategy = PopDistStrategy ()
6263
6364 @classmethod
6465 def tearDownClass (cls ):
6566 hvd .shutdown ()
6667
67- def _create_test_objects (self ,
68- auto_select_ipus = 1 ,
69- add_ipu_cross_replica_reductions = True ):
68+ def _create_test_objects (self , auto_select_ipus = 1 ):
7069 config = IPUConfig ()
7170 config .auto_select_ipus = auto_select_ipus
7271 config .configure_ipu_system ()
7372
74- strategy = PopDistStrategy (
75- add_ipu_cross_replica_reductions = add_ipu_cross_replica_reductions )
76-
77- return config , strategy
73+ return config
7874
7975 def test_update_ipu_config (self ):
80- config , strategy = self ._create_test_objects ()
81- strategy .update_ipu_config (config )
76+ config = self ._create_test_objects ()
77+ self . strategy .update_ipu_config (config )
8278 self .assertEqual (
8379 config .experimental .multi_replica_distribution .process_count ,
8480 popdist .getNumInstances ())
@@ -87,9 +83,10 @@ def test_update_ipu_config(self):
8783 popdist .getInstanceIndex ())
8884
8985 def test_strategy (self ):
90- config , strategy = self ._create_test_objects ()
86+ config = self ._create_test_objects ()
87+ self .strategy .update_ipu_config (config )
9188
92- with strategy .scope ():
89+ with self . strategy .scope ():
9390 v = variables .Variable (initial_value = popdist .getInstanceIndex () + 1 ,
9491 dtype = np .float32 )
9592 self .assertEndsWith (v .device , "/device:IPU:0" )
@@ -108,11 +105,11 @@ def per_replica_fn(x):
108105
109106 return y_all_reduced
110107
111- per_replica_value = strategy .run (per_replica_fn ,
112- args = [constant_op .constant (2.0 )])
108+ per_replica_value = self . strategy .run (per_replica_fn ,
109+ args = [constant_op .constant (2.0 )])
113110
114111 # This reduction is performed on CPU, and hence uses Horovod.
115- value_all_reduced = strategy .reduce (ReduceOp .SUM , per_replica_value )
112+ value_all_reduced = self . strategy .reduce (ReduceOp .SUM , per_replica_value )
116113
117114 # The initial value should be broadcast from rank 0.
118115 self .assertEqual (v , 1.0 )
@@ -121,10 +118,13 @@ def per_replica_fn(x):
121118 self .assertEqual (value_all_reduced , popdist .getNumInstances () * 2.0 )
122119
123120 def test_strategy_without_ipu_reduction (self ):
124- config , strategy = self ._create_test_objects (
125- add_ipu_cross_replica_reductions = False )
121+ config = self ._create_test_objects ()
122+ self .strategy .update_ipu_config (config )
123+
124+ # Modify to keep one instance of the strategy
125+ self .strategy ._extended ._add_ipu_cross_replica_reductions = False # pylint: disable=protected-access
126126
127- with strategy .scope ():
127+ with self . strategy .scope ():
128128 v = variables .Variable (initial_value = 1.0 , dtype = np .float32 )
129129
130130 @def_function .function
@@ -139,12 +139,16 @@ def per_replica_fn(x):
139139 return y_out
140140
141141 # It is sufficient to test the TF graph construction.
142- strategy .run (per_replica_fn , args = [constant_op .constant (2.0 )])
142+ self .strategy .run (per_replica_fn , args = [constant_op .constant (2.0 )])
143+
144+ # Set back to default value
145+ self .strategy ._extended ._add_ipu_cross_replica_reductions = True # pylint: disable=protected-access
143146
144147 def test_strategy_with_sync_on_read_variable (self ):
145- config , strategy = self ._create_test_objects ()
148+ config = self ._create_test_objects ()
149+ self .strategy .update_ipu_config (config )
146150
147- with strategy .scope ():
151+ with self . strategy .scope ():
148152 w = variables .Variable (initial_value = float (popdist .getInstanceIndex () +
149153 1 ),
150154 dtype = np .float32 ,
@@ -160,15 +164,16 @@ def per_replica_fn(x):
160164
161165 # Both should have initial value from first worker
162166 debugging .assert_equal ([1.0 ], w )
163- strategy .run (
167+ self . strategy .run (
164168 per_replica_fn ,
165169 args = [constant_op .constant (popdist .getInstanceIndex () + 1.0 )])
166170 debugging .assert_equal ([2.5 ], w .read_value ())
167171
168172 def test_strategy_with_mirrored_variable (self ):
169- config , strategy = self ._create_test_objects ()
173+ config = self ._create_test_objects ()
174+ self .strategy .update_ipu_config (config )
170175
171- with strategy .scope ():
176+ with self . strategy .scope ():
172177 w = variables .Variable (initial_value = float (popdist .getInstanceIndex () +
173178 1 ),
174179 dtype = np .float32 ,
@@ -180,19 +185,20 @@ def per_replica_fn():
180185 self .assertIsInstance (w , IPUMirroredVariable )
181186 return w * w
182187
183- per_replica_ret = strategy .run (per_replica_fn , args = [])
184- sum_ret = strategy .reduce (ReduceOp .SUM , per_replica_ret , axis = None )
188+ per_replica_ret = self . strategy .run (per_replica_fn , args = [])
189+ sum_ret = self . strategy .reduce (ReduceOp .SUM , per_replica_ret , axis = None )
185190 self .assertEqual ([1.0 ], per_replica_ret )
186191 self .assertEqual (2.0 , sum_ret )
187192
188193 def test_distribute_dataset (self ):
189- config , strategy = self ._create_test_objects ()
194+ config = self ._create_test_objects ()
195+ self .strategy .update_ipu_config (config )
190196
191197 dataset = dataset_ops .Dataset .range (10 , output_type = np .float32 )
192198 dataset = dataset .shard (num_shards = popdist .getNumInstances (),
193199 index = popdist .getInstanceIndex ())
194200
195- with strategy .scope ():
201+ with self . strategy .scope ():
196202
197203 @def_function .function
198204 def step_fn (iterator ):
@@ -203,31 +209,33 @@ def step_fn(iterator):
203209 dist_iterator = ipu_infeed_queue .IPUIterator (dataset = dataset )
204210
205211 def run_fn (iterator ):
206- per_replica_y = strategy .run (step_fn , args = [iterator ])
207- return strategy .reduce (ReduceOp .SUM , per_replica_y , axis = None )
212+ per_replica_y = self . strategy .run (step_fn , args = [iterator ])
213+ return self . strategy .reduce (ReduceOp .SUM , per_replica_y , axis = None )
208214
209215 self .assertEqual (1.0 , run_fn (dist_iterator ))
210216 self .assertEqual (13.0 , run_fn (dist_iterator ))
211217 self .assertEqual (41.0 , run_fn (dist_iterator ))
212218
213219 def test_all_reduce (self ):
214- config , strategy = self ._create_test_objects ()
220+ config = self ._create_test_objects ()
221+ self .strategy .update_ipu_config (config )
215222
216- with strategy .scope ():
223+ with self . strategy .scope ():
217224
218225 @def_function .function
219226 def per_replica_fn (x ):
220227 return x * x
221228
222- per_replica_y = strategy .run (per_replica_fn ,
223- args = [popdist .getInstanceIndex () + 1 ])
224- sum_y = strategy .reduce (ReduceOp .SUM , per_replica_y , axis = None )
229+ per_replica_y = self . strategy .run (per_replica_fn ,
230+ args = [popdist .getInstanceIndex () + 1 ])
231+ sum_y = self . strategy .reduce (ReduceOp .SUM , per_replica_y , axis = None )
225232 self .assertEqual (5 , sum_y ) # 1*1 + 2*2
226233
227234 def test_batch_normalization (self ):
228- config , strategy = self ._create_test_objects ()
235+ config = self ._create_test_objects ()
236+ self .strategy .update_ipu_config (config )
229237
230- with strategy .scope ():
238+ with self . strategy .scope ():
231239 batch_norm = layers .BatchNormalization (momentum = 0.0 )
232240
233241 @def_function .function
@@ -245,35 +253,32 @@ def per_replica_fn(x):
245253
246254 x = constant_op .constant ([[2.0 * (popdist .getInstanceIndex () + 1 )],
247255 [0.0 ]])
248- per_replica_y = strategy .run (per_replica_fn , args = (x ,))
249- sum_y = strategy .reduce (ReduceOp .SUM , per_replica_y , axis = None )
256+ per_replica_y = self . strategy .run (per_replica_fn , args = (x ,))
257+ sum_y = self . strategy .reduce (ReduceOp .SUM , per_replica_y , axis = None )
250258
251259 # mean(mean(2, 0), mean(4, 0)) = mean(1, 3) = 1.5
252260 self .assertAllEqual ([1.5 ], batch_norm .moving_mean )
253261 # mean(var(2, 0), var(4, 0)) = mean(1, 4) = 2.5
254262 self .assertAllEqual ([2.5 ], batch_norm .moving_variance )
255263
256264 def test_dataset_infeed (self ):
257- config , strategy = self ._create_test_objects ()
265+ config = self ._create_test_objects ()
266+ self .strategy .update_ipu_config (config )
258267
259268 dataset = dataset_ops .Dataset .from_tensor_slices ([0.0 ]).repeat ()
260269 # Test with a dataset host op.
261270 dataset = dataset .map (lambda x : x + popdist .getInstanceIndex ())
262271 infeed_queue = iter (dataset )
263272
264- @def_function .function
265- def body (x ):
266- return 2 * x
267-
268273 @def_function .function
269274 def net (iterator ):
270275 s = 0.0
271276 for _ in math_ops .range (10 ):
272- s += body ( next (iterator ) )
277+ s += 2 * next (iterator )
273278 return s
274279
275- with strategy .scope ():
276- res = strategy .run (net , args = (infeed_queue ,))
280+ with self . strategy .scope ():
281+ res = self . strategy .run (net , args = (infeed_queue ,))
277282 self .assertEqual (popdist .getInstanceIndex () * 20 , res )
278283
279284