2525
2626from tensorflow .python .data .ops import dataset_ops
2727from tensorflow .python .eager import def_function
28- from tensorflow .python .framework import tensor_spec
2928from tensorflow .python .framework import convert_to_constants
29+ from tensorflow .python .framework import tensor_spec
30+ from tensorflow .python .framework .ops import convert_to_tensor
3031from tensorflow .python .framework .ops import Tensor
3132from tensorflow .python .ipu import application_compile_op
3233from tensorflow .python .ipu import embedded_runtime
@@ -205,7 +206,8 @@ def validate_single_signature(signature_name, signature):
205206def _prepare_input_signature (defunc ,
206207 defunc_signature = None ,
207208 input_dataset = None ,
208- non_feed_inputs = None ):
209+ non_feed_inputs = None ,
210+ remove_non_feed_inputs_from_signature = True ):
209211 """Prepare `input_signature` for `defunc` from given arguments.
210212
211213 Args:
@@ -218,6 +220,8 @@ def _prepare_input_signature(defunc,
218220 will be inferred.
219221 non_feed_inputs (list, optional): List of inputs that will be provided
220222 to the graph without usage of infeed queue.
223+ remove_non_feed_inputs_from_signature (bool, optional): If True passed
224+ non_feed_inputs will be removed from created input signature.
221225
222226 Returns:
223227 list: List of `tf.TensorSpec` objects with types, shapes and names.
@@ -241,6 +245,8 @@ def _prepare_input_signature(defunc,
241245 input_signature = input_dataset .element_spec
242246 if isinstance (input_signature , tensor_spec .TensorSpec ):
243247 input_signature = input_signature ,
248+ input_signature = tuple (
249+ _get_signature_from_tensors (non_feed_inputs )) + input_signature
244250 elif defunc_signature is None :
245251 if isinstance (defunc , def_function .Function ):
246252 input_signature = defunc .input_signature
@@ -253,14 +259,16 @@ def _prepare_input_signature(defunc,
253259 f'list, received { str (type (input_signature ))} ' )
254260
255261 names = list (inspect .signature (defunc ).parameters .keys ())
256- if non_feed_inputs :
262+ if non_feed_inputs is not None and remove_non_feed_inputs_from_signature :
257263 names = names [len (non_feed_inputs ):]
258264 if len (input_signature ) > len (names ):
259265 input_signature = input_signature [len (non_feed_inputs ):]
260266
261267 if len (input_signature ) != len (names ):
262- raise ValueError ('Length of input_signature does not match the number of '
263- f'{ defunc .__name__ } arguments' )
268+ raise ValueError (
269+ 'Length of input_signature does not match the number of '
270+ f'{ defunc .__name__ } arguments, input_signature : { input_signature } , '
271+ f'names : { names } ' )
264272
265273 # Store argument names in the input_signature
266274 input_signature = [
@@ -293,14 +301,52 @@ def _create_feeds(input_signature, input_dataset=None):
293301 return (infeed , outfeed )
294302
295303
304+ def _get_signature_from_tensors (tensors ):
305+ if tensors is None :
306+ return []
307+
308+ def to_tensor (data ):
309+ if not isinstance (data , Tensor ):
310+ return convert_to_tensor (data )
311+
312+ return data
313+
314+ return [
315+ tensor_spec .TensorSpec .from_tensor (to_tensor (tensor ))
316+ for tensor in tensors
317+ ]
318+
319+
296320def _freeze_defunc (defunc , input_signature ):
297321 @def_function .function (input_signature = input_signature )
298322 def defunc_wrapper (* args ):
299323 return defunc (* args )
300324
301- return convert_to_constants .convert_variables_to_constants_v2 (
325+ concrete_defunc = convert_to_constants .convert_variables_to_constants_v2 (
302326 defunc_wrapper .get_concrete_function (* input_signature ))
303327
328+ @def_function .function (input_signature = input_signature )
329+ def transformed_defunc_wrapper (* args ):
330+ return concrete_defunc (* args )
331+
332+ return transformed_defunc_wrapper , _get_signature_from_tensors (
333+ concrete_defunc .outputs )
334+
335+
336+ def _freeze_single_step (defunc , input_signature ):
337+ return _freeze_defunc (defunc , input_signature )[0 ]
338+
339+
340+ def _freeze_computational_stages (computational_stages , input_signature ):
341+ def transform (stage ):
342+ nonlocal input_signature
343+ transformed_stage , output_signature = _freeze_defunc (
344+ stage , input_signature )
345+ input_signature = output_signature
346+ return transformed_stage
347+
348+ return [transform (stage ) for stage in computational_stages ]
349+
304350
305351def _export_saved_model (predict_step ,
306352 export_dir ,
@@ -373,11 +419,12 @@ def _export_saved_model(predict_step,
373419 with_postprocessing = postprocessing_step is not None
374420
375421 if with_preprocessing :
376- preprocessing_step = _freeze_defunc (preprocessing_step , input_signature )
422+ preprocessing_step = _freeze_single_step (preprocessing_step ,
423+ input_signature )
377424
378425 if with_postprocessing :
379- postprocessing_step = _freeze_defunc (postprocessing_step ,
380- postprocessing_step_signature )
426+ postprocessing_step = _freeze_single_step (postprocessing_step ,
427+ postprocessing_step_signature )
381428
382429 def validate_io_matching (src_return_tensors , dst_input_signature ,
383430 src_step_name , dst_step_name ):
@@ -618,6 +665,7 @@ def export_single_step(predict_step,
618665 postprocessing_step_signature = _prepare_input_signature (
619666 postprocessing_step , postprocessing_step_signature )
620667
668+ predict_step = _freeze_single_step (predict_step , predict_step_signature )
621669 predict_loop = _wrap_in_loop (predict_step , predict_step_signature ,
622670 input_dataset , iterations )
623671 return _export_saved_model (predict_loop , export_dir , input_signature ,
@@ -769,6 +817,13 @@ def export_pipeline(computational_stages,
769817 preprocessing_step , preprocessing_step_signature ,
770818 postprocessing_step , postprocessing_step_signature )
771819
820+ computational_stages = _freeze_computational_stages (
821+ computational_stages ,
822+ _prepare_input_signature (predict_step ,
823+ predict_step_signature ,
824+ input_dataset ,
825+ non_feed_inputs = inputs ,
826+ remove_non_feed_inputs_from_signature = False ))
772827 if preprocessing_step is not None :
773828 input_signature = _prepare_input_signature (preprocessing_step ,
774829 preprocessing_step_signature ,
0 commit comments