Skip to content

Commit 2ee6513

Browse files
committed
fix shape issue for array_ops.size_internal
1 parent bad610d commit 2ee6513

File tree

6 files changed

+66
-39
lines changed

6 files changed

+66
-39
lines changed

src/TensorFlowNET.Core/Gradients/nn_grad.py.cs

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,24 @@ public static Tensor[] _TopKGrad(Operation op, Tensor[] grads)
8383
var ind_shape = array_ops.shape(op.outputs[1]);
8484

8585
// int32 is not supported on GPU hence up-casting
86-
var ind_lastdim = array_ops.gather(math_ops.cast(
87-
ind_shape, TF_DataType.TF_INT64), array_ops.size(ind_shape) - 1);
86+
var cast = math_ops.cast(ind_shape, TF_DataType.TF_INT64);
87+
var size = array_ops.size(ind_shape) - 1;
88+
var ind_lastdim = array_ops.gather(cast, size);
8889

8990
// Flatten indices to 2D.
90-
var ind_2d = array_ops.reshape(op.outputs[1], array_ops.stack(new object[] { -1, ind_lastdim }));
91+
var stack = array_ops.stack(new object[] { -1L, ind_lastdim });
92+
var ind_2d = array_ops.reshape(op.outputs[1], stack);
93+
94+
var in_lastdim = array_ops.gather(math_ops.cast(in_shape, TF_DataType.TF_INT64),
95+
array_ops.size(in_shape) - 1);
96+
var outerdim = array_ops.shape(ind_2d);
97+
98+
// Compute linear indices(flattened to 1D).
99+
var cast1 = math_ops.cast(outerdim, TF_DataType.TF_INT64);
100+
var range2 = math_ops.range(0L, cast1 * in_lastdim, in_lastdim);
101+
var dim2 = array_ops.expand_dims(range2, -1);
102+
var cast2 = math_ops.cast(dim2, TF_DataType.TF_INT32);
103+
var ind = array_ops.reshape(ind_2d + cast2, new int[] { -1 });
91104

92105
throw new NotImplementedException("nn_grad._TopKGrad");
93106
}

src/TensorFlowNET.Core/Operations/array_ops.py.cs

Lines changed: 36 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,35 @@ private static Tensor _constant_if_small<T>(T value, Shape shape, TF_DataType dt
4646
}
4747
}
4848

