1111using System . Linq . Expressions ;
1212using Tensorflow . Keras . Utils ;
1313using Tensorflow . Common . Types ;
14+ using System . Runtime . CompilerServices ;
1415// from tensorflow.python.distribute import distribution_strategy_context as ds_context;
1516
1617namespace Tensorflow . Keras . Layers . Rnn
@@ -30,7 +31,19 @@ public class RNN : RnnBase
3031 private int _num_constants ;
3132 protected IVariableV1 _kernel ;
3233 protected IVariableV1 _bias ;
33- protected IRnnCell _cell ;
34+ private IRnnCell _cell ;
35+ protected IRnnCell Cell
36+ {
37+ get
38+ {
39+ return _cell ;
40+ }
41+ init
42+ {
43+ _cell = value ;
44+ _self_tracked_trackables . Add ( _cell ) ;
45+ }
46+ }
3447
3548 public RNN ( RNNArgs args ) : base ( PreConstruct ( args ) )
3649 {
@@ -40,14 +53,14 @@ public RNN(RNNArgs args) : base(PreConstruct(args))
4053 // if is StackedRnncell
4154 if ( args . Cells != null )
4255 {
43- _cell = new StackedRNNCells ( new StackedRNNCellsArgs
56+ Cell = new StackedRNNCells ( new StackedRNNCellsArgs
4457 {
4558 Cells = args . Cells
4659 } ) ;
4760 }
4861 else
4962 {
50- _cell = args . Cell ;
63+ Cell = args . Cell ;
5164 }
5265
5366 // get input_shape
@@ -65,7 +78,7 @@ public Tensors States
6578 if ( _states == null )
6679 {
6780 // CHECK(Rinne): check if this is correct.
68- var nested = _cell . StateSize . MapStructure < Tensor ? > ( x => null ) ;
81+ var nested = Cell . StateSize . MapStructure < Tensor ? > ( x => null ) ;
6982 _states = nested . AsNest ( ) . ToTensors ( ) ;
7083 }
7184 return _states ;
@@ -83,7 +96,7 @@ private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape)
8396 }
8497
8598 // state_size is a array of ints or a positive integer
86- var state_size = _cell . StateSize . ToSingleShape ( ) ;
99+ var state_size = Cell . StateSize . ToSingleShape ( ) ;
87100
88101 // TODO(wanglongzhi2001),flat_output_size应该是什么类型的,Shape还是Tensor
89102 Func < Shape , Shape > _get_output_shape ;
@@ -110,12 +123,12 @@ private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape)
110123 return output_shape ;
111124 } ;
112125
113- Type type = _cell . GetType ( ) ;
126+ Type type = Cell . GetType ( ) ;
114127 PropertyInfo output_size_info = type . GetProperty ( "output_size" ) ;
115128 Shape output_shape ;
116129 if ( output_size_info != null )
117130 {
118- output_shape = nest . map_structure ( _get_output_shape , _cell . OutputSize . ToSingleShape ( ) ) ;
131+ output_shape = nest . map_structure ( _get_output_shape , Cell . OutputSize . ToSingleShape ( ) ) ;
119132 // TODO(wanglongzhi2001),output_shape应该简单的就是一个元组还是一个Shape类型
120133 output_shape = ( output_shape . Length == 1 ? ( int ) output_shape [ 0 ] : output_shape ) ;
121134 }
@@ -171,7 +184,9 @@ private Tensors compute_mask(Tensors inputs, Tensors mask)
171184
172185 public override void build ( KerasShapesWrapper input_shape )
173186 {
174- object get_input_spec ( Shape shape )
187+ input_shape = new KerasShapesWrapper ( input_shape . Shapes [ 0 ] ) ;
188+
189+ InputSpec get_input_spec ( Shape shape )
175190 {
176191 var input_spec_shape = shape . as_int_list ( ) ;
177192
@@ -213,10 +228,13 @@ object get_state_spec(Shape shape)
213228 // numpy inputs.
214229
215230
216- if ( ! _cell . Built )
231+ if ( Cell is Layer layer && ! layer . Built )
217232 {
218- _cell . build ( input_shape ) ;
233+ layer . build ( input_shape ) ;
234+ layer . Built = true ;
219235 }
236+
237+ this . built = true ;
220238 }
221239
222240 /// <summary>
@@ -247,10 +265,10 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bo
247265
248266 ( inputs , initial_state , constants ) = _process_inputs ( inputs , initial_state , constants ) ;
249267
250- _maybe_reset_cell_dropout_mask ( _cell ) ;
251- if ( _cell is StackedRNNCells )
268+ _maybe_reset_cell_dropout_mask ( Cell ) ;
269+ if ( Cell is StackedRNNCells )
252270 {
253- var stack_cell = _cell as StackedRNNCells ;
271+ var stack_cell = Cell as StackedRNNCells ;
254272 foreach ( IRnnCell cell in stack_cell . Cells )
255273 {
256274 _maybe_reset_cell_dropout_mask ( cell ) ;
@@ -300,10 +318,10 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bo
300318 bool is_tf_rnn_cell = false ;
301319 if ( constants is not null )
302320 {
303- if ( ! _cell . SupportOptionalArgs )
321+ if ( ! Cell . SupportOptionalArgs )
304322 {
305323 throw new ValueError (
306- $ "RNN cell { _cell } does not support constants." +
324+ $ "RNN cell { Cell } does not support constants." +
307325 $ "Received: constants={ constants } ") ;
308326 }
309327
@@ -312,7 +330,7 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bo
312330 constants = new Tensors ( states . TakeLast ( _num_constants ) . ToArray ( ) ) ;
313331 states = new Tensors ( states . SkipLast ( _num_constants ) . ToArray ( ) ) ;
314332 states = len ( states ) == 1 && is_tf_rnn_cell ? new Tensors ( states [ 0 ] ) : states ;
315- var ( output , new_states ) = _cell . Apply ( inputs , states , optional_args : new RnnOptionalArgs ( ) { Constants = constants } ) ;
333+ var ( output , new_states ) = Cell . Apply ( inputs , states , optional_args : new RnnOptionalArgs ( ) { Constants = constants } ) ;
316334 return ( output , new_states . Single ) ;
317335 } ;
318336 }
@@ -321,7 +339,7 @@ protected override Tensors Call(Tensors inputs, Tensors initial_state = null, bo
321339 step = ( inputs , states ) =>
322340 {
323341 states = len ( states ) == 1 && is_tf_rnn_cell ? new Tensors ( states . First ( ) ) : states ;
324- var ( output , new_states ) = _cell . Apply ( inputs , states ) ;
342+ var ( output , new_states ) = Cell . Apply ( inputs , states ) ;
325343 return ( output , new_states ) ;
326344 } ;
327345 }
@@ -562,7 +580,7 @@ protected Tensors get_initial_state(Tensors inputs)
562580 var batch_size = _args . TimeMajor ? input_shape [ 1 ] : input_shape [ 0 ] ;
563581 var dtype = input . dtype ;
564582
565- Tensors init_state = _cell . GetInitialState ( null , batch_size , dtype ) ;
583+ Tensors init_state = Cell . GetInitialState ( null , batch_size , dtype ) ;
566584
567585 return init_state ;
568586 }
0 commit comments