Skip to content

Commit 6509ae0

Browse files
committed
Fix datatype for slice.
1 parent 615105c commit 6509ae0

File tree

8 files changed

+116
-104
lines changed

8 files changed

+116
-104
lines changed

src/TensorFlowNET.Core/Gradients/array_grad.cs

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -293,20 +293,24 @@ public static Tensor[] _StridedSliceGrad(Operation op, Tensor[] grads)
293293
var strides = op.inputs[3];
294294

295295
var x = array_ops.shape(op.inputs[0], out_type: begin.dtype);
296+
var x_static = tensor_util.constant_value(x);
297+
var begin_static = tensor_util.constant_value(begin);
298+
var end_static = tensor_util.constant_value(end);
299+
var strides_static = tensor_util.constant_value(strides);
296300

297301
return new Tensor[]
298302
{
299-
gen_array_ops.strided_slice_grad(
300-
x,
301-
begin,
302-
end,
303-
strides,
303+
array_ops.strided_slice_grad(
304+
x_static,
305+
begin_static,
306+
end_static,
307+
strides_static,
304308
grad,
305-
begin_mask: int.Parse(op.get_attr("begin_mask").ToString()),
306-
end_mask: int.Parse(op.get_attr("end_mask").ToString()),
307-
ellipsis_mask: int.Parse(op.get_attr("ellipsis_mask").ToString()),
308-
new_axis_mask: int.Parse(op.get_attr("new_axis_mask").ToString()),
309-
shrink_axis_mask: int.Parse(op.get_attr("shrink_axis_mask").ToString())),
309+
begin_mask: op.get_attr<long>("begin_mask"),
310+
end_mask: op.get_attr<long>("end_mask"),
311+
ellipsis_mask: op.get_attr<long>("ellipsis_mask"),
312+
new_axis_mask: op.get_attr<long>("new_axis_mask"),
313+
shrink_axis_mask: op.get_attr<long>("shrink_axis_mask")),
310314
null,
311315
null,
312316
null
@@ -331,11 +335,11 @@ public static Tensor[] _StridedSliceGradGrad(Operation op, Tensor[] grads)
331335
begin,
332336
end,
333337
strides,
334-
begin_mask: (int)op.get_attr("begin_mask"),
335-
end_mask: (int)op.get_attr("end_mask"),
336-
ellipsis_mask: (int)op.get_attr("ellipsis_mask"),
337-
new_axis_mask: (int)op.get_attr("new_axis_mask"),
338-
shrink_axis_mask: (int)op.get_attr("shrink_axis_mask"))
338+
begin_mask: op.get_attr<long>("begin_mask"),
339+
end_mask: op.get_attr<long>("end_mask"),
340+
ellipsis_mask: op.get_attr<long>("ellipsis_mask"),
341+
new_axis_mask: op.get_attr<long>("new_axis_mask"),
342+
shrink_axis_mask: op.get_attr<long>("shrink_axis_mask"))
339343
};
340344
}
341345

src/TensorFlowNET.Core/Operations/OpDefLibrary.cs

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -276,12 +276,16 @@ private void SetAttrs(string op_type_name,
276276
}
277277
else
278278
{
279-
attrs[input_arg.NumberAttr] = (values as Tensor[]).Length;
280-
inferred_from[input_arg.NumberAttr] = input_name;
281-
var num_attr = op_def.Attr.First(x => x.Name == input_arg.NumberAttr);
282-
if (num_attr.HasMinimum && (values as Tensor[]).Length < num_attr.Minimum)
283-
throw new ValueError($"List argument '{input_name}' to '{op_type_name}' Op with length {(values as Tensor[]).Length} shorter " +
284-
$"than minimum length {num_attr.Minimum}");
279+
if(values is Tensor[] tensors)
280+
{
281+
var num_attr = op_def.Attr.First(x => x.Name == input_arg.NumberAttr);
282+
if (num_attr.HasMinimum && tensors.Length < num_attr.Minimum)
283+
throw new ValueError($"List argument '{input_name}' to '{op_type_name}' Op with length {(values as Tensor[]).Length} shorter " +
284+
$"than minimum length {num_attr.Minimum}");
285+
286+
attrs[input_arg.NumberAttr] = Convert.ToInt64(tensors.Length);
287+
inferred_from[input_arg.NumberAttr] = input_name;
288+
}
285289
}
286290

