
Commit dc90fd9

Propagate constants from functional graphs
Summary:
The TF2XLA bridge is less aggressive. Gather ops generate sparse tensors
which get converted to dense tensors; to do that, the shapes (obtained via
VariableShape) need to be constant, so those constants need to be propagated
out of the functional graphs.

Fixes T43247 (TF2.4 only).

Test Plan: CI

Reviewers: jackh, jakeh, alfiee, samuelh, #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, davidn

Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, davidn

Subscribers: davidn

Maniphest Tasks: T43247

Differential Revision: https://phabricator.sourcevertex.net/D49099
1 parent: 35135d9
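To make the failure mode concrete, here is a minimal sketch of the kind of computation this fixes. It is not part of this commit and uses the public TF2 API rather than the internal IPU modules touched below; `table` and `lookup_grad` are illustrative names (`table` mirrors the [300, 300] float16 variable in the updated pipelining test).

import tensorflow as tf

# Illustrative only, not from this commit. `experimental_compile` is the
# TF 2.4 spelling of the XLA JIT flag.
table = tf.Variable(tf.ones([300, 300], dtype=tf.float16))

@tf.function(experimental_compile=True)
def lookup_grad(indices):
  with tf.GradientTape() as tape:
    loss = tf.reduce_sum(tf.gather(table, indices))
  # The gradient of the gather w.r.t. `table` is a sparse IndexedSlices
  # value; XLA converts it to a dense tensor, which needs the shape of
  # `table` (a VariableShape op) to fold to a compile-time constant.
  return tape.gradient(loss, table)

print(lookup_grad(tf.constant([0, 1, 2])))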

9 files changed (+78, -14 lines)
tensorflow/python/ipu/ops/application_compile_op.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ def wrapped_func(*args):
     xla_context = control_flow_ops.XLAControlFlowContext()
     try:
       xla_context.Enter()
-      func_graph, captured_args = _compile_function(
+      func_graph, captured_args, _ = _compile_function(
           wrapped_func,
           inputs,
           scope, [],

tensorflow/python/ipu/ops/functional_ops.py

Lines changed: 52 additions & 2 deletions
@@ -21,11 +21,14 @@
 from tensorflow.compiler.xla import xla_data_pb2
 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.compiler.plugin.poplar.ops import gen_functional_ops
+from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import func_graph as func_graph_module
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
 from tensorflow.python.ipu import scopes
 from tensorflow.python.ops import control_flow_util_v2 as util
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.util import nest


@@ -75,7 +78,7 @@ def decorated(inner_func):
   def func_wrapper(*args):
     args = _convert_to_list(args)
     with ops.name_scope(name) as scope:
-      func_graph, captured_args = _compile_function(
+      func_graph, captured_args, constant_outputs = _compile_function(
           inner_func, args, scope, [], allow_external_captures=True)

       with ops.control_dependencies(list(func_graph.control_captures)):
@@ -87,6 +90,7 @@ def func_wrapper(*args):
             unique_sharding=unique_sharding,
             keep_input_layouts=keep_input_layouts,
             name=name)
+        outputs = _replace_outputs(outputs, constant_outputs)

         # pack_sequence_as requires a list of Tensors, but the gen_ operation
         # returns an Operation under some circumstances (probably when that
@@ -169,7 +173,8 @@ def func_wrapper(*args, **kwargs):
   op._set_shape_list_attr("_xla_inferred_shapes", output_shapes)
   # pylint: enable=protected-access

-  return func_graph, captured_args
+  constant_outputs = _get_constant_outputs(func_graph, captured_args)
+  return func_graph, captured_args, constant_outputs


 def _pack_sequence_as(structured_outputs, op_outputs):
@@ -203,3 +208,48 @@ def _convert_to_list(xs):
   if not isinstance(xs, (list, tuple)):
     return [xs]
   return list(xs)
+
+
+def _get_constant_outputs(func_graph, func_inputs):
+  """Get constant outputs for a functional graph.
+
+  Get constant outputs in order to propagate them in the XLA graph. This
+  includes the `VariableShape` operation, which needs to return a constant."""
+  if not func_graph.outputs:
+    return None
+
+  def get_output_info(output):
+    while output.op.type == "Identity":
+      output = output.op.inputs[0]
+    if constant_op.is_constant(output):
+      # Propagate constants.
+      return constant_op.constant(tensor_util.constant_value(output),
+                                  dtype=output.dtype)
+
+    if output.op.type == "VariableShape":
+      # Propagate variable shapes.
+      # Find the variable inside the function and its input index.
+      var = output.op.inputs[0]
+      assert var.dtype == dtypes.resource
+      index = [
+          i for i, v in enumerate(func_graph.inputs)
+          if v.dtype == dtypes.resource and v is var
+      ]
+      assert len(index) == 1
+      # Get the corresponding outer input variable.
+      outer_var = func_inputs[index[0]]
+      return resource_variable_ops.variable_shape(outer_var,
+                                                  out_type=output.dtype)
+    return None
+
+  return [get_output_info(x) for x in nest.flatten(func_graph.outputs)]
+
+
+def _replace_outputs(outputs, to_replace_with):
+  flat_outputs = nest.flatten(outputs)
+  flat_to_replace_with = nest.flatten(to_replace_with)
+  assert len(flat_outputs) == len(flat_to_replace_with)
+  flat_outputs = [
+      x if y is None else y for x, y in zip(flat_outputs, flat_to_replace_with)
+  ]
+  return nest.pack_sequence_as(outputs, flat_outputs)
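Every caller updated in this commit follows the same three-step pattern around these two helpers; a simplified sketch (`my_func` and the argument plumbing are placeholders, and the exact `gen_functional_ops.function` keyword set varies per call site):

# 1. Compile the Python function into a FuncGraph, now also collecting the
#    outputs that resolve to constants (or to the shape of an outer variable).
func_graph, captured_args, constant_outputs = _compile_function(
    my_func, args, scope, [], allow_external_captures=True)

# 2. Lower the function into the XLA graph via the generated functional op.
outputs = gen_functional_ops.function(
    captured_args,
    to_apply=util.create_new_tf_function(func_graph),
    Tout=func_graph.output_types,
    output_shapes=func_graph.output_shapes)

# 3. Swap each constant-resolvable output for its propagated value; entries
#    that are None in `constant_outputs` pass through unchanged.
outputs = _replace_outputs(outputs, constant_outputs)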

tensorflow/python/ipu/ops/functional_ops_grad.py

Lines changed: 7 additions & 3 deletions
@@ -170,16 +170,19 @@ def _get_gradients_for_function(op, *grads):
   fwd_op._set_shape_list_attr("output_shapes", func_graph.output_shapes)
   fwd_op._add_outputs([t.dtype for t in extra_func_outputs],
                       [t.shape for t in extra_func_outputs])
+  # pylint: enable=protected-access

   func_grad_inputs = _resolve_grad_inputs(func_graph, func_grad_graph, op)
-  # pylint: enable=protected-access
-  return func_grad_graph, func_grad_inputs
+  constant_outputs = functional_ops._get_constant_outputs(  # pylint: disable=protected-access
+      func_grad_graph, func_grad_inputs)
+  return func_grad_graph, func_grad_inputs, constant_outputs


 @ops.RegisterGradient("Function")
 def _function_grad(op, *grads):
   """The gradient of a Function op."""
-  func_grad_graph, func_grad_inputs = _get_gradients_for_function(op, *grads)
+  func_grad_graph, func_grad_inputs, constant_outputs = \
+      _get_gradients_for_function(op, *grads)
   outputs = gen_functional_ops.function(
       func_grad_inputs,
       to_apply=util.create_new_tf_function(func_grad_graph),
@@ -188,6 +191,7 @@ def _function_grad(op, *grads):
       unique_sharding=op.get_attr("unique_sharding"),
       keep_input_layouts=True)

+  outputs = functional_ops._replace_outputs(outputs, constant_outputs)  # pylint: disable=protected-access
   return functional_ops._pack_sequence_as(  # pylint: disable=protected-access
       func_grad_graph.structured_outputs, outputs)

tensorflow/python/ipu/ops/nn_ops.py

Lines changed: 3 additions & 1 deletion
@@ -168,7 +168,8 @@ def func_wrapper(*args):

     args = functional_ops._convert_to_list(args)  # pylint: disable=protected-access
     with ops.name_scope("multi_conv") as scope:
-      func_graph, captured_args = functional_ops._compile_function(  # pylint: disable=protected-access
+      func_graph, captured_args, constant_outputs = \
+          functional_ops._compile_function(  # pylint: disable=protected-access
              func_wrapper,
              args,
              scope, [],
@@ -181,6 +182,7 @@ def func_wrapper(*args):
           Tout=func_graph.output_types,
           output_shapes=func_graph.output_shapes,
           option_flags=json_format.MessageToJson(option_proto))
+      outputs = functional_ops._replace_outputs(outputs, constant_outputs)  # pylint: disable=protected-access

       return functional_ops._pack_sequence_as(  # pylint: disable=protected-access
           func_graph.structured_outputs, outputs)

tensorflow/python/ipu/ops/nn_ops_grad.py

Lines changed: 2 additions & 1 deletion
@@ -49,7 +49,7 @@ def _ipu_swish_grad(op, grad):
 @ops.RegisterGradient("MultiConv")
 def _multi_conv_grad(op, *grads):
   """The gradient of a MultiConv op."""
-  func_grad_graph, func_grad_inputs = \
+  func_grad_graph, func_grad_inputs, constant_outputs = \
       functional_ops_grad._get_gradients_for_function(op, *grads)  # pylint: disable=protected-access
   outputs = gen_functional_ops.multi_conv(
       func_grad_inputs,
@@ -58,6 +58,7 @@ def _multi_conv_grad(op, *grads):
       output_shapes=func_grad_graph.output_shapes,
       option_flags=op.get_attr("option_flags"))

+  outputs = functional_ops._replace_outputs(outputs, constant_outputs)  # pylint: disable=protected-access
   return functional_ops._pack_sequence_as(  # pylint: disable=protected-access
       func_grad_graph.structured_outputs, outputs)

tensorflow/python/ipu/ops/pipelining_ops.py

Lines changed: 7 additions & 3 deletions
@@ -1021,7 +1021,8 @@ def resource_update_():
       resource_update_ops.append(enqueue)

   with ops.name_scope(name + "/WU") as scope:
-    func_graph, captured_args = functional_ops._compile_function(  # pylint: disable=protected-access
+    func_graph, captured_args, constant_outputs = \
+        functional_ops._compile_function(  # pylint: disable=protected-access
         resource_update_, [], scope, resource_update_ops, True)

   # Create the pipeline resource update stage and lower the function into XLA.
@@ -1035,6 +1036,7 @@ def resource_update_():
       replicated_optimizer_state_sharding=
       replicated_optimizer_state_sharding,
       num_batches_to_accumulate=gradient_accumulation_count)
+  outputs = functional_ops._replace_outputs(outputs, constant_outputs)  # pylint: disable=protected-access

   if not isinstance(outputs, ops.Operation):
     if not outfeed_queue:
@@ -1053,7 +1055,7 @@ def resource_update_():
   with ops.name_scope(name) as scope:
     # pylint: disable=protected-access
     try:
-      func_graph, captured_args = functional_ops._compile_function(
+      func_graph, captured_args, _ = functional_ops._compile_function(
          _pipeline, inputs, scope, control_outputs)
     except functional_ops._InvalidCaptureException as e:
       raise ValueError(
@@ -1212,7 +1214,8 @@ def gradient_override_wrapper(*args, **kwargs):
   with ops.name_scope(name) as scope:
     # pylint: disable=protected-access
     try:
-      func_graph, captured_args = functional_ops._compile_function(
+      func_graph, captured_args, constant_outputs = \
+          functional_ops._compile_function(
          gradient_override_wrapper, args, scope, control_outputs)
     except functional_ops._InvalidCaptureException as e:
       raise ValueError(
@@ -1233,6 +1236,7 @@ def gradient_override_wrapper(*args, **kwargs):
   if isinstance(outputs, ops.Operation):
     return outputs

+  outputs = functional_ops._replace_outputs(outputs, constant_outputs)  # pylint: disable=protected-access
   return functional_ops._pack_sequence_as(  # pylint: disable=protected-access
       func_graph.structured_outputs, outputs)
tensorflow/python/ipu/ops/pipelining_ops_grad.py

Lines changed: 2 additions & 1 deletion
@@ -25,7 +25,7 @@
 @ops.RegisterGradient("PipelineStage")
 def _pipeline_stage_grad(op, *grads):
   """The gradient of a PipelineStage op."""
-  func_grad_graph, func_grad_inputs = \
+  func_grad_graph, func_grad_inputs, constant_outputs = \
       functional_ops_grad._get_gradients_for_function(op, *grads)  # pylint: disable=protected-access
   stage_op = op.outputs[0].op
   stage_id = stage_op.get_attr('stage_id')
@@ -37,6 +37,7 @@ def _pipeline_stage_grad(op, *grads):
       output_shapes=func_grad_graph.output_shapes,
       stage_id=stage_id)

+  outputs = functional_ops._replace_outputs(outputs, constant_outputs)  # pylint: disable=protected-access
   return functional_ops._pack_sequence_as(  # pylint: disable=protected-access
       func_grad_graph.structured_outputs, outputs)

tensorflow/python/ipu/optimizers/gradient_accumulation_optimizer.py

Lines changed: 3 additions & 1 deletion
@@ -137,7 +137,8 @@ def apply_gradient_accumulation(resource_update_, name, apply_grad_ops,
                                 replicated_optimizer_state_sharding,
                                 num_mini_batches):
   with ops.name_scope(name + "/WU") as scope:
-    func_graph, captured_args = functional_ops._compile_function(  # pylint: disable=protected-access
+    func_graph, captured_args, constant_outputs = \
+        functional_ops._compile_function(  # pylint: disable=protected-access
         resource_update_, [], scope, apply_grad_ops, True)

   # Create the resource update and lower the function into XLA.
@@ -151,6 +152,7 @@ def apply_gradient_accumulation(resource_update_, name, apply_grad_ops,
       replicated_optimizer_state_sharding=
       replicated_optimizer_state_sharding,
       num_batches_to_accumulate=num_mini_batches)
+  outputs = functional_ops._replace_outputs(outputs, constant_outputs)  # pylint: disable=protected-access

   return outputs

tensorflow/python/ipu/tests/pipelining_test.py

Lines changed: 1 addition & 1 deletion
@@ -2276,7 +2276,7 @@ def stage1(indices):
           shape=[300, 300],
           dtype=dtypes.float16,
           initializer=init_ops.ones_initializer())
-      return embedding_ops.embedding_lookup(table, indices)
+      return array_ops.gather(table, indices)

     def identity(*args):
       return args
