Skip to content

Commit 09ed826

Browse files
hakosgeorgepaw
authored and committed
Allow external functional captures when capturing by value
Summary: When an outer FuncGraph is set to be capturing by value, allow external non-resource captures by default when compiling/constructing an inner FuncGraph. This allows capturing resource variables by their constant tensor values (i.e. freezing them) when using the compile op to compile a pipeline op with resource variables. Fixes T43559.
Reviewers: #tensorflow, simonl, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, georgep
Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, georgep
Subscribers: georgep
Maniphest Tasks: T43559
Differential Revision: https://phabricator.sourcevertex.net/D49317
1 parent 733f924 commit 09ed826

File tree

2 files changed

+46
-2
lines changed

2 files changed

+46
-2
lines changed

tensorflow/python/ipu/ops/functional_ops.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -112,17 +112,25 @@ class _InvalidCaptureException(Exception):
112112
pass
113113

114114

115+
def _is_capturing_by_value(graph):
  """Tell whether `graph` is a FuncGraph that captures by value.

  Args:
    graph: The graph object to inspect.

  Returns:
    False when `graph` is not a `func_graph_module.FuncGraph`; otherwise the
    graph's own `capture_by_value` attribute.
  """
  # Only FuncGraph instances are asked for `capture_by_value`; any other
  # graph object is treated as not capturing by value.
  if not isinstance(graph, func_graph_module.FuncGraph):
    return False
  return graph.capture_by_value
118+
119+
115120
def _compile_function(func,
116121
args,
117122
scope,
118123
control_outputs,
119-
allow_external_captures=False,
124+
allow_external_captures=None,
120125
capture_by_value=None):
121126
parent_graph = ops.get_default_graph()
122127
# Automatic control dependencies are added in defuns, but not in v1
123128
# graphs. Propagate that behavior here.
124129
add_control_dependencies = parent_graph._add_control_dependencies # pylint: disable=protected-access
125130

131+
if allow_external_captures is None:
132+
allow_external_captures = _is_capturing_by_value(parent_graph)
133+
126134
# Functions inherit frontend attributes and the gradient override map from the
127135
# parent graph.
128136
proto = xla_data_pb2.FrontendAttributes()

tensorflow/python/ipu/tests/application_compile_test.py

Lines changed: 37 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
import tempfile
1818
import numpy as np
1919

20+
from absl.testing import parameterized
2021
from tensorflow.python.client import session
2122
from tensorflow.python.data.ops import dataset_ops
2223
from tensorflow.python.framework import dtypes
@@ -26,6 +27,7 @@
2627
from tensorflow.python.ipu import loops
2728
from tensorflow.python.ipu.config import DeviceConnectionType
2829
from tensorflow.python.ipu.config import IPUConfig
30+
from tensorflow.python.ipu.ops import pipelining_ops
2931
from tensorflow.python.ipu.ops.application_compile_op import experimental_application_compile_op as application_compile_op
3032
from tensorflow.python.keras import layers
3133
from tensorflow.python.ops import array_ops
@@ -36,7 +38,8 @@
3638
from tensorflow.python.training import gradient_descent
3739

3840

39-
class TestApplicationCompile(test_util.TensorFlowTestCase):
41+
class TestApplicationCompile(test_util.TensorFlowTestCase,
42+
parameterized.TestCase):
4043
def setUp(self):
4144
super().setUp()
4245

@@ -209,6 +212,39 @@ def my_net(lr):
209212

210213
self.assertGreater(os.path.getsize(compiled_path.decode()), 0)
211214

215+
# Compile a two-stage pipeline with the application compile op and check that
# a non-empty compiled file is produced. Runs twice: once with
# freeze_variables=False ("resources") and once with True ("constants").
@parameterized.named_parameters(("resources", False), ("constants", True))
216+
@test_util.deprecated_graph_mode_only
217+
def test_compile_pipeline(self, freeze_variables):
218+
with session.Session() as sess:
219+
220+
# Input data: a single-component dataset of ones with shape (10, 5),
# float32, batched one row at a time.
dataset = dataset_ops.Dataset.from_tensor_slices((np.ones(
221+
(10, 5), dtype=np.float32),))
222+
dataset = dataset.batch(1, drop_remainder=True)
223+
infeed_queue = ipu_infeed_queue.IPUInfeedQueue(dataset)
224+
outfeed_queue = ipu_outfeed_queue.IPUOutfeedQueue()
225+
226+
# First stage: Dense(5, relu) plus `offset`.
# NOTE(review): `offset` presumably receives the `inputs=[42.0]` value passed
# to the pipeline below - confirm against pipelining_ops.pipeline.
def stage1(offset, x):
227+
return layers.Dense(5, activation="relu")(x) + offset
228+
229+
# Second stage: Dense(10, softmax) applied to the first stage's output.
def stage2(x):
230+
return layers.Dense(10, activation="softmax")(x)
231+
232+
# The function handed to the compile op: builds the pipeline from the two
# stages above, fed by `infeed_queue` and writing to `outfeed_queue`.
def my_net():
233+
return pipelining_ops.pipeline(computational_stages=[stage1, stage2],
234+
gradient_accumulation_count=4,
235+
infeed_queue=infeed_queue,
236+
inputs=[42.0],
237+
outfeed_queue=outfeed_queue,
238+
device_mapping=[0, 0])
239+
240+
# The parameterized `freeze_variables` flag is forwarded unchanged.
result = application_compile_op(my_net,
241+
freeze_variables=freeze_variables)
242+
243+
# Initialize variables before running the compile op.
sess.run(variables.global_variables_initializer())
244+
compiled_path = sess.run(result)
245+
246+
# `compiled_path` is returned as bytes; decode it and require a non-empty file.
self.assertGreater(os.path.getsize(compiled_path.decode()), 0)
247+
212248

213249
# Run the tests via `test.main()` when this file is executed as a script.
if __name__ == "__main__":
214250
test.main()

0 commit comments

Comments
 (0)