Skip to content

Commit 10f6819

Browse files
committed
Merge branch 'master' into ndarrayload
# Conflicts: # src/TensorFlowNET.Core/Tensorflow.Binding.csproj # src/TensorFlowNET.Keras/Datasets/Imdb.cs
2 parents ea978bb + 70d681c commit 10f6819

File tree

276 files changed

+14477
-1544
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

276 files changed

+14477
-1544
lines changed

src/TensorFlowNET.Core/APIs/c_api.cs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616

1717
using System;
1818
using System.Runtime.InteropServices;
19+
using static Tensorflow.CppShapeInferenceResult.Types;
1920

2021
namespace Tensorflow
2122
{
@@ -50,6 +51,35 @@ public static string StringPiece(IntPtr handle)
5051
return handle == IntPtr.Zero ? String.Empty : Marshal.PtrToStringAnsi(handle);
5152
}
5253

54+
public unsafe static byte[] ByteStringPiece(Buffer? handle)
55+
{
56+
if (handle is null)
57+
{
58+
return new byte[0];
59+
}
60+
var data = handle.ToArray();
61+
return data;
62+
}
63+
64+
public unsafe static byte[] ByteStringPieceFromNativeString(IntPtr handle)
65+
{
66+
if (handle == IntPtr.Zero)
67+
{
68+
return new byte[0];
69+
}
70+
71+
byte* str_data = (byte*)handle.ToPointer();
72+
List<byte> bytes = new List<byte>();
73+
byte current = 255;
74+
while (current != ((byte)'\0'))
75+
{
76+
current = *(str_data++);
77+
bytes.Add(current);
78+
}
79+
var data = bytes.ToArray();
80+
return data;
81+
}
82+
5383
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
5484
public delegate void Deallocator(IntPtr data, IntPtr size, ref DeallocatorArgs args);
5585

src/TensorFlowNET.Core/APIs/c_api.customize.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ public partial class c_api
1010
[DllImport(TensorFlowLibName)]
1111
public static extern void TFC_SetAttr(SafeGraphHandle graph, IntPtr op, string attr_name, SafeBufferHandle attr_value_proto, SafeStatusHandle status);
1212
[DllImport(TensorFlowLibName)]
13-
public static extern IntPtr TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
13+
public static extern SafeBufferHandle TFC_GetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output);
1414
[DllImport(TensorFlowLibName)]
1515
public static extern void TFC_SetHandleShapeAndType(SafeGraphHandle c_graph, TF_Output output, byte[] data, long proto_len, SafeStatusHandle status);
1616
}

src/TensorFlowNET.Core/APIs/tf.array.cs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,7 @@ public Tensor concat(IEnumerable<Tensor> values, int axis, string name = "concat
9191
return identity(values.First(), name: scope);
9292
});
9393
}
94-
95-
return gen_array_ops.concat_v2(values.ToArray(), ops.convert_to_tensor(axis), name: name);
94+
return array_ops.concat(values.ToArray(), axis, name: name);
9695
}
9796

