@@ -14,6 +14,38 @@ namespace Tensorflow.Keras.Engine
1414{
1515 public partial class Model
1616 {
17+ protected Dictionary < string , float > evaluate ( CallbackList callbacks , DataHandler data_handler , bool is_val )
18+ {
19+ callbacks . on_test_begin ( ) ;
20+
21+ //Dictionary<string, float>? logs = null;
22+ var logs = new Dictionary < string , float > ( ) ;
23+ int x_size = data_handler . DataAdapter . GetDataset ( ) . FirstInputTensorCount ;
24+ foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
25+ {
26+ reset_metrics ( ) ;
27+ callbacks . on_epoch_begin ( epoch ) ;
28+ // data_handler.catch_stop_iteration();
29+
30+ foreach ( var step in data_handler . steps ( ) )
31+ {
32+ callbacks . on_test_batch_begin ( step ) ;
33+
34+ var data = iterator . next ( ) ;
35+
36+ logs = train_step ( data_handler , new Tensors ( data . Take ( x_size ) ) , new Tensors ( data . Skip ( x_size ) ) ) ;
37+ tf_with ( ops . control_dependencies ( Array . Empty < object > ( ) ) , ctl => _test_counter . assign_add ( 1 ) ) ;
38+
39+ var end_step = step + data_handler . StepIncrement ;
40+
41+ if ( ! is_val )
42+ callbacks . on_test_batch_end ( end_step , logs ) ;
43+ }
44+ }
45+
46+ return logs ;
47+ }
48+
1749 /// <summary>
1850 /// Returns the loss value & metrics values for the model in test mode.
1951 /// </summary>
@@ -64,31 +96,8 @@ public Dictionary<string, float> evaluate(Tensor x, Tensor y,
6496 Verbose = verbose ,
6597 Steps = data_handler . Inferredsteps
6698 } ) ;
67- callbacks . on_test_begin ( ) ;
68-
69- //Dictionary<string, float>? logs = null;
70- var logs = new Dictionary < string , float > ( ) ;
71- foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
72- {
73- reset_metrics ( ) ;
74- // data_handler.catch_stop_iteration();
7599
76- foreach ( var step in data_handler . steps ( ) )
77- {
78- callbacks . on_test_batch_begin ( step ) ;
79- logs = test_function ( data_handler , iterator ) ;
80- var end_step = step + data_handler . StepIncrement ;
81- if ( is_val == false )
82- callbacks . on_test_batch_end ( end_step , logs ) ;
83- }
84- }
85-
86- var results = new Dictionary < string , float > ( ) ;
87- foreach ( var log in logs )
88- {
89- results [ log . Key ] = log . Value ;
90- }
91- return results ;
100+ return evaluate ( callbacks , data_handler , is_val ) ;
92101 }
93102
94103 public Dictionary < string , float > evaluate ( IEnumerable < Tensor > x , Tensor y , int verbose = 1 , bool is_val = false )
@@ -107,31 +116,8 @@ public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int v
107116 Verbose = verbose ,
108117 Steps = data_handler . Inferredsteps
109118 } ) ;
110- callbacks . on_test_begin ( ) ;
111119
112- Dictionary < string , float > logs = null ;
113- foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
114- {
115- reset_metrics ( ) ;
116- callbacks . on_epoch_begin ( epoch ) ;
117- // data_handler.catch_stop_iteration();
118-
119- foreach ( var step in data_handler . steps ( ) )
120- {
121- callbacks . on_test_batch_begin ( step ) ;
122- logs = test_function ( data_handler , iterator ) ;
123- var end_step = step + data_handler . StepIncrement ;
124- if ( is_val == false )
125- callbacks . on_test_batch_end ( end_step , logs ) ;
126- }
127- }
128-
129- var results = new Dictionary < string , float > ( ) ;
130- foreach ( var log in logs )
131- {
132- results [ log . Key ] = log . Value ;
133- }
134- return results ;
120+ return evaluate ( callbacks , data_handler , is_val ) ;
135121 }
136122
137123
@@ -150,51 +136,8 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
150136 Verbose = verbose ,
151137 Steps = data_handler . Inferredsteps
152138 } ) ;
153- callbacks . on_test_begin ( ) ;
154-
155- Dictionary < string , float > logs = null ;
156- foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
157- {
158- reset_metrics ( ) ;
159- callbacks . on_epoch_begin ( epoch ) ;
160- // data_handler.catch_stop_iteration();
161-
162- foreach ( var step in data_handler . steps ( ) )
163- {
164- callbacks . on_test_batch_begin ( step ) ;
165- logs = test_function ( data_handler , iterator ) ;
166- var end_step = step + data_handler . StepIncrement ;
167- if ( is_val == false )
168- callbacks . on_test_batch_end ( end_step , logs ) ;
169- }
170- }
171-
172- var results = new Dictionary < string , float > ( ) ;
173- foreach ( var log in logs )
174- {
175- results [ log . Key ] = log . Value ;
176- }
177- return results ;
178- }
179-
180- Dictionary < string , float > test_function ( DataHandler data_handler , OwnedIterator iterator )
181- {
182- var data = iterator . next ( ) ;
183- var x_size = data_handler . DataAdapter . GetDataset ( ) . FirstInputTensorCount ;
184- var outputs = train_step ( data_handler , new Tensors ( data . Take ( x_size ) ) , new Tensors ( data . Skip ( x_size ) ) ) ;
185- tf_with ( ops . control_dependencies ( new object [ 0 ] ) , ctl => _test_counter . assign_add ( 1 ) ) ;
186- return outputs ;
187- }
188-
189- Dictionary < string , float > test_step ( DataHandler data_handler , Tensor x , Tensor y )
190- {
191- ( x , y ) = data_handler . DataAdapter . Expand1d ( x , y ) ;
192- var y_pred = Apply ( x , training : false ) ;
193- var loss = compiled_loss . Call ( y , y_pred ) ;
194-
195- compiled_metrics . update_state ( y , y_pred ) ;
196139
197- return metrics . Select ( x => ( x . Name , x . result ( ) ) ) . ToDictionary ( x => x . Item1 , x => ( float ) x . Item2 ) ;
140+ return evaluate ( callbacks , data_handler , is_val ) ;
198141 }
199142 }
200- }
143+ }
0 commit comments