Skip to content

Commit 369aaca

Browse files
committed
Freeze variables using convert_variables_to_constants_v2 in serving export
Summary: TF2.6 Only Ref T67667 Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, bartlomiejw, georgep Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, georgep Maniphest Tasks: T67667 Differential Revision: https://phabricator.sourcevertex.net/D73929
1 parent 94c5d3b commit 369aaca

File tree

3 files changed

+113
-10
lines changed

3 files changed

+113
-10
lines changed

tensorflow/python/ipu/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1765,7 +1765,7 @@ tf_py_test(
17651765
size = "large",
17661766
srcs = ["tests/serving_export_test.py"],
17671767
# Shard count needs to match the number of tests.
1768-
shard_count = 20,
1768+
shard_count = 21,
17691769
tags = ["hw_poplar_test_1_ipus"],
17701770
deps = [
17711771
"//tensorflow/compiler/tests:xla_test",

tensorflow/python/ipu/serving.py

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,9 @@
2525

2626
from tensorflow.python.data.ops import dataset_ops
2727
from tensorflow.python.eager import def_function
28-
from tensorflow.python.framework import tensor_spec
2928
from tensorflow.python.framework import convert_to_constants
29+
from tensorflow.python.framework import tensor_spec
30+
from tensorflow.python.framework.ops import convert_to_tensor
3031
from tensorflow.python.framework.ops import Tensor
3132
from tensorflow.python.ipu import application_compile_op
3233
from tensorflow.python.ipu import embedded_runtime
@@ -205,7 +206,8 @@ def validate_single_signature(signature_name, signature):
205206
def _prepare_input_signature(defunc,
206207
defunc_signature=None,
207208
input_dataset=None,
208-
non_feed_inputs=None):
209+
non_feed_inputs=None,
210+
remove_non_feed_inputs_from_signature=True):
209211
"""Prepare `input_signature` for `defunc` from given arguments.
210212
211213
Args:
@@ -218,6 +220,8 @@ def _prepare_input_signature(defunc,
218220
will be inferred.
219221
non_feed_inputs (list, optional): List of inputs that will be provided
220222
to the graph without usage of infeed queue.
223+
remove_non_feed_inputs_from_signature (bool, optional): If True passed
224+
non_feed_inputs will be removed from created input signature.
221225
222226
Returns:
223227
list: List of `tf.TensorSpec` objects with types, shapes and names.
@@ -241,6 +245,8 @@ def _prepare_input_signature(defunc,
241245
input_signature = input_dataset.element_spec
242246
if isinstance(input_signature, tensor_spec.TensorSpec):
243247
input_signature = input_signature,
248+
input_signature = tuple(
249+
_get_signature_from_tensors(non_feed_inputs)) + input_signature
244250
elif defunc_signature is None:
245251
if isinstance(defunc, def_function.Function):
246252
input_signature = defunc.input_signature
@@ -253,14 +259,16 @@ def _prepare_input_signature(defunc,
253259
f'list, received {str(type(input_signature))}')
254260

255261
names = list(inspect.signature(defunc).parameters.keys())
256-
if non_feed_inputs:
262+
if non_feed_inputs is not None and remove_non_feed_inputs_from_signature:
257263
names = names[len(non_feed_inputs):]
258264
if len(input_signature) > len(names):
259265
input_signature = input_signature[len(non_feed_inputs):]
260266

261267
if len(input_signature) != len(names):
262-
raise ValueError('Length of input_signature does not match the number of '
263-
f'{defunc.__name__} arguments')
268+
raise ValueError(
269+
'Length of input_signature does not match the number of '
270+
f'{defunc.__name__} arguments, input_signature : {input_signature}, '
271+
f'names : {names}')
264272

265273
# Store argument names in the input_signature
266274
input_signature = [
@@ -293,14 +301,52 @@ def _create_feeds(input_signature, input_dataset=None):
293301
return (infeed, outfeed)
294302

295303

304+
def _get_signature_from_tensors(tensors):
305+
if tensors is None:
306+
return []
307+
308+
def to_tensor(data):
309+
if not isinstance(data, Tensor):
310+
return convert_to_tensor(data)
311+
312+
return data
313+
314+
return [
315+
tensor_spec.TensorSpec.from_tensor(to_tensor(tensor))
316+
for tensor in tensors
317+
]
318+
319+
296320
def _freeze_defunc(defunc, input_signature):
297321
@def_function.function(input_signature=input_signature)
298322
def defunc_wrapper(*args):
299323
return defunc(*args)
300324

301-
return convert_to_constants.convert_variables_to_constants_v2(
325+
concrete_defunc = convert_to_constants.convert_variables_to_constants_v2(
302326
defunc_wrapper.get_concrete_function(*input_signature))
303327

328+
@def_function.function(input_signature=input_signature)
329+
def transformed_defunc_wrapper(*args):
330+
return concrete_defunc(*args)
331+
332+
return transformed_defunc_wrapper, _get_signature_from_tensors(
333+
concrete_defunc.outputs)
334+
335+
336+
def _freeze_single_step(defunc, input_signature):
337+
return _freeze_defunc(defunc, input_signature)[0]
338+
339+
340+
def _freeze_computational_stages(computational_stages, input_signature):
341+
def transform(stage):
342+
nonlocal input_signature
343+
transformed_stage, output_signature = _freeze_defunc(
344+
stage, input_signature)
345+
input_signature = output_signature
346+
return transformed_stage
347+
348+
return [transform(stage) for stage in computational_stages]
349+
304350

305351
def _export_saved_model(predict_step,
306352
export_dir,
@@ -373,11 +419,12 @@ def _export_saved_model(predict_step,
373419
with_postprocessing = postprocessing_step is not None
374420

375421
if with_preprocessing:
376-
preprocessing_step = _freeze_defunc(preprocessing_step, input_signature)
422+
preprocessing_step = _freeze_single_step(preprocessing_step,
423+
input_signature)
377424

378425
if with_postprocessing:
379-
postprocessing_step = _freeze_defunc(postprocessing_step,
380-
postprocessing_step_signature)
426+
postprocessing_step = _freeze_single_step(postprocessing_step,
427+
postprocessing_step_signature)
381428

382429
def validate_io_matching(src_return_tensors, dst_input_signature,
383430
src_step_name, dst_step_name):
@@ -618,6 +665,7 @@ def export_single_step(predict_step,
618665
postprocessing_step_signature = _prepare_input_signature(
619666
postprocessing_step, postprocessing_step_signature)
620667

668+
predict_step = _freeze_single_step(predict_step, predict_step_signature)
621669
predict_loop = _wrap_in_loop(predict_step, predict_step_signature,
622670
input_dataset, iterations)
623671
return _export_saved_model(predict_loop, export_dir, input_signature,
@@ -769,6 +817,13 @@ def export_pipeline(computational_stages,
769817
preprocessing_step, preprocessing_step_signature,
770818
postprocessing_step, postprocessing_step_signature)
771819

820+
computational_stages = _freeze_computational_stages(
821+
computational_stages,
822+
_prepare_input_signature(predict_step,
823+
predict_step_signature,
824+
input_dataset,
825+
non_feed_inputs=inputs,
826+
remove_non_feed_inputs_from_signature=False))
772827
if preprocessing_step is not None:
773828
input_signature = _prepare_input_signature(preprocessing_step,
774829
preprocessing_step_signature,

tensorflow/python/ipu/tests/serving_export_test.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@
2424
from tensorflow.python.eager import def_function
2525
from tensorflow.python.framework import constant_op
2626
from tensorflow.python.framework import dtypes
27+
from tensorflow.python.framework import function
2728
from tensorflow.python.framework import test_util
2829
from tensorflow.python.framework import tensor_spec
30+
from tensorflow.python.ops.functional_ops import partitioned_call
2931
from tensorflow.python.ipu import config
3032
from tensorflow.python.ipu import ipu_strategy
3133
from tensorflow.python.ipu import serving
@@ -863,6 +865,52 @@ def stage(x):
863865
device_mapping=[0, 0],
864866
predict_step_signature=predict_step_signature)
865867

868+
@tu.test_uses_ipus(num_ipus=1, allow_ipu_model=False)
869+
@test_util.run_v2_only
870+
def test_export_simple_model_with_variables_stateful_partition_call(self):
871+
element_count = 3
872+
input_shape = (element_count,)
873+
input_tensor = array_ops.zeros(shape=input_shape, dtype=np.float16)
874+
875+
predict_step_signature = (tensor_spec.TensorSpec(shape=input_shape,
876+
dtype=dtypes.float32),)
877+
878+
var_value = np.float32(10.)
879+
w = variables.Variable(var_value)
880+
881+
z_value = np.float32(15.)
882+
z = variables.Variable(z_value)
883+
884+
@function.Defun(dtypes.float32, dtypes.float32)
885+
def partition_call_body(constant, var):
886+
return (constant + constant) * var
887+
888+
@def_function.function(
889+
input_signature=(tensor_spec.TensorSpec(shape=input_shape,
890+
dtype=dtypes.float32),
891+
tensor_spec.TensorSpec(shape=input_shape,
892+
dtype=dtypes.float32)))
893+
def predict_step(x, y):
894+
call_result = partitioned_call(args=[x, w], f=partition_call_body)
895+
896+
predict_step_result = call_result + y + z
897+
return array_ops.reshape(predict_step_result, input_shape)
898+
899+
x = np.arange(element_count, dtype=np.float32)
900+
y = np.arange(element_count, dtype=np.float32)
901+
expected_result = list(((x + x) * var_value) + y + z_value)
902+
903+
with tempfile.TemporaryDirectory() as tmp_folder:
904+
iterations = 16
905+
serving.export_single_step(predict_step, tmp_folder, iterations)
906+
907+
result = list(
908+
self._load_and_run(tmp_folder, {
909+
'x': x,
910+
'y': y
911+
})["output_0"].numpy())
912+
self.assertAllClose(result, expected_result)
913+
866914

867915
if __name__ == "__main__":
868916
test.main()

0 commit comments

Comments
 (0)