@@ -21,6 +21,7 @@ limitations under the License.
2121using Tensorflow . Operations . ControlFlows ;
2222using util = Tensorflow . control_flow_util ;
2323using static Tensorflow . Binding ;
24+ using Tensorflow . Util ;
2425
2526namespace Tensorflow
2627{
@@ -251,12 +252,16 @@ public static Tensor _Identity(Tensor data, string name = null)
251252 return gen_array_ops . identity ( data , name : name ) ;
252253 }
253254
254- public static void _SetShapeInvariants ( Tensor [ ] input_vars , Tensor [ ] enter_vars , TensorShape shapes = null )
255+ public static void _SetShapeInvariants ( Tensor [ ] input_vars , Tensor [ ] enter_vars , TensorShape [ ] shapes = null )
255256 {
256257 if ( shapes == null )
257258 return ;
258259
259- throw new NotImplementedException ( "_SetShapeInvariants" ) ;
260+ var flat_shapes = nest . flatten2 ( shapes ) ;
261+ foreach ( var ( inp , var , shape ) in zip ( input_vars , enter_vars , flat_shapes ) )
262+ {
263+ var . set_shape ( shape ) ;
264+ }
260265 }
261266
262267 /// <summary>
@@ -428,12 +433,12 @@ raise ValueError(
428433 . Select ( pair => merge ( new Tensor [ ] { pair . Item1 , pair . Item2 } ) )
429434 . ToArray ( ) ;
430435
431- merges = _convert_flows_to_tensorarrays ( new Tensor [ ] { ( Tensor ) orig_res_t } , merges ) ;
436+ var merges2 = _convert_flows_to_tensorarrays ( new ITensorOrTensorArray [ ] { ( Tensor ) orig_res_t } , merges ) ;
432437
433438 ops . add_to_collection ( tf . GraphKeys . COND_CONTEXT , context_t ) ;
434439 ops . add_to_collection ( tf . GraphKeys . COND_CONTEXT , context_f ) ;
435440
436- return merges [ 0 ] ;
441+ return new Tensor ( IntPtr . Zero ) ;
437442 } ) ;
438443 }
439444
@@ -473,22 +478,28 @@ public static Tensor[] cond<T>(Tensor pred,
473478 var res_f_flat = res_f ;
474479
475480 var merges = zip ( res_f_flat , res_t_flat )
476- . Select ( pair => merge ( new Tensor [ ] { pair . Item1 , pair . Item2 } ) )
481+ . Select ( pair => merge ( new [ ] { pair . Item1 , pair . Item2 } ) )
477482 . ToArray ( ) ;
478483
479- merges = _convert_flows_to_tensorarrays ( orig_res_t , merges ) ;
484+ var merges2 = _convert_flows_to_tensorarrays ( orig_res_t . Select ( x => ( ITensorOrTensorArray ) x ) . ToArray ( ) , merges ) ;
480485
481486 ops . add_to_collection ( tf . GraphKeys . COND_CONTEXT , context_t ) ;
482487 ops . add_to_collection ( tf . GraphKeys . COND_CONTEXT , context_f ) ;
483488
484- return merges ;
489+ return new [ ] { new Tensor ( IntPtr . Zero ) } ;
485490 } ) ;
486491 }
487492
488- public static Tensor [ ] _convert_flows_to_tensorarrays < T > ( T tensors_or_tensorarrays , Tensor [ ] tensors_or_flows )
493+ public static ITensorOrTensorArray [ ] _convert_flows_to_tensorarrays ( ITensorOrTensorArray [ ] tensors_or_tensorarrays , Tensor [ ] tensors_or_flows )
489494 {
490- // zip(tensors_or_tensorarrays, tensors_or_flows).Select((ta, t_or_flow) => ta).ToArray();
491- return tensors_or_flows ;
495+ return zip ( tensors_or_tensorarrays , tensors_or_flows ) . Select ( x =>
496+ {
497+ var ( ta , t_or_flow ) = ( x . Item1 , x . Item2 ) ;
498+ if ( ta is TensorArray ta_1 )
499+ return tensor_array_ops . build_ta_with_new_flow ( ta_1 , t_or_flow ) as ITensorOrTensorArray ;
500+ else
501+ return t_or_flow as ITensorOrTensorArray ;
502+ } ) . ToArray ( ) ;
492503 }
493504
494505 /// <summary>
@@ -592,7 +603,7 @@ public static Tensor ZerosLikeOutsideLoop(Operation op, int index)
592603 /// <param name="loop_vars"></param>
593604 /// <param name="i"></param>
594605 public static Tensor while_loop < TItem > ( Func < TItem , Tensor > cond , Func < TItem , TItem > body , TItem loop_vars ,
595- TensorShape shape_invariants = null ,
606+ TensorShape [ ] shape_invariants = null ,
596607 int parallel_iterations = 10 ,
597608 bool back_prop = true ,
598609 bool swap_memory = false ,
@@ -617,8 +628,8 @@ public static Tensor while_loop<TItem>(Func<TItem, Tensor> cond, Func<TItem, TIt
617628 var orig_body = body ;
618629
619630 LoopVar < TItem > loop_vars_1 = null ;
620- Func < Tensor , TItem , LoopVar < TItem > > body_buildloop = null ;
621- Func < Tensor , TItem , Tensor > cond_buildloop = null ;
631+ Func < LoopVar < TItem > , LoopVar < TItem > > body_buildloop = null ;
632+ Func < LoopVar < TItem > , Tensor > cond_buildloop = null ;
622633
623634 if ( try_to_pack )
624635 {
@@ -627,9 +638,18 @@ public static Tensor while_loop<TItem>(Func<TItem, Tensor> cond, Func<TItem, TIt
627638 else
628639 {
629640 loop_vars_1 = new LoopVar < TItem > ( counter , loop_vars ) ;
630- cond_buildloop = ( i , lv ) =>
631- math_ops . logical_and ( i < maximum_iterations , orig_cond ( lv ) ) ;
632- body_buildloop = ( i , lv ) => new LoopVar < TItem > ( i + 1 , orig_body ( lv ) ) ;
641+ cond_buildloop = ( item ) =>
642+ {
643+ var ( i , lv ) = ( item . Counter , item . Item ) ;
644+ var oc = orig_cond ( lv ) ;
645+ return math_ops . logical_and ( i < maximum_iterations , oc ) ;
646+ } ;
647+
648+ body_buildloop = ( item ) =>
649+ {
650+ var ( i , lv ) = ( item . Counter , item . Item ) ;
651+ return new LoopVar < TItem > ( i + 1 , orig_body ( lv ) ) ;
652+ } ;
633653 }
634654 try_to_pack = false ;
635655
0 commit comments