1- using Tensorflow . NumPy ;
21using System ;
32using System . Collections . Generic ;
43using System . Linq ;
4+ using Tensorflow ;
55using Tensorflow . Keras . ArgsDefinition ;
6+ using Tensorflow . Keras . Callbacks ;
67using Tensorflow . Keras . Engine . DataAdapters ;
7- using static Tensorflow . Binding ;
88using Tensorflow . Keras . Layers ;
99using Tensorflow . Keras . Utils ;
10- using Tensorflow ;
11- using Tensorflow . Keras . Callbacks ;
10+ using Tensorflow . NumPy ;
11+ using static Tensorflow . Binding ;
1212
1313namespace Tensorflow . Keras . Engine
1414{
1515 public partial class Model
1616 {
17- protected Dictionary < string , float > evaluate ( CallbackList callbacks , DataHandler data_handler , bool is_val )
18- {
19- callbacks . on_test_begin ( ) ;
20-
21- //Dictionary<string, float>? logs = null;
22- var logs = new Dictionary < string , float > ( ) ;
23- int x_size = data_handler . DataAdapter . GetDataset ( ) . FirstInputTensorCount ;
24- foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
25- {
26- reset_metrics ( ) ;
27- callbacks . on_epoch_begin ( epoch ) ;
28- // data_handler.catch_stop_iteration();
29-
30- foreach ( var step in data_handler . steps ( ) )
31- {
32- callbacks . on_test_batch_begin ( step ) ;
33-
34- var data = iterator . next ( ) ;
35-
36- logs = train_step ( data_handler , new Tensors ( data . Take ( x_size ) ) , new Tensors ( data . Skip ( x_size ) ) ) ;
37- tf_with ( ops . control_dependencies ( Array . Empty < object > ( ) ) , ctl => _test_counter . assign_add ( 1 ) ) ;
38-
39- var end_step = step + data_handler . StepIncrement ;
40-
41- if ( ! is_val )
42- callbacks . on_test_batch_end ( end_step , logs ) ;
43- }
44- }
45-
46- return logs ;
47- }
48-
4917 /// <summary>
5018 /// Returns the loss value & metrics values for the model in test mode.
5119 /// </summary>
@@ -97,7 +65,7 @@ public Dictionary<string, float> evaluate(Tensor x, Tensor y,
9765 Steps = data_handler . Inferredsteps
9866 } ) ;
9967
100- return evaluate ( callbacks , data_handler , is_val ) ;
68+ return evaluate ( data_handler , callbacks , is_val , test_function ) ;
10169 }
10270
10371 public Dictionary < string , float > evaluate ( IEnumerable < Tensor > x , Tensor y , int verbose = 1 , bool is_val = false )
@@ -117,10 +85,9 @@ public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int v
11785 Steps = data_handler . Inferredsteps
11886 } ) ;
11987
120- return evaluate ( callbacks , data_handler , is_val ) ;
88+ return evaluate ( data_handler , callbacks , is_val , test_step_multi_inputs_function ) ;
12189 }
12290
123-
12491 public Dictionary < string , float > evaluate ( IDatasetV2 x , int verbose = 1 , bool is_val = false )
12592 {
12693 var data_handler = new DataHandler ( new DataHandlerArgs
@@ -137,7 +104,74 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
137104 Steps = data_handler . Inferredsteps
138105 } ) ;
139106
140- return evaluate ( callbacks , data_handler , is_val ) ;
107+ return evaluate ( data_handler , callbacks , is_val , test_function ) ;
108+ }
109+
110+ /// <summary>
111+ /// Internal bare implementation of evaluate function.
112+ /// </summary>
113+ /// <param name="data_handler">Interations handling objects</param>
114+ /// <param name="callbacks"></param>
115+ /// <param name="test_func">The function to be called on each batch of data.</param>
116+ /// <param name="is_val">Whether it is validation or test.</param>
117+ /// <returns></returns>
118+ Dictionary < string , float > evaluate ( DataHandler data_handler , CallbackList callbacks , bool is_val , Func < DataHandler , Tensor [ ] , Dictionary < string , float > > test_func )
119+ {
120+ callbacks . on_test_begin ( ) ;
121+
122+ var results = new Dictionary < string , float > ( ) ;
123+ var logs = results ;
124+ foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
125+ {
126+ reset_metrics ( ) ;
127+ callbacks . on_epoch_begin ( epoch ) ;
128+ // data_handler.catch_stop_iteration();
129+
130+ foreach ( var step in data_handler . steps ( ) )
131+ {
132+ callbacks . on_test_batch_begin ( step ) ;
133+
134+ var data = iterator . next ( ) ;
135+
136+ logs = test_func ( data_handler , iterator . next ( ) ) ;
137+
138+ tf_with ( ops . control_dependencies ( Array . Empty < object > ( ) ) , ctl => _train_counter . assign_add ( 1 ) ) ;
139+
140+ var end_step = step + data_handler . StepIncrement ;
141+ if ( ! is_val )
142+ callbacks . on_test_batch_end ( end_step , logs ) ;
143+ }
144+
145+ if ( ! is_val )
146+ callbacks . on_epoch_end ( epoch , logs ) ;
147+ }
148+
149+ foreach ( var log in logs )
150+ {
151+ results [ log . Key ] = log . Value ;
152+ }
153+
154+ return results ;
155+ }
156+
157+ Dictionary < string , float > test_function ( DataHandler data_handler , Tensor [ ] data )
158+ {
159+ var ( x , y ) = data_handler . DataAdapter . Expand1d ( data [ 0 ] , data [ 1 ] ) ;
160+
161+ var y_pred = Apply ( x , training : false ) ;
162+ var loss = compiled_loss . Call ( y , y_pred ) ;
163+
164+ compiled_metrics . update_state ( y , y_pred ) ;
165+
166+ var outputs = metrics . Select ( x => ( x . Name , x . result ( ) ) ) . ToDictionary ( x => x . Name , x => ( float ) x . Item2 ) ;
167+ return outputs ;
168+ }
169+
170+ Dictionary < string , float > test_step_multi_inputs_function ( DataHandler data_handler , Tensor [ ] data )
171+ {
172+ var x_size = data_handler . DataAdapter . GetDataset ( ) . FirstInputTensorCount ;
173+ var outputs = train_step ( data_handler , new Tensors ( data . Take ( x_size ) ) , new Tensors ( data . Skip ( x_size ) ) ) ;
174+ return outputs ;
141175 }
142176 }
143- }
177+ }
0 commit comments