Skip to content

Commit d03d7b5

Browse files
committed
Merge branch 'master' into tf.keras-0.3.image-classification
2 parents 8c0feae + 0ff31a3 commit d03d7b5

File tree

7 files changed

+116
-15
lines changed

7 files changed

+116
-15
lines changed

src/TensorFlowNET.Core/APIs/tf.array.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ public Tensor where<Tx, Ty>(Tensor condition, Tx x, Ty y, string name = null)
152152
/// <param name="name"></param>
153153
/// <param name="conjugate"></param>
154154
/// <returns></returns>
155-
public Tensor transpose<T1>(T1 a, int[] perm = null, string name = "transpose", bool conjugate = false)
155+
public Tensor transpose<T1>(T1 a, TensorShape perm = null, string name = "transpose", bool conjugate = false)
156156
=> array_ops.transpose(a, perm, name, conjugate);
157157

158158
/// <summary>

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -779,7 +779,22 @@ public static Tensor gather<T1, T2>(T1 @params, T2 indices, string name = null,
779779
return gen_array_ops.gather_v2(@params, indices, axis, name: name);
780780
}
781781

782-
public static Tensor transpose<T1, T2>(T1 a, T2 perm, string name = "transpose", bool conjugate = false)
782+
public static Tensor transpose<T1>(T1 a, TensorShape perm, string name = "transpose", bool conjugate = false)
783+
{
784+
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
785+
{
786+
var a_tensor = ops.convert_to_tensor(a);
787+
if(perm == null)
788+
{
789+
var rank = a_tensor.rank;
790+
perm = range(0, rank).OrderByDescending(x => x).ToArray();
791+
}
792+
793+
return gen_array_ops.transpose(a_tensor, perm, name: scope);
794+
});
795+
}
796+
797+
public static Tensor transpose(Tensor a, Tensor perm, string name = "transpose", bool conjugate = false)
783798
{
784799
return tf_with(ops.name_scope(name, "transpose", new { a }), scope =>
785800
{

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ public static Tensor tile<T>(Tensor input, T multiples, string name = null)
531531
input, multiples).FirstOrDefault(),
532532
input);
533533

534-
public static Tensor transpose<T1, T2>(T1 x, T2 perm, string name = null)
534+
public static Tensor transpose<T1>(Tensor x, T1 perm, string name = null)
535535
{
536536
if (tf.Context.executing_eagerly())
537537
{

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*****************************************************************************
1+
/*****************************************************************************
22
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
33
44
Licensed under the Apache License, Version 2.0 (the "License");
@@ -619,6 +619,16 @@ public static Tensor logical_xor(Tensor x, Tensor y, string name = "LogicalXor")
619619

620620
public static Tensor squared_difference(Tensor x, Tensor y, string name = null)
621621
{
622+
if (tf.Context.executing_eagerly())
623+
{
624+
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
625+
"SquaredDifference", name,
626+
null,
627+
x,y);
628+
629+
return results[0];
630+
}
631+
622632
var _op = tf.OpDefLib._apply_op_helper("SquaredDifference", name, args: new { x, y, name });
623633

624634
return _op.outputs[0];
@@ -1210,4 +1220,4 @@ public static Tensor zero_fraction(Tensor value, string name = null)
12101220
return _op.outputs[0];
12111221
}
12121222
}
1213-
}
1223+
}

src/TensorFlowNET.Keras/Utils/losses_utils.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*****************************************************************************
1+
/*****************************************************************************
22
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
33
44
Licensed under the Apache License, Version 2.0 (the "License");
@@ -25,7 +25,7 @@ public class losses_utils
2525
public static Tensor compute_weighted_loss(Tensor losses, Tensor sample_weight = null, string reduction = null, string name = null)
2626
{
2727
if (sample_weight == null)
28-
sample_weight = tf.constant(1.0f);
28+
sample_weight = losses.dtype == TF_DataType.TF_DOUBLE ? tf.constant(1.0) : tf.constant(1.0f);
2929
var weighted_losses = scale_losses_by_sample_weight(losses, sample_weight);
3030
// Apply reduction function to the individual weighted losses.
3131
var loss = reduce_weighted_loss(weighted_losses, reduction);

test/TensorFlowNET.UnitTest/Basics/VariableTest.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,26 @@ public void SliceAssign()
8383
Assert.AreEqual(nd[2], x[2].numpy());
8484
}
8585

86+
[TestMethod, Ignore]
87+
public void TypeMismatchedSliceAssign()
88+
{
89+
NDArray intNd = new int[]
90+
{
91+
1, -2, 3
92+
};
93+
NDArray doubleNd = new double[]
94+
{
95+
-5, 6, -7
96+
};
97+
var x = tf.Variable(doubleNd);
98+
99+
var slice = x[":"];
100+
Assert.ThrowsException<System.Exception>(
101+
// this statement exit without throwing any exception but the "test execution summary" seems not able to detect that.
102+
() => slice.assign(intNd)
103+
);
104+
}
105+
86106
[TestMethod]
87107
public void Accumulation()
88108
{

test/TensorFlowNET.UnitTest/ManagedAPI/TensorOperate.cs

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,70 @@ public class TensorOperate
1111
[TestMethod]
1212
public void TransposeTest()
1313
{
14-
var a = tf.constant(np.array(new[, , ,] { { { { 1, 11, 2, 22 } }, { { 3, 33, 4, 44 } } },
15-
{ { { 5, 55, 6, 66 } }, { { 7, 77, 8, 88 } } } }));
16-
var b = tf.transpose(a, new[] { 3, 1, 2, 0 });
17-
var transpose_a = tf.constant(np.array(new[, , ,] { { { { 1, 5 } }, { { 3, 7 } } },
18-
{ { { 11, 55 } }, { { 33, 77 } } }, { { { 2, 6 } }, { { 4, 8 } } },
19-
{ { { 22, 66 } }, { { 44, 88 } } } }));
20-
Assert.IsTrue(Enumerable.SequenceEqual(new[] { 4, 2, 1, 2 }, b.shape));
21-
Assert.IsTrue(Enumerable.SequenceEqual(transpose_a.numpy().ToArray<int>(), b.numpy().ToArray<int>()));
14+
// https://www.tensorflow.org/api_docs/python/tf/transpose#for_example_2
15+
var x = tf.constant(new int[,]
16+
{
17+
{ 1, 2, 3 },
18+
{ 4, 5, 6 }
19+
});
20+
var transpose_x = tf.transpose(x);
21+
Assert.AreEqual(new[] { 1, 4 }, transpose_x[0].numpy());
22+
Assert.AreEqual(new[] { 2, 5 }, transpose_x[1].numpy());
23+
Assert.AreEqual(new[] { 3, 6 }, transpose_x[2].numpy());
24+
25+
#region constant a
26+
var a = tf.constant(np.array(new[, , ,]
27+
{
28+
{
29+
{
30+
{ 1, 11, 2, 22 }
31+
},
32+
{
33+
{ 3, 33, 4, 44 }
34+
}
35+
},
36+
{
37+
{
38+
{ 5, 55, 6, 66 }
39+
},
40+
{
41+
{ 7, 77, 8, 88 }
42+
}
43+
}
44+
}));
45+
46+
#endregion
47+
var actual_transposed_a = tf.transpose(a, new[] { 3, 1, 2, 0 });
48+
49+
#region constant transpose_a
50+
var expected_transposed_a = tf.constant(np.array(new[, , ,]
51+
{
52+
{
53+
{ { 1, 5 } }, { { 3, 7 } }
54+
},
55+
{
56+
{ { 11, 55 } }, { { 33, 77 } }
57+
},
58+
{
59+
{
60+
{ 2, 6 }
61+
},
62+
{
63+
{ 4, 8 }
64+
}
65+
},
66+
{
67+
{
68+
{ 22, 66 }
69+
},
70+
{
71+
{ 44, 88 }
72+
}
73+
}
74+
}));
75+
#endregion
76+
Assert.AreEqual((4, 2, 1, 2 ), actual_transposed_a.TensorShape);
77+
Assert.AreEqual(expected_transposed_a.numpy(), actual_transposed_a.numpy());
2278
}
2379

2480
[TestMethod]

0 commit comments

Comments
 (0)