@@ -10,12 +10,13 @@ namespace Tensorflow.Gradients
1010 /// </summary>
1111 public class math_grad
1212 {
13- public static ( Tensor , Tensor ) _AddGrad ( Operation op , Tensor grad )
13+ public static Tensor [ ] _AddGrad ( Operation op , Tensor [ ] grads )
1414 {
1515 var x = op . inputs [ 0 ] ;
1616 var y = op . inputs [ 1 ] ;
17+ var grad = grads [ 0 ] ;
1718 if ( grad is Tensor && _ShapesFullySpecifiedAndEqual ( x , y , grad ) )
18- return ( grad , grad ) ;
19+ return new Tensor [ ] { grad , grad } ;
1920
2021 var sx = array_ops . shape ( x ) ;
2122 var sy = array_ops . shape ( y ) ;
@@ -24,21 +25,22 @@ public static (Tensor, Tensor) _AddGrad(Operation op, Tensor grad)
2425 var r1 = gen_array_ops . reshape ( math_ops . reduce_sum ( grad , rx ) , sx ) ;
2526 var r2 = gen_array_ops . reshape ( math_ops . reduce_sum ( grad , ry ) , sy ) ;
2627
27- return ( r1 , r2 ) ;
28+ return new Tensor [ ] { r1 , r2 } ;
2829 }
2930
30- public static Tensor _IdGrad ( Operation op , Tensor grad )
31+ public static Tensor [ ] _IdGrad ( Operation op , Tensor [ ] grads )
3132 {
32- return grad ;
33+ return new Tensor [ ] { grads [ 0 ] } ;
3334 }
3435
35- public static ( Tensor , Tensor ) _MulGrad ( Operation op , Tensor grad )
36+ public static Tensor [ ] _MulGrad ( Operation op , Tensor [ ] grads )
3637 {
3738 var x = op . inputs [ 0 ] ;
3839 var y = op . inputs [ 1 ] ;
40+ var grad = grads [ 0 ] ;
3941 if ( grad is Tensor && _ShapesFullySpecifiedAndEqual ( x , y , grad ) &&
4042 new TF_DataType [ ] { tf . int32 , tf . float32 } . Contains ( grad . dtype ) )
41- return ( gen_math_ops . mul ( grad , y ) , gen_math_ops . mul ( grad , x ) ) ;
43+ return new Tensor [ ] { gen_math_ops . mul ( grad , y ) , gen_math_ops . mul ( grad , x ) } ;
4244
4345 var sx = array_ops . shape ( x ) ;
4446 var sy = array_ops . shape ( y ) ;
@@ -54,11 +56,12 @@ public static (Tensor, Tensor) _MulGrad(Operation op, Tensor grad)
5456 var reshape1 = gen_array_ops . reshape ( reduce_sum1 , sx ) ;
5557 var reshape2 = gen_array_ops . reshape ( reduce_sum2 , sy ) ;
5658
57- return ( reshape1 , reshape2 ) ;
59+ return new Tensor [ ] { reshape1 , reshape2 } ;
5860 }
5961
60- public static ( Tensor , Tensor ) _MatMulGrad ( Operation op , Tensor grad )
62+ public static Tensor [ ] _MatMulGrad ( Operation op , Tensor [ ] grads )
6163 {
64+ var grad = grads [ 0 ] ;
6265 Tensor grad_a = null , grad_b = null ;
6366
6467 var t_a = ( bool ) op . get_attr ( "transpose_a" ) ;
@@ -86,33 +89,35 @@ public static (Tensor, Tensor) _MatMulGrad(Operation op, Tensor grad)
8689 grad_b = gen_math_ops . mat_mul ( grad , a , transpose_a : true , transpose_b : true ) ;
8790 }
8891
89- return ( grad_a , grad_b ) ;
92+ return new Tensor [ ] { grad_a , grad_b } ;
9093 }
9194
92- public static ( Tensor , Tensor ) _MeanGrad ( Operation op , Tensor grad )
95+ public static Tensor [ ] _MeanGrad ( Operation op , Tensor [ ] grads )
9396 {
94- var sum_grad = _SumGrad ( op , grad ) . Item1 ;
97+ var grad = grads [ 0 ] ;
98+ var sum_grad = _SumGrad ( op , grads ) [ 0 ] ;
9599 var input_shape = op . inputs [ 0 ] . _shape_tuple ( ) ;
96100 var output_shape = op . outputs [ 0 ] . _shape_tuple ( ) ;
97101
98102 var input_shape_tensor = array_ops . shape ( op . inputs [ 0 ] ) ;
99103 var output_shape_tensor = array_ops . shape ( op . outputs [ 0 ] ) ;
100104 var factor = _safe_shape_div ( math_ops . reduce_prod ( input_shape_tensor ) , math_ops . reduce_prod ( output_shape_tensor ) ) ;
101105
102- return ( math_ops . truediv ( sum_grad , math_ops . cast ( factor , sum_grad . dtype ) ) , null ) ;
106+ return new Tensor [ ] { math_ops . truediv ( sum_grad , math_ops . cast ( factor , sum_grad . dtype ) ) , null } ;
103107 }
104108
105109 private static Tensor _safe_shape_div ( Tensor x , Tensor y )
106110 {
107111 return math_ops . floordiv ( x , gen_math_ops . maximum ( y , 1 ) ) ;
108112 }
109113
110- public static ( Tensor , Tensor ) _SubGrad ( Operation op , Tensor grad )
114+ public static Tensor [ ] _SubGrad ( Operation op , Tensor [ ] grads )
111115 {
116+ var grad = grads [ 0 ] ;
112117 var x = op . inputs [ 0 ] ;
113118 var y = op . inputs [ 1 ] ;
114119 if ( grad is Tensor && _ShapesFullySpecifiedAndEqual ( x , y , grad ) )
115- return ( grad , - grad ) ;
120+ return new Tensor [ ] { grad , - grad } ;
116121
117122 var sx = array_ops . shape ( x ) ;
118123 var sy = array_ops . shape ( y ) ;
@@ -121,16 +126,17 @@ public static (Tensor, Tensor) _SubGrad(Operation op, Tensor grad)
121126 var r1 = gen_array_ops . reshape ( math_ops . reduce_sum ( grad , rx ) , sx ) ;
122127 var r2 = gen_array_ops . reshape ( - math_ops . reduce_sum ( grad , ry ) , sy ) ;
123128
124- return ( r1 , r2 ) ;
129+ return new Tensor [ ] { r1 , r2 } ;
125130 }
126131
127132 public static bool _ShapesFullySpecifiedAndEqual ( Tensor x , Tensor y , Tensor grad )
128133 {
129134 return x . NDims == y . NDims && y . NDims == grad . NDims && x . NDims > - 1 ;
130135 }
131136
132- public static ( Tensor , Tensor ) _SumGrad ( Operation op , Tensor grad )
137+ public static Tensor [ ] _SumGrad ( Operation op , Tensor [ ] grads )
133138 {
139+ var grad = grads [ 0 ] ;
134140 var input_0_shape = op . inputs [ 0 ] . _shape_tuple ( ) ;
135141 Tensor input_shape = null ;
136142
@@ -145,7 +151,7 @@ public static (Tensor, Tensor) _SumGrad(Operation op, Tensor grad)
145151 input_shape = constant_op . constant ( input_0_shape ) ;
146152 else
147153 input_shape = array_ops . shape ( op . inputs [ 0 ] ) ;
148- return ( gen_array_ops . tile ( grad , input_shape ) , null ) ;
154+ return new Tensor [ ] { gen_array_ops . tile ( grad , input_shape ) , null } ;
149155 }
150156 }
151157
@@ -155,11 +161,12 @@ public static (Tensor, Tensor) _SumGrad(Operation op, Tensor grad)
155161 var tile_scaling = _safe_shape_div ( input_shape , output_shape_kept_dims ) ;
156162 grad = gen_array_ops . reshape ( grad , output_shape_kept_dims ) ;
157163
158- return ( gen_array_ops . tile ( grad , tile_scaling ) , null ) ;
164+ return new Tensor [ ] { gen_array_ops . tile ( grad , tile_scaling ) , null } ;
159165 }
160166
161- public static ( Tensor , Tensor ) _RealDivGrad ( Operation op , Tensor grad )
167+ public static Tensor [ ] _RealDivGrad ( Operation op , Tensor [ ] grads )
162168 {
169+ var grad = grads [ 0 ] ;
163170 var x = op . inputs [ 0 ] ;
164171 var y = op . inputs [ 1 ] ;
165172
@@ -177,11 +184,12 @@ public static (Tensor, Tensor) _RealDivGrad(Operation op, Tensor grad)
177184 var reduce_sum2 = math_ops . reduce_sum ( realdiv3 , rx ) ;
178185 var reshape2 = gen_array_ops . reshape ( reduce_sum2 , sx ) ;
179186
180- return ( reshape2 , reshape1 ) ;
187+ return new Tensor [ ] { reshape2 , reshape1 } ;
181188 }
182189
183- public static ( Tensor , Tensor ) _PowGrad ( Operation op , Tensor grad )
190+ public static Tensor [ ] _PowGrad ( Operation op , Tensor [ ] grads )
184191 {
192+ var grad = grads [ 0 ] ;
185193 var x = op . inputs [ 0 ] ;
186194 var y = op . inputs [ 1 ] ;
187195 var z = op . outputs [ 0 ] ;
@@ -212,7 +220,7 @@ public static (Tensor, Tensor) _PowGrad(Operation op, Tensor grad)
212220 var reduce_sum1 = math_ops . reduce_sum ( mul1 , ry ) ;
213221 var gy = gen_array_ops . reshape ( reduce_sum1 , sy ) ;
214222
215- return ( gx , gy ) ;
223+ return new Tensor [ ] { gx , gy } ;
216224 }
217225 }
218226}
0 commit comments