Skip to content

Commit d2e50dd

Browse files
committed
Add keras model.predict.
1 parent 1ae9bbc commit d2e50dd

File tree

10 files changed

+74
-23
lines changed

10 files changed

+74
-23
lines changed

src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ public class DataHandler
1818
public int Inferredsteps => _inferred_steps;
1919
int _current_step;
2020
int _step_increment;
21+
public int StepIncrement => _step_increment;
2122
bool _insufficient_data;
2223
int _steps_per_execution_value;
2324
int _initial_epoch => args.InitialEpoch;
@@ -73,7 +74,7 @@ public DataHandler(DataHandlerArgs args)
7374
_dataset = _adapter.GetDataset();
7475
_inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
7576
_current_step = 0;
76-
_step_increment = args.StepsPerExecution.numpy() - 1;
77+
_step_increment = _steps_per_execution_value - 1;
7778
_insufficient_data = false;
7879
}
7980

src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ public class TensorLikeDataAdapter : DataAdapter, IDataAdapter
1414
int _batch_size;
1515
int num_samples;
1616
int num_full_batches;
17+
int _partial_batch_size;
1718

1819
public TensorLikeDataAdapter(DataAdapterArgs args)
1920
{
@@ -22,9 +23,9 @@ public TensorLikeDataAdapter(DataAdapterArgs args)
2223
num_samples = args.X.shape[0];
2324
var batch_size = args.BatchSize == -1 ? 32 : args.BatchSize;
2425
_batch_size = batch_size;
25-
_size = Convert.ToInt32(Math.Floor(num_samples / (batch_size + 0f)));
26+
_size = num_samples < batch_size ? num_samples % batch_size : num_samples / batch_size;
2627
num_full_batches = num_samples / batch_size;
27-
var _partial_batch_size = num_samples % batch_size;
28+
_partial_batch_size = num_samples % batch_size;
2829

2930
var indices_dataset = tf.data.Dataset.range(1);
3031
indices_dataset = indices_dataset.repeat(args.Epochs);
@@ -57,6 +58,15 @@ IDatasetV2 slice_batch_indices(Tensor indices)
5758
var first_k_indices = array_ops.slice(indices, new int[] { 0 }, new int[] { num_in_full_batch });
5859
first_k_indices = array_ops.reshape(first_k_indices, new int[] { num_full_batches, _batch_size });
5960
var flat_dataset = tf.data.Dataset.from_tensor_slices(first_k_indices);
61+
if (_partial_batch_size > 0)
62+
{
63+
var array = array_ops.slice(indices,
64+
new[] { constant_op.constant(num_in_full_batch)},
65+
new[] { constant_op.constant(_partial_batch_size)});
66+
var index_remainder = tf.data.Dataset.from_tensor(array);
67+
flat_dataset = flat_dataset.concatenate(index_remainder);
68+
}
69+
6070
return flat_dataset;
6171
}
6272

src/TensorFlowNET.Keras/Engine/Functional.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask =
340340
tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name}");
341341
var outputs = node.Layer.Apply(layer_inputs, is_training: training);
342342
foreach (var output in outputs.Where(x => x != null))
343-
tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name} {output.TensorShape}");
343+
tf.Logger.Information($"Depth {depth}: {node.Layer}: {node.Layer.Name} {output.TensorShape}");
344344
// Update tensor_dict for next input
345345
foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs))
346346
tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y));

src/TensorFlowNET.Keras/Engine/Model.Fit.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ void FitInternal(int epochs)
9595
foreach (var step in data_handler.steps())
9696
{
9797
// callbacks.on_train_batch_begin(step)
98-
var results = step_function(iterator);
98+
var results = train_step_function(iterator);
9999
var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}"));
100100
Console.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}");
101101
}

src/TensorFlowNET.Keras/Engine/Model.Predict.cs

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
using NumSharp;
22
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
35
using Tensorflow.Keras.ArgsDefinition;
46
using Tensorflow.Keras.Engine.DataAdapters;
7+
using static Tensorflow.Binding;
58

69
namespace Tensorflow.Keras.Engine
710
{
@@ -21,7 +24,7 @@ public partial class Model
2124
/// <param name="workers"></param>
2225
/// <param name="use_multiprocessing"></param>
2326
/// <returns></returns>
24-
public Tensor predict(Tensor x,
27+
public Tensors predict(Tensor x,
2528
int batch_size = -1,
2629
int verbose = 0,
2730
int steps = -1,
@@ -43,7 +46,35 @@ public Tensor predict(Tensor x,
4346
StepsPerExecution = _steps_per_execution
4447
});
4548

46-
throw new NotImplementedException("");
49+
Tensors outputs = null;
50+
_predict_counter.assign(0);
51+
// callbacks.on_predict_begin()
52+
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
53+
{
54+
foreach(var step in data_handler.steps())
55+
{
56+
// callbacks.on_predict_batch_begin(step)
57+
var batch_outputs = run_predict_step(iterator);
58+
outputs = batch_outputs;
59+
var end_step = step + data_handler.StepIncrement;
60+
// callbacks.on_predict_batch_end(end_step, {'outputs': batch_outputs})
61+
}
62+
}
63+
// callbacks.on_predict_end()
64+
return outputs;
65+
}
66+
67+
Tensors run_predict_step(OwnedIterator iterator)
68+
{
69+
var data = iterator.next();
70+
var outputs = predict_step(data[0]);
71+
tf_with(ops.control_dependencies(new object[0]), ctl => _predict_counter.assign_add(1));
72+
return outputs;
73+
}
74+
75+
Tensors predict_step(Tensor data)
76+
{
77+
return Apply(data, is_training: false);
4778
}
4879
}
4980
}