287291
// All tensors must have the same base type.
@@ -378,7 +382,10 @@ private AttrValue SetAttrValue(OpDef op_def, AttrDef attr_def, object value)
378382
attr_value.F = (float)value;
379383
break;
380384
case "int":
381-
attr_value.I = (int)value;
385+
if (value is long value_long)
386+
attr_value.I = value_long;
387+
else
388+
attr_value.I = Convert.ToInt64(value);
382389
if (attr_def.HasMinimum && attr_value.I < attr_def.Minimum)
383390
throw new ValueError($"Attr '{attr_def.Name}' of '{op_def.Name}' Op passed {attr_value.I} less than minimum {attr_def.Minimum}.");
384391
break;

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -242,16 +242,17 @@ public virtual object get_attr(string name)
242242
if (string.IsNullOrEmpty(oneof_value))
243243
return null;
244244

245-
if (oneof_value == "list")
246-
throw new NotImplementedException($"Unsupported field type in {x.ToString()}");
247-
248-
if (string.Equals("type", oneof_value, StringComparison.OrdinalIgnoreCase))
249-
return x.Type;
250-
251-
object result = x.GetType().GetProperty(oneof_value).GetValue(x);
252-
if (result is Google.Protobuf.ByteString byteString)
253-
return byteString.ToStringUtf8();
254-
return result;
245+
switch (oneof_value.ToLower())
246+
{
247+
case "list":
248+
throw new NotImplementedException($"Unsupported field type in {oneof_value}");
249+
case "type":
250+
return x.Type;
251+
case "s":
252+
return x.S.ToStringUtf8();
253+
default:
254+
return x.GetType().GetProperty(oneof_value).GetValue(x);
255+
}
255256
}
256257

257258
public TF_AttrMetadata GetAttributeMetadata(string attr_name, Status s)

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ public static Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF
122122
case TF_DataType.TF_FLOAT:
123123
return _constant_if_small(0.0F, shape, dtype, name);
124124
case TF_DataType.TF_INT64:
125-
return _constant_if_small(0l, shape, dtype, name);
125+
return _constant_if_small(0L, shape, dtype, name);
126126
case TF_DataType.TF_INT32:
127127
return _constant_if_small(0, shape, dtype, name);
128128
case TF_DataType.TF_INT8:
@@ -671,6 +671,68 @@ public static Tensor strided_slice(Tensor input_, Tensor begin, Tensor end,
671671
return op;
672672
}
673673

