Skip to content

Commit a9e29eb

Browse files
authored
Fix bug in converter to load saved model (#1981)
BUG Fix a bug that prohibits the conversion of a Saved Model produced by AutoML object detection. The error was produced by an internal `convert_variables_to_constants() method`: `ValueError: Cannot find the variable that is an input to the ReadVariableOp.` Comparing our code to the `freeze_graph` util showed that we were not loading the graph properly into a Session object, before calling `convert_variables_to_constants()`. The unit test emulates the Saved Model by adding a hash table op that is not used by the inference signature. Verified that this test fails at master. Cloud will make an integration test internally that asserts that this conversion continuous to work.
1 parent 61b355e commit a9e29eb

File tree

3 files changed

+93
-18
lines changed

3 files changed

+93
-18
lines changed

tfjs-converter/python/run-python-tests.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ TEST_FILES="$(find "${SCRIPTS_DIR}" -name '*_test.py')"
2424

2525
pip install virtualenv
2626

27-
TMP_VENV_DIR="$(mktemp -d --suffix=_venv)"
27+
TMP_VENV_DIR="$(mktemp -u).venv"
2828
virtualenv -p "python" "${TMP_VENV_DIR}"
2929
source "${TMP_VENV_DIR}/bin/activate"
3030

tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
from tensorflow.python.grappler import cluster as gcluster
3030
from tensorflow.python.grappler import tf_optimizer
3131
from tensorflow.python.saved_model.load import load
32+
from tensorflow.python.saved_model import loader
33+
from tensorflow.python.tools import saved_model_utils
3234
from tensorflow.python.training.saver import export_meta_graph
3335
from google.protobuf.json_format import MessageToDict
3436
import tensorflow_hub as hub
@@ -272,15 +274,20 @@ def _check_signature_in_model(saved_model, signature_name):
272274
saved_model.signatures.keys()))
273275

274276

275-
def _freeze_saved_model_v1(graph, output_node_names):
276-
frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
277-
tf.compat.v1.Session(), graph.as_graph_def(), output_node_names)
277+
def _freeze_saved_model_v1(saved_model_dir, saved_model_tags,
278+
output_node_names):
279+
with tf.compat.v1.Session() as sess:
280+
loader.load(sess, saved_model_tags, saved_model_dir)
281+
input_graph_def = saved_model_utils.get_meta_graph_def(
282+
saved_model_dir, ','.join(saved_model_tags)).graph_def
283+
frozen_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
284+
sess, input_graph_def, output_node_names)
278285

279-
frozen_graph = tf.Graph()
280-
with frozen_graph.as_default():
281-
tf.import_graph_def(frozen_graph_def, name='')
286+
frozen_graph = tf.Graph()
287+
with frozen_graph.as_default():
288+
tf.import_graph_def(frozen_graph_def, name='')
282289

283-
return frozen_graph
290+
return frozen_graph
284291

