|
5 | 5 | $ python tools/infer/text/predict_from_yaml.py --config configs/det/dbnet/db++_r50_icdar15.yaml |
6 | 6 | $ python tools/infer/text/predict_from_yaml.py --config configs/rec/crnn/crnn_resnet34.yaml |
7 | 7 | """ |
| 8 | +import argparse |
8 | 9 | import logging |
9 | 10 | import os |
10 | 11 | import sys |
11 | 12 |
|
| 13 | +import yaml |
12 | 14 | from addict import Dict |
13 | 15 | from PIL import Image |
14 | 16 | from predict_det import save_det_res, validate_det_res |
15 | 17 | from predict_rec import save_rec_res |
| 18 | +from predict_system import save_res |
16 | 19 | from tqdm import tqdm |
| 20 | +from utils import crop_text_region |
17 | 21 |
|
18 | | -from mindspore import get_context, set_auto_parallel_context, set_context |
| 22 | +from mindspore import Tensor, get_context, set_auto_parallel_context, set_context |
19 | 23 | from mindspore.communication import get_group_size, get_rank, init |
20 | 24 |
|
21 | 25 | from mindocr.data import build_dataset |
| 26 | +from mindocr.data.transforms import create_transforms, run_transforms |
22 | 27 | from mindocr.models import build_model |
23 | 28 | from mindocr.postprocess import build_postprocess |
24 | 29 | from mindocr.utils.visualize import draw_boxes, show_imgs |
25 | | -from tools.arg_parser import parse_args_and_config |
| 30 | +from tools.arg_parser import _merge_options, _parse_options |
| 31 | +from tools.modelarts_adapter.modelarts import modelarts_setup |
26 | 32 |
|
27 | 33 | __dir__ = os.path.dirname(os.path.abspath(__file__)) |
28 | 34 | sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../../"))) |
@@ -92,7 +98,8 @@ def save_predict_result(task: str, preds_list: list, output_save_dir: str): |
92 | 98 | save_rec_res(rec_res, img_list, save_path=os.path.join(output_save_dir, "rec_results.txt")) |
93 | 99 |
|
94 | 100 |
|
95 | | -def main(cfg): |
| 101 | +def predict_single_step(cfg): |
| 102 | + """Run predict for det task or rec task""" |
96 | 103 | # 1. Set the environment information. |
97 | 104 | set_context(mode=cfg.system.mode) |
98 | 105 | output_save_dir = cfg.predict.output_save_dir or "./output" |
@@ -170,6 +177,7 @@ def main(cfg): |
170 | 177 | meta_data_indices = cfg.predict.dataset.pop("meta_data_column_index", None) |
171 | 178 |
|
172 | 179 | # 6.Start prediction |
| 180 | + logger.info(f"Start {cfg.model.type}") |
173 | 181 | preds_list = [] |
174 | 182 | for i, data in tqdm(enumerate(iterator), total=num_batches_predict): |
175 | 183 | if input_indices is not None: |
@@ -199,13 +207,162 @@ def main(cfg): |
199 | 207 | # Add "img_ori" to preds if present, which means task is det |
200 | 208 | if "image_ori" in output_columns: |
201 | 209 | preds["img_ori"] = data[output_columns.index("image_ori")].numpy() |
| 210 | + if "polys" in preds: |
| 211 | + preds["crops"] = [] |
| 212 | + polys_batch = preds["polys"].copy() |
| 213 | + for i, polys in enumerate(polys_batch): |
| 214 | + crops_per_img = [] |
| 215 | + for poly in polys: |
| 216 | + cropped_img = crop_text_region(preds["img_ori"][i], poly, box_type=cfg.postprocess.box_type) |
| 217 | + crops_per_img.append(cropped_img) |
| 218 | + preds["crops"].append(crops_per_img) |
202 | 219 | preds_list.append(preds) |
203 | 220 |
|
204 | 221 | # 7. Save the prediction results to the specified directory |
205 | 222 | save_predict_result(cfg.model.type, preds_list, output_save_dir) |
| 223 | + return preds_list |
| 224 | + |
| 225 | + |
def predict_system(args, det_cfg, rec_cfg, conf_thresh=0.5):
    """Run text detection followed by text recognition and save the merged results.

    The detection stage is delegated to ``predict_single_step``; every text
    region it cropped is then recognised one crop at a time, and the boxes
    whose recognised text is confident enough are written to
    ``system_results.txt`` under the detection config's output directory.

    Args:
        args: Parsed command-line arguments; ``args.image_dir`` is used as the
            detection input location.
        det_cfg: Detection config with attribute-style access. Its
            ``predict.dataset_root`` and ``predict.data_dir`` are overwritten
            in place.
        rec_cfg: Recognition config with attribute-style access.
        conf_thresh: A recognised text is kept only when its confidence is
            strictly greater than this value (default 0.5).
    """
    # Point the detection dataset at the images given on the command line.
    det_cfg.predict.dataset_root = ""
    det_cfg.predict.data_dir = args.image_dir
    output_save_dir = det_cfg.predict.output_save_dir or "./output"

    # Detection stage. Each element is the prediction dict of one batch.
    preds_list = predict_single_step(det_cfg)

    # The AMP level applies to the recognition network built below, so it is
    # read from the recognition config rather than the detection config.
    amp_level = rec_cfg.system.get("amp_level_infer", "O0")
    if get_context("device_target") == "GPU" and amp_level == "O3":
        logger.warning(
            "Model evaluation does not support amp_level O3 on GPU currently. "
            "The program has switched to amp_level O2 automatically."
        )
        amp_level = "O2"

    # Pre-processing and post-processing for the recognition task.
    transforms = create_transforms(rec_cfg.predict.dataset.transform_pipeline)
    postprocessor = build_postprocess(rec_cfg.postprocess)
    # The first transform of the pipeline is skipped for in-memory crops.
    # NOTE(review): presumably it is the image-decoding step - confirm against
    # the recognition config.
    crop_transforms = transforms[1:]

    # Build the recognition model from the yaml config.
    rec_network = build_model(rec_cfg.model, ckpt_load_path=rec_cfg.predict.ckpt_load_path, amp_level=amp_level)

    logger.info("Start rec")
    img_list = []  # image path, one entry per image
    boxes_all = []  # kept boxes, one list per image
    text_scores_all = []  # kept (text, confidence) pairs, one list per image
    for preds_batch in tqdm(preds_list):
        # preds_batch holds the detection output of one batch.
        # These two keys are initialised here but not populated by this function.
        preds_batch["texts"] = []
        preds_batch["confs"] = []
        for img_idx, crops in enumerate(preds_batch["crops"]):
            # A batch may contain several images.
            img_path = preds_batch["img_path"][img_idx]
            img_boxes = []
            img_text_scores = []
            for crop_idx, crop in enumerate(crops):
                # An image may contain several cropped text regions.
                sample = {
                    "image": crop,
                    "image_ori": crop.copy(),
                    "image_shape": crop.shape,
                }
                sample = run_transforms(sample, crop_transforms)
                # The network is fed a batch of one crop.
                net_out = rec_network(Tensor(sample["image"]).expand_dims(0))
                rec_out = postprocessor(net_out)
                conf = rec_out["confs"][0]
                if conf > conf_thresh:
                    # crops[crop_idx] is indexed the same way as polys[img_idx][crop_idx].
                    img_boxes.append(preds_batch["polys"][img_idx][crop_idx])
                    img_text_scores.append((rec_out["texts"][0], conf))
            # One entry per image: its path, kept boxes and kept texts/scores.
            img_list.append(img_path)
            boxes_all.append(img_boxes)
            text_scores_all.append(img_text_scores)
    save_res(boxes_all, text_scores_all, img_list, save_path=os.path.join(output_save_dir, "system_results.txt"))
| 286 | + |
| 287 | + |
def _str2bool(value):
    """Convert a command-line string such as "True", "false" or "0" to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    "False", as ``True``; this converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays falsy, as it was with ``type=bool``.
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def create_parser():
    """Build the command-line parser of the yaml-driven prediction script.

    Returns:
        argparse.ArgumentParser: A parser (created with ``add_help=False``)
        exposing ``--image_dir``, ``--task_mode``, ``--det_config``,
        ``--rec_config``, ``-o/--opt`` and the ModelArts option group.
    """
    parser = argparse.ArgumentParser(description="Inference Config", add_help=False)
    parser.add_argument("--image_dir", type=str, help="image path or image directory")
    parser.add_argument("--task_mode", type=str, default="system", choices=["det", "rec", "system"], help="Task mode")
    parser.add_argument(
        "--det_config",
        type=str,
        default="configs/det/dbnet/db_r50_icdar15.yaml",
        help='YAML config file specifying default arguments for det (default="configs/det/dbnet/db_r50_icdar15.yaml")',
    )
    parser.add_argument(
        "--rec_config",
        type=str,
        default="configs/rec/crnn/crnn_resnet34.yaml",
        help='YAML config file specifying default arguments for rec (default="configs/rec/crnn/crnn_resnet34.yaml")',
    )
    parser.add_argument(
        "-o",
        "--opt",
        nargs="+",
        help="Options to change yaml configuration values, "
        "e.g. `-o system.distribute=False eval.dataset.dataset_root=/my_path/to/ocr_data`",
    )

    # ModelArts options
    group = parser.add_argument_group("modelarts")
    group.add_argument(
        "--enable_modelarts",
        type=_str2bool,  # type=bool would turn "False" into True
        default=False,
        help="Run on modelarts platform (default=False)",
    )
    group.add_argument(
        "--device_target",
        type=str,
        default="Ascend",
        help="Target device, only used on modelarts platform (default=Ascend)",
    )
    # The URLs are provided by ModelArts; usually they are S3 paths.
    group.add_argument("--multi_data_url", type=str, default="", help="path to multi dataset")
    group.add_argument("--data_url", type=str, default="", help="path to dataset")
    group.add_argument("--ckpt_url", type=str, default="", help="pre_train_model path in obs")
    group.add_argument("--pretrain_url", type=str, default="", help="pre_train_model paths in obs")
    group.add_argument("--train_url", type=str, default="", help="model folder to save/load")

    return parser
