@@ -2212,12 +2212,12 @@ def gated_linear_unit_layer(x, name=None):
     return x * tf.nn.sigmoid(gating_x)


-def sru_with_scan(x,
-                  num_layers=2,
-                  activation=None,
-                  initial_state=None,
-                  name=None,
-                  reuse=None):
+def sru(x,
+        num_layers=2,
+        activation=None,
+        initial_state=None,
+        name=None,
+        reuse=None):
   """SRU cell as in https://arxiv.org/abs/1709.02755.

   This implementation uses tf.scan and can incur overhead, see the full SRU
@@ -2275,94 +2275,6 @@ def next_state(cur_state, args_tup):
     return tf.reshape(x, x_shape)


-class CumsumprodCell(object):
-  """Cumulative sum and product object for use with functional_rnn API."""
-
-  def __init__(self, initializer):
-    self._initializer = initializer
-
-  @property
-  def output_size(self):
-    return int(shape_list(self._initializer)[-1])
-
-  def zero_state(self, batch_size, dtype):
-    dtype = dtype or tf.float32
-    return tf.zeros([batch_size, self.output_size], dtype=dtype)
-
-  def __call__(self, inputs_t, state_t):
-    cur_x_times_one_minus_f, cur_f = tf.split(inputs_t, 2, axis=-1)
-    state_next = cur_f * state_t + cur_x_times_one_minus_f
-    outputs_t = state_next
-    return outputs_t, state_next
-
-
-def sru(x,
-        num_layers=2,
-        activation=None,
-        initial_state=None,
-        name=None,
-        reuse=None):
-  """SRU cell as in https://arxiv.org/abs/1709.02755.
-
-  As defined in the paper:
-  (1) x'_t = W x_t
-  (2) f_t = sigmoid(Wf x_t + bf)
-  (3) r_t = sigmoid(Wr x_t + br)
-  (4) c_t = f_t * c_{t-1} + (1 - f_t) * x'_t
-  (5) h_t = r_t * activation(c_t) + (1 - r_t) * x_t
-
-  This version uses functional ops to be faster on GPUs with TF-1.9+.
-
-  Args:
-    x: A tensor of shape [batch, ..., channels] ; ... is treated as time.
-    num_layers: How many SRU layers; default is 2 as results for 1 disappoint.
-    activation: Optional activation function, try tf.nn.tanh or tf.nn.relu.
-    initial_state: Optional initial c-state, set to zeros if None.
-    name: Optional name, "sru" by default.
-    reuse: Optional reuse.
-
-  Returns:
-    A tensor of the same shape as x.
-
-  Raises:
-    ValueError: if num_layers is not positive.
-  """
-  if num_layers < 1:
-    raise ValueError("Number of layers must be positive: %d" % num_layers)
-  if is_xla_compiled():  # On TPU the XLA does a good job with while.
-    return sru_with_scan(x, num_layers, activation, initial_state, name, reuse)
-  try:
-    from tensorflow.contrib.recurrent.python.ops import functional_rnn  # pylint: disable=g-import-not-at-top
-  except ImportError:
-    tf.logging.info("functional_rnn not found, using sru_with_scan instead")
-    return sru_with_scan(x, num_layers, activation, initial_state, name, reuse)
-
-  with tf.variable_scope(name, default_name="sru", values=[x], reuse=reuse):
-    # We assume x is [batch, ..., channels] and treat all ... as time.
-    x_shape = shape_list(x)
-    x = tf.reshape(x, [x_shape[0], -1, x_shape[-1]])
-    initial_state = initial_state or tf.zeros([x_shape[0], x_shape[-1]])
-    cell = CumsumprodCell(initial_state)
-    # Calculate SRU on each layer.
-    for i in range(num_layers):
-      # The parallel part of the SRU.
-      x_orig = x
-      x, f, r = tf.split(
-          layers().Dense(3 * x_shape[-1], name="kernel_%d" % i)(x), 3, axis=-1)
-      f, r = tf.sigmoid(f), tf.sigmoid(r)
-      x_times_one_minus_f = x * (1.0 - f)  # Compute in parallel for speed.
-      # Calculate states.
-      concat = tf.concat([x_times_one_minus_f, f], axis=-1)
-      c_states, _ = functional_rnn.functional_rnn(
-          cell, concat, time_major=False)
-      # Final output.
-      if activation is not None:
-        c_states = activation(c_states)
-      h = c_states * r + (1.0 - r) * x_orig
-      x = h  # Next layer.
-    return tf.reshape(x, x_shape)
-
-
 def linear_set_layer(layer_size,
                      inputs,
                      context=None,
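For reference, both the scan-based `sru` kept by this change and the removed `functional_rnn`-based variant compute recurrence (1)-(5) from the docstring above; they differ only in how the sequential step (4) is driven (tf.scan versus `functional_rnn` with `CumsumprodCell`). A minimal NumPy sketch of a single SRU layer, with the weight names `w`, `wf`, `wr`, `bf`, `br` chosen here purely for illustration and not taken from the library, might look like:

```python
import numpy as np


def sigmoid(z):
  return 1.0 / (1.0 + np.exp(-z))


def sru_layer(x, w, wf, bf, wr, br, c0=None, activation=np.tanh):
  """One SRU layer over x of shape [time, channels] (illustration only)."""
  num_steps, channels = x.shape
  c = np.zeros(channels) if c0 is None else c0
  h = np.empty_like(x)
  for t in range(num_steps):
    x_prime = x[t] @ w                           # (1) x'_t = W x_t
    f = sigmoid(x[t] @ wf + bf)                  # (2) forget gate
    r = sigmoid(x[t] @ wr + br)                  # (3) reset gate
    c = f * c + (1.0 - f) * x_prime              # (4) sequential cell update
    h[t] = r * activation(c) + (1.0 - r) * x[t]  # (5) highway-style output
  return h


# Toy usage with random weights; each weight matrix is [channels, channels].
rng = np.random.default_rng(0)
d = 8
x = rng.standard_normal((5, d))
w, wf, wr = (rng.standard_normal((d, d)) for _ in range(3))
print(sru_layer(x, w, wf, np.zeros(d), wr, np.zeros(d)).shape)  # -> (5, 8)
```

The removed implementation fuses the three projections in (1)-(3) into a single Dense layer of width 3 * channels and precomputes x * (1 - f) up front, so only the cheap elementwise update in (4) remains for the sequential pass.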