11using System ;
22using System . Collections . Generic ;
3+ using System . Linq ;
34using System . Text ;
45using Tensorflow . Operations ;
56
@@ -13,16 +14,17 @@ public class nn_grad
1314 /// <param name="op"></param>
1415 /// <param name="grad"></param>
1516 /// <returns></returns>
16- public static Tensor [ ] _BiasAddGrad ( Operation op , Tensor grad )
17+ public static Tensor [ ] _BiasAddGrad ( Operation op , Tensor [ ] grads )
1718 {
19+ var grad = grads [ 0 ] ;
1820 string data_format = op . get_attr ( "data_format" ) ? . ToString ( ) ;
1921 var bias_add_grad = gen_nn_ops . bias_add_grad ( out_backprop : grad , data_format : data_format ) ;
2022 return new Tensor [ ] { grad , bias_add_grad } ;
2123 }
2224
23- public static Tensor [ ] _ReluGrad ( Operation op , Tensor grad )
25+ public static Tensor [ ] _ReluGrad ( Operation op , Tensor [ ] grads )
2426 {
25- return new Tensor [ ] { gen_nn_ops . relu_grad ( grad , op . outputs [ 0 ] ) } ;
27+ return new Tensor [ ] { gen_nn_ops . relu_grad ( grads [ 0 ] , op . outputs [ 0 ] ) } ;
2628 }
2729
2830 /// <summary>
@@ -37,8 +39,57 @@ public static Tensor[] _SoftmaxCrossEntropyWithLogitsGrad(Operation op, Tensor[]
3739 var grad_loss = grads [ 0 ] ;
3840 var grad_grad = grads [ 1 ] ;
3941 var softmax_grad = op . outputs [ 1 ] ;
42+ var grad = _BroadcastMul ( grad_loss , softmax_grad ) ;
4043
41- throw new NotImplementedException ( "_SoftmaxCrossEntropyWithLogitsGrad" ) ;
44+ var logits = op . inputs [ 0 ] ;
45+ if ( grad_grad != null && ! IsZero ( grad_grad ) )
46+ {
47+ throw new NotImplementedException ( "_SoftmaxCrossEntropyWithLogitsGrad" ) ;
48+ }
49+
50+ return new Tensor [ ]
51+ {
52+ grad ,
53+ _BroadcastMul ( grad_loss , - nn_ops . log_softmax ( logits ) )
54+ } ;
55+ }
56+
57+ private static bool IsZero ( Tensor g )
58+ {
59+ if ( new string [ ] { "ZerosLike" , "Zeros" } . Contains ( g . op . type ) )
60+ return true ;
61+
62+ throw new NotImplementedException ( "IsZero" ) ;
63+ }
64+
65+ private static Tensor _BroadcastMul ( Tensor vec , Tensor mat )
66+ {
67+ vec = array_ops . expand_dims ( vec , - 1 ) ;
68+ return vec * mat ;
69+ }
70+
71+ /// <summary>
72+ /// Return the gradients for TopK.
73+ /// </summary>
74+ /// <param name="op"></param>
75+ /// <param name="grads"></param>
76+ /// <returns></returns>
77+ public static Tensor [ ] _TopKGrad ( Operation op , Tensor [ ] grads )
78+ {
79+ var grad = grads [ 0 ] ;
80+ var _ = grads [ 1 ] ;
81+
82+ var in_shape = array_ops . shape ( op . inputs [ 0 ] ) ;
83+ var ind_shape = array_ops . shape ( op . outputs [ 1 ] ) ;
84+
85+ // int32 is not supported on GPU hence up-casting
86+ var ind_lastdim = array_ops . gather ( math_ops . cast (
87+ ind_shape , TF_DataType . TF_INT64 ) , array_ops . size ( ind_shape ) - 1 ) ;
88+
89+ // Flatten indices to 2D.
90+ var ind_2d = array_ops . reshape ( op . outputs [ 1 ] , array_ops . stack ( new object [ ] { - 1 , ind_lastdim } ) ) ;
91+
92+ throw new NotImplementedException ( "nn_grad._TopKGrad" ) ;
4293 }
4394 }
4495}
0 commit comments