@@ -80,6 +80,44 @@ def _create_saved_model_v1(self):
8080
8181 builder .save ()
8282
83+ def _create_saved_model_v1_with_hashtable (self ):
84+ """Create a TensorFlow SavedModel V1 with unused hash table for testing."""
85+
86+ graph = tf .Graph ()
87+ with graph .as_default ():
88+ x = tf .placeholder ('float32' , [2 , 2 ])
89+ w = tf .compat .v1 .get_variable ('w' , shape = [2 , 2 ])
90+ output = tf .compat .v1 .matmul (x , w )
91+ init_op = w .initializer
92+
93+ # Add a hash table that is not used by the output.
94+ keys = tf .constant (['key' ])
95+ values = tf .constant ([1 ])
96+ initializer = tf .contrib .lookup .KeyValueTensorInitializer (keys , values )
97+ table = tf .contrib .lookup .HashTable (initializer , - 1 )
98+
99+ # Create a builder.
100+ save_dir = os .path .join (self ._tmp_dir , SAVED_MODEL_DIR )
101+ builder = tf .compat .v1 .saved_model .builder .SavedModelBuilder (save_dir )
102+
103+ with tf .compat .v1 .Session () as sess :
104+ # Run the initializer on `w`.
105+ sess .run (init_op )
106+ table .init .run ()
107+
108+ builder .add_meta_graph_and_variables (
109+ sess , [tf .compat .v1 .saved_model .tag_constants .SERVING ],
110+ signature_def_map = {
111+ "serving_default" :
112+ tf .compat .v1 .saved_model \
113+ .signature_def_utils .predict_signature_def (
114+ inputs = {"x" : x },
115+ outputs = {"output" : output })
116+ },
117+ assets_collection = None )
118+
119+ builder .save ()
120+
83121 def _create_saved_model_with_fusable_conv2d (self ):
84122 """Test a basic model with fusable conv2d."""
85123 layers = [
@@ -192,32 +230,62 @@ def double_module_fn():
192230 def test_convert_saved_model_v1 (self ):
193231 self ._create_saved_model_v1 ()
194232
233+ input_dir = os .path .join (self ._tmp_dir , SAVED_MODEL_DIR )
234+ output_dir = os .path .join (input_dir , 'js' )
195235 tf_saved_model_conversion_v2 .convert_tf_saved_model (
196- os . path . join ( self . _tmp_dir , SAVED_MODEL_DIR ) ,
197- os . path . join ( self . _tmp_dir , SAVED_MODEL_DIR )
236+ input_dir ,
237+ output_dir
198238 )
199239
200- weights = [{
240+ expected_weights_manifest = [{
201241 'paths' : ['group1-shard1of1.bin' ],
202242 'weights' : [{'dtype' : 'float32' , 'name' : 'w' , 'shape' : [2 , 2 ]}]}]
203243
204- tfjs_path = os .path .join (self ._tmp_dir , SAVED_MODEL_DIR )
244+ tfjs_path = os .path .join (self ._tmp_dir , SAVED_MODEL_DIR , 'js' )
205245 # Check model.json and weights manifest.
206246 with open (os .path .join (tfjs_path , 'model.json' ), 'rt' ) as f :
207247 model_json = json .load (f )
208248 self .assertTrue (model_json ['modelTopology' ])
209249 weights_manifest = model_json ['weightsManifest' ]
210- self .assertEqual (weights_manifest , weights )
250+ self .assertEqual (weights_manifest , expected_weights_manifest )
211251 # Check meta-data in the artifact JSON.
212252 self .assertEqual (model_json ['format' ], 'graph-model' )
213253 self .assertEqual (
214254 model_json ['convertedBy' ],
215255 'TensorFlow.js Converter v%s' % version .version )
216256 self .assertEqual (model_json ['generatedBy' ],
217257 tf .__version__ )
218- self .assertTrue (
219- glob .glob (
220- os .path .join (self ._tmp_dir , SAVED_MODEL_DIR , 'group*-*' )))
258+ self .assertTrue (glob .glob (os .path .join (output_dir , 'group*-*' )))
259+
260+ def test_convert_saved_model_v1_with_hashtable (self ):
261+ self ._create_saved_model_v1_with_hashtable ()
262+
263+ input_dir = os .path .join (self ._tmp_dir , SAVED_MODEL_DIR )
264+ output_dir = os .path .join (input_dir , 'js' )
265+ tf_saved_model_conversion_v2 .convert_tf_saved_model (
266+ input_dir ,
267+ output_dir
268+ )
269+
270+ expected_weights_manifest = [{
271+ 'paths' : ['group1-shard1of1.bin' ],
272+ 'weights' : [{'dtype' : 'float32' , 'name' : 'w' , 'shape' : [2 , 2 ]}]}]
273+
274+ tfjs_path = os .path .join (self ._tmp_dir , SAVED_MODEL_DIR , 'js' )
275+ # Check model.json and weights manifest.
276+ with open (os .path .join (tfjs_path , 'model.json' ), 'rt' ) as f :
277+ model_json = json .load (f )
278+ self .assertTrue (model_json ['modelTopology' ])
279+ weights_manifest = model_json ['weightsManifest' ]
280+ self .assertEqual (weights_manifest , expected_weights_manifest )
281+ # Check meta-data in the artifact JSON.
282+ self .assertEqual (model_json ['format' ], 'graph-model' )
283+ self .assertEqual (
284+ model_json ['convertedBy' ],
285+ 'TensorFlow.js Converter v%s' % version .version )
286+ self .assertEqual (model_json ['generatedBy' ],
287+ tf .__version__ )
288+ self .assertTrue (glob .glob (os .path .join (output_dir , 'group*-*' )))
221289
222290 def test_convert_saved_model (self ):
223291 self ._create_saved_model ()
0 commit comments