Commit ff3c30a
IPU Multi Replica Strategy support for Keras API in TF2.4
Summary: TF2.4 Only

Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, babakk, jackh

Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, babakk, jackh

Subscribers: hakons, georgep, babakk

Maniphest Tasks: T46807

Differential Revision: https://phabricator.sourcevertex.net/D52073
1 parent: b4c53f0

File tree

7 files changed: +218 -56 lines

tensorflow/compiler/plugin/poplar/BUILD

Lines changed: 13 additions & 0 deletions
@@ -4253,6 +4253,18 @@ tf_xla_py_test(
     ],
 )
 
+poprun_py_test(
+    name = "distributed_tf2_test",
+    size = "large",
+    srcs = ["tests/distributed_tf2_test.py"],
+    main = "tests/distributed_tf2_test.py",
+    num_instances = 2,
+    num_replicas = 4,
+    deps = [
+        "//tensorflow/python/ipu:ipu_lib",
+    ],
+)
+
 xla_test(
     name = "replicated_resource_update_elementwise_clustering_hw_test",
     srcs = ["tests/replicated_resource_update_elementwise_clustering_hw_test.cc"],

@@ -5771,6 +5783,7 @@ test_suite(
         "device_connection_test",
         "distributed_batch_norm_decomposer_test",
         "distributed_batch_norm_test",
+        "distributed_tf2_test",
         "dump_poplar_info",
         "dynamic_slice_layout_test",
         "dynamic_slice_test",
tensorflow/compiler/plugin/poplar/tests/distributed_tf2_test.py

Lines changed: 92 additions & 0 deletions (new file)

@@ -0,0 +1,92 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np

import popdist
import popdist.tensorflow

import tensorflow as tf
from tensorflow.python import ipu
from tensorflow.python.framework import constant_op, test_util
from tensorflow.python.ipu.horovod import ipu_multi_replica_strategy
from tensorflow.python.platform import test
from tensorflow.python.ipu import horovod as hvd


class DistributedTF2Test(test_util.TensorFlowTestCase):
  def assert_all_instances_equal(self, local_value, name=None):
    """Assert that the current instance has the same value as the root instance."""
    local_tensor = constant_op.constant(local_value)
    root_tensor = hvd.broadcast(local_tensor, root_rank=0)
    np.testing.assert_equal(local_value, root_tensor.numpy(), name)

  def test_tf2_distributed(self):
    config = ipu.config.IPUConfig()
    popdist.tensorflow.set_ipu_config(config, ipus_per_replica=1)
    config.configure_ipu_system()

    hvd.init()

    strategy = ipu_multi_replica_strategy.IPUMultiReplicaStrategy()

    def generator():
      for _ in range(100):
        yield np.random.rand(32, 32, 1), np.random.randint(1, 10, size=1)

    dataset = tf.data.Dataset.from_generator(
        generator,
        output_types=(tf.float32, tf.float32),
        output_shapes=((32, 32, 1), (1,)),
    )

    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy =\
        tf.data.experimental.AutoShardPolicy.OFF
    dataset = dataset.with_options(options)

    dataset = dataset.shard(num_shards=popdist.getNumInstances(),
                            index=popdist.getInstanceIndex())
    dataset = dataset.batch(10, drop_remainder=True)

    with strategy.scope():
      model = tf.keras.models.Sequential([
          tf.keras.layers.Conv2D(32, 3, activation='relu'),
          tf.keras.layers.MaxPooling2D(),
          tf.keras.layers.Conv2D(32, 3, activation='relu'),
          tf.keras.layers.MaxPooling2D(),
          tf.keras.layers.Flatten(),
          tf.keras.layers.Dense(32, activation='relu'),
          tf.keras.layers.Dense(10),
      ])

      optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
      loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

      model.compile(optimizer=optimizer,
                    loss=loss_fn,
                    steps_per_execution=popdist.getNumTotalReplicas())
      history = model.fit(dataset,
                          steps_per_epoch=popdist.getNumTotalReplicas(),
                          epochs=1)

    # Make sure the losses and weights are identical as we reduce over all IPUs
    self.assert_all_instances_equal(history.history['loss'])

    for v in model.trainable_variables:
      self.assert_all_instances_equal(v)


if __name__ == "__main__":
  test.main()
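
The assert_all_instances_equal helper is worth calling out: broadcasting a value from rank 0 and comparing it with the local copy is a cheap cross-instance sync check, and any divergence fails on the instance that drifted. A standalone sketch of the same pattern, using the same Horovod wrappers the test imports:

    import numpy as np
    from tensorflow.python.framework import constant_op
    from tensorflow.python.ipu import horovod as hvd

    def assert_in_sync(local_value):
      # Rank 0's value is broadcast to every instance; each instance then
      # compares its own copy against the broadcast result.
      local_tensor = constant_op.constant(local_value)
      root_tensor = hvd.broadcast(local_tensor, root_rank=0)
      np.testing.assert_equal(local_value, root_tensor.numpy())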

tensorflow/python/ipu/BUILD

Lines changed: 1 addition & 0 deletions
@@ -144,6 +144,7 @@ py_library(
         "keras/optimizers/ipu_wrappers.py",
         "keras/optimizers/map_gradient_optimizer.py",
         "keras/pipeline.py",
+        "keras_extensions.py",
         "loops.py",
         "ops/all_to_all_op.py",
         "ops/cross_replica_ops.py",

tensorflow/python/ipu/horovod/ipu_multi_replica_strategy.py

Lines changed: 19 additions & 6 deletions
@@ -17,11 +17,15 @@
 from tensorflow.python.distribute import distribute_lib
 from tensorflow.python.distribute import reduce_util
 from tensorflow.python.distribute import values
-from tensorflow.python.distribute.cluster_resolver import cluster_resolver as cluster_resolver_lib
+from tensorflow.python.distribute.cluster_resolver import \
+    cluster_resolver as cluster_resolver_lib
 from tensorflow.python.framework import device as tf_device
-from tensorflow.python.ipu import utils as ipu_utils
-from tensorflow.python.ipu.horovod import Sum, Average, size, rank, allreduce as hvd_allreduce, broadcast as hvd_broadcast
-from tensorflow.python.ipu.ipu_multi_worker_strategy import IPUMultiWorkerExtendedV1
+from tensorflow.python.ipu import keras_extensions
+from tensorflow.python.ipu.horovod import Sum, Average, size, rank, \
+    allreduce as hvd_allreduce, \
+    broadcast as hvd_broadcast
+from tensorflow.python.ipu.ipu_multi_worker_strategy import \
+    IPUMultiWorkerExtendedV1
 from tensorflow.python.ipu.ops import cross_replica_ops
 from tensorflow.python.training import server_lib

@@ -40,7 +44,8 @@ def _is_current_device_ipu():
   return current_device.device_type == "IPU"
 
 
-class IPUMultiReplicaStrategyV1(distribute_lib.StrategyV1):
+class IPUMultiReplicaStrategyV1(distribute_lib.StrategyV1,
+                                keras_extensions.KerasExtensions):
   """This is a distribution strategy for multi-replica distribution
   that uses compiled communications with GCL for reductions over IPU
   links and gateway links, while using Horovod for broadcasting of

@@ -55,7 +60,9 @@ class IPUMultiReplicaStrategyV1(distribute_lib.StrategyV1):
 
   def __init__(self,
                ipu_device="/device:IPU:0",
-               add_ipu_cross_replica_reductions=True):
+               add_ipu_cross_replica_reductions=True,
+               enable_dataset_iterators=True,
+               enable_keras_extensions=True):
     # We create an empty cluster here since we will not be using gRPC for communication.
     # All the communication is delegated to either GCL or Horovod (MPI) below.
     cluster_resolver = cluster_resolver_lib.SimpleClusterResolver(

@@ -64,6 +71,8 @@ def __init__(self,
     super().__init__(
         IPUMultiReplicaExtendedV1(self, cluster_resolver, ipu_device,
                                   add_ipu_cross_replica_reductions))
+    keras_extensions.KerasExtensions.__init__(self, enable_dataset_iterators,
+                                              enable_keras_extensions)
 
   def update_ipu_config(self, config):
     """Update the given IPU configuration with the multi-replica

@@ -89,6 +98,10 @@ def __init__(self, container_strategy, cluster_resolver, ipu_device,
     self._num_workers = size()
     self._add_ipu_cross_replica_reductions = add_ipu_cross_replica_reductions
 
+  def non_slot_devices(self, var_list):
+    del var_list
+    return self._ipu_device
+
   def _reduce_to(self, reduce_op, value, destinations, options):
     del destinations
     del options
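
With this change the strategy gains two Keras-related switches. A minimal construction sketch (the values shown are the defaults introduced by the diff):

    from tensorflow.python.ipu.horovod import ipu_multi_replica_strategy

    # The new keyword arguments gate the behaviour pulled in through the
    # KerasExtensions mixin; both default to True.
    strategy = ipu_multi_replica_strategy.IPUMultiReplicaStrategyV1(
        ipu_device="/device:IPU:0",
        add_ipu_cross_replica_reductions=True,
        enable_dataset_iterators=True,
        enable_keras_extensions=True)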

tensorflow/python/ipu/ipu_multi_worker_strategy.py

Lines changed: 14 additions & 4 deletions
@@ -31,6 +31,7 @@
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import collective_ops
 from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.util import tf_contextlib

@@ -432,10 +433,19 @@ def _batch_reduce_to(self, reduce_op, value_destination_pairs,
                            options)
 
   def _call_for_each_replica(self, fn, args, kwargs):
-    with distribute_lib.ReplicaContext(
-        self._container_strategy(), replica_id_in_sync_group=0), \
-        ops.device(self._ipu_device):
-      return fn(*args, **kwargs)
+    with distribute_lib.ReplicaContext(self._container_strategy(),
+                                       replica_id_in_sync_group=0), ops.device(
+                                           self._ipu_device):
+      # Make sure it is compiled as a single engine when called in graph mode.
+      # This is similar to the mechanism used by xla.compile.
+      xla_context = control_flow_ops.XLAControlFlowContext()
+      try:
+        xla_context.Enter()
+        outputs = fn(*args, **kwargs)
+      finally:
+        xla_context.Exit()
+
+      return outputs
 
   def _validate_colocate_with_variable(self, colocate_with_variable):
     if colocate_with_variable.device != self._variable_device:
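
The Enter/try/finally/Exit discipline above matters: if fn raises, the XLA control-flow context must still be popped, otherwise later graph construction would silently be captured into the cluster. The same pattern in isolation, as a sketch using the API this diff imports:

    from tensorflow.python.ops import control_flow_ops

    def run_as_single_engine(fn, *args, **kwargs):
      # In graph mode, ops created by fn inside this context are compiled
      # as one engine, similar to the mechanism used by xla.compile. The
      # finally block guarantees the context is exited even if fn raises.
      xla_context = control_flow_ops.XLAControlFlowContext()
      try:
        xla_context.Enter()
        return fn(*args, **kwargs)
      finally:
        xla_context.Exit()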

tensorflow/python/ipu/ipu_strategy.py

Lines changed: 5 additions & 46 deletions
@@ -40,11 +40,13 @@
 from tensorflow.python.ipu.keras.extensions import functional_extensions
 from tensorflow.python.ipu.keras.extensions import sequential_extensions
 from tensorflow.python.ipu import ipu_infeed_queue
+from tensorflow.python.ipu import keras_extensions
 
 _pvti_trace_channel = libpvti.createTraceChannel("TensorFlow")
 
 
-class IPUStrategyV1(distribute_lib.StrategyV1):
+class IPUStrategyV1(distribute_lib.StrategyV1,
+                    keras_extensions.KerasExtensions):
   """This is a distribution strategy for targeting a system with one
   or more IPUs.

@@ -78,9 +80,6 @@ class IPUStrategyV1(distribute_lib.StrategyV1):
 
 
   """
-
-  _enable_legacy_iterators = True
-
   def __init__(self,
                ipu_device="/device:IPU:0",
                cpu_device="/device:CPU:0",

@@ -98,15 +97,8 @@ def __init__(self,
       to improve Keras performance when using IPUs.
     """
     super().__init__(IPUExtendedV1(self, ipu_device, cpu_device))
-    self._enable_iterators = enable_dataset_iterators
-    self._enable_keras_extensions = enable_keras_extensions
-    self._keras_extensions = OrderedDict()
-    # Insert Sequential before Functional as Sequential models inherit from
-    # Functional models.
-    self._register_keras_extension(sequential.Sequential,
-                                   sequential_extensions.SequentialExtension)
-    self._register_keras_extension(functional.Functional,
-                                   functional_extensions.FunctionalExtension)
+    keras_extensions.KerasExtensions.__init__(self, enable_dataset_iterators,
+                                              enable_keras_extensions)
 
   @libpvti.instrument_fn(_pvti_trace_channel)
   def run(self, fn, args=(), kwargs=None, options=None):

@@ -123,39 +115,6 @@ def _device_ordinal(self):
     current_device = tf_device.DeviceSpec.from_string(device_string)
     return current_device.device_index
 
-  def _enable_dataset_iterators(self):
-    return context.executing_eagerly() and self._enable_iterators
-
-  def _create_dataset_iterator(self, dataset):
-    assert self._enable_dataset_iterators()
-    return ipu_infeed_queue.IPUOwnedIterator(dataset=dataset)  # pylint: disable=protected-access
-
-  def _register_keras_extension(self, class_type, extension):
-    self._keras_extensions[class_type] = extension
-
-  def _delete_keras_extension(self, class_type):
-    self._keras_extensions.pop(class_type, None)
-
-  def _patch_keras_extension(self, instance):
-    if not self._enable_keras_extensions:
-      return
-
-    for class_type, extension in self._keras_extensions.items():
-      if isinstance(instance, class_type):
-        if isinstance(instance, base_layer.KerasExtension):
-          if not isinstance(instance, extension):
-            raise RuntimeError(
-                "KerasExtension patching failed - already patched with a "
-                "different extension.")
-          break
-
-        # Patch in the extension.
-        # Note that we keep the name as Keras sometimes does __name__ checks.
-        cls = instance.__class__
-        instance.__class__ = cls.__class__(cls.__name__, (cls, extension), {})
-        extension.__init__(instance)
-        break
-
   @property
   def supports_loss_scaling(self):
     return True
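
One reason the diff calls KerasExtensions.__init__ explicitly instead of relying on super(): distribute_lib.StrategyV1.__init__ does not chain cooperatively to the next class in the MRO. A stripped-down sketch of the pattern with illustrative stand-in classes (not the real API):

    class KerasExtensionsMixin:
      def __init__(self, enable_dataset_iterators=True,
                   enable_keras_extensions=True):
        self._enable_iterators = enable_dataset_iterators
        self._enable_keras_extensions = enable_keras_extensions

    class StrategyBase:  # stands in for distribute_lib.StrategyV1
      def __init__(self, extended):
        self._extended = extended  # does not call super().__init__()

    class MyStrategy(StrategyBase, KerasExtensionsMixin):
      def __init__(self, extended):
        super().__init__(extended)  # reaches StrategyBase only
        # The mixin is therefore initialized explicitly, as in the diff.
        KerasExtensionsMixin.__init__(self, True, True)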
tensorflow/python/ipu/keras_extensions.py

Lines changed: 74 additions & 0 deletions (new file)

@@ -0,0 +1,74 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from collections import OrderedDict

from tensorflow.python.eager import context
from tensorflow.python.ipu import ipu_infeed_queue
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.ipu.keras.extensions import functional_extensions
from tensorflow.python.ipu.keras.extensions import sequential_extensions
from tensorflow.python.keras.engine import functional
from tensorflow.python.keras.engine import sequential


class KerasExtensions:
  _enable_legacy_iterators = True

  def __init__(self,
               enable_dataset_iterators=True,
               enable_keras_extensions=True):
    self._enable_iterators = enable_dataset_iterators
    self._enable_keras_extensions = enable_keras_extensions
    self._keras_extensions = OrderedDict()

    # Insert Sequential before Functional as Sequential models inherit from
    # Functional models.
    self._register_keras_extension(sequential.Sequential,
                                   sequential_extensions.SequentialExtension)
    self._register_keras_extension(functional.Functional,
                                   functional_extensions.FunctionalExtension)

  def _enable_dataset_iterators(self):
    return context.executing_eagerly() and self._enable_iterators

  def _create_dataset_iterator(self, dataset):
    assert self._enable_dataset_iterators()
    return ipu_infeed_queue.IPUOwnedIterator(dataset=dataset)  # pylint: disable=protected-access

  def _register_keras_extension(self, class_type, extension):
    self._keras_extensions[class_type] = extension

  def _delete_keras_extension(self, class_type):
    self._keras_extensions.pop(class_type, None)

  def _patch_keras_extension(self, instance):
    if not self._enable_keras_extensions:
      return

    for class_type, extension in self._keras_extensions.items():
      if isinstance(instance, class_type):
        if isinstance(instance, base_layer.KerasExtension):
          if not isinstance(instance, extension):
            raise RuntimeError(
                "KerasExtension patching failed - already patched with a "
                "different extension.")
          break

        # Patch in the extension.
        # Note that we keep the name as Keras sometimes does __name__ checks.
        cls = instance.__class__
        instance.__class__ = cls.__class__(cls.__name__, (cls, extension), {})
        extension.__init__(instance)
        break
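
The dynamic patch in _patch_keras_extension deserves a note: cls.__class__(cls.__name__, (cls, extension), {}) invokes the metaclass as a three-argument type constructor, building a new class with the original name, and reassigning instance.__class__ mixes the extension into an already-constructed object. A self-contained sketch of the mechanism with toy classes (not the real Keras types):

    class Model:  # stands in for a Keras model class
      def fit(self):
        return "fit"

    class Extension:  # stands in for SequentialExtension etc.
      def fit_with_infeed(self):
        # Extension methods can call straight back into the instance.
        return "infeed + " + self.fit()

    m = Model()
    cls = m.__class__
    # type(cls)(name, bases, dict) builds the patched class through the same
    # metaclass; keeping cls.__name__ preserves Keras's __name__ checks.
    m.__class__ = cls.__class__(cls.__name__, (cls, Extension), {})

    assert type(m).__name__ == "Model"
    assert m.fit_with_infeed() == "infeed + fit"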
