@@ -42,7 +42,7 @@ public class WhileContext : ControlFlowContext
4242 public override GradLoopState grad_state => _grad_state ;
4343 public override bool back_prop => _back_prop ;
4444
45- public WhileContext ( int ? maximum_iterations = null ,
45+ public WhileContext ( Tensor maximum_iterations = null ,
4646 int parallel_iterations = 10 ,
4747 bool back_prop = true ,
4848 bool swap_memory = false ,
@@ -64,7 +64,7 @@ public WhileContext(int? maximum_iterations = null,
6464 _grad_state = grad_state ;
6565 }
6666
67- private void _init_from_args ( int ? maximum_iterations ,
67+ private void _init_from_args ( Tensor maximum_iterations ,
6868 int parallel_iterations ,
6969 bool back_prop ,
7070 bool swap_memory ,
@@ -107,9 +107,9 @@ private void _init_from_proto(WhileContextDef context_def, string import_scope =
107107 /// <summary>
108108 /// Add the loop termination condition and body to the graph.
109109 /// </summary>
110- public Tensor [ ] BuildLoop ( Func < Tensor , Tensor > pred ,
111- Func < Tensor , Tensor > body ,
112- Tensor [ ] loop_vars ,
110+ internal Tensor [ ] BuildLoop < TItem > ( Func < Tensor , TItem , Tensor > pred ,
111+ Func < Tensor , TItem , LoopVar < TItem > > body ,
112+ TItem loop_vars ,
113113 TensorShape shape_invariants ,
114114 bool return_same_structure )
115115 {
@@ -131,88 +131,107 @@ public Tensor[] BuildLoop(Func<Tensor, Tensor> pred,
131131 return packed_exit_vars as Tensor [ ] ;
132132 }
133133
134- private ( Tensor [ ] , Tensor [ ] ) _BuildLoop ( Func < Tensor , Tensor > pred ,
135- Func < Tensor , Tensor > body ,
136- Tensor [ ] original_loop_vars ,
137- Tensor [ ] loop_vars ,
134+ private Tensor _convert_tensorarray_to_flow < TItem > ( TItem tensor_or_tensor_array )
135+ {
136+ if ( tensor_or_tensor_array is TensorArray tensor_array )
137+ return tensor_array . flow ;
138+ else if ( tensor_or_tensor_array is Tensor tensor )
139+ return tensor ;
140+
141+ throw new NotImplementedException ( "_convert_tensorarray_to_flow" ) ;
142+ }
143+
144+ private ( Tensor [ ] , Tensor [ ] ) _BuildLoop < TItem > ( Func < Tensor , TItem , Tensor > pred ,
145+ Func < Tensor , TItem , LoopVar < TItem > > body ,
146+ TItem original_loop_vars ,
147+ TItem loop_vars ,
138148 TensorShape shape_invariants )
139149 {
140150 var flat_loop_vars = original_loop_vars ;
141151
152+ // Convert TensorArrays to their flow variables
153+ var loop_vars_tensor = nest . map_structure (
154+ _convert_tensorarray_to_flow ,
155+ nest . flatten ( loop_vars ) ) ;
156+
142157 // Let the context know the loop variables so the loop variables
143158 // would be added in the outer contexts properly.
144- _InitializeValues ( loop_vars ) ;
145- var real_vars = loop_vars ;
146- Tensor [ ] enter_vars = null ;
147- tf_with ( ops . control_dependencies ( null ) , delegate
159+ if ( loop_vars is Tensor [ ] real_vars )
148160 {
149- enter_vars = real_vars . Select ( x => _Enter ( x ,
150- _name ,
151- is_constant : false ,
152- parallel_iterations : _parallel_iterations ,
153- use_input_shape : shape_invariants == null ) )
154- . ToArray ( ) ;
155-
156- foreach ( var x in enter_vars )
161+ _InitializeValues ( real_vars ) ;
162+ Tensor [ ] enter_vars = null ;
163+ tf_with ( ops . control_dependencies ( null ) , delegate
164+ {
165+ enter_vars = real_vars . Select ( x => _Enter ( x ,
166+ _name ,
167+ is_constant : false ,
168+ parallel_iterations : _parallel_iterations ,
169+ use_input_shape : shape_invariants == null ) )
170+ . ToArray ( ) ;
171+
172+ foreach ( var x in enter_vars )
173+ {
174+ x . graph . prevent_feeding ( x ) ;
175+ if ( _outer_context != null )
176+ _outer_context . AddInnerOp ( x . op ) ;
177+ }
178+ } ) ;
179+
180+ // Finds the closest enclosing non-None control pivot.
181+ var outer_context = _outer_context ;
182+ while ( outer_context != null )
157183 {
158- x . graph . prevent_feeding ( x ) ;
159- if ( _outer_context != null )
160- _outer_context . AddInnerOp ( x . op ) ;
184+
161185 }
162- } ) ;
163186
164- // Finds the closest enclosing non-None control pivot.
165- var outer_context = _outer_context ;
166- while ( outer_context != null )
167- {
187+ _SetShapeInvariants ( real_vars , enter_vars , shape_invariants ) ;
188+
189+ // Fix the control inputs and control flow context of these enter ops.
190+ _FixControlInputsAndContext ( enter_vars ) ;
191+ _InitializeValues ( enter_vars ) ;
192+ _loop_enters = enter_vars . ToList ( ) ;
193+
194+ var merge_vars = enter_vars
195+ . Select ( x => merge ( new [ ] { x , x } ) )
196+ . ToArray ( ) ;
197+
198+ _pivot_for_pred = merge_vars [ 0 ] ;
199+
200+ // Build the graph for pred.
201+ var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays ( flat_loop_vars , merge_vars ) ;
202+ // var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays);
203+ var c = ops . convert_to_tensor ( pred ( merge_vars_with_tensor_arrays [ 0 ] , default ( TItem ) ) ) ;
204+ _pivot = gen_control_flow_ops . loop_cond ( c , name : "LoopCond" ) ;
205+ var switch_vars = merge_vars . Select ( x => _SwitchRefOrTensor ( x , _pivot ) )
206+ . ToArray ( ) ;
168207
208+ // Build the graph for body.
209+ var vars_for_body = switch_vars . Select ( x => _Identity ( x [ 1 ] ) ) . ToArray ( ) ;
210+ // Convert TensorArray flow variables inside the context back into
211+ // their associated TensorArrays for calling the body.
212+ var packed_vars_for_body = _convert_flows_to_tensorarrays ( flat_loop_vars , vars_for_body ) ;
213+ /*var body_result = body(packed_vars_for_body[0]);
214+ var post_summaries = ops.get_collection(tf.GraphKeys._SUMMARY_COLLECTION);
215+
216+ // Store body_result to keep track of TensorArrays returned by body
217+ var original_body_result = new[] { body_result };
218+ // Convert TensorArrays returned by body into their flow variables
219+ var result = new[] { body_result };
220+
221+ var next_vars = new List<Tensor>();
222+ foreach (var (m, v) in zip(merge_vars, result))
223+ next_vars.Add(_AddNextAndBackEdge(m, v));
224+
225+ // Add the exit ops.
226+ var exit_vars = switch_vars.Select(x => exit(x[0])).ToList();
227+ _loop_exits = exit_vars;
228+
229+ // Exit the loop.
230+ // ExitResult(exit_vars);
231+ return (original_body_result, exit_vars.ToArray());*/
169232 }
170233
171- _SetShapeInvariants ( real_vars , enter_vars , shape_invariants ) ;
172-
173- // Fix the control inputs and control flow context of these enter ops.
174- _FixControlInputsAndContext ( enter_vars ) ;
175- _InitializeValues ( enter_vars ) ;
176- _loop_enters = enter_vars . ToList ( ) ;
177-
178- var merge_vars = enter_vars
179- . Select ( x => merge ( new [ ] { x , x } ) )
180- . ToArray ( ) ;
181-
182- _pivot_for_pred = merge_vars [ 0 ] ;
183-
184- // Build the graph for pred.
185- var merge_vars_with_tensor_arrays = _convert_flows_to_tensorarrays ( flat_loop_vars , merge_vars ) ;
186- // var packed_vars = nest.pack_sequence_as(original_loop_vars, merge_vars_with_tensor_arrays);
187- var c = ops . convert_to_tensor ( pred ( merge_vars_with_tensor_arrays [ 0 ] ) ) ;
188- _pivot = gen_control_flow_ops . loop_cond ( c , name : "LoopCond" ) ;
189- var switch_vars = merge_vars . Select ( x => _SwitchRefOrTensor ( x , _pivot ) )
190- . ToArray ( ) ;
191-
192- // Build the graph for body.
193- var vars_for_body = switch_vars . Select ( x => _Identity ( x [ 1 ] ) ) . ToArray ( ) ;
194- // Convert TensorArray flow variables inside the context back into
195- // their associated TensorArrays for calling the body.
196- var packed_vars_for_body = _convert_flows_to_tensorarrays ( flat_loop_vars , vars_for_body ) ;
197- var body_result = body ( packed_vars_for_body [ 0 ] ) ;
198- var post_summaries = ops . get_collection ( tf . GraphKeys . _SUMMARY_COLLECTION ) ;
199-
200- // Store body_result to keep track of TensorArrays returned by body
201- var original_body_result = new [ ] { body_result } ;
202- // Convert TensorArrays returned by body into their flow variables
203- var result = new [ ] { body_result } ;
204-
205- var next_vars = new List < Tensor > ( ) ;
206- foreach ( var ( m , v ) in zip ( merge_vars , result ) )
207- next_vars . Add ( _AddNextAndBackEdge ( m , v ) ) ;
208-
209- // Add the exit ops.
210- var exit_vars = switch_vars . Select ( x => exit ( x [ 0 ] ) ) . ToList ( ) ;
211- _loop_exits = exit_vars ;
212-
213- // Exit the loop.
214- // ExitResult(exit_vars);
215- return ( original_body_result , exit_vars . ToArray ( ) ) ;
234+ throw new NotImplementedException ( "" ) ;
216235 }
217236
218237 private void _FixControlInputsAndContext ( Tensor [ ] enters )
0 commit comments