Skip to content

Commit 0b6c675

Browse files
Alfie-Edwards authored and georgepaw committed
Supporting models with required training arg in call
Summary: Support Keras models whose `call` signature has no default value for the `training` parameter. Upstream Keras supports such models by checking the model's call signature and passing False if the argument is required but has not been passed. TF2.5 only.

Test Plan: Added a test which checks that you can call build/fit/evaluate/predict on models with a required `training` parameter. Added a test which checks that you get an appropriate error message when `call` requires other extra arguments.

Reviewers: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, georgep

Reviewed By: #tensorflow, #framework_ip_review_-_any_oss_or_third-party_code_use_has_been_approved, georgep

Maniphest Tasks: T57217

Differential Revision: https://phabricator.sourcevertex.net/D62100
1 parent db3d683 commit 0b6c675

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

tensorflow/python/ipu/keras/extensions/model_extensions.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -215,6 +215,7 @@ def to_input_tensor(x):
215215
return x
216216
return input_layer.Input(batch_shape=x)
217217

218+
self._check_call_args('build')
218219
inputs = map_shapes(input_shape, to_input_tensor)
219220
outputs = self._trace_graph_network(inputs)
220221
self._update_graph_network(inputs, outputs)
@@ -547,6 +548,18 @@ def _trace_graph_network(self, inputs):
547548
layer._inbound_nodes.clear() # pylint: disable=protected-access
548549
layer._outbound_nodes.clear() # pylint: disable=protected-access
549550

551+
# Case where `training` is a positional arg with no default.
552+
kwargs = {}
553+
# Update self._call_full_argspec.
554+
self._init_call_fn_args()
555+
call_signature = self._call_full_argspec
556+
if len(call_signature.args) > 2:
557+
n_required = len(call_signature.args) - len(call_signature.defaults
558+
or [])
559+
for arg in call_signature.args[2:n_required]:
560+
if arg == 'training':
561+
kwargs['training'] = False
562+
550563
# Invoke call within a call_context in case metrics are added in call.
551564
# We can't invoke self() because the base_layer __call__ override sees that
552565
# we are passing in keras tensors and treats the whole model as a layer
@@ -567,7 +580,7 @@ def get_shape(tensor):
567580
input_shape = get_shape(inputs)
568581

569582
self.build(input_shape)
570-
return self.call(inputs)
583+
return self.call(inputs, **kwargs)
571584

572585
@trackable.no_automatic_dependency_tracking
573586
def _update_graph_network(self, inputs, outputs):

tensorflow/python/ipu/tests/keras/extensions/single_ipu/functional_test.py

Lines changed: 34 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -570,6 +570,40 @@ def call(self, x): # pylint: disable=arguments-differ
570570
output = model(sample_input) # pylint: disable=not-callable
571571
self.assertEqual(output.shape, (1, 3))
572572

573+
def testModelWithRequiredTrainingArg(self):
574+
class MyModel(training_lib.Model): # pylint: disable=abstract-method
575+
def call(self, inputs, training): # pylint: disable=arguments-differ,unused-argument
576+
return inputs
577+
578+
model = MyModel()
579+
model.compile()
580+
x = np.ones(8)
581+
y = np.ones(8)
582+
583+
model.build([8])
584+
model.fit(x, y, batch_size=1)
585+
model.evaluate(x, y, batch_size=1)
586+
model.predict(x, batch_size=1)
587+
588+
def testModelWithAdditionalRequiredArgs(self):
589+
class MyModel(training_lib.Model): # pylint: disable=abstract-method
590+
def call(self, inputs, additional_arg): # pylint: disable=arguments-differ,unused-argument
591+
return inputs
592+
593+
model = MyModel()
594+
model.compile()
595+
x = np.ones(8)
596+
y = np.ones(8)
597+
598+
with self.assertRaisesRegex(ValueError, r'Models passed to `build`'):
599+
model.build([8])
600+
with self.assertRaisesRegex(ValueError, r'Models passed to `fit`'):
601+
model.fit(x, y, batch_size=1)
602+
with self.assertRaisesRegex(ValueError, r'Models passed to `evaluate`'):
603+
model.evaluate(x, y, batch_size=1)
604+
with self.assertRaisesRegex(ValueError, r'Models passed to `predict`'):
605+
model.predict(x, batch_size=1)
606+
573607
def testNoneInShapeWithFunctionalAPI(self):
574608
# pylint: disable=abstract-method
575609
class BasicBlock(training_lib.Model):

0 commit comments

Comments
 (0)