@@ -27,7 +27,7 @@ public partial class Model
2727 /// <param name="use_multiprocessing"></param>
2828 /// <param name="return_dict"></param>
2929 /// <param name="is_val"></param>
30- public Dictionary < string , float > evaluate ( Tensor x , Tensor y ,
30+ public Dictionary < string , float > evaluate ( NDArray x , NDArray y ,
3131 int batch_size = - 1 ,
3232 int verbose = 1 ,
3333 int steps = - 1 ,
@@ -115,62 +115,53 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
115115 /// <param name="test_func">The function to be called on each batch of data.</param>
116116 /// <param name="is_val">Whether it is validation or test.</param>
117117 /// <returns></returns>
118- Dictionary < string , float > evaluate ( DataHandler data_handler , CallbackList callbacks , bool is_val , Func < DataHandler , Tensor [ ] , Dictionary < string , float > > test_func )
118+ Dictionary < string , float > evaluate ( DataHandler data_handler , CallbackList callbacks , bool is_val , Func < DataHandler , OwnedIterator , Dictionary < string , float > > test_func )
119119 {
120120 callbacks . on_test_begin ( ) ;
121121
122- var results = new Dictionary < string , float > ( ) ;
123- var logs = results ;
122+ var logs = new Dictionary < string , float > ( ) ;
124123 foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
125124 {
126125 reset_metrics ( ) ;
127- callbacks . on_epoch_begin ( epoch ) ;
128- // data_handler.catch_stop_iteration();
129-
130126 foreach ( var step in data_handler . steps ( ) )
131127 {
132128 callbacks . on_test_batch_begin ( step ) ;
133-
134- logs = test_func ( data_handler , iterator . next ( ) ) ;
135-
136- tf_with ( ops . control_dependencies ( Array . Empty < object > ( ) ) , ctl => _train_counter . assign_add ( 1 ) ) ;
137-
129+ logs = test_func ( data_handler , iterator ) ;
138130 var end_step = step + data_handler . StepIncrement ;
139131 if ( ! is_val )
140132 callbacks . on_test_batch_end ( end_step , logs ) ;
141133 }
142-
143- if ( ! is_val )
144- callbacks . on_epoch_end ( epoch , logs ) ;
145134 }
146-
147- foreach ( var log in logs )
148- {
149- results [ log . Key ] = log . Value ;
150- }
151-
135+ callbacks . on_test_end ( logs ) ;
136+ var results = new Dictionary < string , float > ( logs ) ;
152137 return results ;
153138 }
154139
155- Dictionary < string , float > test_function ( DataHandler data_handler , Tensor [ ] data )
140+ Dictionary < string , float > test_function ( DataHandler data_handler , OwnedIterator iterator )
156141 {
157- var ( x , y ) = data_handler . DataAdapter . Expand1d ( data [ 0 ] , data [ 1 ] ) ;
158-
159- var y_pred = Apply ( x , training : false ) ;
160- var loss = compiled_loss . Call ( y , y_pred ) ;
161-
162- compiled_metrics . update_state ( y , y_pred ) ;
163-
164- var outputs = metrics . Select ( x => ( x . Name , x . result ( ) ) ) . ToDictionary ( x => x . Name , x => ( float ) x . Item2 ) ;
142+ var data = iterator . next ( ) ;
143+ var outputs = test_step ( data_handler , data [ 0 ] , data [ 1 ] ) ;
144+ tf_with ( ops . control_dependencies ( new object [ 0 ] ) , ctl => _test_counter . assign_add ( 1 ) ) ;
165145 return outputs ;
166146 }
167147
168- Dictionary < string , float > test_step_multi_inputs_function ( DataHandler data_handler , Tensor [ ] data )
148+ Dictionary < string , float > test_step_multi_inputs_function ( DataHandler data_handler , OwnedIterator iterator )
169149 {
150+ var data = iterator . next ( ) ;
170151 var x_size = data_handler . DataAdapter . GetDataset ( ) . FirstInputTensorCount ;
171- var outputs = train_step ( data_handler , new Tensors ( data . Take ( x_size ) . ToArray ( ) ) , new Tensors ( data . Skip ( x_size ) . ToArray ( ) ) ) ;
172- tf_with ( ops . control_dependencies ( new object [ 0 ] ) , ctl => _train_counter . assign_add ( 1 ) ) ;
152+ var outputs = test_step ( data_handler , data . Take ( x_size ) . ToArray ( ) , data . Skip ( x_size ) . ToArray ( ) ) ;
153+ tf_with ( ops . control_dependencies ( new object [ 0 ] ) , ctl => _test_counter . assign_add ( 1 ) ) ;
173154 return outputs ;
174155 }
156+
157+
158+ Dictionary < string , float > test_step ( DataHandler data_handler , Tensors x , Tensors y )
159+ {
160+ ( x , y ) = data_handler . DataAdapter . Expand1d ( x , y ) ;
161+ var y_pred = Apply ( x , training : false ) ;
162+ var loss = compiled_loss . Call ( y , y_pred ) ;
163+ compiled_metrics . update_state ( y , y_pred ) ;
164+ return metrics . Select ( x => ( x . Name , x . result ( ) ) ) . ToDictionary ( x => x . Item1 , x => ( float ) x . Item2 ) ;
165+ }
175166 }
176167}
0 commit comments