Skip to content

Commit 90638a8

Browse files
committed
TextVectorization
1 parent ac53791 commit 90638a8

File tree

4 files changed

+39
-5
lines changed

4 files changed

+39
-5
lines changed

src/TensorFlowNET.Core/Data/MapDataset.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@ public MapDataset(IDatasetV2 input_dataset,
1717
{
1818
var func = new ConcreteFunction($"{map_func.Method.Name}_{Guid.NewGuid()}");
1919
func.Enter();
20-
var input = tf.placeholder(input_dataset.element_spec[0].dtype);
21-
var output = map_func(input);
22-
func.ToGraph(input, output);
20+
var inputs = new Tensors();
21+
foreach (var input in input_dataset.element_spec)
22+
inputs.Add(tf.placeholder(input.dtype, shape: input.shape));
23+
var outputs = map_func(inputs);
24+
func.ToGraph(inputs, outputs);
2325
func.Exit();
2426

2527
structure = func.OutputStructure;

src/TensorFlowNET.Core/Tensorflow.Binding.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ tf.net 0.4x.x aligns with TensorFlow v2.4.1 native library.</PackageReleaseNotes
8686
<ItemGroup>
8787
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" />
8888
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="5.0.1" />
89-
<PackageReference Include="NumSharp.Lite" Version="0.1.12" />
89+
<PackageReference Include="NumSharp" Version="0.30.0" />
9090
<PackageReference Include="Protobuf.Text" Version="0.5.0" />
9191
<PackageReference Include="Serilog.Sinks.Console" Version="3.1.1" />
9292
</ItemGroup>

src/TensorFlowNET.Keras/Layers/Preprocessing/TextVectorization.cs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,40 @@ public class TextVectorization : CombinerPreprocessingLayer
1313
public TextVectorization(TextVectorizationArgs args)
1414
: base(args)
1515
{
16+
this.args = args;
1617
args.DType = TF_DataType.TF_STRING;
1718
// string standardize = "lower_and_strip_punctuation",
1819
}
20+
21+
/// <summary>
22+
/// Fits the state of the preprocessing layer to the dataset.
23+
/// </summary>
24+
/// <param name="data"></param>
25+
/// <param name="reset_state"></param>
26+
public void adapt(IDatasetV2 data, bool reset_state = true)
27+
{
28+
var shape = data.output_shapes[0];
29+
if (shape.rank == 1)
30+
data = data.map(tensor => array_ops.expand_dims(tensor, -1));
31+
build(data.variant_tensor);
32+
var preprocessed_inputs = data.map(_preprocess);
33+
}
34+
35+
protected override void build(Tensors inputs)
36+
{
37+
base.build(inputs);
38+
}
39+
40+
Tensors _preprocess(Tensors inputs)
41+
{
42+
if (args.Standardize != null)
43+
inputs = args.Standardize(inputs);
44+
if (!string.IsNullOrEmpty(args.Split))
45+
{
46+
if (inputs.shape.ndim > 1)
47+
inputs = array_ops.squeeze(inputs, axis: new[] { -1 });
48+
}
49+
return inputs;
50+
}
1951
}
2052
}

src/TensorFlowNET.Keras/Preprocessings/Preprocessing.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ public partial class Preprocessing
1111
public DatasetUtils dataset_utils => new DatasetUtils();
1212

1313
public TextVectorization TextVectorization(Func<Tensor, Tensor> standardize = null,
14-
string split = "standardize",
14+
string split = "whitespace",
1515
int max_tokens = -1,
1616
string output_mode = "int",
1717
int output_sequence_length = -1) => new TextVectorization(new TextVectorizationArgs

0 commit comments

Comments
 (0)