9897
/// <summary>
@@ -163,14 +162,17 @@ public Tensor transpose<T1>(T1 a, Axis perm = null, string name = "transpose", b
163162
/// Reverses specific dimensions of a tensor.
164163
/// </summary>
165164
/// <param name="tensor"></param>
166-
/// <param name="axis"></param>
165+
/// <param name="axis">The indices of the dimensions to reverse. Must be in the range [-rank(tensor), rank(tensor)).</param>
167166
/// <param name="name"></param>
168167
/// <returns></returns>
169-
public Tensor reverse(Tensor tensor, int[] axis, string name = null)
170-
=> gen_array_ops.reverse(tensor, ops.convert_to_tensor(axis), name: name);
171-
172-
public Tensor reverse(Tensor tensor, Tensor axis, string name = null)
173-
=> gen_array_ops.reverse(tensor, axis, name: name);
168+
public Tensor reverse(Tensor tensor, Axis axis, string name = null)
169+
{
170+
if (axis.IsScalar)
171+
{
172+
axis = new Axis(axis.axis);
173+
}
174+
return array_ops.reverse(tensor, axis, name: name);
175+
}
174176

175177
/// <summary>
176178
/// Returns the rank of a tensor.

src/TensorFlowNET.Core/APIs/tf.control_flow.cs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ public Tensor while_loop(Func<Tensor, Tensor> cond,
4646
Tensor loop_vars,
4747
int parallel_iterations = 10)
4848
{
49-
Func<Tensor[], Tensor> cond1 = x
49+
Func<Tensors, Tensor> cond1 = x
5050
=> cond(x[0]);
5151

52-
Func<Tensor[], Tensor[]> body1 = x
52+
Func<Tensors, Tensors> body1 = x
5353
=> new[] { body(x[0]) };
5454

5555
var results = control_flow_ops.while_loop(cond1,
@@ -58,9 +58,9 @@ public Tensor while_loop(Func<Tensor, Tensor> cond,
5858
return results[0];
5959
}
6060

61-
public Tensor[] while_loop(Func<Tensor[], Tensor> cond,
62-
Func<Tensor[], Tensor[]> body,
63-
Tensor[] loop_vars,
61+
public Tensor[] while_loop(Func<Tensors, Tensor> cond,
62+
Func<Tensors, Tensors> body,
63+
Tensors loop_vars,
6464
int parallel_iterations = 10,
6565
string name = null)
6666
=> control_flow_ops.while_loop(cond, body, loop_vars,

src/TensorFlowNET.Core/APIs/tf.image.cs

Lines changed: 119 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using OneOf.Types;
18+
using System;
19+
using System.Buffers.Text;
20+
using Tensorflow.Contexts;
1721
using static Tensorflow.Binding;
1822

1923
namespace Tensorflow
@@ -162,17 +166,108 @@ public Tensor ssim_multiscale(Tensor img1, Tensor img2, float max_val, float[] p
162166
public Tensor sobel_edges(Tensor image)
163167
=> image_ops_impl.sobel_edges(image);
164168

165-
public Tensor decode_jpeg(Tensor contents,
166-
int channels = 0,
167-
int ratio = 1,
168-
bool fancy_upscaling = true,
169-
bool try_recover_truncated = false,
170-
int acceptable_fraction = 1,
171-
string dct_method = "",
172-
string name = null)
173-
=> gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio,
174-
fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated,
175-
acceptable_fraction: acceptable_fraction, dct_method: dct_method);
169+
/// <summary>
170+
/// Adjust contrast of RGB or grayscale images.
171+
/// </summary>
172+
/// <param name="images">Images to adjust. At least 3-D.</param>
173+
/// <param name="contrast_factor"></param>
174+
/// <param name="name">A float multiplier for adjusting contrast.</param>
175+
/// <returns>The contrast-adjusted image or images.</returns>
176+
public Tensor adjust_contrast(Tensor images, float contrast_factor, string name = null)
177+
=> gen_image_ops.adjust_contrastv2(images, contrast_factor, name);
178+
179+
/// <summary>
180+
/// Adjust hue of RGB images.
181+
/// </summary>
182+
/// <param name="images">RGB image or images. The size of the last dimension must be 3.</param>
183+
/// <param name="delta">float. How much to add to the hue channel.</param>
184+
/// <param name="name">A name for this operation (optional).</param>
185+
/// <returns>Adjusted image(s), same shape and DType as `image`.</returns>
186+
/// <exception cref="ValueError">if `delta` is not in the interval of `[-1, 1]`.</exception>
187+
public Tensor adjust_hue(Tensor images, float delta, string name = null)
188+
{
189+
if (tf.Context.executing_eagerly())
190+
{
191+
if (delta < -1f || delta > 1f)
192+
throw new ValueError("delta must be in the interval [-1, 1]");
193+
}
194+
return gen_image_ops.adjust_hue(images, delta, name: name);
195+
}
196+
197+
/// <summary>
198+
/// Adjust saturation of RGB images.
199+
/// </summary>
200+
/// <param name="image">RGB image or images. The size of the last dimension must be 3.</param>
201+
/// <param name="saturation_factor">float. Factor to multiply the saturation by.</param>
202+
/// <param name="name">A name for this operation (optional).</param>
203+
/// <returns>Adjusted image(s), same shape and DType as `image`.</returns>
204+
public Tensor adjust_saturation(Tensor image, float saturation_factor, string name = null)
205+
=> gen_image_ops.adjust_saturation(image, saturation_factor, name);
206+
207+
/// <summary>
208+
/// Greedily selects a subset of bounding boxes in descending order of score.
209+
/// </summary>
210+
/// <param name="boxes">
211+
/// A 4-D float `Tensor` of shape `[batch_size, num_boxes, q, 4]`. If `q`
212+
/// is 1 then same boxes are used for all classes otherwise, if `q` is equal
213+
/// to number of classes, class-specific boxes are used.
214+
/// </param>
215+
/// <param name="scores">
216+
/// A 3-D float `Tensor` of shape `[batch_size, num_boxes, num_classes]`
217+
/// representing a single score corresponding to each box(each row of boxes).
218+
/// </param>
219+
/// <param name="max_output_size_per_class">
220+
/// A scalar integer `Tensor` representing the
221+
/// maximum number of boxes to be selected by non-max suppression per class
222+
/// </param>
223+
/// <param name="max_total_size">
224+
/// A int32 scalar representing maximum number of boxes retained
225+
/// over all classes.Note that setting this value to a large number may
226+
/// result in OOM error depending on the system workload.
227+
/// </param>
228+
/// <param name="iou_threshold">
229+
/// A float representing the threshold for deciding whether boxes
230+
/// overlap too much with respect to IOU.
231+
/// </param>
232+
/// <param name="score_threshold">
233+
/// A float representing the threshold for deciding when to
234+
/// remove boxes based on score.
235+
/// </param>
236+
/// <param name="pad_per_class">
237+
/// If false, the output nmsed boxes, scores and classes are
238+
/// padded/clipped to `max_total_size`. If true, the output nmsed boxes, scores and classes are padded to be of length `max_size_per_class`*`num_classes`,
239+
/// unless it exceeds `max_total_size` in which case it is clipped to `max_total_size`. Defaults to false.
240+
/// </param>
241+
/// <param name="clip_boxes">
242+
/// If true, the coordinates of output nmsed boxes will be clipped
243+
/// to[0, 1]. If false, output the box coordinates as it is. Defaults to true.
244+
/// </param>
245+
/// <returns>
246+
/// 'nmsed_boxes': A [batch_size, max_detections, 4] float32 tensor containing the non-max suppressed boxes.
247+
/// 'nmsed_scores': A [batch_size, max_detections] float32 tensor containing the scores for the boxes.
248+
/// 'nmsed_classes': A [batch_size, max_detections] float32 tensor containing the class for boxes.
249+
/// 'valid_detections': A [batch_size] int32 tensor indicating the number of
250+
/// valid detections per batch item. Only the top valid_detections[i] entries
251+
/// in nms_boxes[i], nms_scores[i] and nms_class[i] are valid. The rest of the
252+
/// entries are zero paddings.
253+
/// </returns>
254+
public (Tensor, Tensor, Tensor, Tensor) combined_non_max_suppression(
255+
Tensor boxes,
256+
Tensor scores,
257+
int max_output_size_per_class,
258+
int max_total_size,
259+
float iou_threshold,
260+
float score_threshold,
261+
bool pad_per_class = false,
262+
bool clip_boxes = true)
263+
{
264+
var iou_threshold_t = ops.convert_to_tensor(iou_threshold, TF_DataType.TF_FLOAT, name: "iou_threshold");
265+
var score_threshold_t = ops.convert_to_tensor(score_threshold, TF_DataType.TF_FLOAT, name: "score_threshold");
266+
var max_total_size_t = ops.convert_to_tensor(max_total_size);
267+
var max_output_size_per_class_t = ops.convert_to_tensor(max_output_size_per_class);
268+
return gen_image_ops.combined_non_max_suppression(boxes, scores, max_output_size_per_class_t, max_total_size_t,
269+
iou_threshold_t, score_threshold_t, pad_per_class, clip_boxes);
270+
}
176271

177272
/// <summary>
178273
/// Extracts crops from the input image tensor and resizes them using bilinear sampling or nearest neighbor sampling (possibly with aspect ratio change) to a common output size specified by crop_size. This is more general than the crop_to_bounding_box op which extracts a fixed size slice from the input image and does not allow resizing or aspect ratio change.
@@ -187,7 +282,19 @@ public Tensor decode_jpeg(Tensor contents,
187282
/// <param name="name">A name for the operation (optional).</param>
188283
/// <returns>A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth].</returns>
189284
public Tensor crop_and_resize(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method = "bilinear", float extrapolation_value = 0f, string name = null) =>
190-
image_ops_impl.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name);
285+
gen_image_ops.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name);
286+
287+
public Tensor decode_jpeg(Tensor contents,
288+
int channels = 0,
289+
int ratio = 1,
290+
bool fancy_upscaling = true,
291+
bool try_recover_truncated = false,
292+
int acceptable_fraction = 1,
293+
string dct_method = "",
294+
string name = null)
295+
=> gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio,
296+
fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated,
297+
acceptable_fraction: acceptable_fraction, dct_method: dct_method);
191298

192299
public Tensor extract_glimpse(Tensor input, Tensor size, Tensor offsets, bool centered = true, bool normalized = true,
193300
bool uniform_noise = true, string name = null)

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using Tensorflow.NumPy;
1718
using Tensorflow.Operations;
1819

1920
namespace Tensorflow
@@ -42,10 +43,20 @@ public Tensor erf(Tensor x, string name = null)
4243

4344
public Tensor multiply(Tensor x, Tensor y, string name = null)
4445
=> math_ops.multiply(x, y, name: name);
45-
4646
public Tensor divide_no_nan(Tensor a, Tensor b, string name = null)
4747
=> math_ops.div_no_nan(a, b);
4848

49+
/// <summary>
50+
/// Computes the Euclidean norm of elements across dimensions of a tensor.
51+
/// </summary>
52+
/// <param name="input_tensor">The tensor to reduce. Should have numeric type.</param>
53+
/// <param name="axis">The dimensions to reduce. If `None` (the default), reduces all dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`</param>
54+
/// <param name="keepdims">If true, retains reduced dimensions with length 1.</param>
55+
/// <param name="name">A name for the operation (optional).</param>
56+
/// <returns>The reduced tensor, of the same dtype as the input_tensor.</returns>
57+
public Tensor reduce_euclidean_norm(Tensor input_tensor, Axis? axis = null, bool keepdims = false, string name = null)
58+
=> math_ops.reduce_euclidean_norm(input_tensor, axis: axis, keepdims: keepdims, name);
59+
4960
public Tensor square(Tensor x, string name = null)
5061
=> math_ops.square(x, name: name);
5162

@@ -354,7 +365,7 @@ public Tensor divide(Tensor a, Tensor b)
354365
=> a / b;
355366

356367
public Tensor sqrt(Tensor a, string name = null)
357-
=> gen_math_ops.sqrt(a, name);
368+
=> math_ops.sqrt(a, name);
358369

359370
public Tensor sign(Tensor a, string name = null)
360371
=> gen_math_ops.sign(a, name);
@@ -452,7 +463,18 @@ public Tensor multiply(Tensor x, Tensor y, string name = null)
452463
/// <returns></returns>
453464
public Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null)
454465
=> gen_math_ops.mul(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name);
455-
466+
/// <summary>
467+
/// return scalar product
468+
/// </summary>
469+
/// <typeparam name="Tx"></typeparam>
470+
/// <typeparam name="Ty"></typeparam>
471+
/// <param name="x"></param>
472+
/// <param name="y"></param>
473+
/// <param name="axes"></param>
474+
/// <param name="name"></param>
475+
/// <returns></returns>
476+
public Tensor dot_prod<Tx, Ty>(Tx x, Ty y, NDArray axes, string name = null)
477+
=> math_ops.tensordot(convert_to_tensor(x), convert_to_tensor(y), axes, name: name);
456478
public Tensor negative(Tensor x, string name = null)
457479
=> gen_math_ops.neg(x, name);
458480

@@ -600,5 +622,7 @@ public Tensor squared_difference(Tensor x, Tensor y, string name = null)
600622
=> gen_math_ops.squared_difference(x: x, y: y, name: name);
601623
public Tensor complex(Tensor real, Tensor imag, Tensorflow.TF_DataType? dtype = null,
602624
string name = null) => gen_ops.complex(real, imag, dtype, name);
625+
public Tensor exp(Tensor x,
626+
string name = null) => gen_math_ops.exp(x, name);
603627
}
604628
}

