@@ -98,7 +98,7 @@ def save_predict_result(task: str, preds_list: list, output_save_dir: str):
9898 save_rec_res (rec_res , img_list , save_path = os .path .join (output_save_dir , "rec_results.txt" ))
9999
100100
101- def predict_single_step (cfg ):
101+ def predict_single_step (cfg , save_res = True ):
102102 """Run predict for det task or rec task"""
103103 # 1. Set the environment information.
104104 set_context (mode = cfg .system .mode )
@@ -162,6 +162,12 @@ def predict_single_step(cfg):
162162 "The program has switched to amp_level O2 automatically."
163163 )
164164 amp_level = "O2"
165+ cfg .model .backbone .pretrained = False
166+ if cfg .predict .ckpt_load_path is None :
167+ logger .warning (
168+ f"No ckpt is available for { cfg .model .task } , "
169+ "please check your configuration of 'predict.ckpt_load_path' in the yaml file."
170+ )
165171 network = build_model (cfg .model , ckpt_load_path = cfg .predict .ckpt_load_path , amp_level = amp_level )
166172 network .set_train (False )
167173
@@ -219,7 +225,8 @@ def predict_single_step(cfg):
219225 preds_list .append (preds )
220226
221227 # 7. Save the prediction results to the specified directory
222- save_predict_result (cfg .model .type , preds_list , output_save_dir )
228+ if save_res is True :
229+ save_predict_result (cfg .model .type , preds_list , output_save_dir )
223230 return preds_list
224231
225232
@@ -231,7 +238,7 @@ def predict_system(args, det_cfg, rec_cfg):
231238 output_save_dir = det_cfg .predict .output_save_dir or "./output"
232239
233240 # get det result from predict
234- preds_list = predict_single_step (det_cfg )
241+ preds_list = predict_single_step (det_cfg , save_res = False )
235242
236243 # set amp level
237244 amp_level = det_cfg .system .get ("amp_level_infer" , "O0" )
@@ -247,6 +254,12 @@ def predict_system(args, det_cfg, rec_cfg):
247254 postprocessor = build_postprocess (rec_cfg .postprocess )
248255
249256 # build rec model from yaml
257+ rec_cfg .model .backbone .pretrained = False
258+ if rec_cfg .predict .ckpt_load_path is None :
259+ logger .warning (
260+ f"No ckpt is available for { rec_cfg .model .type } , "
261+ "please check your configuration of 'predict.ckpt_load_path' in the yaml file."
262+ )
250263 rec_network = build_model (rec_cfg .model , ckpt_load_path = rec_cfg .predict .ckpt_load_path , amp_level = amp_level )
251264
252265 # start rec task
0 commit comments