@@ -510,7 +510,7 @@ Tensor swap_batch_timestep(Tensor input_t)
510510 }
511511
512512 }
513-
513+
514514 // tf.where needs its condition tensor to be the same shape as its two
515515 // result tensors, but in our case the condition (mask) tensor is
516516 // (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
@@ -535,7 +535,7 @@ Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1)
535535 {
536536 mask_t = tf . expand_dims ( mask_t , - 1 ) ;
537537 }
538- var multiples = Enumerable . Repeat ( 1 , fixed_dim ) . ToArray ( ) . concat ( input_t . shape . as_int_list ( ) . ToList ( ) . GetRange ( fixed_dim , input_t . rank ) ) ;
538+ var multiples = Enumerable . Repeat ( 1 , fixed_dim ) . ToArray ( ) . concat ( input_t . shape . as_int_list ( ) . Skip ( fixed_dim ) . ToArray ( ) ) ;
539539 return tf . tile ( mask_t , multiples ) ;
540540 }
541541
@@ -570,9 +570,6 @@ Tensors _expand_mask(Tensors mask_t, Tensors input_t, int fixed_dim = 1)
570570 // individually. The result of this will be a tuple of lists, each of
571571 // the item in tuple is list of the tensor with shape (batch, feature)
572572
573-
574-
575-
576573 Tensors _process_single_input_t ( Tensor input_t )
577574 {
578575 var unstaked_input_t = array_ops . unstack ( input_t ) ; // unstack for time_step dim
@@ -609,7 +606,7 @@ object _get_input_tensor(int time)
609606 var mask_list = tf . unstack ( mask ) ;
610607 if ( go_backwards )
611608 {
612- mask_list . Reverse ( ) ;
609+ mask_list . Reverse ( ) . ToArray ( ) ;
613610 }
614611
615612 for ( int i = 0 ; i < time_steps ; i ++ )
@@ -629,9 +626,10 @@ object _get_input_tensor(int time)
629626 }
630627 else
631628 {
632- prev_output = successive_outputs [ successive_outputs . Length - 1 ] ;
629+ prev_output = successive_outputs . Last ( ) ;
633630 }
634631
632+ // output could be a tensor
635633 output = tf . where ( tiled_mask_t , output , prev_output ) ;
636634
637635 var flat_states = Nest . Flatten ( states ) . ToList ( ) ;
@@ -661,13 +659,13 @@ object _get_input_tensor(int time)
661659 }
662660
663661 }
664- last_output = successive_outputs [ successive_outputs . Length - 1 ] ;
665- new_states = successive_states [ successive_states . Length - 1 ] ;
662+ last_output = successive_outputs . Last ( ) ;
663+ new_states = successive_states . Last ( ) ;
666664 outputs = tf . stack ( successive_outputs ) ;
667665
668666 if ( zero_output_for_mask )
669667 {
670- last_output = tf . where ( _expand_mask ( mask_list [ mask_list . Length - 1 ] , last_output ) , last_output , tf . zeros_like ( last_output ) ) ;
668+ last_output = tf . where ( _expand_mask ( mask_list . Last ( ) , last_output ) , last_output , tf . zeros_like ( last_output ) ) ;
671669 outputs = tf . where ( _expand_mask ( mask , outputs , fixed_dim : 2 ) , outputs , tf . zeros_like ( outputs ) ) ;
672670 }
673671 else // mask is null
@@ -689,8 +687,8 @@ object _get_input_tensor(int time)
689687 successive_states = new Tensors { newStates } ;
690688 }
691689 }
692- last_output = successive_outputs [ successive_outputs . Length - 1 ] ;
693- new_states = successive_states [ successive_states . Length - 1 ] ;
690+ last_output = successive_outputs . Last ( ) ;
691+ new_states = successive_states . Last ( ) ;
694692 outputs = tf . stack ( successive_outputs ) ;
695693 }
696694 }
@@ -701,6 +699,8 @@ object _get_input_tensor(int time)
701699 // Create input tensor array, if the inputs is nested tensors, then it
702700 // will be flattened first, and tensor array will be created one per
703701 // flattened tensor.
702+
703+
704704 var input_ta = new List < TensorArray > ( ) ;
705705 for ( int i = 0 ; i < flatted_inptus . Count ; i ++ )
706706 {
@@ -719,6 +719,7 @@ object _get_input_tensor(int time)
719719 }
720720 }
721721
722+
722723 // Get the time(0) input and compute the output for that, the output will
723724 // be used to determine the dtype of output tensor array. Don't read from
724725 // input_ta due to TensorArray clear_after_read default to True.
@@ -773,7 +774,7 @@ object _get_input_tensor(int time)
773774 return res ;
774775 } ;
775776 }
776- // TODO(Wanglongzhi2001), what the input_length's type should be(an integer or a single tensor)?
777+ // TODO(Wanglongzhi2001): input_length's type is ambiguous — it may be either an integer or a single tensor; unify its handling.
777778 else if ( input_length is Tensor )
778779 {
779780 if ( go_backwards )
0 commit comments