src/TensorFlowNET.Core/APIs/tf.nn.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using System.Xml.Linq;
1718
using Tensorflow.Operations;
1819
using Tensorflow.Operations.Activation;
1920
using static Tensorflow.Binding;
@@ -126,6 +127,26 @@ public Tensor[] fused_batch_norm(Tensor x,
126127
name: name,
127128
exponential_avg_factor: exponential_avg_factor);
128129

130+
/// <summary>
131+
/// Normalizes a tensor by `mean` and `variance`, and applies (optionally) a`scale` \\(\gamma\\) to it, as well as an `offset` \\(\beta\\).
132+
/// </summary>
133+
/// <param name="x">A floating point tensor.</param>
134+
/// <param name="mean">A mean `Tensor`.</param>
135+
/// <param name="variance">A variance `Tensor`.</param>
136+
/// <param name="offset"> An offset `Tensor`, often denoted \\(\beta\\) in equations, or NULL. If present, will be added to the normalized tensor.</param>
137+
/// <param name="scale"> A scale `Tensor`, often denoted \\(\gamma\\) in equations, or NULL. If present, the scale is applied to the normalized tensor.</param>
138+
/// <param name="variance_epsilon"> A small float number to avoid dividing by 0.</param>
139+
/// <param name="name">A name for this operation.</param>
140+
/// <returns>the normalized, scaled, offset tensor.</returns>
141+
public Tensor batch_normalization(Tensor x,
142+
Tensor mean,
143+
Tensor variance,
144+
Tensor offset,
145+
Tensor scale,
146+
float variance_epsilon,
147+
string name = null) => nn_impl.batch_normalization(x, mean, variance, offset, scale, variance_epsilon, name);
148+
149+
129150
public Tensor max_pool(Tensor value, int[] ksize, int[] strides, string padding, string data_format = "NHWC", string name = null)
130151
=> nn_ops.max_pool(value, ksize, strides, padding, data_format: data_format, name: name);
131152

src/TensorFlowNET.Core/APIs/tf.reshape.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,6 @@ public Tensor reshape(Tensor tensor,
3131
public Tensor reshape(Tensor tensor,
3232
object[] shape,
3333
string name = null)
34-
=> gen_array_ops.reshape(tensor, ops.convert_to_tensor(shape), name);
34+
=> array_ops.reshape(tensor, shape, name);
3535
}
3636
}

0 commit comments

Comments
 (0)