674+
/// <summary>
675+
/// Returns the gradient of `StridedSlice`.
676+
///
677+
/// Since `StridedSlice` cuts out pieces of its `input` which is size
678+
/// `shape`, its gradient will have the same shape (which is passed here
679+
/// as `shape`). The gradient will be zero in any element that the slice
680+
/// does not select.
681+
/// </summary>
682+
/// <param name="shape">Must be one of the following types: `int32`, `int64`.</param>
683+
/// <param name="begin">Must have the same type as `shape`.</param>
684+
/// <param name="end">Must have the same type as `shape`.</param>
685+
/// <param name="strides">Must have the same type as `shape`.</param>
686+
/// <param name="dy">A `Tensor`.</param>
687+
/// <param name="begin_mask">An optional `int`. Defaults to `0`.</param>
688+
/// <param name="end_mask">An optional `int`. Defaults to `0`.</param>
689+
/// <param name="ellipsis_mask">An optional `int`. Defaults to `0`.</param>
690+
/// <param name="new_axis_mask">An optional `int`. Defaults to `0`.</param>
691+
/// <param name="shrink_axis_mask">An optional `int`. Defaults to `0`.</param>
692+
/// <param name="name">A name for the operation (optional).</param>
693+
/// <returns>A `Tensor`. Has the same type as `dy`.</returns>
694+
public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy,
695+
long begin_mask = 0, long end_mask = 0, long ellipsis_mask = 0, long new_axis_mask = 0,
696+
long shrink_axis_mask = 0, string name = null)
697+
=> tf.Context.RunInAutoMode2(
698+
() => tf.OpDefLib._apply_op_helper("StridedSliceGrad", name, new
699+
{
700+
shape,
701+
begin,
702+
end,
703+
strides,
704+
dy,
705+
begin_mask,
706+
end_mask,
707+
ellipsis_mask,
708+
new_axis_mask,
709+
shrink_axis_mask
710+
}).output,
711+
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
712+
"StridedSliceGrad", name,
713+
null,
714+
shape, begin, end, strides, dy,
715+
"begin_mask", begin_mask,
716+
"end_mask", end_mask,
717+
"ellipsis_mask", ellipsis_mask,
718+
"new_axis_mask", new_axis_mask,
719+
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(),
720+
(op) =>
721+
{
722+
var attrs = new object[]
723+
{
724+
"T", op.get_attr<TF_DataType>("T"),
725+
"Index", op.get_attr<TF_DataType>("Index"),
726+
"begin_mask", op.get_attr<long>("begin_mask"),
727+
"end_mask", op.get_attr<long>("end_mask"),
728+
"ellipsis_mask", op.get_attr<long>("ellipsis_mask"),
729+
"new_axis_mask", op.get_attr<long>("new_axis_mask"),
730+
"shrink_axis_mask", op.get_attr<long>("shrink_axis_mask")
731+
};
732+
tf.Runner.RecordGradient("StridedSliceGrad", op.inputs, attrs, op.outputs);
733+
},
734+
new Tensors(shape, begin, end, strides, dy));
735+
674736
/// <summary>
675737
/// Removes dimensions of size 1 from the shape of a tensor.
676738
/// Given a tensor `input`, this operation returns a tensor of the same type with

src/TensorFlowNET.Core/Operations/control_flow_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ public static Tensor _NextIteration(Tensor data, string name = null)
6565
return gen_control_flow_ops.next_iteration(data, name: name);
6666
}
6767

