2121from tensorflow .compiler .xla import xla_data_pb2
2222from tensorflow .core .framework import attr_value_pb2
2323from tensorflow .compiler .plugin .poplar .ops import gen_functional_ops
24+ from tensorflow .python .framework import constant_op
2425from tensorflow .python .framework import dtypes
2526from tensorflow .python .framework import func_graph as func_graph_module
2627from tensorflow .python .framework import ops
28+ from tensorflow .python .framework import tensor_util
2729from tensorflow .python .ipu import scopes
2830from tensorflow .python .ops import control_flow_util_v2 as util
31+ from tensorflow .python .ops import resource_variable_ops
2932from tensorflow .python .util import nest
3033
3134
@@ -75,7 +78,7 @@ def decorated(inner_func):
7578 def func_wrapper (* args ):
7679 args = _convert_to_list (args )
7780 with ops .name_scope (name ) as scope :
78- func_graph , captured_args = _compile_function (
81+ func_graph , captured_args , constant_outputs = _compile_function (
7982 inner_func , args , scope , [], allow_external_captures = True )
8083
8184 with ops .control_dependencies (list (func_graph .control_captures )):
@@ -87,6 +90,7 @@ def func_wrapper(*args):
8790 unique_sharding = unique_sharding ,
8891 keep_input_layouts = keep_input_layouts ,
8992 name = name )
93+ outputs = _replace_outputs (outputs , constant_outputs )
9094
9195 # pack_sequence_as requires a list of Tensors, but the gen_ operation
9296 # returns an Operation under some circumstances (probably when that
@@ -169,7 +173,8 @@ def func_wrapper(*args, **kwargs):
169173 op ._set_shape_list_attr ("_xla_inferred_shapes" , output_shapes )
170174 # pylint: enable=protected-access
171175
172- return func_graph , captured_args
176+ constant_outputs = _get_constant_outputs (func_graph , captured_args )
177+ return func_graph , captured_args , constant_outputs
173178
174179
175180def _pack_sequence_as (structured_outputs , op_outputs ):
@@ -203,3 +208,48 @@ def _convert_to_list(xs):
203208 if not isinstance (xs , (list , tuple )):
204209 return [xs ]
205210 return list (xs )
211+
212+
213+ def _get_constant_outputs (func_graph , func_inputs ):
214+ """Get constant outputs for a functional graph.
215+
216+ Get constant outputs in order to propagate them in the XLA graph. This
217+ includes `VariableShape` operation which needs to return a constant."""
218+ if not func_graph .outputs :
219+ return None
220+
221+ def get_output_info (output ):
222+ while output .op .type == "Identity" :
223+ output = output .op .inputs [0 ]
224+ if constant_op .is_constant (output ):
225+ # Propagate constants.
226+ return constant_op .constant (tensor_util .constant_value (output ),
227+ dtype = output .dtype )
228+
229+ if output .op .type == "VariableShape" :
230+ # Propagate variable shapes.
231+ # Find the variable inside the function and its inputs index.
232+ var = output .op .inputs [0 ]
233+ assert var .dtype == dtypes .resource
234+ index = [
235+ i for i , v in enumerate (func_graph .inputs )
236+ if v .dtype == dtypes .resource and v is var
237+ ]
238+ assert len (index ) == 1
239+ # Get the input variable.
240+ outter_var = func_inputs [index [0 ]]
241+ return resource_variable_ops .variable_shape (outter_var ,
242+ out_type = output .dtype )
243+ return None
244+
245+ return [get_output_info (x ) for x in nest .flatten (func_graph .outputs )]
246+
247+
248+ def _replace_outputs (outputs , to_replace_with ):
249+ flat_outputs = nest .flatten (outputs )
250+ flat_to_replace_with = nest .flatten (to_replace_with )
251+ assert len (flat_outputs ) == len (flat_to_replace_with )
252+ flat_outputs = [
253+ x if y is None else y for x , y in zip (flat_outputs , flat_to_replace_with )
254+ ]
255+ return nest .pack_sequence_as (outputs , flat_outputs )
0 commit comments