Skip to content

Commit fd64ad1

Browse files
committed
Fix Sequential model.
1 parent 400cde2 commit fd64ad1

File tree

21 files changed

+240
-76
lines changed

21 files changed

+240
-76
lines changed

README.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ In comparison to other projects, like for instance [TensorFlowSharp](https://www
2626

2727
### How to use
2828

29-
| TensorFlow | tf native1.14 | tf native 1.15 | tf native 2.3 |
30-
| -------------------------- | ------------- | -------------- | ------------- |
31-
| tf.net 0.3x, tf.keras 0.2 | | | x |
32-
| tf.net 0.2x | | x | x |
33-
| tf.net 0.15 | x | x | |
34-
| tf.net 0.14 | x | | |
29+
| TensorFlow | tf native1.14, cuda 10.0 | tf native 1.15, cuda 10.0 | tf native 2.3, cuda 10.1 | tf native 2.4, cuda 11 |
30+
| -------------------------- | ------------- | -------------- | ------------- | ------------- |
31+
| tf.net 0.3x, tf.keras 0.2 | | | x | not compatible |
32+
| tf.net 0.2x | | x | x | |
33+
| tf.net 0.15 | x | x | | |
34+
| tf.net 0.14 | x | | | |
3535

3636
Troubleshooting of running example or installation, please refer [here](tensorflowlib/README.md).
3737

src/SciSharp.TensorFlow.Redist/README.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,19 @@ https://www.nuget.org/packages/SciSharp.TensorFlow.Redist
2222

2323
Related merged [commits](https://github.com/SciSharp/TensorFlow.NET/commit/854a5ba61ad0e400623821236bd117cc24c6cb77).
2424

25+
26+
27+
#### Download pre-build package
28+
29+
[Mac OSX CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-darwin-x86_64-2.4.0.tar.gz), [Linux CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-linux-x86_64-2.4.0.tar.gz), [Linux GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-linux-x86_64-2.4.0.tar.gz), [Windows CPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-2.4.0.tar.gz), [Windows GPU](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-gpu-windows-x86_64-2.4.0.zip)
30+
31+
32+
2533
#### Pack and Deploy ####
2634

2735
On Windows, the tar command does not support extracting archives with symlinks. So when `dotnet pack` runs on Windows it will only package the Windows binaries.
2836

2937
1. Run `dotnet pack SciSharp.TensorFlow.Redist.nupkgproj` under `src/SciSharp.TensorFlow.Redist` directory in Linux.
30-
2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.3.1.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600`
38+
2. Run `dotnet nuget push SciSharp.TensorFlow.Redist.2.4.0.nupkg -k APIKEY -s https://api.nuget.org/v3/index.json -t 600`
3139

3240

src/TensorFlowNET.Console/TensorFlowNET.Console.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
</PropertyGroup>
99

1010
<ItemGroup>
11-
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.0" />
11+
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.1" />
1212
</ItemGroup>
1313

1414
<ItemGroup>

src/TensorFlowNET.Core/Tensors/tensor_util.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ public static string to_numpy_string(Tensor tensor)
574574
return string.Join(string.Empty, nd.ToArray<byte>()
575575
.Select(x => x < 32 || x > 127 ? "\\x" + x.ToString("x") : Convert.ToChar(x).ToString()));
576576
case TF_DataType.TF_BOOL:
577-
return (nd.GetByte(0) > 0).ToString();
577+
return nd.GetBoolean(0).ToString();
578578
case TF_DataType.TF_VARIANT:
579579
case TF_DataType.TF_RESOURCE:
580580
return "<unprintable>";

src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -37,19 +37,38 @@ public DataHandler(DataHandlerArgs args)
3737
_steps_per_execution_value = args.StepsPerExecution.numpy();
3838
}
3939

40-
_adapter = new TensorLikeDataAdapter(new TensorLikeDataAdapterArgs
40+
if(args.Dataset == null)
4141
{
42-
X = args.X,
43-
Y = args.Y,
44-
BatchSize = args.BatchSize,
45-
Steps = args.StepsPerEpoch,
46-
Epochs = args.Epochs - args.InitialEpoch,
47-
Shuffle = args.Shuffle,
48-
MaxQueueSize = args.MaxQueueSize,
49-
Worker = args.Workers,
50-
UseMultiprocessing = args.UseMultiprocessing,
51-
Model = args.Model
52-
});
42+
_adapter = new TensorLikeDataAdapter(new DataAdapterArgs
43+
{
44+
X = args.X,
45+
Y = args.Y,
46+
BatchSize = args.BatchSize,
47+
Steps = args.StepsPerEpoch,
48+
Epochs = args.Epochs - args.InitialEpoch,
49+
Shuffle = args.Shuffle,
50+
MaxQueueSize = args.MaxQueueSize,
51+
Worker = args.Workers,
52+
UseMultiprocessing = args.UseMultiprocessing,
53+
Model = args.Model
54+
});
55+
}
56+
else
57+
{
58+
_adapter = new DatasetAdapter(new DataAdapterArgs
59+
{
60+
Dataset = args.Dataset,
61+
BatchSize = args.BatchSize,
62+
Steps = args.StepsPerEpoch,
63+
Epochs = args.Epochs - args.InitialEpoch,
64+
Shuffle = args.Shuffle,
65+
MaxQueueSize = args.MaxQueueSize,
66+
Worker = args.Workers,
67+
UseMultiprocessing = args.UseMultiprocessing,
68+
Model = args.Model
69+
});
70+
}
71+
5372
_dataset = _adapter.GetDataset();
5473
_inferred_steps = _infer_steps(args.StepsPerEpoch, _dataset);
5574
_current_step = 0;
@@ -66,7 +85,8 @@ int _infer_steps(int steps_per_epoch, IDatasetV2 dataset)
6685
if (adapter_steps > -1)
6786
return adapter_steps;
6887

