@@ -485,7 +485,7 @@ public static Tensor[] cond<T>(Tensor pred,
485485 } ) ;
486486 }
487487
488- public static Tensor [ ] _convert_flows_to_tensorarrays < T > ( T [ ] tensors_or_tensorarrays , Tensor [ ] tensors_or_flows )
488+ public static Tensor [ ] _convert_flows_to_tensorarrays < T > ( T tensors_or_tensorarrays , Tensor [ ] tensors_or_flows )
489489 {
490490 // zip(tensors_or_tensorarrays, tensors_or_flows).Select((ta, t_or_flow) => ta).ToArray();
491491 return tensors_or_flows ;
@@ -591,18 +591,18 @@ public static Tensor ZerosLikeOutsideLoop(Operation op, int index)
591591 /// <param name="body"></param>
592592 /// <param name="loop_vars"></param>
593593 /// <param name="i"></param>
594- public static Tensor while_loop ( Func < Tensor , Tensor > cond , Func < Tensor , Tensor > body , Tensor [ ] loop_vars ,
594+ public static Tensor while_loop < TItem > ( Func < TItem , Tensor > cond , Func < TItem , TItem > body , TItem loop_vars ,
595595 TensorShape shape_invariants = null ,
596596 int parallel_iterations = 10 ,
597597 bool back_prop = true ,
598598 bool swap_memory = false ,
599599 string name = null ,
600- int ? maximum_iterations = null ,
600+ Tensor maximum_iterations = null ,
601601 bool return_same_structure = false )
602602 {
603603 tf_with ( ops . name_scope ( name , "while" , loop_vars ) , scope =>
604604 {
605- if ( loop_vars == null || loop_vars . Length == 0 )
605+ if ( loop_vars == null )
606606 throw new ValueError ( "No loop variables provided" ) ;
607607 if ( cond == null )
608608 throw new ValueError ( "cond must be callable." ) ;
@@ -611,6 +611,28 @@ public static Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor>
611611 if ( parallel_iterations < 1 )
612612 throw new ValueError ( "parallel_iterations must be a positive integer." ) ;
613613
614+ var try_to_pack = loop_vars is Tensor && ! return_same_structure ;
615+ var counter = constant_op . constant ( 0 , dtype : maximum_iterations . dtype , name : "iteration_counter" ) ;
616+ var orig_cond = cond ;
617+ var orig_body = body ;
618+
619+ LoopVar < TItem > loop_vars_1 = null ;
620+ Func < Tensor , TItem , LoopVar < TItem > > body_buildloop = null ;
621+ Func < Tensor , TItem , Tensor > cond_buildloop = null ;
622+
623+ if ( try_to_pack )
624+ {
625+
626+ }
627+ else
628+ {
629+ loop_vars_1 = new LoopVar < TItem > ( counter , loop_vars ) ;
630+ cond_buildloop = ( i , lv ) =>
631+ math_ops . logical_and ( i < maximum_iterations , orig_cond ( lv ) ) ;
632+ body_buildloop = ( i , lv ) => new LoopVar < TItem > ( i + 1 , orig_body ( lv ) ) ;
633+ }
634+ try_to_pack = false ;
635+
614636 var loop_context = new WhileContext (
615637 maximum_iterations : maximum_iterations ,
616638 parallel_iterations : parallel_iterations ,
@@ -620,7 +642,7 @@ public static Tensor while_loop(Func<Tensor, Tensor> cond, Func<Tensor, Tensor>
620642 if ( loop_context . outer_context == null )
621643 ops . add_to_collection ( tf . GraphKeys . WHILE_CONTEXT , loop_context ) ;
622644
623- var results = loop_context . BuildLoop ( cond , body , loop_vars , shape_invariants ,
645+ var results = loop_context . BuildLoop ( cond_buildloop , body_buildloop , loop_vars , shape_invariants ,
624646 return_same_structure ) ;
625647
626648 if ( maximum_iterations != null )
0 commit comments