@@ -86,7 +86,7 @@ public Tensors States
8686 set { _states = value ; }
8787 }
8888
89- private OneOf < Shape , List < Shape > > compute_output_shape ( Shape input_shape )
89+ private INestStructure < Shape > compute_output_shape ( Shape input_shape )
9090 {
9191 var batch = input_shape [ 0 ] ;
9292 var time_step = input_shape [ 1 ] ;
@@ -96,13 +96,15 @@ private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape)
9696 }
9797
9898 // state_size is a array of ints or a positive integer
99- var state_size = Cell . StateSize . ToSingleShape ( ) ;
99+ var state_size = Cell . StateSize ;
100+ if ( state_size ? . TotalNestedCount == 1 )
101+ {
102+ state_size = new NestList < long > ( state_size . Flatten ( ) . First ( ) ) ;
103+ }
100104
101- // TODO(wanglongzhi2001),flat_output_size应该是什么类型的,Shape还是Tensor
102- Func < Shape , Shape > _get_output_shape ;
103- _get_output_shape = ( flat_output_size ) =>
105+ Func < long , Shape > _get_output_shape = ( flat_output_size ) =>
104106 {
105- var output_dim = flat_output_size . as_int_list ( ) ;
107+ var output_dim = new Shape ( flat_output_size ) . as_int_list ( ) ;
106108 Shape output_shape ;
107109 if ( _args . ReturnSequences )
108110 {
@@ -125,31 +127,28 @@ private OneOf<Shape, List<Shape>> compute_output_shape(Shape input_shape)
125127
126128 Type type = Cell . GetType ( ) ;
127129 PropertyInfo output_size_info = type . GetProperty ( "output_size" ) ;
128- Shape output_shape ;
130+ INestStructure < Shape > output_shape ;
129131 if ( output_size_info != null )
130132 {
131- output_shape = nest . map_structure ( _get_output_shape , Cell . OutputSize . ToSingleShape ( ) ) ;
132- // TODO(wanglongzhi2001),output_shape应该简单的就是一个元组还是一个Shape类型
133- output_shape = ( output_shape . Length == 1 ? ( int ) output_shape [ 0 ] : output_shape ) ;
133+ output_shape = Nest . MapStructure ( _get_output_shape , Cell . OutputSize ) ;
134134 }
135135 else
136136 {
137- output_shape = _get_output_shape ( state_size ) ;
137+ output_shape = new NestNode < Shape > ( _get_output_shape ( state_size . Flatten ( ) . First ( ) ) ) ;
138138 }
139139
140140 if ( _args . ReturnState )
141141 {
142- Func < Shape , Shape > _get_state_shape ;
143- _get_state_shape = ( flat_state ) =>
142+ Func < long , Shape > _get_state_shape = ( flat_state ) =>
144143 {
145- var state_shape = new int [ ] { ( int ) batch } . concat ( flat_state . as_int_list ( ) ) ;
144+ var state_shape = new int [ ] { ( int ) batch } . concat ( new Shape ( flat_state ) . as_int_list ( ) ) ;
146145 return new Shape ( state_shape ) ;
147146 } ;
148147
149148
150- var state_shape = _get_state_shape ( state_size ) ;
149+ var state_shape = Nest . MapStructure ( _get_state_shape , state_size ) ;
151150
152- return new List < Shape > { output_shape , state_shape } ;
151+ return new Nest < Shape > ( new [ ] { output_shape , state_shape } ) ;
153152 }
154153 else
155154 {
@@ -435,7 +434,7 @@ public override Tensors Apply(Tensors inputs, Tensors initial_states = null, boo
435434 tmp . add ( tf . math . count_nonzero ( s . Single ( ) ) ) ;
436435 }
437436 var non_zero_count = tf . add_n ( tmp ) ;
438- // initial_state = tf.cond(non_zero_count > 0, () => States, () => initial_state);
437+ initial_state = tf . cond ( non_zero_count > 0 , States , initial_state ) ;
439438 if ( ( int ) non_zero_count . numpy ( ) > 0 )
440439 {
441440 initial_state = States ;
@@ -445,16 +444,7 @@ public override Tensors Apply(Tensors inputs, Tensors initial_states = null, boo
445444 {
446445 initial_state = States ;
447446 }
448- // TODO(Wanglongzhi2001),
449- // initial_state = tf.nest.map_structure(
450- //# When the layer has a inferred dtype, use the dtype from the
451- //# cell.
452- // lambda v: tf.cast(
453- // v, self.compute_dtype or self.cell.compute_dtype
454- // ),
455- // initial_state,
456- // )
457-
447+ //initial_state = Nest.MapStructure(v => tf.cast(v, this.), initial_state);
458448 }
459449 else if ( initial_state is null )
460450 {
0 commit comments