src/TensorFlowNET.Keras/Engine/Model.Train.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace Tensorflow.Keras.Engine
88
{
99
public partial class Model
1010
{
11-
IEnumerable<(string, Tensor)> step_function(OwnedIterator iterator)
11+
IEnumerable<(string, Tensor)> train_step_function(OwnedIterator iterator)
1212
{
1313
var data = iterator.next();
1414
var outputs = train_step(data[0], data[1]);

src/TensorFlowNET.Keras/Layers/LayersApi.Reshaping.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,5 +45,11 @@ public Reshape Reshape(TensorShape target_shape)
4545
{
4646
TargetShape = target_shape
4747
});
48+
49+
public Reshape Reshape(object[] target_shape)
50+
=> new Reshape(new ReshapeArgs
51+
{
52+
TargetShapeObjects = target_shape
53+
});
4854
}
4955
}

src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,6 @@ private Tensor _fused_batch_norm(Tensor inputs, Tensor training)
142142
if (use_fused_avg_updates)
143143
exponential_avg_factor = 1.0f - momentum;
144144

145-
var beta = this.beta;
146-
var gamma = this.gamma;
147-
148145
Func<Tensor[]> _fused_batch_norm_training = () =>
149146
{
150147
return tf.nn.fused_batch_norm(

src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,26 +21,32 @@ public Reshape(ReshapeArgs args)
2121

2222
protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
2323
{
24-
var shape_tensor = array_ops.shape(inputs);
25-
var shape = new List<int> { inputs.shape[0] };
26-
shape.AddRange(args.TargetShape.dims);
24+
var shapes = new List<object>();
25+
shapes.Add(array_ops.shape(inputs)[0]);
26+
if (args.TargetShapeObjects != null)
27+
shapes.AddRange(args.TargetShapeObjects);
28+
if (args.TargetShape != null)
29+
args.TargetShape.dims.ToList().ForEach(x => shapes.Add(x));
30+
var shape = ops.convert_to_tensor(shapes);
2731

28-
var result = array_ops.reshape(inputs, shape.ToArray());
32+
var result = array_ops.reshape(inputs, shape);
2933
if (!tf.Context.executing_eagerly())
3034
result.set_shape(ComputeOutputShape(inputs.shape));
3135
return result;
3236
}
3337

3438
public override TensorShape ComputeOutputShape(TensorShape input_shape)
3539
{
36-
if (input_shape.dims[0] == -1)
40+
if (input_shape.dims[1..].Contains(-1))
41+
{
42+
throw new NotImplementedException("");
43+
}
44+
else
3745
{
3846
input_shape = input_shape.dims[0];
3947
var output_shape = input_shape.concatenate(args.TargetShape.dims);
4048
return output_shape;
4149
}
42-
else
43-
throw new NotImplementedException("");
4450
}
4551
}
4652
}

src/TensorFlowNET.Keras/Tensorflow.Keras.csproj

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
<LangVersion>8.0</LangVersion>
77
<RootNamespace>Tensorflow.Keras</RootNamespace>
88
<Platforms>AnyCPU;x64</Platforms>
9-
<Version>0.3.0</Version>
9+
<Version>0.4.0</Version>
1010
<Authors>Haiping Chen</Authors>
1111
<Product>Keras for .NET</Product>
1212
<Copyright>Apache 2.0, Haiping Chen 2020</Copyright>
@@ -20,7 +20,8 @@
2020
* Support Conv2D functional API.
2121
* Support BatchNormalization layer.
2222
* Building keras model in subclass, functional and sequential api
23-
* Implemented backward_function.</PackageReleaseNotes>
23+
* Implemented backward_function.
24+
* Support model.load_weights.</PackageReleaseNotes>
2425
<Description>Keras for .NET
2526

2627
Keras is an API designed for human beings, not machines. Keras follows best practices for reducing cognitive load: it offers consistent &amp; simple APIs, it minimizes the number of user actions required for common use cases, and it provides clear &amp; actionable error messages.</Description>
@@ -31,8 +32,8 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
3132
<RepositoryType>Git</RepositoryType>
3233
<SignAssembly>true</SignAssembly>
3334
<AssemblyOriginatorKeyFile>Open.snk</AssemblyOriginatorKeyFile>
34-
<AssemblyVersion>0.3.0.0</AssemblyVersion>
35-
<FileVersion>0.3.0.0</FileVersion>
35+
<AssemblyVersion>0.4.0.0</AssemblyVersion>
36+
<FileVersion>0.4.0.0</FileVersion>
3637
<PackageLicenseFile>LICENSE</PackageLicenseFile>
3738
</PropertyGroup>
3839

@@ -48,7 +49,6 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
4849
<ItemGroup>
4950
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" />
5051
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
51-
<PackageReference Include="NumSharp.Lite" Version="0.1.10" />
5252
<PackageReference Include="SciSharp.Keras.HDF5" Version="1.1.10.500" />
5353
<PackageReference Include="SharpZipLib" Version="1.3.1" />
5454
</ItemGroup>

0 commit comments

Comments
 (0)