@@ -71,6 +71,8 @@ private void _init_from_args(Tensor maximum_iterations,
7171 string name )
7272 {
7373 _name = ops . get_default_graph ( ) . unique_name ( name ) ;
74+ _maximum_iterations = maximum_iterations ;
75+ _parallel_iterations = parallel_iterations ;
7476 _back_prop = back_prop ;
7577 _swap_memory = swap_memory ;
7678 _loop_exits = new List < Tensor > ( ) ;
@@ -107,18 +109,27 @@ private void _init_from_proto(WhileContextDef context_def, string import_scope =
107109 /// <summary>
108110 /// Add the loop termination condition and body to the graph.
109111 /// </summary>
110- internal Tensor [ ] BuildLoop < TItem > ( Func < Tensor , TItem , Tensor > pred ,
111- Func < Tensor , TItem , LoopVar < TItem > > body ,
112+ internal Tensor [ ] BuildLoop < TItem > ( Func < LoopVar < TItem > , Tensor > pred ,
113+ Func < LoopVar < TItem > , LoopVar < TItem > > body ,
112114 LoopVar < TItem > loop_vars ,
113- TensorShape shape_invariants ,
115+ TensorShape [ ] shape_invariants ,
114116 bool return_same_structure )
115117 {
116118 // Keep original_loop_vars to identify which are TensorArrays
117119 var original_loop_vars = loop_vars ;
118120 // Convert TensorArrays to their flow variables
121+ var loop_vars_tensors = nest . flatten2 ( loop_vars )
122+ . Select ( x => _convert_tensorarray_to_flow ( x ) )
123+ . ToArray ( ) ;
124+
125+ if ( shape_invariants == null )
126+ shape_invariants = loop_vars_tensors
127+ . Select ( x => _get_shape_invariant ( x as Tensor ) )
128+ . ToArray ( ) ;
129+
119130 Enter ( ) ;
120131 var ( original_body_result , exit_vars ) = _BuildLoop (
121- pred , body , original_loop_vars , loop_vars , shape_invariants ) ;
132+ pred , body , original_loop_vars , loop_vars_tensors , shape_invariants ) ;
122133 Exit ( ) ;
123134
124135 var flat_result = original_body_result ;
@@ -131,7 +142,7 @@ internal Tensor[] BuildLoop<TItem>(Func<Tensor, TItem, Tensor> pred,
131142 return packed_exit_vars as Tensor [ ] ;
132143 }
133144
134- private Tensor _convert_tensorarray_to_flow < TItem > ( TItem tensor_or_tensor_array )
145+ private Tensor _convert_tensorarray_to_flow ( object tensor_or_tensor_array )
135146 {
136147 if ( tensor_or_tensor_array is TensorArray tensor_array )
137148 return tensor_array . flow ;
@@ -141,97 +152,116 @@ private Tensor _convert_tensorarray_to_flow<TItem>(TItem tensor_or_tensor_array)
141152 throw new NotImplementedException ( "_convert_tensorarray_to_flow" ) ;
142153 }
143154
144- private ( Tensor [ ] , Tensor [ ] ) _BuildLoop < TItem > ( Func < Tensor , TItem , Tensor > pred ,
145- Func < Tensor , TItem , LoopVar < TItem > > body ,
146- LoopVar < TItem > original_loop_vars ,
147- LoopVar < TItem > loop_vars ,
148- TensorShape shape_invariants )
155+ private TensorShape _get_shape_invariant ( Tensor var , int [ ] shape = null )
149156 {
150- var flat_loop_vars = original_loop_vars ;
157+ return var . TensorShape ;
158+ }
151159
152- // Convert TensorArrays to their flow variables
153- var loop_vars_tensor = nest . map_structure (
154- _convert_tensorarray_to_flow ,
155- nest . flatten2 ( loop_vars ) ) ;
160+ /// <summary>
161+ /// Add the loop termination condition and body to the graph.
162+ /// </summary>
163+ /// <typeparam name="TItem"></typeparam>
164+ /// <param name="pred"></param>
165+ /// <param name="body"></param>
166+ /// <param name="original_loop_vars"></param>
167+ /// <param name="loop_vars"></param>
168+ /// <param name="shape_invariants"></param>
169+ /// <returns></returns>
170+ private ( Tensor [ ] , Tensor [ ] ) _BuildLoop < TItem > ( Func < LoopVar < TItem > , Tensor > pred ,
171+ Func < LoopVar < TItem > , LoopVar < TItem > > body ,
172+ LoopVar < TItem > original_loop_vars ,
173+ Tensor [ ] loop_vars ,
174+ TensorShape [ ] shape_invariants )
175+ {
176+ var flat_loop_vars = nest . flatten2 ( original_loop_vars )
177+ . Select ( x => ( ITensorOrTensorArray ) x )
178+ . ToArray ( ) ;
156179
157180 // Let the context know the loop variables so the loop variables
158181 // would be added in the outer contexts properly.
159- if ( loop_vars is Tensor [ ] real_vars )
182+ _InitializeValues ( loop_vars ) ;
183+ var real_vars = loop_vars ;
184+ Tensor [ ] enter_vars = null ;
185+ tf_with ( ops . control_dependencies ( null ) , delegate
160186 {
161- _InitializeValues ( real_vars ) ;
162- Tensor [ ] enter_vars = null ;
163- tf_with ( ops . control_dependencies ( null ) , delegate
164- {
165- enter_vars = real_vars . Select ( x => _Enter ( x ,
166- _name ,
167- is_constant : false ,
168- parallel_iterations : _parallel_iterations ,
169- use_input_shape : shape_invariants == null ) )
170- . ToArray ( ) ;
171-
172- foreach ( var x in enter_vars )
173- {
174- x . graph . prevent_feeding ( x ) ;
175- if ( _outer_context != null )
176- _outer_context . AddInnerOp ( x . op ) ;
177- }
178- } ) ;
179-
180- // Finds the closest enclosing non-None control pivot.
181- var outer_context = _outer_context ;
182- while ( outer_context != null )
187+ enter_vars = real_vars . Select ( x => _Enter ( x ,
188+ _name ,
189+ is_constant : false ,
190+ parallel_iterations : _parallel_iterations ,
191+ use_input_shape : shape_invariants == null ) )
192+ . ToArray ( ) ;
193+
194+ foreach ( var x in enter_vars )
183195 {
184-
196+ x . graph . prevent_feeding ( x ) ;
197+ if ( _outer_context != null )
198+ _outer_context . AddInnerOp ( x . op ) ;
185199 }
200+ } ) ;
186201
187- _SetShapeInvariants ( real_vars , enter_vars , shape_invariants ) ;
188-
189- // Fix the control inputs and control flow context of these enter ops.
190- _FixControlInputsAndContext ( enter_vars ) ;
191- _InitializeValues ( enter_vars ) ;
192- _loop_enters = enter_vars . ToList ( ) ;
193-
194- var merge_vars = enter_vars
195- . Select ( x => merge ( new [ ] { x , x } ) )
196- . ToArray ( ) ;
202+ // Finds the closest enclosing non-None control pivot.
203+ var outer_context = _outer_context ;
204+ object control_pivot = null ;
205+ while ( outer_context != null && control_pivot == null )
206+ {
197207
198- _pivot_for_pred = merge_vars [ 0 ] ;
208+ }
199209
200- // Build the graph for pred.
201- var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays ( flat_loop_vars , merge_vars ) ;
202- // var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays);
203- var c = ops . convert_to_tensor ( pred ( merge_vars_with_tensor_arrays [ 0 ] , default ( TItem ) ) ) ;
204- _pivot = gen_control_flow_ops . loop_cond ( c , name : "LoopCond" ) ;
205- var switch_vars = merge_vars . Select ( x => _SwitchRefOrTensor ( x , _pivot ) )
206- . ToArray ( ) ;
210+ if ( control_pivot != null )
211+ {
207212
208- // Build the graph for body.
209- var vars_for_body = switch_vars . Select ( x => _Identity ( x [ 1 ] ) ) . ToArray ( ) ;
210- // Convert TensorArray flow variables inside the context back into
211- // their associated TensorArrays for calling the body.
212- var packed_vars_for_body = _convert_flows_to_tensorarrays ( flat_loop_vars , vars_for_body ) ;
213- /*var body_result = body(packed_vars_for_body[0]);
214- var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);
215-
216- // Store body_result to keep track of TensorArrays returned by body
217- var original_body_result = new[] { body_result };
218- // Convert TensorArrays returned by body into their flow variables
219- var result = new[] { body_result };
220-
221- var next_vars = new List<Tensor>();
222- foreach (var (m, v) in zip(merge_vars, result))
223- next_vars.Add(_AddNextAndBackEdge(m, v));
224-
225- // Add the exit ops.
226- var exit_vars = switch_vars.Select(x => exit(x[0])).ToList();
227- _loop_exits = exit_vars;
228-
229- // Exit the loop.
230- // ExitResult(exit_vars);
231- return (original_body_result, exit_vars.ToArray());*/
232213 }
233214
234- throw new NotImplementedException ( "" ) ;
215+ _SetShapeInvariants ( real_vars , enter_vars , shape_invariants ) ;
216+
217+ // Fix the control inputs and control flow context of these enter ops.
218+ _FixControlInputsAndContext ( enter_vars ) ;
219+ _InitializeValues ( enter_vars ) ;
220+ _loop_enters = enter_vars . ToList ( ) ;
221+
222+ var merge_vars = enter_vars
223+ . Select ( x => merge ( new [ ] { x , x } ) )
224+ . ToArray ( ) ;
225+
226+ _pivot_for_pred = merge_vars [ 0 ] ;
227+
228+ // Build the graph for pred.
229+ var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays ( flat_loop_vars , merge_vars ) ;
230+ //var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays, expand_composites: true);
231+ var packed_vars = new LoopVar < TItem > ( ( Tensor ) merge_vars_with_tensor_arrays [ 0 ] ,
232+ ( TItem ) ( object ) new BodyItemInRnnWhileLoop ( ( Tensor ) merge_vars_with_tensor_arrays [ 1 ] ,
233+ new [ ] { ( TensorArray ) merge_vars_with_tensor_arrays [ 2 ] } ,
234+ ( Tensor ) merge_vars_with_tensor_arrays [ 3 ] ) ) ;
235+ var pp = pred ( packed_vars ) ;
236+ var c = ops . convert_to_tensor ( pp ) ;
237+ _pivot = gen_control_flow_ops . loop_cond ( c , name : "LoopCond" ) ;
238+ var switch_vars = merge_vars . Select ( x => _SwitchRefOrTensor ( x , _pivot ) )
239+ . ToArray ( ) ;
240+
241+ // Build the graph for body.
242+ var vars_for_body = switch_vars . Select ( x => _Identity ( x [ 1 ] ) ) . ToArray ( ) ;
243+ // Convert TensorArray flow variables inside the context back into
244+ // their associated TensorArrays for calling the body.
245+ var packed_vars_for_body = _convert_flows_to_tensorarrays ( flat_loop_vars , vars_for_body ) ;
246+ var body_result = body ( original_loop_vars ) ;
247+ var post_summaries = ops . get_collection ( tf . GraphKeys . _SUMMARY_COLLECTION ) ;
248+
249+ // Store body_result to keep track of TensorArrays returned by body
250+ var original_body_result = new [ ] { body_result } ;
251+ // Convert TensorArrays returned by body into their flow variables
252+ var result = new [ ] { body_result } ;
253+
254+ var next_vars = new List < Tensor > ( ) ;
255+ //foreach (var (m, v) in zip(merge_vars, result))
256+ //next_vars.Add(_AddNextAndBackEdge(m, v));
257+
258+ // Add the exit ops.
259+ var exit_vars = switch_vars . Select ( x => exit ( x [ 0 ] ) ) . ToList ( ) ;
260+ _loop_exits = exit_vars ;
261+
262+ // Exit the loop.
263+ // ExitResult(exit_vars);
264+ return ( null , exit_vars . ToArray ( ) ) ;
235265 }
236266
237267 private void _FixControlInputsAndContext ( Tensor [ ] enters )
@@ -258,6 +288,23 @@ private void _InitializeValues(Tensor[] values)
258288 _values . Add ( x . name ) ;
259289 }
260290
291+ public override Tensor AddValue ( Tensor val )
292+ {
293+ var result = val ;
294+ var new_value = _values . Contains ( val . name ) ;
295+ new_value &= val . op . _get_control_flow_context ( ) != this ;
296+ if ( new_value )
297+ throw new NotImplementedException ( "" ) ;
298+ else
299+ {
300+ var actual_val = _external_values . ContainsKey ( val . name ) ? _external_values [ val . name ] : null ;
301+ if ( actual_val != null )
302+ result = actual_val as Tensor ;
303+ }
304+
305+ return result ;
306+ }
307+
261308 public override WhileContext GetWhileContext ( )
262309 {
263310 return this ;
0 commit comments