@@ -81,7 +81,7 @@ def __repr__(self):
     else:
       return '{}[{}]'.format(class_str, fields_str)
 
-  def call(self, inputs, params=(), **kwargs):
+  def call(self, inputs, params=(), state=(), **kwargs):
    """Applies this layer to given activation tensors, using trainable params.
 
    Args:
@@ -94,6 +94,7 @@ def call(self, inputs, params=(), **kwargs):
        and one for each of this layer's sublayers. If a layer (or sublayer)
        has no trainable parameters, the corresponding params element is an
        empty tuple.
+      state: start state.
      **kwargs: Layer-specific keyword args.
 
    Returns:
@@ -106,6 +107,7 @@ def call(self, inputs, params=(), **kwargs):
    """
    raise NotImplementedError
 
+  # TODO(wangpeng): Should be called `new_parameters_and_state`.
  def new_parameters(self, input_shapes, input_dtype, rng):
    """Creates layer-specific parameters based on data shape, dtype and rng.
 
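For orientation, here is a minimal sketch (not part of this change) of what a stateful layer could look like under the new contract: new_parameters returns a (params, state) pair and call consumes the start state and returns an (output, end state) pair. The RunningSum name and its accumulator logic are invented for illustration; Layer is the base class defined in this file.

class RunningSum(Layer):  # hypothetical example, not in this commit

  def new_parameters(self, input_shapes, input_dtype, rng):
    # No trainable parameters; the start state is a scalar accumulator.
    return (), 0.0

  def call(self, inputs, params=(), state=(), **kwargs):
    # Pass activations through unchanged; accumulate their sum in state.
    new_state = state + inputs.sum()
    return inputs, new_state
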
@@ -144,7 +146,7 @@ def has_custom_grad(self):
    """Whether to use custom gradients (in which case, see below)."""
    return False
 
-  def custom_grad(self, inputs, output, grad, params, **kwargs):
+  def custom_grad(self, inputs, output, grad, params, state, **kwargs):
    """Custom backward pass to propagate gradients in a custom way.
 
    Args:
@@ -153,6 +155,7 @@ def custom_grad(self, inputs, output, grad, params, **kwargs):
      grad: gradient signal (called cotangent in jax) computed based on
        subsequent layers. The structure and shape must match output.
      params: layer parameters
+      state: start state.
      **kwargs: kwargs for the layer
 
    Returns:
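To make the new custom_grad signature concrete, here is a hypothetical layer that overrides its backward pass. The gradient-halving behavior is invented; what matters is that the return value supplies one cotangent per differentiable argument of call (inputs, then params), as the VJP wiring further below expects.

class HalfGradient(Layer):  # illustrative only, not part of this commit

  def call(self, inputs, params=(), state=(), **kwargs):
    return inputs, state  # identity forward pass

  @property
  def has_custom_grad(self):
    return True

  def custom_grad(self, inputs, output, grad, params, state, **kwargs):
    # Scale the incoming gradient instead of passing it through unchanged;
    # return cotangents for (inputs, params); params are empty here.
    return grad * 0.5, ()
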
@@ -164,14 +167,15 @@ def custom_grad(self, inputs, output, grad, params, **kwargs):
 
  # End of subclassing interface, all functions below are internal.
 
-  def pseudo_call(self, pseudo_inputs, params):
+  def pseudo_call(self, pseudo_inputs, params, state):
    """Computes shapes and types this layer would produce for the given inputs.
 
    Args:
      pseudo_inputs: A ShapeType instance (input data minus the actual values)
        or a tuple of ShapeType instances, following the same conventions as
        Layer.call's input arg.
      params: Parameters for this layer.
+      state: start state.
 
    Returns:
      A ShapeType instance representing the shape and type of the output (if
@@ -183,12 +187,12 @@ def pseudo_call(self, pseudo_inputs, params):
      # cause a large number of dropout masks to be computed and permanently
      # stored in global memory.
      rng = ShapeType(shape=(2,), dtype=onp.uint32)
-      def call_on_input(x, params, rng):
-        return self.call(x, params=params, rng=rng)
+      def call_on_input(x, params, state, rng):
+        return self.call(x, params=params, state=state, rng=rng)
      params_shapes = nested_map(
          params, lambda x: ShapeType(shape=x.shape, dtype=x.dtype))
      s = backend.eval_on_shapes(call_on_input)(pseudo_inputs,
-                                                params_shapes, rng)
+                                                params_shapes, state, rng)
      return s
    except Exception:
      name, trace = self.__class__.__name__, _short_traceback(skip=3)
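Seen from the outside, pseudo_call now mirrors call and returns a (pseudo_output, state) pair computed purely from shapes and dtypes, with no real data materialized. A sketch under stated assumptions: a Dense layer is assumed to exist in this library, and the rng is taken to be a JAX PRNG key (which matches the (2,)-uint32 ShapeType above).

import jax
import numpy as onp

rng = jax.random.PRNGKey(0)          # assumed rng format
layer = Dense(16)                    # assumed layer mapping 8 -> 16 features
params, state = layer.initialize((4, 8), onp.float32, rng)
pseudo_in = ShapeType(shape=(4, 8), dtype=onp.float32)
pseudo_out, _ = layer.pseudo_call(pseudo_in, params, state)
assert pseudo_out.shape == (4, 16)   # shapes inferred without real inputs
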
@@ -213,52 +217,74 @@ def initialize(self, input_shapes, input_dtype, rng):
    """
    try:
      # Initialize params once; store them for use when this layer is called.
+      # Needs to call new_parameters regardless of _init_finished because state
+      # also needs to be initialized. After jitting, graph pruning should be
+      # able to remove unnecessary computation.
+      # TODO(lukaszkaiser): Revisit this decision and see whether layers sharing
+      # params should also share states.
+      params, state = self.new_parameters(input_shapes, input_dtype, rng)
      if not self._init_finished:
-        self._params = self.new_parameters(input_shapes, input_dtype, rng)
        self._init_finished = True
-        return self._params
+        self._params = params
      else:
-        return ()
+        params = ()
+      return (params, state)
    except Exception:
      name, trace = self.__class__.__name__, _short_traceback(skip=3)
      raise LayerError(name, 'initialize', self._caller, input_shapes, trace)
 
-  def __call__(self, x, params=(), **kwargs):
+  def __call__(self, x, params=(), state=(), **kwargs):
    try:
      # If params are nothing, we may be reusing this layer.
      # Use the cached parameters to calculate the value.
      # Note: to make sure jit tracers can decide this branch in python we
      # use "params is ()" instead of, e.g., "not params" or "params == ()".
      if params is ():  # pylint: disable=literal-comparison
        params = self._params
-      # In this case, we're called for the first time: cache parameters.
-      self._params = params
+      else:
+        # In this case, we're called for the first time: cache parameters.
+        self._params = params
 
      if not self.has_custom_grad:
-        return self.call(x, params=params, **kwargs)
+        return self.call(x, params=params, state=state, **kwargs)
 
      # Custom gradients part.
      assert backend.get_name() == 'jax', (
          'Custom gradients are only supported in JAX for now.')
 
+      # TODO(wangpeng): JAX doesn't support custom grads for functions with
+      # auxiliary output yet (https://github.com/google/jax/issues/844). Will
+      # remove the constraints on state below when this feature is added to
+      # JAX.
+
+      assert state is (), (  # pylint: disable=literal-comparison
+          'Custom gradients do not allow non-trivial start state.')
+
+      def check_end_state(output_state):
+        output, state = output_state
+        assert state is (), (  # pylint: disable=literal-comparison
+            'Custom gradients do not allow non-trivial end state.')
+        return output
+
      # See this link for how custom transformations are defined in JAX:
      # https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms
      # Note that we capture the kwargs and don't calculate gradients wrt. them.
      @jax.custom_transforms
      def do_call(y, params):
-        return self.call(y, params=params, **kwargs)
+        return check_end_state(self.call(y, params=params, state=(), **kwargs))
 
      # This is the custom gradient (vector-jacobian product in JAX) function.
      # For the exact specification of this custom transformation see this link:
      # https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp_all
      def do_call_vjp(y, params):
-        output = self.call(y, params=params, **kwargs)
+        output = check_end_state(self.call(y, params=params, state=(),
+                                           **kwargs))
        def vjpfun(grad):
          return self.custom_grad(y, output, grad, params, **kwargs)
        return output, vjpfun
 
      jax.defvjp_all(do_call, do_call_vjp)
-      return do_call(x, params)
+      return do_call(x, params), ()
 
    except Exception:
      name, trace = self.__class__.__name__, _short_traceback()
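The requirement that custom-gradient layers keep a trivial () state comes from the jax.custom_transforms / jax.defvjp_all pair used above, which (per the linked issue) cannot yet carry an auxiliary output alongside the primal result. A standalone sketch of that JAX mechanism, using the same (since-deprecated) APIs referenced in the code; the function f and the deliberate gradient halving are made up:

import jax
import jax.numpy as jnp

@jax.custom_transforms
def f(x):
  return jnp.sin(x)

def f_vjp(x):
  out = jnp.sin(x)
  # Return the primal output plus a function mapping the output cotangent
  # to a tuple of input cotangents (one entry per argument of f).
  return out, lambda g: (g * jnp.cos(x) * 0.5,)

jax.defvjp_all(f, f_vjp)
print(jax.grad(f)(0.0))  # prints 0.5 rather than the true derivative 1.0
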
@@ -413,22 +439,23 @@ def _n_outputs(self):
 
  def _new_parameters(self, input_shapes, input_dtype, rng):
    if new_parameters is None:
-      return (), ()
+      return (), ()
    kwargs = self._init_kwargs  # pylint: disable=protected-access
-    return new_parameters(input_shapes, input_dtype, rng, **kwargs)
+    return new_parameters(input_shapes, input_dtype, rng, **kwargs), ()
 
  def _is_empty(raw_output):
    return raw_output is None or (isinstance(raw_output, (list, tuple))
                                   and len(raw_output) == 0)  # pylint: disable=g-explicit-length-test
 
-  def _call_with_context(self, x, params=(), **kwargs):
+  def _call_with_context(self, x, params=(), state=(), **kwargs):
    """Calls raw_call_fn with extra keyword args from Layer.__init__."""
    merged_kwargs = kwargs.copy()
    merged_kwargs.update(self._init_kwargs)  # pylint: disable=protected-access
 
    _validate_call_input(x, n_inputs)
    raw_output = raw_call_fn(x, params=params, **merged_kwargs)
-    return () if _is_empty(raw_output) else raw_output
+    output = () if _is_empty(raw_output) else raw_output
+    return (output, state)
 
  # Set docstrings and create the class.
  _call_with_context.__doc__ = raw_call_fn.__doc__
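For function-based layers, the wrapper above now always returns an (output, state) pair and threads the start state through unchanged. A hypothetical usage sketch, reusing onp and rng from the earlier sketch; the decorator name `layer` and its call pattern are assumptions, since the decorator itself is not shown in this hunk:

@layer()                      # assumed decorator that wraps raw_call_fn
def Double(x, **unused_kwargs):
  return x * 2

dbl = Double()
params, state = dbl.initialize((2, 3), onp.float32, rng)
y, end_state = dbl(onp.ones((2, 3), dtype=onp.float32),
                   params=params, state=state, rng=rng)
# end_state equals the start state: _call_with_context passes it through
# untouched, since a plain function layer has no state of its own.
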
@@ -502,15 +529,15 @@ def check_shape_agreement(layer_fn, input_shapes, integer_inputs=False):
    input_dtype = tuple(input_dtype for _ in input_shapes)
  else:
    pseudo_data = ShapeType(input_shapes, input_dtype)
-  params = layer_fn.initialize(input_shapes, input_dtype, rng1)
-  pseudo_output = layer_fn.pseudo_call(pseudo_data, params)
+  params, state = layer_fn.initialize(input_shapes, input_dtype, rng1)
+  pseudo_output, _ = layer_fn.pseudo_call(pseudo_data, params, state)
  if isinstance(pseudo_output, tuple):
    output_shape = tuple(x.shape for x in pseudo_output)
  else:
    output_shape = pseudo_output.shape
 
  random_input = _random_values(input_shapes, rng2, integer_inputs)
-  real_output = layer_fn(random_input, params, rng=rng3)
+  real_output, _ = layer_fn(random_input, params, state=state, rng=rng3)
  result_shape = shapes(real_output)
 
  msg = 'output shape %s != real result shape %s' % (output_shape, result_shape)
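Call sites of this test helper are unchanged by the diff: it now unpacks (params, state) from initialize and discards the end states internally. A hypothetical test, again assuming a Dense layer exists:

def test_dense_shape_agreement():
  # The helper initializes the layer, runs pseudo_call and a real call,
  # and checks that the statically inferred and actual output shapes agree.
  check_shape_agreement(Dense(32), (7, 16))
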