Skip to content

Commit cdc0c2e

Browse files
committed
Add ShellProgressBar for model fitting.
1 parent b71bfe3 commit cdc0c2e

File tree

3 files changed

+19
-20
lines changed

3 files changed

+19
-20
lines changed

src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ public class DataHandler
1515
public IDataAdapter DataAdapter => _adapter;
1616
IDatasetV2 _dataset;
1717
int _inferred_steps;
18+
public int Inferredsteps => _inferred_steps;
1819
int _current_step;
1920
int _step_increment;
2021
bool _insufficient_data;

src/TensorFlowNET.Keras/Engine/Model.Fit.cs

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using NumSharp;
2+
using ShellProgressBar;
23
using System;
34
using System.Collections.Generic;
45
using System.Linq;
@@ -51,22 +52,7 @@ public void fit(NDArray x, NDArray y,
5152
StepsPerExecution = _steps_per_execution
5253
});
5354

54-
stop_training = false;
55-
_train_counter.assign(0);
56-
Console.WriteLine($"Training...");
57-
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
58-
{
59-
// reset_metrics();
60-
// callbacks.on_epoch_begin(epoch)
61-
// data_handler.catch_stop_iteration();
62-
IEnumerable<(string, Tensor)> results = null;
63-
foreach (var step in data_handler.steps())
64-
{
65-
// callbacks.on_train_batch_begin(step)
66-
results = step_function(iterator);
67-
}
68-
Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
69-
}
55+
FitInternal(epochs);
7056
}
7157

7258
public void fit(IDatasetV2 dataset,
@@ -95,21 +81,32 @@ public void fit(IDatasetV2 dataset,
9581
StepsPerExecution = _steps_per_execution
9682
});
9783

84+
FitInternal(epochs);
85+
}
86+
87+
void FitInternal(int epochs)
88+
{
9889
stop_training = false;
9990
_train_counter.assign(0);
100-
Console.WriteLine($"Training...");
91+
var options = new ProgressBarOptions
92+
{
93+
ProgressCharacter = '.',
94+
ProgressBarOnBottom = true
95+
};
96+
10197
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
10298
{
99+
using var pbar = new ProgressBar(data_handler.Inferredsteps, "Training...", options);
103100
// reset_metrics();
104101
// callbacks.on_epoch_begin(epoch)
105102
// data_handler.catch_stop_iteration();
106-
IEnumerable<(string, Tensor)> results = null;
107103
foreach (var step in data_handler.steps())
108104
{
109105
// callbacks.on_train_batch_begin(step)
110-
results = step_function(iterator);
106+
var results = step_function(iterator);
107+
var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}"));
108+
pbar.Tick($"[Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}]");
111109
}
112-
Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
113110
}
114111
}
115112
}

src/TensorFlowNET.Keras/Tensorflow.Keras.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ Keras is an API designed for human beings, not machines. Keras follows best prac
4747
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" />
4848
<PackageReference Include="Newtonsoft.Json" Version="12.0.3" />
4949
<PackageReference Include="SharpZipLib" Version="1.3.1" />
50+
<PackageReference Include="ShellProgressBar" Version="5.0.0" />
5051
</ItemGroup>
5152

5253
<ItemGroup>

0 commit comments

Comments
 (0)