@@ -253,6 +253,12 @@ def _forward(estimators, *x):
253253 self .tb_logger .add_scalar (
254254 "bagging/Validation_Acc" , acc , epoch
255255 )
256+ # No validation
257+ else :
258+ self .estimators_ = nn .ModuleList ()
259+ self .estimators_ .extend (estimators )
260+ if save_model :
261+ io .save (self , save_dir , self .logger )
256262
257263 # Update the scheduler
258264 with warnings .catch_warnings ():
@@ -271,11 +277,6 @@ def _forward(estimators, *x):
271277 else :
272278 scheduler_ .step ()
273279
274- self .estimators_ = nn .ModuleList ()
275- self .estimators_ .extend (estimators )
276- if save_model and not test_loader :
277- io .save (self , save_dir , self .logger )
278-
279280 @torchensemble_model_doc (item = "classifier_evaluate" )
280281 def evaluate (self , test_loader , return_loss = False ):
281282 return super ().evaluate (test_loader , return_loss )
@@ -449,6 +450,12 @@ def _forward(estimators, *x):
449450 self .tb_logger .add_scalar (
450451 "bagging/Validation_Loss" , val_loss , epoch
451452 )
453+ # No validation
454+ else :
455+ self .estimators_ = nn .ModuleList ()
456+ self .estimators_ .extend (estimators )
457+ if save_model :
458+ io .save (self , save_dir , self .logger )
452459
453460 # Update the scheduler
454461 with warnings .catch_warnings ():
@@ -464,11 +471,6 @@ def _forward(estimators, *x):
464471 else :
465472 scheduler_ .step ()
466473
467- self .estimators_ = nn .ModuleList ()
468- self .estimators_ .extend (estimators )
469- if save_model and not test_loader :
470- io .save (self , save_dir , self .logger )
471-
472474 @torchensemble_model_doc (item = "regressor_evaluate" )
473475 def evaluate (self , test_loader ):
474476 return super ().evaluate (test_loader )
0 commit comments