Skip to content

Commit 4ef675f

Browse files
committed
Consolidate MapDataset function.
1 parent 2860139 commit 4ef675f

29 files changed

+258
-38
lines changed

src/TensorFlowNET.Console/Tensorflow.Console.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
<PropertyGroup>
44
<OutputType>Exe</OutputType>
5-
<TargetFramework>netcoreapp3.1</TargetFramework>
5+
<TargetFramework>net5.0</TargetFramework>
66
<RootNamespace>Tensorflow</RootNamespace>
77
<AssemblyName>Tensorflow</AssemblyName>
88
<Platforms>AnyCPU;x64</Platforms>

src/TensorFlowNET.Core/APIs/c_api.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ namespace Tensorflow
4343
/// </summary>
4444
public partial class c_api
4545
{
46-
public const string TensorFlowLibName = @"D:\Projects\tensorflow-haiping\bazel-bin\tensorflow\tensorflow";
46+
public const string TensorFlowLibName = "tensorflow";
4747

4848
public static string StringPiece(IntPtr handle)
4949
{

src/TensorFlowNET.Core/APIs/tf.strings.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,30 @@ public class StringsApi
2424
{
2525
string_ops ops = new string_ops();
2626

27+
/// <summary>
28+
/// Converts all uppercase characters into their respective lowercase replacements.
29+
/// </summary>
30+
/// <param name="input"></param>
31+
/// <param name="encoding"></param>
32+
/// <param name="name"></param>
33+
/// <returns></returns>
34+
public Tensor lower(Tensor input, string encoding = "", string name = null)
35+
=> ops.lower(input: input, encoding: encoding, name: name);
36+
37+
/// <summary>
38+
///
39+
/// </summary>
40+
/// <param name="input"></param>
41+
/// <param name="pattern"></param>
42+
/// <param name="rewrite"></param>
43+
/// <param name="replace_global"></param>
44+
/// <param name="name"></param>
45+
/// <returns></returns>
46+
public Tensor regex_replace(Tensor input, string pattern, string rewrite,
47+
bool replace_global = true, string name = null)
48+
=> ops.regex_replace(input, pattern, rewrite,
49+
replace_global: replace_global, name: name);
50+
2751
/// <summary>
2852
/// Return substrings from `Tensor` of strings.
2953
/// </summary>

src/TensorFlowNET.Core/Data/DatasetV2.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ namespace Tensorflow
1414
public class DatasetV2 : IDatasetV2
1515
{
1616
protected dataset_ops ops = new dataset_ops();
17+
public string[] class_names { get; set; }
1718
public Tensor variant_tensor { get; set; }
1819

1920
public TensorSpec[] structure { get; set; }
@@ -54,7 +55,7 @@ public IDatasetV2 skip(int count)
5455
public IDatasetV2 optimize(string[] optimizations, string[] optimization_configs)
5556
=> new OptimizeDataset(this, optimizations, optimization_configs: optimization_configs);
5657

57-
public IDatasetV2 map(Func<Tensor, Tensor> map_func,
58+
public IDatasetV2 map(Func<Tensors, Tensors> map_func,
5859
bool use_inter_op_parallelism = true,
5960
bool preserve_cardinality = true,
6061
bool use_legacy_function = false)
@@ -64,7 +65,7 @@ public IDatasetV2 map(Func<Tensor, Tensor> map_func,
6465
preserve_cardinality: preserve_cardinality,
6566
use_legacy_function: use_legacy_function);
6667

67-
public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls = -1)
68+
public IDatasetV2 map(Func<Tensors, Tensors> map_func, int num_parallel_calls)
6869
=> new ParallelMapDataset(this, map_func, num_parallel_calls: num_parallel_calls);
6970

7071
public IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func)

src/TensorFlowNET.Core/Data/IDatasetV2.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ namespace Tensorflow
66
{
77
public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)>
88
{
9+
string[] class_names { get; set; }
10+
911
Tensor variant_tensor { get; set; }
1012

1113
TensorShape[] output_shapes { get; }
@@ -62,13 +64,13 @@ public interface IDatasetV2 : IEnumerable<(Tensor, Tensor)>
6264

6365
IDatasetV2 optimize(string[] optimizations, string[] optimization_configs);
6466

65-
IDatasetV2 map(Func<Tensor, Tensor> map_func,
67+
IDatasetV2 map(Func<Tensors, Tensors> map_func,
6668
bool use_inter_op_parallelism = true,
6769
bool preserve_cardinality = true,
6870
bool use_legacy_function = false);
6971

7072
IDatasetV2 map(Func<Tensors, Tensors> map_func,
71-
int num_parallel_calls = -1);
73+
int num_parallel_calls);
7274

7375
IDatasetV2 flat_map(Func<Tensor, IDatasetV2> map_func);
7476

src/TensorFlowNET.Core/Data/MapDataset.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ namespace Tensorflow
1010
public class MapDataset : UnaryDataset
1111
{
1212
public MapDataset(IDatasetV2 input_dataset,
13-
Func<Tensor, Tensor> map_func,
13+
Func<Tensors, Tensors> map_func,
1414
bool use_inter_op_parallelism = true,
1515
bool preserve_cardinality = false,
1616
bool use_legacy_function = false) : base(input_dataset)

src/TensorFlowNET.Core/Framework/Models/TensorSpec.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public TensorSpec _unbatch()
1515
if (_shape.ndim == 0)
1616
throw new ValueError("Unbatching a tensor is only supported for rank >= 1");
1717

18-
return new TensorSpec(_shape.dims[1..], _dtype);
18+
return new TensorSpec(_shape.dims.Skip(1).ToArray(), _dtype);
1919
}
2020

2121
public TensorSpec _batch(int dim = -1)

src/TensorFlowNET.Core/Gradients/image_grad.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public static Tensor[] _ResizeNearestNeighborGrad(Operation op, Tensor[] grads)
3030
var shape = new TensorShape(image.shape.Skip(1).Take(2).ToArray());
3131
Tensor image_shape = null;
3232
if (shape.is_fully_defined())
33-
image_shape = constant_op.constant(image.shape[1..3]);
33+
image_shape = constant_op.constant(image.shape.Skip(1).Take(2).ToArray());
3434
else
3535
image_shape = array_ops.shape(image)["1:3"];
3636

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.ArgsDefinition
6+
{
7+
public class PreprocessingLayerArgs : LayerArgs
8+
{
9+
}
10+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.ArgsDefinition
6+
{
7+
public class TextVectorizationArgs : PreprocessingLayerArgs
8+
{
9+
public Func<Tensor, Tensor> Standardize { get; set; }
10+
public string Split { get; set; } = "standardize";
11+
public int MaxTokens { get; set; } = -1;
12+
public string OutputMode { get; set; } = "int";
13+
public int OutputSequenceLength { get; set; } = -1;
14+
}
15+
}

0 commit comments

Comments
 (0)