| 330 | + |
| 331 | + |
def parse_args_and_config():
    """Parse the command line and load the detection and recognition configs.

    Both yaml files are always loaded, whatever the task mode. When ``-o/--opt``
    overrides are given, the same overrides are merged into both configs.

    Returns:
        tuple: ``(args, det_cfg, rec_cfg)`` where ``args`` holds the parsed
        command-line arguments and ``det_cfg`` / ``rec_cfg`` are the configs
        loaded from ``args.det_config`` / ``args.rec_config``.

    Raises:
        ValueError: If the task mode is "system" and ``--image_dir`` is missing.
    """
    parser = create_parser()
    args = parser.parse_args()  # CLI args

    modelarts_setup(args)
    if args.task_mode == "system" and args.image_dir is None:
        raise ValueError("When the task mode is 'system', the 'image_dir' is necessary.")

    # Read the yaml files as UTF-8 instead of the platform default encoding.
    with open(args.det_config, "r", encoding="utf-8") as f:
        det_cfg = yaml.safe_load(f)
    with open(args.rec_config, "r", encoding="utf-8") as f:
        rec_cfg = yaml.safe_load(f)

    if args.opt:
        options = _parse_options(args.opt)
        det_cfg = _merge_options(det_cfg, options)
        rec_cfg = _merge_options(rec_cfg, options)
    return args, det_cfg, rec_cfg
206 | 355 |
|
207 | 356 |
|
if __name__ == "__main__":
    args, det_cfg, rec_cfg = parse_args_and_config()
    # Dispatch on the task mode; only the config(s) the task needs are wrapped
    # in an attribute-access Dict.
    mode = args.task_mode
    if mode == "system":
        predict_system(args, Dict(det_cfg), Dict(rec_cfg))
    elif mode == "det":
        predict_single_step(Dict(det_cfg))
    elif mode == "rec":
        predict_single_step(Dict(rec_cfg))
0 commit comments