69-
throw new NotImplementedException("");
88+
var size = dataset.dataset_cardinality();
89+
return size.numpy();
7090
}
7191

7292
public IEnumerable<(int, OwnedIterator)> enumerate_epochs()
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras.ArgsDefinition;
5+
6+
namespace Tensorflow.Keras.Engine.DataAdapters
7+
{
8+
public class DatasetAdapter : IDataAdapter
9+
{
10+
DataAdapterArgs args;
11+
IDatasetV2 _dataset => args.Dataset;
12+
public DatasetAdapter(DataAdapterArgs args)
13+
{
14+
this.args = args;
15+
}
16+
17+
public bool CanHandle(Tensor x, Tensor y = null)
18+
{
19+
throw new NotImplementedException();
20+
}
21+
22+
public IDatasetV2 GetDataset()
23+
=> _dataset;
24+
25+
public int GetSize()
26+
=> -1;
27+
28+
public (Tensor, Tensor) Expand1d(Tensor x, Tensor y)
29+
{
30+
if (y.TensorShape.ndim == 1)
31+
y = array_ops.expand_dims(y, axis: -1);
32+
return (x, y);
33+
}
34+
}
35+
}

src/TensorFlowNET.Keras/Engine/DataAdapters/TensorLikeDataAdapter.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@ namespace Tensorflow.Keras.Engine.DataAdapters
99
/// </summary>
1010
public class TensorLikeDataAdapter : IDataAdapter
1111
{
12-
TensorLikeDataAdapterArgs args;
12+
DataAdapterArgs args;
1313
int _size;
1414
int _batch_size;
1515
int num_samples;
1616
int num_full_batches;
1717
IDatasetV2 _dataset;
1818

19-
public TensorLikeDataAdapter(TensorLikeDataAdapterArgs args)
19+
public TensorLikeDataAdapter(DataAdapterArgs args)
2020
{
2121
this.args = args;
2222
_process_tensorlike();

src/TensorFlowNET.Keras/Engine/Functional.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,12 @@ public Functional(Tensors inputs, Tensors outputs, string name = null)
3939
_input_coordinates = new List<KerasHistory>();
4040
_output_coordinates = new List<KerasHistory>();
4141
tensor_usage_count = new Dictionary<int, int>();
42+
if (this is Sequential)
43+
return;
4244
_init_graph_network(inputs, outputs);
4345
}
4446

45-
void _init_graph_network(Tensors inputs, Tensors outputs)
47+
protected void _init_graph_network(Tensors inputs, Tensors outputs)
4648
{
4749
_is_graph_network = true;
4850
this.inputs = inputs;

src/TensorFlowNET.Keras/Engine/Model.Compile.cs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,6 @@ public partial class Model
99
{
1010
LossesContainer compiled_loss;
1111
MetricsContainer compiled_metrics;
12-
public void compile(string optimizerName, ILossFunc lossName)
13-
{
14-
throw new NotImplementedException("");
15-
}
1612

1713
public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics)
1814
{
@@ -29,12 +25,12 @@ public void compile(ILossFunc loss, OptimizerV2 optimizer, string[] metrics)
2925
this.loss = loss;
3026
}
3127

32-
public void compile(string optimizerName, string lossName)
28+
public void compile(string optimizer, string loss, string[] metrics)
3329
{
34-
switch (optimizerName)
30+
switch (optimizer)
3531
{
3632
case "rmsprop":
37-
optimizer = new RMSprop(new RMSpropArgs
33+
this.optimizer = new RMSprop(new RMSpropArgs
3834
{
3935

4036
});

src/TensorFlowNET.Keras/Engine/Model.Fit.cs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,5 +68,49 @@ public void fit(NDArray x, NDArray y,
6868
Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
6969
}
7070
}
71+
72+
public void fit(IDatasetV2 dataset,
73+
IDatasetV2 validation_data = null,
74+
int batch_size = -1,
75+
int epochs = 1,
76+
int verbose = 1,
77+
float validation_split = 0f,
78+
bool shuffle = true,
79+
int initial_epoch = 0,
80+
int max_queue_size = 10,
81+
int workers = 1,
82+
bool use_multiprocessing = false)
83+
{
84+
data_handler = new DataHandler(new DataHandlerArgs
85+
{
86+
Dataset = dataset,
87+
BatchSize = batch_size,
88+
InitialEpoch = initial_epoch,
89+
Epochs = epochs,
90+
Shuffle = shuffle,
91+
MaxQueueSize = max_queue_size,
92+
Workers = workers,
93+
UseMultiprocessing = use_multiprocessing,
94+
Model = this,
95+
StepsPerExecution = _steps_per_execution
96+
});
97+
98+
stop_training = false;
99+
_train_counter.assign(0);
100+
Console.WriteLine($"Training...");
101+
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
102+
{
103+
// reset_metrics();
104+
// callbacks.on_epoch_begin(epoch)
105+
// data_handler.catch_stop_iteration();
106+
IEnumerable<(string, Tensor)> results = null;
107+
foreach (var step in data_handler.steps())
108+
{
109+
// callbacks.on_train_batch_begin(step)
110+
results = step_function(iterator);
111+
}
112+
Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
113+
}
114+
}
71115
}
72116
}

0 commit comments

Comments
 (0)