This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 65c2178

jaingaurav authored and copybara-github committed
Remove functional_ops based sru
PiperOrigin-RevId: 307958075
1 parent b1bebfb commit 65c2178

File tree

1 file changed: +6 -94 lines changed

tensor2tensor/layers/common_layers.py

Lines changed: 6 additions & 94 deletions
@@ -2212,12 +2212,12 @@ def gated_linear_unit_layer(x, name=None):
   return x * tf.nn.sigmoid(gating_x)
 
 
-def sru_with_scan(x,
-                  num_layers=2,
-                  activation=None,
-                  initial_state=None,
-                  name=None,
-                  reuse=None):
+def sru(x,
+        num_layers=2,
+        activation=None,
+        initial_state=None,
+        name=None,
+        reuse=None):
   """SRU cell as in https://arxiv.org/abs/1709.02755.
 
   This implementation uses tf.scan and can incur overhead, see the full SRU
@@ -2275,94 +2275,6 @@ def next_state(cur_state, args_tup):
     return tf.reshape(x, x_shape)
 
 
-class CumsumprodCell(object):
-  """Cumulative sum and product object for use with functional_rnn API."""
-
-  def __init__(self, initializer):
-    self._initializer = initializer
-
-  @property
-  def output_size(self):
-    return int(shape_list(self._initializer)[-1])
-
-  def zero_state(self, batch_size, dtype):
-    dtype = dtype or tf.float32
-    return tf.zeros([batch_size, self.output_size], dtype=dtype)
-
-  def __call__(self, inputs_t, state_t):
-    cur_x_times_one_minus_f, cur_f = tf.split(inputs_t, 2, axis=-1)
-    state_next = cur_f * state_t + cur_x_times_one_minus_f
-    outputs_t = state_next
-    return outputs_t, state_next
-
-
-def sru(x,
-        num_layers=2,
-        activation=None,
-        initial_state=None,
-        name=None,
-        reuse=None):
-  """SRU cell as in https://arxiv.org/abs/1709.02755.
-
-  As defined in the paper:
-  (1) x'_t = W x_t
-  (2) f_t = sigmoid(Wf x_t + bf)
-  (3) r_t = sigmoid(Wr x_t + br)
-  (4) c_t = f_t * c_{t-1} + (1 - f_t) * x'_t
-  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t
-
-  This version uses functional ops to be faster on GPUs with TF-1.9+.
-
-  Args:
-    x: A tensor of shape [batch, ..., channels] ; ... is treated as time.
-    num_layers: How many SRU layers; default is 2 as results for 1 disappoint.
-    activation: Optional activation function, try tf.nn.tanh or tf.nn.relu.
-    initial_state: Optional initial c-state, set to zeros if None.
-    name: Optional name, "sru" by default.
-    reuse: Optional reuse.
-
-  Returns:
-    A tensor of the same shape as x.
-
-  Raises:
-    ValueError: if num_layers is not positive.
-  """
-  if num_layers < 1:
-    raise ValueError("Number of layers must be positive: %d" % num_layers)
-  if is_xla_compiled():  # On TPU the XLA does a good job with while.
-    return sru_with_scan(x, num_layers, activation, initial_state, name, reuse)
-  try:
-    from tensorflow.contrib.recurrent.python.ops import functional_rnn  # pylint: disable=g-import-not-at-top
-  except ImportError:
-    tf.logging.info("functional_rnn not found, using sru_with_scan instead")
-    return sru_with_scan(x, num_layers, activation, initial_state, name, reuse)
-
-  with tf.variable_scope(name, default_name="sru", values=[x], reuse=reuse):
-    # We assume x is [batch, ..., channels] and treat all ... as time.
-    x_shape = shape_list(x)
-    x = tf.reshape(x, [x_shape[0], -1, x_shape[-1]])
-    initial_state = initial_state or tf.zeros([x_shape[0], x_shape[-1]])
-    cell = CumsumprodCell(initial_state)
-    # Calculate SRU on each layer.
-    for i in range(num_layers):
-      # The parallel part of the SRU.
-      x_orig = x
-      x, f, r = tf.split(
-          layers().Dense(3 * x_shape[-1], name="kernel_%d" % i)(x), 3, axis=-1)
-      f, r = tf.sigmoid(f), tf.sigmoid(r)
-      x_times_one_minus_f = x * (1.0 - f)  # Compute in parallel for speed.
-      # Calculate states.
-      concat = tf.concat([x_times_one_minus_f, f], axis=-1)
-      c_states, _ = functional_rnn.functional_rnn(
-          cell, concat, time_major=False)
-      # Final output.
-      if activation is not None:
-        c_states = activation(c_states)
-      h = c_states * r + (1.0 - r) * x_orig
-      x = h  # Next layer.
-    return tf.reshape(x, x_shape)
-
-
 def linear_set_layer(layer_size,
                      inputs,
                      context=None,
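
Note: equations (1)-(5) in the removed docstring describe the same recurrence that the retained scan-based sru computes, so no modeling behavior changes. A minimal, framework-free NumPy sketch of one SRU layer may help read the diff; the function name, weight layout, and shapes below are illustrative and not part of the library.

import numpy as np

def sru_layer_reference(x, w, w_f, b_f, w_r, b_r, c0=None, activation=np.tanh):
  """Illustrative single SRU layer following eqs. (1)-(5) of the docstring.

  x: [time, channels]; w, w_f, w_r: [channels, channels]; b_f, b_r: [channels].
  """
  x = np.asarray(x, dtype=float)
  sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
  c = np.zeros(x.shape[-1]) if c0 is None else c0
  h = np.empty_like(x)
  for t in range(x.shape[0]):
    x_prime = x[t] @ w                           # (1) x'_t = W x_t
    f = sigmoid(x[t] @ w_f + b_f)                # (2) forget gate
    r = sigmoid(x[t] @ w_r + b_r)                # (3) reset gate
    c = f * c + (1.0 - f) * x_prime              # (4) internal c-state
    h[t] = r * activation(c) + (1.0 - r) * x[t]  # (5) highway-style output
  return h, c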

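With this change the functional_ops path is gone and the scan-based implementation is exported directly as sru; the removed wrapper already fell back to it under XLA compilation or when functional_rnn was unavailable. A hedged usage sketch, assuming TF 1.x graph mode and the import path implied by the changed file (the tensor sizes are made up):

import tensorflow as tf  # TF 1.x API, as used by tensor2tensor at this revision
from tensor2tensor.layers import common_layers

x = tf.zeros([8, 20, 64])  # [batch, time, channels]; dimensions are illustrative
y = common_layers.sru(x, num_layers=2, activation=tf.nn.tanh)
# Per the docstring, y has the same shape as x: [8, 20, 64].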