@@ -14,12 +14,17 @@ namespace Tensorflow
1414 /// Basic LSTM recurrent network cell.
1515 /// The implementation is based on: http://arxiv.org/abs/1409.2329.
1616 /// </summary>
17- public class BasicLSTMCell : LayerRnnCell
17+ public class BasicLstmCell : LayerRnnCell
1818 {
// Number of units in the cell; also used as the depth of the hidden state in build().
int _num_units;
// Bias added to the forget-gate pre-activation in call().
float _forget_bias;
// When true the recurrent state is handled as a (c, h) LSTMStateTuple.
bool _state_is_tuple;
// Activation applied to the candidate input and to the new cell state in call().
IActivation _activation;
// State stored by __call__(inputs, state) and read back by call().
LSTMStateTuple _state;
// Gate weight matrix, created in build().
VariableV1 _kernel;
// Gate bias vector, created in build().
VariableV1 _bias;
// Variable names passed to add_weight; fixed values, so declared as constants.
const string _WEIGHTS_VARIABLE_NAME = "kernel";
const string _BIAS_VARIABLE_NAME = "bias";
2328
2429 /// <summary>
2530 /// Initialize the basic LSTM cell.
@@ -31,7 +36,7 @@ public class BasicLSTMCell : LayerRnnCell
3136 /// <param name="reuse"></param>
3237 /// <param name="name"></param>
3338 /// <param name="dtype"></param>
34- public BasicLSTMCell ( int num_units , float forget_bias = 1.0f , bool state_is_tuple = true ,
39+ public BasicLstmCell ( int num_units , float forget_bias = 1.0f , bool state_is_tuple = true ,
3540 IActivation activation = null , bool ? reuse = null , string name = null ,
3641 TF_DataType dtype = TF_DataType . DtInvalid ) : base ( _reuse : reuse , name : name , dtype : dtype )
3742 {
@@ -44,13 +49,123 @@ public BasicLSTMCell(int num_units, float forget_bias = 1.0f, bool state_is_tupl
4449 _activation = tf . nn . tanh ( ) ;
4550 }
4651
47- public LSTMStateTuple state_size
/// <summary>
/// Creates the kernel and bias variables of the cell and marks it as built.
/// </summary>
/// <param name="input_shape">
/// Shape of the inputs; its last dimension is used as the input depth.
/// </param>
/// <exception cref="ArgumentException">
/// Thrown when the last dimension of <paramref name="input_shape"/> is negative.
/// </exception>
protected override void build(TensorShape input_shape)
{
    var input_depth = input_shape.dims.Last();
    // NOTE(review): assumes an unknown dimension is encoded as a negative value — confirm.
    // Without this check the kernel would silently be created with a wrong row count.
    if (input_depth < 0)
        throw new ArgumentException(
            $"Expected the last dimension of input_shape to be known, saw shape: {input_shape}");

    var h_depth = _num_units;

    // One kernel holds the weights of all four gates, hence the 4 * _num_units columns.
    _kernel = add_weight(_WEIGHTS_VARIABLE_NAME,
        shape: new[] { input_depth + h_depth, 4 * _num_units });
    _bias = add_weight(_BIAS_VARIABLE_NAME,
        shape: new[] { 4 * _num_units },
        initializer: tf.zeros_initializer);
    built = true;
}
63+
/// <summary>
/// Stores <paramref name="state"/> on the cell and forwards
/// <paramref name="inputs"/> to the base <c>__call__</c>.
/// </summary>
/// <param name="inputs">Input tensor passed on to the base implementation.</param>
/// <param name="state">Recurrent state to be read by <c>call</c>; must not be null.</param>
/// <returns>Whatever the base <c>__call__</c> returns.</returns>
/// <exception cref="ArgumentNullException">Thrown when <paramref name="state"/> is null.</exception>
public Tensor[] __call__(Tensor inputs, LSTMStateTuple state)
{
    // Reject a missing state here instead of failing later when call() dereferences _state.
    if (state is null)
        throw new ArgumentNullException(nameof(state));

    _state = state;
    return base.__call__(inputs);
}
69+
/// <summary>
/// Long short-term memory cell (LSTM): computes the new cell state and the new
/// hidden state from <paramref name="inputs"/> and the state stored on the cell.
/// </summary>
/// <param name="inputs">Input tensor; concatenated with the previous hidden state along axis 1.</param>
/// <param name="training">Not used by this implementation.</param>
/// <param name="state">
/// Not used by this implementation; the recurrent state is read from the
/// <c>_state</c> field set by <c>__call__(inputs, state)</c>.
/// </param>
/// <returns>An array <c>{ new_c, new_h }</c>.</returns>
/// <exception cref="NotImplementedException">Thrown when the cell is not configured for tuple state.</exception>
/// <exception cref="InvalidOperationException">Thrown when no state has been stored on the cell.</exception>
protected override Tensor[] call(Tensor inputs, Tensor training = null, Tensor state = null)
{
    // Only tuple state is supported. Fail before any op is created.
    if (!_state_is_tuple)
    {
        // A non-tuple implementation would split the state in two along axis 1:
        // array_ops.split(value: state, num_or_size_splits: 2, axis: one);
        // and return array_ops.concat(new[] { new_c, new_h }, 1) as a single tensor.
        throw new NotImplementedException("BasicLstmCell call");
    }

    if (_state is null)
        throw new InvalidOperationException(
            "BasicLstmCell call: no state is set; invoke __call__(inputs, state) first.");

    var one = constant_op.constant(1, dtype: dtypes.int32);
    var c = (Tensor)_state.c;
    var h = (Tensor)_state.h;

    // Parameters of gates are concatenated into one multiply for efficiency.
    var gate_inputs = math_ops.matmul(array_ops.concat(new[] { inputs, h }, 1), _kernel as RefVariable);
    gate_inputs = nn_ops.bias_add(gate_inputs, _bias as RefVariable);

    // i = input_gate, j = new_input, f = forget_gate, o = output_gate
    var tensors = array_ops.split(value: gate_inputs, num_or_size_splits: 4, axis: one);
    var (i, j, f, o) = (tensors[0], tensors[1], tensors[2], tensors[3]);

    var forget_bias_tensor = constant_op.constant(_forget_bias, dtype: f.dtype);
    // Note that using `add` and `multiply` instead of `+` and `*` gives a
    // performance improvement. So using those at the cost of readability.
    var new_c = gen_math_ops.add(
        math_ops.multiply(c, math_ops.sigmoid(gen_math_ops.add(f, forget_bias_tensor))),
        math_ops.multiply(math_ops.sigmoid(i), _activation.Activate(j)));

    var new_h = math_ops.multiply(_activation.Activate(new_c), math_ops.sigmoid(o));

    return new[] { new_c, new_h };
}
112+
/// <summary>
/// Returns the zero-filled initial state for the given batch size and dtype.
/// Supplying <paramref name="inputs"/> is not supported.
/// </summary>
/// <param name="inputs">Must be null; any other value throws.</param>
/// <param name="batch_size">Batch size forwarded to <c>zero_state</c>.</param>
/// <param name="dtype">Data type forwarded to <c>zero_state</c>.</param>
/// <returns>The result of <c>zero_state(batch_size, dtype)</c>.</returns>
public override object get_initial_state(Tensor inputs = null, Tensor batch_size = null, TF_DataType dtype = TF_DataType.DtInvalid)
{
    return inputs != null
        ? throw new NotImplementedException("get_initial_state input is not null")
        : zero_state(batch_size, dtype);
}
120+
/// <summary>
/// Builds the zero-filled state tensor(s) inside a name scope derived from
/// the runtime type name of this cell.
/// </summary>
/// <param name="batch_size">Batch size tensor used for the leading dimension.</param>
/// <param name="dtype">Data type of the state tensors.</param>
/// <returns>The state produced by <c>_zero_state_tensors</c> for <c>state_size</c>.</returns>
private LSTMStateTuple zero_state(Tensor batch_size, TF_DataType dtype)
{
    var scope_name = $"{GetType().Name}ZeroState";
    LSTMStateTuple zeros = null;

    tf_with(ops.name_scope(scope_name, values: new { batch_size }), delegate
    {
        zeros = _zero_state_tensors(state_size, batch_size, dtype);
    });

    return zeros;
}
137+
/// <summary>
/// Creates one zero tensor per entry of a tuple state size and packs the
/// first two into an <c>LSTMStateTuple</c>.
/// </summary>
/// <param name="state_size">State size; only <c>LSTMStateTuple</c> is supported.</param>
/// <param name="batch_size">Batch size tensor combined with each entry to form the shape.</param>
/// <param name="dtype">Data type of the created tensors.</param>
/// <returns>A tuple of zero tensors.</returns>
/// <exception cref="NotImplementedException">
/// Thrown when <paramref name="state_size"/> is not an <c>LSTMStateTuple</c>.
/// </exception>
private LSTMStateTuple _zero_state_tensors(object state_size, Tensor batch_size, TF_DataType dtype)
{
    if (!(state_size is LSTMStateTuple tuple_sizes))
        throw new NotImplementedException("_zero_state_tensors");

    // Builds a zero tensor from the concatenated shape, then applies the static shape.
    Tensor zeros_for(int units)
    {
        var dynamic_shape = rnn_cell_impl._concat(batch_size, units);
        var zeros = array_ops.zeros(dynamic_shape, dtype: dtype);

        var static_shape = rnn_cell_impl._concat(batch_size, units, @static: true);
        zeros.set_shape(static_shape);

        return zeros;
    }

    var states = tuple_sizes.Flatten()
        .Select(x => zeros_for((int)x))
        .ToArray();

    return new LSTMStateTuple(states[0], states[1]);
}
160+
/// <summary>
/// Size of the recurrent state: an <c>LSTMStateTuple</c> of
/// (<c>_num_units</c>, <c>_num_units</c>) when tuple state is enabled,
/// otherwise the integer <c>2 * _num_units</c>.
/// </summary>
public override object state_size =>
    _state_is_tuple
        // The explicit object cast keeps the int branch boxed rather than converted.
        ? (object)new LSTMStateTuple(_num_units, _num_units)
        : 2 * _num_units;
56171 }
0 commit comments