@@ -19,8 +19,10 @@ public class EarlyStopping: ICallback
1919 string _monitor ;
2020 string _mode ;
2121 bool _restore_best_weights ;
22- List < IVariableV1 > ? _best_weights ;
22+ List < NDArray > ? _best_weights ;
2323 CallbackParams _parameters ;
24+ Func < NDArray , NDArray , NDArray > _monitor_op ;
25+
2426 public Dictionary < string , List < float > > ? history { get ; set ; }
2527 // user need to pass a CallbackParams to EarlyStopping, CallbackParams at least need the model
2628 public EarlyStopping ( CallbackParams parameters , string monitor = "val_loss" , float min_delta = 0f , int patience = 0 ,
@@ -38,17 +40,49 @@ public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", floa
3840 _min_delta = Math . Abs ( min_delta ) ;
3941 _restore_best_weights = restore_best_weights ;
4042 _mode = mode ;
41- if ( mode != "auto" && mode != "min" && mode != "max" )
43+
44+ if ( _mode != "auto" && _mode != "min" && _mode != "max" )
45+ {
46+ Console . WriteLine ( $ "EarlyStopping mode { _mode } is unknown, fallback to auto mode.") ;
47+ _mode = "auto" ;
48+ }
49+
50+ if ( _mode == "min" )
51+ {
52+ _monitor_op = np . less ;
53+ }
54+ else if ( _mode == "max" )
55+ {
56+ _monitor_op = np . greater ;
57+ }
58+ else
59+ {
60+ if ( _monitor . EndsWith ( "acc" ) || _monitor . EndsWith ( "accuracy" ) || _monitor . EndsWith ( "auc" ) )
61+ {
62+ _monitor_op = np . greater ;
63+ }
64+ else
65+ {
66+ _monitor_op = np . less ;
67+ }
68+ }
69+
70+ if ( _monitor_op == np . greater )
4271 {
43- Console . WriteLine ( "EarlyStopping mode %s is unknown, fallback to auto mode." , mode ) ;
72+ _min_delta *= 1 ;
73+ }
74+ else
75+ {
76+ _min_delta *= - 1 ;
4477 }
4578 }
4679 public void on_train_begin ( )
4780 {
4881 _wait = 0 ;
4982 _stopped_epoch = 0 ;
83+ _best = _monitor_op == np . less ? ( float ) np . Inf : ( float ) - np . Inf ;
84+ _best_weights = null ;
5085 _best_epoch = 0 ;
51- _best = ( float ) np . Inf ;
5286 }
5387
5488 public void on_epoch_begin ( int epoch )
@@ -74,7 +108,7 @@ public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
74108 // Restore the weights after first epoch if no progress is ever made.
75109 if ( _restore_best_weights && _best_weights == null )
76110 {
77- _best_weights = _parameters . Model . Weights ;
111+ _best_weights = _parameters . Model . get_weights ( ) ;
78112 }
79113 _wait += 1 ;
80114
@@ -83,7 +117,7 @@ public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
83117 _best = current ;
84118 _best_epoch = epoch ;
85119 if ( _restore_best_weights )
86- _best_weights = _parameters . Model . TrainableWeights ;
120+ _best_weights = _parameters . Model . get_weights ( ) ;
87121 // Only restart wait if we beat both the baseline and our previous best.
88122 if ( _baseline == 0f || _is_improvement ( current , _baseline ) )
89123 _wait = 0 ;
@@ -99,7 +133,7 @@ public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
99133 {
100134 Console . WriteLine ( $ "Restoring model weights from the end of the best epoch: { _best_epoch + 1 } ") ;
101135 }
102- _parameters . Model . Weights = _best_weights ;
136+ _parameters . Model . set_weights ( _best_weights ) ;
103137 }
104138 }
105139 }
@@ -131,21 +165,7 @@ float get_monitor_value(Dictionary<string, float> logs)
131165 }
132166 public bool _is_improvement ( float monitor_value , float reference_value )
133167 {
134- bool less_op = ( monitor_value - _min_delta ) < reference_value ;
135- bool greater_op = ( monitor_value - _min_delta ) >= reference_value ;
136- if ( _mode == "min" )
137- return less_op ;
138- else if ( _mode == "max" )
139- return greater_op ;
140- else
141- {
142- if ( _monitor . EndsWith ( "acc" ) || _monitor . EndsWith ( "accuracy" ) || _monitor . EndsWith ( "auc" ) )
143- {
144- return greater_op ;
145- }
146- else
147- return less_op ;
148- }
168+ return _monitor_op ( monitor_value - _min_delta , reference_value ) ;
149169 }
150170
151171 public void on_test_end ( Dictionary < string , float > logs )
0 commit comments