
Commit 4de9708

wangpengmit authored and copybara-github committed
Add state to layers
PiperOrigin-RevId: 264644849
1 parent 46cf96b commit 4de9708

23 files changed: +546 −317 lines changed

tensor2tensor/envs/env_problem_utils.py

Lines changed: 6 additions & 4 deletions
@@ -57,6 +57,7 @@ def play_env_problem_with_policy(env,
                                  num_trajectories=1,
                                  max_timestep=None,
                                  reset=True,
+                                 state=None,
                                  rng=None,
                                  temperature=1.0,
                                  boundary=32,
@@ -73,7 +74,8 @@ def play_env_problem_with_policy(env,
       trajectory that exceeds this time put it in the completed bin, and *dont*
       reset the env.
     reset: bool, true if we want to reset the envs. The envs are also reset if
-      max_max_timestep is None or < 0
+      max_max_timestep is None or < 0.
+    state: the state for `policy_fn`.
     rng: jax rng, splittable.
     temperature: float, temperature used in Gumbel sampling.
     boundary: int, pad the sequences to the multiples of this number.
@@ -118,8 +120,8 @@ def gumbel_sample(log_probs):
     assert (B,) == lengths.shape

     t1 = time.time()
-    log_prob_actions, value_predictions, rng = policy_fun(
-        padded_observations, rng=rng)
+    log_prob_actions, value_predictions, state, rng = policy_fun(
+        padded_observations, state=state, rng=rng)
     policy_application_total_time += (time.time() - t1)

     assert (B, T) == log_prob_actions.shape[:2]
@@ -192,7 +194,7 @@ def gumbel_sample(log_probs):
   }
   timing_info = {k: round(1000 * v, 2) for k, v in timing_info.items()}

-  return completed_trajectories, num_done_trajectories, timing_info
+  return completed_trajectories, num_done_trajectories, timing_info, state


 def make_env(batch_size=1,
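
Taken together, these hunks thread an explicit `state` through the rollout helper: the policy function now accepts `state=` and returns it as an extra element, and `play_env_problem_with_policy` returns the final state as a fourth value. A minimal sketch of a policy written against the new signature (the uniform-random policy and its action count are invented for illustration, not part of this commit):

import numpy as np

def uniform_policy(observations, state=None, rng=None):
  """Illustrative policy function matching the new signature."""
  b, t = observations.shape[:2]
  n_actions = 4  # hypothetical action count
  log_probs = np.log(np.full((b, t, n_actions), 1.0 / n_actions))
  value_predictions = np.zeros((b, t))
  # A stateful policy (e.g. one carrying an RNN hidden state) would return an
  # updated state here; this one passes it through unchanged.
  return log_probs, value_predictions, state, rng

# Hypothetical call site, mirroring the updated 4-tuple return:
# trajectories, n_done, timing_info, state = (
#     env_problem_utils.play_env_problem_with_policy(
#         env, uniform_policy, num_trajectories=2, state=None, rng=rng))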

tensor2tensor/envs/env_problem_utils_test.py

Lines changed: 3 additions & 3 deletions
@@ -53,7 +53,7 @@ def test_play_env_problem_with_policy(self):
    # Let's make sure that at-most 4 observations come to the policy function.
    len_history_for_policy = 4

-    def policy_fun(observations, rng=None):
+    def policy_fun(observations, state=None, rng=None):
      b, t = observations.shape[:2]
      # Assert that observations from time-step len_history_for_policy onwards
      # are zeros.
@@ -65,11 +65,11 @@ def policy_fun(observations, rng=None):
      p = np.random.uniform(size=(b, t, a))
      p = np.exp(p)
      p = p / np.sum(p, axis=-1, keepdims=True)
-      return np.log(p), np.log(p), rng
+      return np.log(p), np.log(p), state, rng

    max_timestep = 15
    num_trajectories = 2
-    trajectories, _, _ = env_problem_utils.play_env_problem_with_policy(
+    trajectories, _, _, _ = env_problem_utils.play_env_problem_with_policy(
        env,
        policy_fun,
        num_trajectories=num_trajectories,

tensor2tensor/trax/layers/attention.py

Lines changed: 1 addition & 1 deletion
@@ -169,7 +169,7 @@ def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train'):
          core.Dense(d_feature),
      ),
      PureAttention(  # pylint: disable=no-value-for-parameter
-          d_feature=d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
+          n_heads=n_heads, dropout=dropout, mode=mode),
      core.Dense(d_feature),
  ]

tensor2tensor/trax/layers/attention_test.py

Lines changed: 2 additions & 2 deletions
@@ -30,7 +30,7 @@ def test_shift_right(self):
    # Test shifts right on axis=1
    layer = attention.ShiftRight()
    input_np = onp.arange(2*3*3).reshape(2, 3, 3)
-    output_np = layer(input_np)
+    output_np, _ = layer(input_np)
    self.assertEqual(input_np.shape, output_np.shape)
    self.assertAllEqual(onp.array([[[0, 0, 0],
                                    [0, 1, 2],
@@ -49,7 +49,7 @@ def test_shift_right_float(self):
    input_np /= 2.0
    self.assertEqual(input_np.dtype, onp.float32)

-    output_np = layer(input_np)
+    output_np, _ = layer(input_np)
    self.assertEqual(input_np.shape, output_np.shape)
    self.assertEqual(output_np.dtype, onp.float32)

tensor2tensor/trax/layers/base.py

Lines changed: 50 additions & 23 deletions
@@ -81,7 +81,7 @@ def __repr__(self):
     else:
       return '{}[{}]'.format(class_str, fields_str)

-  def call(self, inputs, params=(), **kwargs):
+  def call(self, inputs, params=(), state=(), **kwargs):
     """Applies this layer to given activation tensors, using trainable params.

     Args:
@@ -94,6 +94,7 @@ def call(self, inputs, params=(), **kwargs):
        and one for each of this layer's sublayers. If a layer (or sublayer)
        has no trainable parameters, the corresponding params element is an
        empty tuple.
+      state: start state.
      **kwargs: Layer-specific keyword args.

    Returns:
@@ -106,6 +107,7 @@ def call(self, inputs, params=(), **kwargs):
     """
     raise NotImplementedError

+  # TODO(wangpeng): Should be called `new_parameters_and_state`.
   def new_parameters(self, input_shapes, input_dtype, rng):
     """Creates layer-specific parameters based on data shape, dtype and rng.

@@ -144,7 +146,7 @@ def has_custom_grad(self):
     """Whether to use custom gradients (in which case, see below)."""
     return False

-  def custom_grad(self, inputs, output, grad, params, **kwargs):
+  def custom_grad(self, inputs, output, grad, params, state, **kwargs):
     """Custom backward pass to propagate gradients in a custom way.

     Args:
@@ -153,6 +155,7 @@ def custom_grad(self, inputs, output, grad, params, **kwargs):
      grad: gradient signal (called cotangent in jax) computed based on
        subsequent layers. The structure and shape must match output.
      params: layer parameters
+      state: start state.
      **kwargs: kwargs for the layer

    Returns:
@@ -164,14 +167,15 @@ def custom_grad(self, inputs, output, grad, params, **kwargs):

   # End of subclassing interface, all functions below are internal.

-  def pseudo_call(self, pseudo_inputs, params):
+  def pseudo_call(self, pseudo_inputs, params, state):
     """Computes shapes and types this layer would produce for the given inputs.

     Args:
       pseudo_inputs: A ShapeType instance (input data minus the actual values)
         or a tuple of ShapeType instances, following the same conventions as
         Layer.call's input arg.
       params: Parameters for this layer.
+      state: start state.

     Returns:
       A ShapeType instance representing the shape and type of the output (if
@@ -183,12 +187,12 @@ def pseudo_call(self, pseudo_inputs, params):
      # cause a large number of dropout masks to be computed and permanently
      # stored in global memory.
      rng = ShapeType(shape=(2,), dtype=onp.uint32)
-      def call_on_input(x, params, rng):
-        return self.call(x, params=params, rng=rng)
+      def call_on_input(x, params, state, rng):
+        return self.call(x, params=params, state=state, rng=rng)
      params_shapes = nested_map(
          params, lambda x: ShapeType(shape=x.shape, dtype=x.dtype))
      s = backend.eval_on_shapes(call_on_input)(pseudo_inputs,
-                                                params_shapes, rng)
+                                                params_shapes, state, rng)
      return s
    except Exception:
      name, trace = self.__class__.__name__, _short_traceback(skip=3)
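
After these hunks, the subclassing contract is: `new_parameters` returns a `(params, state)` pair and `call` takes a `state` and returns an `(output, new_state)` pair. A minimal sketch of a stateful layer written against that contract (the counting layer is invented for illustration, not part of this commit):

import numpy as onp  # matches the 'onp' alias used in this file

from tensor2tensor.trax.layers import base


class CountingIdentity(base.Layer):
  """Identity layer whose state counts how many times it has been applied."""

  def call(self, x, params=(), state=(), **kwargs):
    del params, kwargs
    # `state` is expected to come from new_parameters / initialize below.
    return x, state + 1

  def new_parameters(self, input_shapes, input_dtype, rng):
    del input_shapes, input_dtype, rng
    return (), onp.zeros((), dtype=onp.int32)  # no params, a scalar counter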
@@ -213,52 +217,74 @@ def initialize(self, input_shapes, input_dtype, rng):
     """
     try:
       # Initialize params once; store them for use when this layer is called.
+      # Needs to call new_parameters regardless of _init_finished because state
+      # also needs to be initialized. After jitting, graph pruning should be
+      # able to remove unnecessary computation.
+      # TODO(lukaszkaiser): Revisit this decision and see whether layers sharing
+      # params should also share states.
+      params, state = self.new_parameters(input_shapes, input_dtype, rng)
       if not self._init_finished:
-        self._params = self.new_parameters(input_shapes, input_dtype, rng)
         self._init_finished = True
-        return self._params
+        self._params = params
       else:
-        return ()
+        params = ()
+      return (params, state)
     except Exception:
       name, trace = self.__class__.__name__, _short_traceback(skip=3)
       raise LayerError(name, 'initialize', self._caller, input_shapes, trace)

-  def __call__(self, x, params=(), **kwargs):
+  def __call__(self, x, params=(), state=(), **kwargs):
     try:
       # If params are nothing, we may be reusing this layer.
       # Use the cached parameters to calculate the value.
       # Note: to make sure jit tracers can decide this branch in python we
       # use "params is ()" instead of, e.g., "not params" or "params == ()".
       if params is ():  # pylint: disable=literal-comparison
         params = self._params
-      # In this case, we're called for the first time: cache parameters.
-      self._params = params
+      else:
+        # In this case, we're called for the first time: cache parameters.
+        self._params = params

       if not self.has_custom_grad:
-        return self.call(x, params=params, **kwargs)
+        return self.call(x, params=params, state=state, **kwargs)

       # Custom gradients part.
       assert backend.get_name() == 'jax', (
           'Custom gradients are only supported in JAX for now.')

+      # TODO(wangpeng): JAX doesn't support custom grads for functions with
+      # auxiliary output yet (https://github.com/google/jax/issues/844). Will
+      # remove the constraints on state below when this feature is added to
+      # JAX.
+
+      assert state is (), (  # pylint: disable=literal-comparison
+          'Custom gradients do not allow non-trivial start state.')
+
+      def check_end_state(output_state):
+        output, state = output_state
+        assert state is (), (  # pylint: disable=literal-comparison
+            'Custom gradients do not allow non-trivial end state.')
+        return output
+
       # See this link for how custom transformations are defined in JAX:
       # https://jax.readthedocs.io/en/latest/jax.html#jax.custom_transforms
       # Note that we capture the kwargs and don't calculate gradients wrt. them.
       @jax.custom_transforms
       def do_call(y, params):
-        return self.call(y, params=params, **kwargs)
+        return check_end_state(self.call(y, params=params, state=(), **kwargs))

       # This is the custom gradient (vector-jacobian product in JAX) function.
       # For the exact specification of this custom transformation see this link:
       # https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp_all
       def do_call_vjp(y, params):
-        output = self.call(y, params=params, **kwargs)
+        output = check_end_state(self.call(y, params=params, state=(),
+                                           **kwargs))
         def vjpfun(grad):
           return self.custom_grad(y, output, grad, params, **kwargs)
         return output, vjpfun

       jax.defvjp_all(do_call, do_call_vjp)
-      return do_call(x, params)
+      return do_call(x, params), ()

     except Exception:
       name, trace = self.__class__.__name__, _short_traceback()
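
On the caller side, `initialize` now returns `(params, state)` and `__call__` returns `(output, state)`, so state is threaded explicitly from one application to the next. A rough usage sketch, reusing the hypothetical `CountingIdentity` layer sketched above (the JAX PRNG key stands in for whatever rng the surrounding code supplies):

import numpy as onp
from jax import random as jax_random

layer = CountingIdentity()
rng = jax_random.PRNGKey(0)
params, state = layer.initialize((2, 3), onp.int32, rng)
x = onp.ones((2, 3), dtype=onp.int32)
y, state = layer(x, params=params, state=state, rng=rng)  # state is now 1
y, state = layer(x, params=params, state=state, rng=rng)  # state is now 2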
@@ -413,22 +439,23 @@ def _n_outputs(self):

   def _new_parameters(self, input_shapes, input_dtype, rng):
     if new_parameters is None:
-      return ()
+      return (), ()
     kwargs = self._init_kwargs  # pylint: disable=protected-access
-    return new_parameters(input_shapes, input_dtype, rng, **kwargs)
+    return new_parameters(input_shapes, input_dtype, rng, **kwargs), ()

   def _is_empty(raw_output):
     return raw_output is None or (isinstance(raw_output, (list, tuple))
                                   and len(raw_output) == 0)  # pylint: disable=g-explicit-length-test

-  def _call_with_context(self, x, params=(), **kwargs):
+  def _call_with_context(self, x, params=(), state=(), **kwargs):
     """Calls raw_call_fn with extra keyword args from Layer.__init__."""
     merged_kwargs = kwargs.copy()
     merged_kwargs.update(self._init_kwargs)  # pylint: disable=protected-access

     _validate_call_input(x, n_inputs)
     raw_output = raw_call_fn(x, params=params, **merged_kwargs)
-    return () if _is_empty(raw_output) else raw_output
+    output = () if _is_empty(raw_output) else raw_output
+    return (output, state)

   # Set docstrings and create the class.
   _call_with_context.__doc__ = raw_call_fn.__doc__
@@ -502,15 +529,15 @@ def check_shape_agreement(layer_fn, input_shapes, integer_inputs=False):
     input_dtype = tuple(input_dtype for _ in input_shapes)
   else:
     pseudo_data = ShapeType(input_shapes, input_dtype)
-  params = layer_fn.initialize(input_shapes, input_dtype, rng1)
-  pseudo_output = layer_fn.pseudo_call(pseudo_data, params)
+  params, state = layer_fn.initialize(input_shapes, input_dtype, rng1)
+  pseudo_output, _ = layer_fn.pseudo_call(pseudo_data, params, state)
   if isinstance(pseudo_output, tuple):
     output_shape = tuple(x.shape for x in pseudo_output)
   else:
     output_shape = pseudo_output.shape

   random_input = _random_values(input_shapes, rng2, integer_inputs)
-  real_output = layer_fn(random_input, params, rng=rng3)
+  real_output, _ = layer_fn(random_input, params, state=state, rng=rng3)
   result_shape = shapes(real_output)

   msg = 'output shape %s != real result shape %s' % (output_shape, result_shape)
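
Layers built from plain functions via the decorator in this file stay stateless: the generated `_new_parameters` appends an empty state and `_call_with_context` returns the incoming state untouched. A rough sketch, assuming the `@base.layer()` decorator pattern used elsewhere in trax (the `Double` layer and the shapes are invented for illustration):

import numpy as onp
from jax import random as jax_random

from tensor2tensor.trax.layers import base


@base.layer()
def Double(x, **unused_kwargs):  # hypothetical function-based layer
  return 2 * x

layer = Double()
rng = jax_random.PRNGKey(0)
params, state = layer.initialize((2, 3), onp.float32, rng)  # -> ((), ())
y, state = layer(onp.ones((2, 3)), params, state=state, rng=rng)  # state stays ()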

tensor2tensor/trax/layers/base_test.py

Lines changed: 6 additions & 6 deletions
@@ -40,11 +40,11 @@ class IdWithZeroGrad(base.Layer):

   def call(self, x, params, **kwargs):
     del params, kwargs
-    return x
+    return x, ()

   def new_parameters(self, input_shapes, input_dtype, rng):
     del input_shapes, input_dtype, rng
-    return ()
+    return (), ()

   @property
   def has_custom_grad(self):
@@ -59,7 +59,7 @@ def custom_grad(self, inputs, output, ct, params, **kwargs):
     input_shape = (9, 17)
     random_input = backend.random.uniform(rng, input_shape, minval=-1.0,
                                           maxval=1.0)
-    f = lambda x: backend.numpy.mean(layer(x, params, rng=rng))
+    f = lambda x: backend.numpy.mean(layer(x, params, rng=rng)[0])
     grad = backend.grad(f)(random_input)
     self.assertEqual(grad.shape, input_shape)  # Gradient for each input.
     self.assertEqual(sum(sum(grad * grad)), 0.0)  # Each one is 0.
@@ -70,11 +70,11 @@ class IdWithIdGrad(base.Layer):

   def call(self, x, params, **kwargs):
     del params, kwargs
-    return x
+    return x, ()

   def new_parameters(self, input_shapes, input_dtype, rng):
     del input_shapes, input_dtype, rng
-    return ()
+    return (), ()

   @property
   def has_custom_grad(self):
@@ -89,7 +89,7 @@ def custom_grad(self, inputs, output, ct, params, **kwargs):
     input_shape = (9, 17)
     random_input = backend.random.uniform(rng, input_shape, minval=-1.0,
                                           maxval=1.0)
-    f = lambda x: backend.numpy.mean(layer(x, params, rng=rng))
+    f = lambda x: backend.numpy.mean(layer(x, params, rng=rng)[0])
     grad = backend.grad(f)(random_input)
     self.assertEqual(grad.shape, input_shape)  # Gradient for each input.
     self.assertEqual(sum(sum(grad)), sum(sum(random_input)))  # Same as input.
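
Because the custom-gradient path cannot carry auxiliary state yet (see the TODO in base.py), custom-grad layers such as the two test classes above must keep both start and end state empty: `call` returns `(output, ())` and `new_parameters` returns `((), ())`. A compressed, hypothetical sketch of that kind of layer with a pass-through gradient (not part of this test):

from tensor2tensor.trax.layers import base


class IdWithPassThroughGrad(base.Layer):  # illustrative custom-grad layer
  """Identity whose custom gradient passes the cotangent straight through."""

  def call(self, x, params, **kwargs):
    del params, kwargs
    return x, ()   # custom-grad layers must end with an empty state

  def new_parameters(self, input_shapes, input_dtype, rng):
    del input_shapes, input_dtype, rng
    return (), ()  # no params, empty start state

  @property
  def has_custom_grad(self):
    return True

  def custom_grad(self, inputs, output, ct, params, **kwargs):
    del inputs, output, params, kwargs
    return ct, ()  # (cotangent w.r.t. inputs, cotangent w.r.t. params)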
