@@ -580,19 +580,19 @@ def _get_result_file(model_path, model_name):
580580def load_model (model_path , model_name , net = None ):
581581 config_file = _get_config_file (model_path , model_name )
582582 model_file = _get_model_file (model_path , model_name )
583- assert os .path .isfile (
584- config_file
585- ), f'Could not find the config file " { config_file } ". Are you sure this is the correct path and you have your model config stored here?'
586- assert os .path .isfile (
587- model_file
588- ), f'Could not find the model file " { model_file } ". Are you sure this is the correct path and you have your model stored here?'
583+ assert os .path .isfile (config_file ), (
584+ f'Could not find the config file " { config_file } ". Are you sure this is the correct path and you have your model config stored here?'
585+ )
586+ assert os .path .isfile (model_file ), (
587+ f'Could not find the model file " { model_file } ". Are you sure this is the correct path and you have your model stored here?'
588+ )
589589 with open (config_file ) as f :
590590 config_dict = json .load (f )
591591 if net is None :
592592 act_fn_name = config_dict ["act_fn" ].pop ("name" ).lower ()
593- assert (
594- act_fn_name in act_fn_by_name
595- ), f'Unknown activation function " { act_fn_name } ". Please add it to the "act_fn_by_name" dict.'
593+ assert act_fn_name in act_fn_by_name , (
594+ f'Unknown activation function " { act_fn_name } ". Please add it to the " act_fn_by_name" dict.'
595+ )
596596 act_fn = act_fn_by_name [act_fn_name ]()
597597 net = BaseNetwork (act_fn = act_fn , ** config_dict )
598598 net .load_state_dict (torch .load (model_file ))
@@ -678,7 +678,7 @@ def train_model(net, model_name, optim_func, max_epochs=50, batch_size=256, over
678678 plt .show ()
679679 plt .close ()
680680
681- print ((f" Test accuracy: { results ['test_acc' ]* 100.0 :4.2f} % " ).center (50 , "=" ) + "\n " )
681+ print ((f" Test accuracy: { results ['test_acc' ] * 100.0 :4.2f} % " ).center (50 , "=" ) + "\n " )
682682 return results
683683
684684
@@ -700,7 +700,7 @@ def epoch_iteration(net, loss_module, optimizer, train_loader_local, val_loader,
700700 # Record statistics during training
701701 true_preds += (preds .argmax (dim = - 1 ) == labels ).sum ().item ()
702702 count += labels .shape [0 ]
703- t .set_description (f"Epoch { epoch + 1 } : loss={ loss .item ():4.2f} " )
703+ t .set_description (f"Epoch { epoch + 1 } : loss={ loss .item ():4.2f} " )
704704 epoch_losses .append (loss .item ())
705705 train_acc = true_preds / count
706706
@@ -709,7 +709,7 @@ def epoch_iteration(net, loss_module, optimizer, train_loader_local, val_loader,
709709 ##############
710710 val_acc = test_model (net , val_loader )
711711 print (
712- f"[Epoch { epoch + 1 :2i} ] Training accuracy: { train_acc * 100.0 :05.2f} %, Validation accuracy: { val_acc * 100.0 :05.2f} %"
712+ f"[Epoch { epoch + 1 :2i} ] Training accuracy: { train_acc * 100.0 :05.2f} %, Validation accuracy: { val_acc * 100.0 :05.2f} %"
713713 )
714714 return train_acc , val_acc , epoch_losses
715715
0 commit comments