49+
public static Tensor _autopacking_conversion_function(object[] v, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
50+
{
51+
var inferred_dtype = _get_dtype_from_nested_lists(v);
52+
if (dtype == TF_DataType.DtInvalid)
53+
dtype = inferred_dtype;
54+
55+
return _autopacking_helper(v, dtype, name == null ? "packed" : name);
56+
}
57+
58+
private static TF_DataType _get_dtype_from_nested_lists(object[] list_or_tuple)
59+
{
60+
TF_DataType dtype = TF_DataType.DtInvalid;
61+
62+
foreach(var obj in list_or_tuple)
63+
{
64+
switch (obj)
65+
{
66+
case Tensor t:
67+
dtype = t.dtype.as_base_dtype();
68+
break;
69+
}
70+
71+
if (dtype != TF_DataType.DtInvalid)
72+
break;
73+
}
74+
75+
return dtype;
76+
}
77+
4978
public static Tensor _autopacking_helper(object[] list_or_tuple, TF_DataType dtype, string name)
5079
{
5180
var must_pack = false;
@@ -242,32 +271,21 @@ private static Tensor shape_internal(Tensor input, string name = null, bool opti
242271

243272
private static Tensor size_internal(Tensor input, string name = null, bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32)
244273
{
245-
return with(ops.name_scope(name, "Size", new Tensor[] { input }), scope =>
274+
return with(ops.name_scope(name, "Size", new { input }), scope =>
246275
{
247276
name = scope;
248277

249-
if (!tf.context.executing_eagerly())
278+
var input_tensor = ops.convert_to_tensor(input);
279+
var input_shape = tensor_util.to_shape(input_tensor.shape);
280+
if (optimize)
250281
{
251-
var input_tensor = ops.convert_to_tensor(input);
252-
var input_shape = tensor_util.to_shape(input_tensor.shape);
253-
if (optimize)
282+
if (input_shape.is_fully_defined())
254283
{
255-
if (input_shape.is_fully_defined())
256-
{
257-
var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype());
258-
return constant_op.constant(nd, name: name);
259-
}
284+
return constant_op.constant(input_shape.Size, dtype: out_type, name: name);
260285
}
261-
262-
return gen_array_ops.size(input, name: name, out_type: out_type);
263-
}
264-
else
265-
{
266-
// result = gen_array_ops.shape();
267-
throw new NotImplementedException("array_ops.size_internal");
268286
}
269287

270-
return null;
288+
return gen_array_ops.size(input, name: name, out_type: out_type);
271289
});
272290
}
273291

src/TensorFlowNET.Core/Operations/math_ops.py.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, s
1818

1919
return with(ops.name_scope(name, "Cast", new { x }), scope =>
2020
{
21+
name = scope;
2122
x = ops.convert_to_tensor(x, name: "x");
2223
if (x.dtype.as_base_dtype() != base_type)
2324
x = gen_math_ops.cast(x, base_type, name: name);
@@ -263,7 +264,7 @@ public static Tensor range(object start, object limit = null, object delta = nul
263264
if (delta == null)
264265
delta = 1;
265266

266-
return with(ops.name_scope(name, "Range", new object[] { start, limit, delta }), scope =>
267+
return with(ops.name_scope(name, "Range", new { start, limit, delta }), scope =>
267268
{
268269
name = scope;
269270
var start1 = ops.convert_to_tensor(start, name: "start");

src/TensorFlowNET.Core/Tensors/dtypes.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ public static TF_DataType as_dtype(Type type)
3434
case "Int32":
3535
dtype = TF_DataType.TF_INT32;
3636
break;
37+
case "Int64":
38+
dtype = TF_DataType.TF_INT64;
39+
break;
3740
case "Single":
3841
dtype = TF_DataType.TF_FLOAT;
3942
break;
@@ -47,7 +50,7 @@ public static TF_DataType as_dtype(Type type)
4750
dtype = TF_DataType.TF_STRING;
4851
break;
4952
default:
50-
throw new Exception("Not Implemented");
53+
throw new Exception($"{type.Name} Not Implemented in as_dtype");
5154
}
5255

5356
return dtype;

src/TensorFlowNET.Core/Tensors/tensor_util.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,9 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
111111
case int intVal:
112112
nparray = intVal;
113113
break;
114+
case long intVal:
115+
nparray = intVal;
116+
break;
114117
case int[] intVals:
115118
nparray = np.array(intVals);
116119
break;
@@ -231,6 +234,9 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
231234
case "Int32":
232235
tensor_proto.IntVal.AddRange(proto_values.Data<int>());
233236
break;
237+
case "Int64":
238+
tensor_proto.Int64Val.AddRange(proto_values.Data<long>());
239+
break;
234240
case "Single":
235241
tensor_proto.FloatVal.AddRange(proto_values.Data<float>());
236242
break;

src/TensorFlowNET.Core/ops.py.cs

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -410,26 +410,12 @@ public static Tensor internal_convert_to_tensor(object value, TF_DataType dtype
410410
return tensor;
411411
case Tensor[] tensors:
412412
return array_ops._autopacking_helper(tensors, dtype, name);
413-
case string str:
414-
return constant_op.constant(str, dtype: dtype, name: name);
415-
case string[] strArray:
416-
return constant_op.constant(strArray, dtype: dtype, name: name);
417-
case int intVal:
418-
return constant_op.constant(intVal, dtype: dtype, name: name);
419-
case int[] intArray:
420-
return constant_op.constant(intArray, dtype: dtype, name: name);
421-
case float floatVal:
422-
return constant_op.constant(floatVal, dtype: dtype, name: name);
423-
case float[] floatArray:
424-
return constant_op.constant(floatArray, dtype: dtype, name: name);
425-
case double doubleVal:
426-
return constant_op.constant(doubleVal, dtype: dtype, name: name);
427413
case RefVariable varVal:
428414
return varVal._TensorConversionFunction(as_ref: as_ref);
429415
case object[] objects:
430-
return array_ops._autopacking_helper(objects, dtype: dtype, name: name);
416+
return array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name);
431417
default:
432-
throw new NotImplementedException($"internal_convert_to_tensor: Can't convert {value.GetType().Name} to Tensor");
418+
return constant_op.constant(value, dtype: dtype, name: name);
433419
}
434420
}
435421

0 commit comments

Comments
 (0)