Skip to content

Commit 2f938dc

Browse files
authored
[TensorRT] Fix Graph contains EmbeddingVariable compiling issue. (#964)
Signed-off-by: 泊霆 <hujunqi.hjq@alibaba-inc.com> Co-authored-by: 泊霆 <hujunqi.hjq@alibaba-inc.com>
1 parent 0f536a2 commit 2f938dc

File tree

2 files changed

+10
-14
lines changed

2 files changed

+10
-14
lines changed

tensorflow/python/compiler/tensorrt/trt_convert.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -539,13 +539,10 @@ def _gather_names(tensor_info):
539539
# EmbeddingVariable cannot be converted to a constant, so we need to
540540
# load ev variables at runtime always.
541541
if self._use_ev:
542-
global_step_collection_ops = sess.graph.get_collection("global_step")
543-
global_step_name = global_step_collection_ops[0].name.split(":")[0]
544542
output_node_names.add(filename_tensor_name)
545543
output_node_names.add(save_tensor_name)
546544
output_node_names.add(restore_op_name)
547545

548-
tf_logging.info("TensorRT - global_step_name: %s" % str(global_step_name))
549546
tf_logging.info("TensorRT - filename_tensor_name: %s" % str(filename_tensor_name))
550547
tf_logging.info("TensorRT - save_tensor_name: %s" % str(save_tensor_name))
551548
tf_logging.info("TensorRT - restore_op_name: %s" % str(restore_op_name))
@@ -559,18 +556,19 @@ def _gather_names(tensor_info):
559556

560557
# Freeze the variables in the SavedModel graph and copy the frozen
561558
# graph over.
562-
variable_names_blacklist = []
563559
if self._use_ev:
564-
variable_names_blacklist.append(global_step_name)
560+
global_step_collection_ops = sess.graph.get_collection("global_step")
561+
if len(global_step_collection_ops) > 0:
562+
sess.run([sess.graph.get_operation_by_name("global_step/Assign")])
565563

566564
frozen_graph_def = graph_util.convert_variables_to_constants(
567565
sess, sess.graph.as_graph_def(add_shapes=True),
568-
list(output_node_names), variable_names_blacklist=variable_names_blacklist)
566+
list(output_node_names))
569567

570568
if self._use_ev:
571569
# Keep KV Variable in saver_def, these kv-vars will be initialized at runtime.
572570
frozen_graph_def = graph_util.create_kv_variable_init_graph(
573-
frozen_graph_def, global_step_name, restore_op_name)
571+
frozen_graph_def, restore_op_name)
574572

575573
self._grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
576574
self._grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)

tensorflow/python/framework/graph_util_impl.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def _bfs_for_reachable_nodes(target_nodes, name_to_input_name):
169169
return nodes_to_keep
170170

171171
@tf_export(v1=["graph_util.create_kv_variable_init_graph"])
172-
def create_kv_variable_init_graph(graph, global_step_name, restore_all_op_name):
172+
def create_kv_variable_init_graph(graph, restore_all_op_name):
173173
name_to_input_name, name_to_node, name_to_seq_num = \
174174
_extract_graph_summary(graph)
175175

@@ -184,8 +184,10 @@ def create_kv_variable_init_graph(graph, global_step_name, restore_all_op_name):
184184
" {} in current graph.".format(restore_all_op_name))
185185

186186
for restore_shard_input_full_name in restore_all_op.input:
187-
restore_shard_input_name = re.sub(r"^\^", "", restore_shard_input_full_name)
188-
restore_shard_input_op = name_to_node[restore_shard_input_name]
187+
restore_shard_input_no_op_name = re.sub(r"^\^", "", restore_shard_input_full_name)
188+
restore_shard_input_no_op = name_to_node[restore_shard_input_no_op_name]
189+
restore_shard_input_op_name = re.sub(r"^\^", "",restore_shard_input_no_op.input[0])
190+
restore_shard_input_op = name_to_node[restore_shard_input_op_name]
189191
# go through all restore_shard ops
190192
new_node = node_def_pb2.NodeDef()
191193
new_node.CopyFrom(restore_shard_input_op)
@@ -198,10 +200,6 @@ def create_kv_variable_init_graph(graph, global_step_name, restore_all_op_name):
198200
n_node.op == "KvResourceImportV2" or \
199201
n_node.op == "KvResourceImport":
200202
new_node.input.append(n_full_name)
201-
else:
202-
# Keep global_step assign op in new save/restore_all
203-
if n_node.input[0] == global_step_name:
204-
new_node.input.append(n_full_name)
205203

206204
graph.node.remove(restore_shard_input_op)
207205
graph.node.extend([new_node])

0 commit comments

Comments (0)