|
1 | | -using System; |
2 | | -using System.Collections.Generic; |
3 | | -using System.Text; |
4 | | -using static Tensorflow.Binding; |
5 | | -using static Tensorflow.KerasApi; |
| 1 | +namespace Tensorflow.Keras.Losses; |
6 | 2 |
|
7 | | -namespace Tensorflow.Keras.Losses |
| 3 | +public class Huber : LossFunctionWrapper |
8 | 4 | { |
9 | | - public class Huber : LossFunctionWrapper, ILossFunc |
| 5 | + protected Tensor delta = tf.Variable(1.0); |
| 6 | + |
| 7 | + public Huber( |
| 8 | + string reduction = null, |
| 9 | + Tensor delta = null, |
| 10 | + string name = null) : |
| 11 | + base(reduction: reduction, name: name == null ? "huber" : name) |
10 | 12 | { |
11 | | - protected Tensor delta = tf.Variable(1.0) ; |
12 | | - public Huber ( |
13 | | - string reduction = null, |
14 | | - Tensor delta = null, |
15 | | - string name = null) : |
16 | | - base(reduction: reduction, name: name == null ? "huber" : name) |
17 | | - { |
18 | | - this.delta = delta==null? this.delta: delta; |
19 | | - |
20 | | - } |
| 13 | + this.delta = delta == null ? this.delta : delta; |
| 14 | + } |
21 | 15 |
|
22 | | - public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) |
23 | | - { |
24 | | - Tensor y_pred_cast = math_ops.cast(y_pred, dtype: TF_DataType.TF_FLOAT); |
25 | | - Tensor y_true_cast = math_ops.cast(y_true, dtype: TF_DataType.TF_FLOAT); |
26 | | - Tensor delta = math_ops.cast(this.delta, dtype: TF_DataType.TF_FLOAT); |
27 | | - Tensor error = math_ops.subtract(y_pred_cast, y_true_cast); |
28 | | - Tensor abs_error = math_ops.abs(error); |
29 | | - Tensor half = ops.convert_to_tensor(0.5, dtype: abs_error.dtype); |
30 | | - return gen_math_ops.mean(array_ops.where_v2(abs_error <= delta, |
31 | | - half * math_ops.pow(error, 2), |
32 | | - half * math_ops.pow(delta, 2) + delta * (abs_error - delta)), |
33 | | - ops.convert_to_tensor(-1)); |
34 | | - } |
| 16 | + public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1) |
| 17 | + { |
| 18 | + Tensor y_pred_cast = math_ops.cast(y_pred, dtype: TF_DataType.TF_FLOAT); |
| 19 | + Tensor y_true_cast = math_ops.cast(y_true, dtype: TF_DataType.TF_FLOAT); |
| 20 | + Tensor delta = math_ops.cast(this.delta, dtype: TF_DataType.TF_FLOAT); |
| 21 | + Tensor error = math_ops.subtract(y_pred_cast, y_true_cast); |
| 22 | + Tensor abs_error = math_ops.abs(error); |
| 23 | + Tensor half = ops.convert_to_tensor(0.5, dtype: abs_error.dtype); |
| 24 | + return gen_math_ops.mean(array_ops.where_v2(abs_error <= delta, |
| 25 | + half * math_ops.pow(error, 2), |
| 26 | + half * math_ops.pow(delta, 2) + delta * (abs_error - delta)), |
| 27 | + ops.convert_to_tensor(-1)); |
35 | 28 | } |
36 | 29 | } |
0 commit comments