@@ -169,7 +169,7 @@ def _bfs_for_reachable_nodes(target_nodes, name_to_input_name):
169169 return nodes_to_keep
170170
171171@tf_export (v1 = ["graph_util.create_kv_variable_init_graph" ])
172- def create_kv_variable_init_graph (graph , global_step_name , restore_all_op_name ):
172+ def create_kv_variable_init_graph (graph , restore_all_op_name ):
173173 name_to_input_name , name_to_node , name_to_seq_num = \
174174 _extract_graph_summary (graph )
175175
@@ -184,8 +184,10 @@ def create_kv_variable_init_graph(graph, global_step_name, restore_all_op_name):
184184 " {} in current graph." .format (restore_all_op_name ))
185185
186186 for restore_shard_input_full_name in restore_all_op .input :
187- restore_shard_input_name = re .sub (r"^\^" , "" , restore_shard_input_full_name )
188- restore_shard_input_op = name_to_node [restore_shard_input_name ]
187+ restore_shard_input_no_op_name = re .sub (r"^\^" , "" , restore_shard_input_full_name )
188+ restore_shard_input_no_op = name_to_node [restore_shard_input_no_op_name ]
189+ restore_shard_input_op_name = re .sub (r"^\^" , "" ,restore_shard_input_no_op .input [0 ])
190+ restore_shard_input_op = name_to_node [restore_shard_input_op_name ]
189191 # go through all restore_shard ops
190192 new_node = node_def_pb2 .NodeDef ()
191193 new_node .CopyFrom (restore_shard_input_op )
@@ -198,10 +200,6 @@ def create_kv_variable_init_graph(graph, global_step_name, restore_all_op_name):
198200 n_node .op == "KvResourceImportV2" or \
199201 n_node .op == "KvResourceImport" :
200202 new_node .input .append (n_full_name )
201- else :
202- # Keep global_step assign op in new save/restore_all
203- if n_node .input [0 ] == global_step_name :
204- new_node .input .append (n_full_name )
205203
206204 graph .node .remove (restore_shard_input_op )
207205 graph .node .extend ([new_node ])
0 commit comments