|
1 | 1 | using NumSharp; |
| 2 | +using ShellProgressBar; |
2 | 3 | using System; |
3 | 4 | using System.Collections.Generic; |
4 | 5 | using System.Linq; |
@@ -51,22 +52,7 @@ public void fit(NDArray x, NDArray y, |
51 | 52 | StepsPerExecution = _steps_per_execution |
52 | 53 | }); |
53 | 54 |
|
54 | | - stop_training = false; |
55 | | - _train_counter.assign(0); |
56 | | - Console.WriteLine($"Training..."); |
57 | | - foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) |
58 | | - { |
59 | | - // reset_metrics(); |
60 | | - // callbacks.on_epoch_begin(epoch) |
61 | | - // data_handler.catch_stop_iteration(); |
62 | | - IEnumerable<(string, Tensor)> results = null; |
63 | | - foreach (var step in data_handler.steps()) |
64 | | - { |
65 | | - // callbacks.on_train_batch_begin(step) |
66 | | - results = step_function(iterator); |
67 | | - } |
68 | | - Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}"))); |
69 | | - } |
| 55 | + FitInternal(epochs); |
70 | 56 | } |
71 | 57 |
|
72 | 58 | public void fit(IDatasetV2 dataset, |
@@ -95,21 +81,32 @@ public void fit(IDatasetV2 dataset, |
95 | 81 | StepsPerExecution = _steps_per_execution |
96 | 82 | }); |
97 | 83 |
|
| 84 | + FitInternal(epochs); |
| 85 | + } |
| 86 | + |
| 87 | + void FitInternal(int epochs) |
| 88 | + { |
98 | 89 | stop_training = false; |
99 | 90 | _train_counter.assign(0); |
100 | | - Console.WriteLine($"Training..."); |
| 91 | + var options = new ProgressBarOptions |
| 92 | + { |
| 93 | + ProgressCharacter = '.', |
| 94 | + ProgressBarOnBottom = true |
| 95 | + }; |
| 96 | + |
101 | 97 | foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) |
102 | 98 | { |
| 99 | + using var pbar = new ProgressBar(data_handler.Inferredsteps, "Training...", options); |
103 | 100 | // reset_metrics(); |
104 | 101 | // callbacks.on_epoch_begin(epoch) |
105 | 102 | // data_handler.catch_stop_iteration(); |
106 | | - IEnumerable<(string, Tensor)> results = null; |
107 | 103 | foreach (var step in data_handler.steps()) |
108 | 104 | { |
109 | 105 | // callbacks.on_train_batch_begin(step) |
110 | | - results = step_function(iterator); |
| 106 | + var results = step_function(iterator); |
| 107 | + var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); |
| 108 | + pbar.Tick($"[Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}]"); |
111 | 109 | } |
112 | | - Console.WriteLine($"epoch: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}"))); |
113 | 110 | } |
114 | 111 | } |
115 | 112 | } |
|
0 commit comments