@@ -18,13 +18,106 @@ limitations under the License.
1818using System ;
1919using System . Collections . Generic ;
2020using System . Linq ;
21+ using Tensorflow . Framework ;
2122using Tensorflow . Util ;
2223using static Tensorflow . Binding ;
2324
2425namespace Tensorflow . Operations
2526{
26- internal class rnn
27+ public class rnn
2728 {
/// <summary>
/// Creates a bidirectional recurrent neural network: unrolls the forward cell
/// over <paramref name="inputs"/> inside an "fw" variable scope, and the
/// backward cell over the time-reversed inputs inside a "bw" variable scope,
/// both nested under <paramref name="scope"/> (default "bidirectional_rnn").
/// Port of tf.nn.static_bidirectional_rnn; outputs are not yet returned.
/// </summary>
/// <param name="cell_fw">Cell used for the forward pass.</param>
/// <param name="cell_bw">Cell used for the backward pass.</param>
/// <param name="inputs">Time-major list of input tensors; must be non-empty.</param>
/// <param name="initial_state_fw">Optional initial state for the forward cell.</param>
/// <param name="initial_state_bw">Optional initial state for the backward cell.</param>
/// <param name="dtype">Dtype used to build initial states when none are supplied.</param>
/// <param name="sequence_length">Optional per-example sequence lengths.</param>
/// <param name="scope">Name of the outer variable scope.</param>
/// <exception cref="ValueError">Thrown when <paramref name="inputs"/> is null or empty.</exception>
public static void static_bidirectional_rnn(BasicLSTMCell cell_fw,
    BasicLSTMCell cell_bw,
    Tensor[] inputs,
    Tensor initial_state_fw = null,
    Tensor initial_state_bw = null,
    TF_DataType dtype = TF_DataType.DtInvalid,
    Tensor sequence_length = null,
    string scope = null)
{
    if (inputs == null || inputs.Length == 0)
        throw new ValueError("inputs must not be empty");

    tf_with(tf.variable_scope(scope ?? "bidirectional_rnn"), delegate
    {
        // Forward direction
        tf_with(tf.variable_scope("fw"), fw_scope =>
        {
            static_rnn(
                cell_fw,
                inputs,
                initial_state_fw,
                dtype,
                sequence_length,
                scope: fw_scope);
        });

        // Backward direction.
        // BUG FIX: the original stopped after the forward pass, so cell_bw,
        // initial_state_bw and the "bw" scope were never used. Mirror the
        // Python reference by running the backward cell over the reversed
        // input sequence.
        // NOTE(review): plain reversal is only correct while sequence_length
        // is null; a sequence_length-aware _reverse_seq is still TODO.
        tf_with(tf.variable_scope("bw"), bw_scope =>
        {
            var reversed_inputs = inputs.Reverse().ToArray();
            static_rnn(
                cell_bw,
                reversed_inputs,
                initial_state_bw,
                dtype,
                sequence_length,
                scope: bw_scope);
        });
    });
}
59+
/// <summary>
/// Creates a recurrent neural network specified by <paramref name="cell"/>,
/// statically unrolled over the time-major list <paramref name="inputs"/>.
/// Port of tf.nn.static_rnn. Only the explicit-scope branch contains logic so
/// far (shape inference and initial-state creation); the default-scope branch
/// and the actual per-timestep unrolling are still unimplemented.
/// </summary>
/// <param name="cell">RNN cell to unroll.</param>
/// <param name="inputs">Time-major list of input tensors; must be non-empty.</param>
/// <param name="initial_state">Optional initial state; created from the cell when null.</param>
/// <param name="dtype">Dtype used when a fresh initial state must be created.</param>
/// <param name="sequence_length">Optional per-example sequence lengths (currently unused).</param>
/// <param name="scope">Variable scope to build the RNN in; null selects the unimplemented "rnn" default.</param>
/// <exception cref="ValueError">When inputs is empty or an input dimension is not statically known.</exception>
/// <exception cref="NotImplementedException">When <paramref name="scope"/> is null.</exception>
public static void static_rnn(BasicLSTMCell cell,
    Tensor[] inputs,
    Tensor initial_state,
    TF_DataType dtype = TF_DataType.DtInvalid,
    Tensor sequence_length = null,
    VariableScope scope = null)
{
    // Guard added for consistency with static_bidirectional_rnn; inputs[0]
    // below would otherwise throw an unhelpful IndexOutOfRangeException.
    if (inputs == null || inputs.Length == 0)
        throw new ValueError("inputs must not be empty");

    // Create a new scope in which the caching device is either
    // determined by the parent scope, or is set to place the cached
    // Variable using the same placement as for the rest of the RNN.
    if (scope == null)
        tf_with(tf.variable_scope("rnn"), varscope =>
        {
            // Default-scope path not ported yet.
            throw new NotImplementedException("static_rnn");
        });
    else
        tf_with(tf.variable_scope(scope), varscope =>
        {
            Dimension fixed_batch_size = null;
            Dimension batch_size = null;
            Tensor batch_size_tensor = null;

            // Obtain the first sequence of the input
            var first_input = inputs[0];
            if (first_input.TensorShape.rank != 1)
            {
                // Rank >= 2: batch dimension is dims[0]; verify every input
                // agrees on it and has a fully-known feature size.
                var input_shape = first_input.TensorShape.with_rank_at_least(2);
                fixed_batch_size = input_shape.dims[0];
                var flat_inputs = nest.flatten2(inputs);
                foreach (var flat_input in flat_inputs)
                {
                    input_shape = flat_input.TensorShape.with_rank_at_least(2);
                    batch_size = tensor_shape.dimension_at_index(input_shape, 0);
                    var input_size = input_shape[1];
                    fixed_batch_size.merge_with(batch_size);
                    foreach (var (i, size) in enumerate(input_size.dims))
                    {
                        if (size < 0)
                            throw new ValueError($"Input size (dimension {i} of inputs) must be accessible via " +
                                "shape inference, but saw value None.");
                    }
                }
            }
            else
                // Rank-1 inputs: the single dimension is the batch size.
                fixed_batch_size = first_input.TensorShape.with_rank_at_least(1).dims[0];

            if (tensor_shape.dimension_value(fixed_batch_size) >= 0)
                batch_size = tensor_shape.dimension_value(fixed_batch_size);
            else
                // Batch size unknown statically; fetch it at graph-run time.
                batch_size_tensor = array_ops.shape(first_input)[0];

            Tensor state = null;
            if (initial_state != null)
                state = initial_state;
            else
            {
                // BUG FIX: the original discarded the result of
                // get_initial_state, leaving `state` null even though it was
                // declared specifically to receive it.
                state = cell.get_initial_state(batch_size: batch_size_tensor, dtype: dtype);
            }
            // TODO: per-timestep unrolling (outputs, final state) not yet ported.
        });
}
120+
28121 public static ( Tensor , Tensor ) dynamic_rnn ( RnnCell cell , Tensor inputs_tensor ,
29122 Tensor sequence_length = null , Tensor initial_state = null ,
30123 TF_DataType dtype = TF_DataType . DtInvalid ,
0 commit comments