Skip to content

Commit c700d02

Browse files
authored
Add system mode to predict_from_yaml.py (#652)
1 parent 5120a2a commit c700d02

File tree

3 files changed

+163
-8
lines changed

3 files changed

+163
-8
lines changed

configs/det/dbnet/db_r50_icdar15.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,6 @@ predict:
164164
type: PredictDataset
165165
dataset_root: path/to/dataset_root
166166
data_dir: ic15/det/test/ch4_test_images
167-
# label_file: test.txt
168167
sample_ratio: 1.0
169168
transform_pipeline:
170169
- DecodeImage:

configs/rec/crnn/crnn_resnet34.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,6 @@ predict:
157157
type: PredictDataset
158158
dataset_root: path/to/dataset_root
159159
data_dir: predict_result/crop
160-
# label_files: # not required when using LMDBDataset
161160
sample_ratio: 1.0
162161
shuffle: False
163162
transform_pipeline:

tools/infer/text/predict_from_yaml.py

Lines changed: 163 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,30 @@
55
$ python tools/infer/text/predict_from_yaml.py --config configs/det/dbnet/db++_r50_icdar15.yaml
66
$ python tools/infer/text/predict_from_yaml.py --config configs/rec/crnn/crnn_resnet34.yaml
77
"""
8+
import argparse
89
import logging
910
import os
1011
import sys
1112

13+
import yaml
1214
from addict import Dict
1315
from PIL import Image
1416
from predict_det import save_det_res, validate_det_res
1517
from predict_rec import save_rec_res
18+
from predict_system import save_res
1619
from tqdm import tqdm
20+
from utils import crop_text_region
1721

18-
from mindspore import get_context, set_auto_parallel_context, set_context
22+
from mindspore import Tensor, get_context, set_auto_parallel_context, set_context
1923
from mindspore.communication import get_group_size, get_rank, init
2024

2125
from mindocr.data import build_dataset
26+
from mindocr.data.transforms import create_transforms, run_transforms
2227
from mindocr.models import build_model
2328
from mindocr.postprocess import build_postprocess
2429
from mindocr.utils.visualize import draw_boxes, show_imgs
25-
from tools.arg_parser import parse_args_and_config
30+
from tools.arg_parser import _merge_options, _parse_options
31+
from tools.modelarts_adapter.modelarts import modelarts_setup
2632

2733
__dir__ = os.path.dirname(os.path.abspath(__file__))
2834
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../../")))
@@ -92,7 +98,8 @@ def save_predict_result(task: str, preds_list: list, output_save_dir: str):
9298
save_rec_res(rec_res, img_list, save_path=os.path.join(output_save_dir, "rec_results.txt"))
9399

94100

95-
def main(cfg):
101+
def predict_single_step(cfg):
102+
"""Run predict for det task or rec task"""
96103
# 1. Set the environment information.
97104
set_context(mode=cfg.system.mode)
98105
output_save_dir = cfg.predict.output_save_dir or "./output"
@@ -170,6 +177,7 @@ def main(cfg):
170177
meta_data_indices = cfg.predict.dataset.pop("meta_data_column_index", None)
171178

172179
# 6.Start prediction
180+
logger.info(f"Start {cfg.model.type}")
173181
preds_list = []
174182
for i, data in tqdm(enumerate(iterator), total=num_batches_predict):
175183
if input_indices is not None:
@@ -199,13 +207,162 @@ def main(cfg):
199207
# Add "img_ori" to preds if present, which means task is det
200208
if "image_ori" in output_columns:
201209
preds["img_ori"] = data[output_columns.index("image_ori")].numpy()
210+
if "polys" in preds:
211+
preds["crops"] = []
212+
polys_batch = preds["polys"].copy()
213+
for i, polys in enumerate(polys_batch):
214+
crops_per_img = []
215+
for poly in polys:
216+
cropped_img = crop_text_region(preds["img_ori"][i], poly, box_type=cfg.postprocess.box_type)
217+
crops_per_img.append(cropped_img)
218+
preds["crops"].append(crops_per_img)
202219
preds_list.append(preds)
203220

204221
# 7. Save the prediction results to the specified directory
205222
save_predict_result(cfg.model.type, preds_list, output_save_dir)
223+
return preds_list
224+
225+
226+
def predict_system(args, det_cfg, rec_cfg):
    """Run text detection followed by text recognition and save the combined result.

    The det stage is delegated to ``predict_single_step``; every crop it returns is
    then passed through the rec transforms, the rec network and the rec postprocess.
    Only texts whose confidence is greater than 0.5 are kept. The final result is
    written by ``save_res`` to ``system_results.txt`` in the det output directory.

    Args:
        args: parsed command line arguments; ``args.image_dir`` is used as the det input.
        det_cfg: det config (attribute-style dict).
        rec_cfg: rec config (attribute-style dict).
    """
    # Point the det dataset at the images given on the command line.
    det_cfg.predict.dataset_root = ""
    det_cfg.predict.data_dir = args.image_dir
    output_save_dir = det_cfg.predict.output_save_dir or "./output"

    # Stage 1: detection. Each entry of det_batches is the prediction dict of one batch.
    det_batches = predict_single_step(det_cfg)

    # NOTE(review): the amp level is read from the det config and reused for the
    # rec model -- confirm this is intended.
    amp_level = det_cfg.system.get("amp_level_infer", "O0")
    if get_context("device_target") == "GPU" and amp_level == "O3":
        logger.warning(
            "Model evaluation does not support amp_level O3 on GPU currently. "
            "The program has switched to amp_level O2 automatically."
        )
        amp_level = "O2"

    # Stage 2 setup: rec preprocess, postprocess and network built from the rec yaml.
    transforms = create_transforms(rec_cfg.predict.dataset.transform_pipeline)
    postprocessor = build_postprocess(rec_cfg.postprocess)
    rec_network = build_model(rec_cfg.model, ckpt_load_path=rec_cfg.predict.ckpt_load_path, amp_level=amp_level)

    # The first transform is skipped for crops
    # (presumably the image decoding step, since crops are already arrays -- TODO confirm).
    crop_transforms = transforms[1:]

    logger.info("Start rec")
    img_list = []  # image path, one entry per image
    boxes_all = []  # kept boxes, one list per image
    text_scores_all = []  # kept (text, confidence) pairs, one list per image
    for det_batch in tqdm(det_batches):
        det_batch["texts"] = []
        det_batch["confs"] = []
        # A batch may hold several images; each image may hold several crops.
        for img_idx, img_crops in enumerate(det_batch["crops"]):
            img_path = det_batch["img_path"][img_idx]
            kept_boxes = []
            kept_text_scores = []
            for crop_idx, crop in enumerate(img_crops):
                sample = {
                    "image": crop,
                    "image_ori": crop.copy(),
                    "image_shape": crop.shape,
                }
                sample = run_transforms(sample, crop_transforms)
                logits = rec_network(Tensor(sample["image"]).expand_dims(0))
                rec_out = postprocessor(logits)
                score = rec_out["confs"][0]
                if not (score > 0.5):
                    # Drop texts whose confidence is not greater than 0.5.
                    continue
                kept_boxes.append(det_batch["polys"][img_idx][crop_idx])
                kept_text_scores.append((rec_out["texts"][0], score))
            # Record path, boxes and text/score pairs for this image.
            img_list.append(img_path)
            boxes_all.append(kept_boxes)
            text_scores_all.append(kept_text_scores)

    result_path = os.path.join(output_save_dir, "system_results.txt")
    save_res(boxes_all, text_scores_all, img_list, save_path=result_path)
286+
287+
288+
def _str2bool(value):
    """Convert a command line string such as "True"/"false"/"1"/"0" to a bool.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised boolean string.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string is kept as False, matching the previous ``bool("")`` result.
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


def create_parser():
    """Build the command line parser for det / rec / system prediction.

    The parser is created with ``add_help=False``, so ``-h/--help`` is not registered.

    Returns:
        argparse.ArgumentParser: parser with the prediction options and the
        "modelarts" option group.
    """
    parser = argparse.ArgumentParser(description="Training Config", add_help=False)
    parser.add_argument("--image_dir", type=str, help="image path or image directory")
    parser.add_argument("--task_mode", type=str, default="system", choices=["det", "rec", "system"], help="Task mode")
    parser.add_argument(
        "--det_config",
        type=str,
        default="configs/det/dbnet/db_r50_icdar15.yaml",
        help='YAML config file specifying default arguments for det (default="configs/det/dbnet/db_r50_icdar15.yaml")',
    )
    parser.add_argument(
        "--rec_config",
        type=str,
        default="configs/rec/crnn/crnn_resnet34.yaml",
        help='YAML config file specifying default arguments for rec (default="configs/rec/crnn/crnn_resnet34.yaml")',
    )
    parser.add_argument(
        "-o",
        "--opt",
        nargs="+",
        help="Options to change yaml configuration values, "
        "e.g. `-o system.distribute=False eval.dataset.dataset_root=/my_path/to/ocr_data`",
    )
    # modelarts
    group = parser.add_argument_group("modelarts")
    # _str2bool instead of bool, so that `--enable_modelarts False` really disables it.
    group.add_argument(
        "--enable_modelarts", type=_str2bool, default=False, help="Run on modelarts platform (default=False)"
    )
    group.add_argument(
        "--device_target",
        type=str,
        default="Ascend",
        help="Target device, only used on modelarts platform (default=Ascend)",
    )
    # The url are provided by modelart, usually they are S3 paths
    group.add_argument("--multi_data_url", type=str, default="", help="path to multi dataset")
    group.add_argument("--data_url", type=str, default="", help="path to dataset")
    group.add_argument("--ckpt_url", type=str, default="", help="pre_train_model path in obs")
    group.add_argument("--pretrain_url", type=str, default="", help="pre_train_model paths in obs")
    group.add_argument("--train_url", type=str, default="", help="model folder to save/load")

    return parser
330+
331+
332+
def parse_args_and_config():
    """Parse the command line and load the det and rec YAML configs.

    Both config files are always loaded, whatever ``--task_mode`` is, and any
    ``-o/--opt`` overrides are merged into both of them.

    Return:
        args: command line arguments
        det_cfg: det config loaded from ``--det_config``
        rec_cfg: rec config loaded from ``--rec_config``

    Raises:
        ValueError: if ``--task_mode`` is "system" and ``--image_dir`` is not given.
    """
    parser = create_parser()
    args = parser.parse_args()  # CLI args

    modelarts_setup(args)
    if args.task_mode == "system" and args.image_dir is None:
        raise ValueError("When the task is 'system', the 'image_dir' is necessary.")

    # Read the YAML files as UTF-8 explicitly, so that loading does not depend on
    # the platform default encoding.
    with open(args.det_config, "r", encoding="utf-8") as f:
        det_cfg = yaml.safe_load(f)
    with open(args.rec_config, "r", encoding="utf-8") as f:
        rec_cfg = yaml.safe_load(f)

    if args.opt:
        # The same command line overrides are applied to both configs.
        options = _parse_options(args.opt)
        det_cfg = _merge_options(det_cfg, options)
        rec_cfg = _merge_options(rec_cfg, options)
    return args, det_cfg, rec_cfg
206355

207356

208357
if __name__ == "__main__":
    # Dispatch on --task_mode: a single det or rec step, or the det + rec system pipeline.
    args, det_cfg, rec_cfg = parse_args_and_config()
    task_mode = args.task_mode
    if task_mode == "system":
        rec_cfg = Dict(rec_cfg)
        det_cfg = Dict(det_cfg)
        predict_system(args, det_cfg, rec_cfg)
    elif task_mode == "det":
        predict_single_step(Dict(det_cfg))
    elif task_mode == "rec":
        predict_single_step(Dict(rec_cfg))

0 commit comments

Comments
 (0)