Skip to content

Commit e80a681

Browse files
committed
_ReluGrad, _BiasAddGrad
1 parent adae3aa commit e80a681

File tree

6 files changed

+106
-1
lines changed

6 files changed

+106
-1
lines changed
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Gradients
{
    /// <summary>
    /// Gradients for operators defined in array_ops.
    /// </summary>
    public class array_grad
    {
        /// <summary>
        /// Gradient for Reshape: reshape the incoming gradient back to the
        /// shape of the original tensor input. The shape input of Reshape
        /// receives no gradient.
        /// </summary>
        /// <param name="op">The forward Reshape operation.</param>
        /// <param name="grad">Gradient w.r.t. the Reshape output.</param>
        /// <returns>Gradients w.r.t. (tensor, shape); the shape slot is null.</returns>
        public static (Tensor, Tensor) _ReshapeGrad(Operation op, Tensor grad)
        {
            var originalShape = array_ops.shape(op.inputs[0]);
            return (array_ops.reshape(grad, originalShape), null);
        }
    }
}

src/TensorFlowNET.Core/Gradients/math_grad.py.cs

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
using System.Linq;
44
using System.Text;
55

6-
namespace Tensorflow
6+
namespace Tensorflow.Gradients
77
{
88
/// <summary>
99
/// Gradients for operators defined in math_ops.py.
@@ -57,6 +57,38 @@ public static (Tensor, Tensor) _MulGrad(Operation op, Tensor grad)
5757
return (reshape1, reshape2);
5858
}
5959