68-
public static Operation Assert(Tensor condition, object[] data, int? summarize = null, string name = null)
68+
public static Operation Assert(Tensor condition, object[] data, long? summarize = null, string name = null)
6969
{
7070
if (tf.executing_eagerly())
7171
{

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 5 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -578,11 +578,11 @@ public static Tensor stop_gradient(Tensor x, string name = null)
578578
}
579579

580580
public static Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tensor strides,
581-
int begin_mask = 0,
582-
int end_mask = 0,
583-
int ellipsis_mask = 0,
584-
int new_axis_mask = 0,
585-
int shrink_axis_mask = 0,
581+
long begin_mask = 0,
582+
long end_mask = 0,
583+
long ellipsis_mask = 0,
584+
long new_axis_mask = 0,
585+
long shrink_axis_mask = 0,
586586
string name = null)
587587
=> tf.Context.RunInAutoMode(()
588588
=> tf.OpDefLib._apply_op_helper("StridedSlice", name, new
@@ -656,68 +656,6 @@ public static Tensor strided_slice<T>(Tensor input, T[] begin, T[] end, T[] stri
656656
return _op.outputs[0];
657657
}
658658

659-
/// <summary>
660-
/// Returns the gradient of `StridedSlice`.
661-
///
662-
/// Since `StridedSlice` cuts out pieces of its `input` which is size
663-
/// `shape`, its gradient will have the same shape (which is passed here
664-
/// as `shape`). The gradient will be zero in any element that the slice
665-
/// does not select.
666-
/// </summary>
667-
/// <param name="shape">Must be one of the following types: `int32`, `int64`.</param>
668-
/// <param name="begin">Must have the same type as `shape`.</param>
669-
/// <param name="end">Must have the same type as `shape`.</param>
670-
/// <param name="strides">Must have the same type as `shape`.</param>
671-
/// <param name="dy">A `Tensor`.</param>
672-
/// <param name="begin_mask">An optional `int`. Defaults to `0`.</param>
673-
/// <param name="end_mask">An optional `int`. Defaults to `0`.</param>
674-
/// <param name="ellipsis_mask">An optional `int`. Defaults to `0`.</param>
675-
/// <param name="new_axis_mask">An optional `int`. Defaults to `0`.</param>
676-
/// <param name="shrink_axis_mask">An optional `int`. Defaults to `0`.</param>
677-
/// <param name="name">A name for the operation (optional).</param>
678-
/// <returns>A `Tensor`. Has the same type as `dy`.</returns>
679-
public static Tensor strided_slice_grad(Tensor shape, Tensor begin, Tensor end, Tensor strides, Tensor dy,
680-
int begin_mask = 0, int end_mask = 0, int ellipsis_mask = 0, int new_axis_mask = 0,
681-
int shrink_axis_mask = 0, string name = null)
682-
=> tf.Context.RunInAutoMode2(
683-
() => tf.OpDefLib._apply_op_helper("StridedSliceGrad", name, new
684-
{
685-
shape,
686-
begin,
687-
end,
688-
strides,
689-
dy,
690-
begin_mask,
691-
end_mask,
692-
ellipsis_mask,
693-
new_axis_mask,
694-
shrink_axis_mask
695-
}).output,
696-
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
697-
"StridedSliceGrad", name,
698-
null,
699-
shape, begin, end, strides, dy,
700-
"begin_mask", begin_mask,
701-
"end_mask", end_mask,
702-
"ellipsis_mask", ellipsis_mask,
703-
"new_axis_mask", new_axis_mask,
704-
"shrink_axis_mask", shrink_axis_mask).FirstOrDefault(),
705-
(op) =>
706-
{
707-
var attrs = new object[]
708-
{
709-
"T", op.get_attr<TF_DataType>("T"),
710-
"Index", op.get_attr<TF_DataType>("Index"),
711-
"begin_mask", op.get_attr<long>("begin_mask"),
712-
"end_mask", op.get_attr<long>("end_mask"),
713-
"ellipsis_mask", op.get_attr<long>("ellipsis_mask"),
714-
"new_axis_mask", op.get_attr<long>("new_axis_mask"),
715-
"shrink_axis_mask", op.get_attr<long>("shrink_axis_mask")
716-
};
717-
tf.Runner.RecordGradient("StridedSliceGrad", op.inputs, attrs, op.outputs);
718-
},
719-
new Tensors(shape, begin, end, strides, dy));
720-
721659
/// <summary>
722660
/// Removes dimensions of size 1 from the shape of a tensor.
723661
/// Given a tensor `input`, this operation returns a tensor of the same type with

src/TensorFlowNET.Core/Operations/gen_image_ops.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ public static Tensor convert_image_dtype(Tensor image, TF_DataType dtype, bool s
6363
}
6464

6565
public static Tensor decode_jpeg(Tensor contents,
66-
int channels = 0,
67-
int ratio = 1,
66+
long channels = 0,
67+
long ratio = 1,
6868
bool fancy_upscaling = true,
6969
bool try_recover_truncated = false,
7070
float acceptable_fraction = 1,

src/TensorFlowNET.Core/Operations/gen_logging_ops.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace Tensorflow
2121
{
2222
public class gen_logging_ops
2323
{
24-
public static Operation _assert(Tensor condition, object[] data, int? summarize = 3, string name = null)
24+
public static Operation _assert(Tensor condition, object[] data, long? summarize = 3, string name = null)
2525
{
2626
if (!summarize.HasValue)
2727
summarize = 3;

0 commit comments

Comments
 (0)