285292
def _freeze_saved_model_v2(concrete_func):
286293
return convert_to_constants.convert_variables_to_constants_v2(
@@ -336,8 +343,8 @@ def convert_tf_saved_model(saved_model_dir,
336343
try:
337344
frozen_graph = _freeze_saved_model_v2(concrete_func)
338345
except BaseException:
339-
frozen_graph = _freeze_saved_model_v1(
340-
concrete_func.graph, output_node_names)
346+
frozen_graph = _freeze_saved_model_v1(saved_model_dir, saved_model_tags,
347+
output_node_names)
341348

342349
optimize_graph(frozen_graph, output_node_names, output_graph,
343350
model.tensorflow_version,

tfjs-converter/python/tensorflowjs/converters/tf_saved_model_conversion_v2_test.py

Lines changed: 76 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,44 @@ def _create_saved_model_v1(self):
8080

8181
builder.save()
8282

83+
def _create_saved_model_v1_with_hashtable(self):
84+
"""Create a TensorFlow SavedModel V1 with unused hash table for testing."""
85+
86+
graph = tf.Graph()
87+
with graph.as_default():
88+
x = tf.placeholder('float32', [2, 2])
89+
w = tf.compat.v1.get_variable('w', shape=[2, 2])
90+
output = tf.compat.v1.matmul(x, w)
91+
init_op = w.initializer
92+
93+
# Add a hash table that is not used by the output.
94+
keys = tf.constant(['key'])
95+
values = tf.constant([1])
96+
initializer = tf.contrib.lookup.KeyValueTensorInitializer(keys, values)
97+
table = tf.contrib.lookup.HashTable(initializer, -1)
98+
99+
# Create a builder.
100+
save_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
101+
builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(save_dir)
102+
103+
with tf.compat.v1.Session() as sess:
104+
# Run the initializer on `w`.
105+
sess.run(init_op)
106+
table.init.run()
107+
108+
builder.add_meta_graph_and_variables(
109+
sess, [tf.compat.v1.saved_model.tag_constants.SERVING],
110+
signature_def_map={
111+
"serving_default":
112+
tf.compat.v1.saved_model \
113+
.signature_def_utils.predict_signature_def(
114+
inputs={"x": x},
115+
outputs={"output": output})
116+
},
117+
assets_collection=None)
118+
119+
builder.save()
120+
83121
def _create_saved_model_with_fusable_conv2d(self):
84122
"""Test a basic model with fusable conv2d."""
85123
layers = [
@@ -192,32 +230,62 @@ def double_module_fn():
192230
def test_convert_saved_model_v1(self):
193231
self._create_saved_model_v1()
194232

233+
input_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
234+
output_dir = os.path.join(input_dir, 'js')
195235
tf_saved_model_conversion_v2.convert_tf_saved_model(
196-
os.path.join(self._tmp_dir, SAVED_MODEL_DIR),
197-
os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
236+
input_dir,
237+
output_dir
198238
)
199239

200-
weights = [{
240+
expected_weights_manifest = [{
201241
'paths': ['group1-shard1of1.bin'],
202242
'weights': [{'dtype': 'float32', 'name': 'w', 'shape': [2, 2]}]}]
203243

204-
tfjs_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
244+
tfjs_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'js')
205245
# Check model.json and weights manifest.
206246
with open(os.path.join(tfjs_path, 'model.json'), 'rt') as f:
207247
model_json = json.load(f)
208248
self.assertTrue(model_json['modelTopology'])
209249
weights_manifest = model_json['weightsManifest']
210-
self.assertEqual(weights_manifest, weights)
250+
self.assertEqual(weights_manifest, expected_weights_manifest)
211251
# Check meta-data in the artifact JSON.
212252
self.assertEqual(model_json['format'], 'graph-model')
213253
self.assertEqual(
214254
model_json['convertedBy'],
215255
'TensorFlow.js Converter v%s' % version.version)
216256
self.assertEqual(model_json['generatedBy'],
217257
tf.__version__)
218-
self.assertTrue(
219-
glob.glob(
220-
os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'group*-*')))
258+
self.assertTrue(glob.glob(os.path.join(output_dir, 'group*-*')))
259+
260+
def test_convert_saved_model_v1_with_hashtable(self):
261+
self._create_saved_model_v1_with_hashtable()
262+
263+
input_dir = os.path.join(self._tmp_dir, SAVED_MODEL_DIR)
264+
output_dir = os.path.join(input_dir, 'js')
265+
tf_saved_model_conversion_v2.convert_tf_saved_model(
266+
input_dir,
267+
output_dir
268+
)
269+
270+
expected_weights_manifest = [{
271+
'paths': ['group1-shard1of1.bin'],
272+
'weights': [{'dtype': 'float32', 'name': 'w', 'shape': [2, 2]}]}]
273+
274+
tfjs_path = os.path.join(self._tmp_dir, SAVED_MODEL_DIR, 'js')
275+
# Check model.json and weights manifest.
276+
with open(os.path.join(tfjs_path, 'model.json'), 'rt') as f:
277+
model_json = json.load(f)
278+
self.assertTrue(model_json['modelTopology'])
279+
weights_manifest = model_json['weightsManifest']
280+
self.assertEqual(weights_manifest, expected_weights_manifest)
281+
# Check meta-data in the artifact JSON.
282+
self.assertEqual(model_json['format'], 'graph-model')
283+
self.assertEqual(
284+
model_json['convertedBy'],
285+
'TensorFlow.js Converter v%s' % version.version)
286+
self.assertEqual(model_json['generatedBy'],
287+
tf.__version__)
288+
self.assertTrue(glob.glob(os.path.join(output_dir, 'group*-*')))
221289

222290
def test_convert_saved_model(self):
223291
self._create_saved_model()

0 commit comments

Comments
 (0)