60+
/// <summary>
/// Gradient for MatMul, honoring the op's transpose_a/transpose_b attributes.
/// Mirrors tensorflow/python/ops/math_grad.py::_MatMulGrad.
/// </summary>
/// <param name="op">The forward MatMul operation.</param>
/// <param name="grad">Gradient w.r.t. the MatMul output.</param>
/// <returns>Gradients w.r.t. inputs (a, b).</returns>
public static (Tensor, Tensor) _MatMulGrad(Operation op, Tensor grad)
{
    Tensor grad_a = null, grad_b = null;

    var t_a = (bool)op.get_attr("transpose_a");
    var t_b = (bool)op.get_attr("transpose_b");
    var a = math_ops.conj(op.inputs[0]);
    var b = math_ops.conj(op.inputs[1]);

    if (!t_a && !t_b)
    {
        grad_a = gen_math_ops.mat_mul(grad, b, transpose_b: true);
        grad_b = gen_math_ops.mat_mul(a, grad, transpose_a: true);
    }
    else if (!t_a && t_b)
    {
        grad_a = gen_math_ops.mat_mul(grad, b);
        grad_b = gen_math_ops.mat_mul(grad, a, transpose_a: true);
    }
    else if (t_a && !t_b)
    {
        // BUG FIX: this branch previously duplicated the (!t_a && t_b) case.
        // For C = A^T * B the correct gradients are:
        //   dA = B * dC^T  and  dB = A * dC
        // (see TensorFlow's math_grad._MatMulGrad reference implementation).
        grad_a = gen_math_ops.mat_mul(b, grad, transpose_b: true);
        grad_b = gen_math_ops.mat_mul(a, grad);
    }
    else // t_a && t_b
    {
        grad_a = gen_math_ops.mat_mul(b, grad, transpose_a: true, transpose_b: true);
        grad_b = gen_math_ops.mat_mul(grad, a, transpose_a: true, transpose_b: true);
    }

    return (grad_a, grad_b);
}
91+
6092
public static (Tensor, Tensor) _MeanGrad(Operation op, Tensor grad)
6193
{
6294
var sum_grad = _SumGrad(op, grad).Item1;
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Operations;
5+
6+
namespace Tensorflow.Gradients
{
    /// <summary>
    /// Gradients for operators defined in nn_ops.
    /// </summary>
    public class nn_grad
    {
        /// <summary>
        /// Return the gradients for the 2 inputs of bias_op: the incoming
        /// gradient passes through unchanged to the value input, while the
        /// bias input's gradient is produced by the BiasAddGrad op.
        /// </summary>
        /// <param name="op">The forward BiasAdd operation.</param>
        /// <param name="grad">Gradient w.r.t. the BiasAdd output.</param>
        /// <returns>Gradients w.r.t. (value, bias).</returns>
        public static (Tensor, Tensor) _BiasAddGrad(Operation op, Tensor grad)
        {
            var data_format = op.get_attr("data_format")?.ToString();
            var bias_grad = gen_nn_ops.bias_add_grad(out_backprop: grad, data_format: data_format);
            return (grad, bias_grad);
        }

        /// <summary>
        /// Gradient for Relu, computed from the forward op's OUTPUT
        /// (op.outputs[0]) rather than its input.
        /// </summary>
        /// <param name="op">The forward Relu operation.</param>
        /// <param name="grad">Gradient w.r.t. the Relu output.</param>
        /// <returns>Gradient w.r.t. the single Relu input; second slot is null.</returns>
        public static (Tensor, Tensor) _ReluGrad(Operation op, Tensor grad)
        {
            var dx = gen_nn_ops.relu_grad(grad, op.outputs[0]);
            return (dx, null);
        }
    }
}

src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,22 @@ public static Tensor bias_add(Tensor value,
5353
return _op.outputs[0];
5454
}
5555

56+
/// <summary>
/// Adds a "BiasAddGrad" op to the graph and returns its single output.
/// </summary>
/// <param name="out_backprop">Gradient flowing back into the BiasAdd output.</param>
/// <param name="data_format">Data layout attribute; a null value is normalized to "NHWC".</param>
/// <param name="name">Optional name for the created op.</param>
/// <returns>The op's single output tensor.</returns>
public static Tensor bias_add_grad(Tensor out_backprop,
    string data_format = "NHWC",
    string name = null)
{
    // Callers may pass an explicit null (e.g. when the attr is absent);
    // fall back to the default layout in that case.
    data_format = data_format ?? "NHWC";

    var op = _op_def_lib._apply_op_helper("BiasAddGrad", name: name, args: new
    {
        out_backprop,
        data_format
    });

    return op.outputs[0];
}
71+
5672
public static Tensor[] _fused_batch_norm(Tensor x,
5773
Tensor scale,
5874
Tensor offset,
@@ -109,6 +125,17 @@ public static Tensor[] top_kv2(Tensor input, int k, bool sorted = true, string n
109125
return _op.outputs;
110126
}
111127

128+
/// <summary>
/// Adds a "ReluGrad" op to the graph and returns its single output.
/// </summary>
/// <param name="gradients">Backpropagated gradients w.r.t. the Relu output.</param>
/// <param name="features">Tensor supplying the Relu activation values
/// (callers pass the forward op's output).</param>
/// <param name="name">Optional name for the created op.</param>
/// <returns>The op's single output tensor.</returns>
public static Tensor relu_grad(Tensor gradients, Tensor features, string name = null)
{
    var op = _op_def_lib._apply_op_helper("ReluGrad", name: name, args: new
    {
        gradients,
        features
    });

    return op.outputs[0];
}
138+
112139
/// <summary>
113140
/// Computes softmax cross entropy cost and gradients to backpropagate.
114141
/// </summary>

src/TensorFlowNET.Core/ops.py.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using System.Linq;
1010
using NumSharp.Core;
1111
using System.ComponentModel;
12+
using Tensorflow.Gradients;
1213

1314
namespace Tensorflow
1415
{
@@ -380,6 +381,9 @@ public static Func<Operation, Tensor, Tensor[]> get_gradient_function(Operation
380381
case "RealDiv":
381382
var realdiv = math_grad._RealDivGrad(oper, out_grads);
382383
return new Tensor[] { realdiv.Item1, realdiv.Item2 };
384+
case "Reshape":
385+
var reshape = array_grad._ReshapeGrad(oper, out_grads);
386+
return new Tensor[] { reshape.Item1, reshape.Item2 };
383387
default:
384388
throw new NotImplementedException($"get_gradient_function {oper.type}");
385389
}

0 commit comments

Comments
 (0)