Skip to content

Commit ec20d02

Browse files
panshaowupanshaowu
andauthored
improve training performance of db_mv3 (#646)
Co-authored-by: panshaowu <panshaowu@huawei.com>
1 parent 40dd378 commit ec20d02

File tree

5 files changed

+317
-2
lines changed

5 files changed

+317
-2
lines changed

configs/det/dbnet/db_mobilenetv3_icdar15.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ train:
7878
data_dir: ic15/det/train/ch4_training_images
7979
label_file: ic15/det/train/det_gt.txt
8080
sample_ratio: 1.0
81+
use_minddata: True
8182
transform_pipeline:
8283
- DecodeImage:
8384
img_mode: RGB
@@ -135,6 +136,7 @@ eval:
135136
data_dir: ic15/det/test/ch4_test_images
136137
label_file: ic15/det/test/det_gt.txt
137138
sample_ratio: 1.0
139+
use_minddata: True
138140
transform_pipeline:
139141
- DecodeImage:
140142
img_mode: RGB
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
system:
2+
mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
3+
distribute: True
4+
amp_level: 'O0'
5+
seed: 42
6+
log_interval: 10
7+
val_while_train: True
8+
val_start_epoch: 500
9+
drop_overflow_update: False
10+
11+
model:
12+
type: det
13+
transform: null
14+
backbone:
15+
name: det_mobilenet_v3
16+
architecture: large
17+
alpha: 0.5
18+
out_stages: [5, 8, 14, 20]
19+
bottleneck_params:
20+
se_version: SqueezeExciteV2
21+
always_expand: True
22+
pretrained: https://download.mindspore.cn/toolkits/mindcv/mobilenet/mobilenetv3/mobilenet_v3_large_050_no_scale_se_v2_expand-3c4047ac.ckpt
23+
neck:
24+
name: DBFPN
25+
out_channels: 256
26+
bias: False
27+
head:
28+
name: DBHead
29+
k: 50
30+
bias: False
31+
adaptive: True
32+
33+
postprocess:
34+
name: DBPostprocess
35+
box_type: quad # whether to output a polygon or a box
36+
binary_thresh: 0.3 # binarization threshold
37+
box_thresh: 0.6 # box score threshold
38+
max_candidates: 1000
39+
expand_ratio: 1.5 # coefficient for expanding predictions
40+
41+
metric:
42+
name: DetMetric
43+
main_indicator: f-score
44+
45+
loss:
46+
name: DBLoss
47+
eps: 1.0e-6
48+
l1_scale: 10
49+
bce_scale: 5
50+
bce_replace: bceloss
51+
52+
scheduler:
53+
scheduler: polynomial_decay
54+
lr: 0.02
55+
num_epochs: 2000
56+
decay_rate: 0.9
57+
warmup_epochs: 3
58+
59+
optimizer:
60+
opt: momentum
61+
filter_bias_and_bn: false
62+
momentum: 0.9
63+
weight_decay: 1.0e-4
64+
65+
# only used for mixed precision training
66+
loss_scaler:
67+
type: dynamic
68+
loss_scale: 512
69+
scale_factor: 2
70+
scale_window: 1000
71+
72+
train:
73+
ckpt_save_dir: './tmp_det'
74+
dataset_sink_mode: True
75+
dataset:
76+
type: DetDataset
77+
dataset_root: /data/ocr_datasets
78+
data_dir: ic15/det/train/ch4_training_images
79+
label_file: ic15/det/train/det_gt.txt
80+
sample_ratio: 1.0
81+
use_minddata: True
82+
transform_pipeline:
83+
- DecodeImage:
84+
img_mode: RGB
85+
to_float32: False
86+
- DetLabelEncode:
87+
- RandomColorAdjust:
88+
brightness: 0.1255 # 32.0 / 255
89+
saturation: 0.5
90+
- RandomHorizontalFlip:
91+
p: 0.5
92+
- RandomRotate:
93+
degrees: [ -10, 10 ]
94+
expand_canvas: False
95+
p: 1.0
96+
- RandomScale:
97+
scale_range: [ 0.5, 3.0 ]
98+
p: 1.0
99+
- RandomCropWithBBox:
100+
max_tries: 10
101+
min_crop_ratio: 0.1
102+
crop_size: [ 640, 640 ]
103+
p: 1.0
104+
- ValidatePolygons:
105+
- ShrinkBinaryMap:
106+
min_text_size: 8
107+
shrink_ratio: 0.4
108+
- BorderMap:
109+
shrink_ratio: 0.4
110+
thresh_min: 0.3
111+
thresh_max: 0.7
112+
- NormalizeImage:
113+
bgr_to_rgb: False
114+
is_hwc: True
115+
mean: imagenet
116+
std: imagenet
117+
- ToCHWImage:
118+
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
119+
output_columns: [ 'image', 'binary_map', 'mask', 'thresh_map', 'thresh_mask']
120+
# output_columns: ['image'] # for debug op performance
121+
net_input_column_index: [0] # input indices for network forward func in output_columns
122+
label_column_index: [1, 2, 3, 4] # input indices marked as label
123+
124+
loader:
125+
shuffle: True
126+
batch_size: 8
127+
drop_remainder: True
128+
num_workers: 10
129+
130+
eval:
131+
ckpt_load_path: tmp_det/best.ckpt
132+
dataset_sink_mode: False
133+
dataset:
134+
type: DetDataset
135+
dataset_root: /data/ocr_datasets
136+
data_dir: ic15/det/test/ch4_test_images
137+
label_file: ic15/det/test/det_gt.txt
138+
sample_ratio: 1.0
139+
use_minddata: True
140+
transform_pipeline:
141+
- DecodeImage:
142+
img_mode: RGB
143+
to_float32: False
144+
- DetLabelEncode:
145+
- DetResize: # GridResize 32
146+
target_size: [ 736, 1280 ]
147+
keep_ratio: False
148+
limit_type: none
149+
divisor: 32
150+
- NormalizeImage:
151+
bgr_to_rgb: False
152+
is_hwc: True
153+
mean: imagenet
154+
std: imagenet
155+
- ToCHWImage:
156+
# the order of the dataloader list, matching the network input and the labels for evaluation
157+
output_columns: [ 'image', 'polys', 'ignore_tags', 'shape_list' ]
158+
net_input_column_index: [0] # input indices for network forward func in output_columns
159+
label_column_index: [1, 2] # input indices marked as label
160+
161+
loader:
162+
shuffle: False
163+
batch_size: 1 # TODO: due to dynamic shape of polygons (num of boxes varies), BS has to be 1
164+
drop_remainder: False
165+
num_workers: 3

docs/cn/tutorials/frequently_asked_questions.md

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
- [关于`RunTimeError:The device address tpe is wrong`](#q6-runtimeerror-the-device-address-type-is-wrong-type-name-in-addresscpu-type-name-in-contextascend)
88
- [模型转换相关问题](#q7-模型转换相关问题)
99
- [推理相关问题](#q8-推理时相关问题)
10+
- [DBNet训练速率不及预期](#q9-DBNet训练速率不及预期)
1011

1112
### Q1 未定义符号
1213

@@ -618,3 +619,58 @@ ERROR: Could not build wheels for lanms-neo, which is required to install pyproj
618619
619620
- 使用恰当的模型。例如在 `--rec_model_path` 错误传入了检测模型,可触发此错误;
620621
- 使用推理模型(非训练模型),用`converter_lite`转换工具转为端侧`mindir`进行推理。
622+
623+
624+
### Q9 DBNet训练速率不及预期
625+
626+
执行以下命令,训练DBNet系列网络(包括DBNet MobileNetV3、DBNet ResNet-18、DBNet ResNet-50、DBNet++ ResNet-50等)时,训练帧率不及预期。例如,DBNet MobileNetV3在Ascend 910A上,训练速率仅80fps,不及预期的100fps。
627+
628+
``` bash
629+
python tools/train.py -c configs/det/dbnet/db_mobilenetv3_icdar15.yaml
630+
```
631+
632+
由于DBNet数据预处理过程相对复杂,如训练服务器CPU单核运算能力较弱,则数据预处理可能成为性能瓶颈。
633+
634+
**解决方法**
635+
636+
1. 尝试将配置文件中`train.dataset.use_minddata``eval.dataset.use_minddata`的选项设置为`True`。MindOCR将采用MindSpore[MindData](https://www.mindspore.cn/docs/zh-CN/master/api_python/dataset/dataset_method/operation/mindspore.dataset.Dataset.map.html?highlight=map#mindspore.dataset.Dataset.map)执行部分数据预处理步骤:
637+
638+
``` yaml
639+
...
640+
train:
641+
ckpt_save_dir: './tmp_det'
642+
dataset_sink_mode: True
643+
dataset:
644+
type: DetDataset
645+
dataset_root: /data/ocr_datasets
646+
data_dir: ic15/det/train/ch4_training_images
647+
label_file: ic15/det/train/det_gt.txt
648+
sample_ratio: 1.0
649+
use_minddata: True <-- 设置该选项
650+
...
651+
eval:
652+
ckpt_load_path: tmp_det/best.ckpt
653+
dataset_sink_mode: False
654+
dataset:
655+
type: DetDataset
656+
dataset_root: /data/ocr_datasets
657+
data_dir: ic15/det/test/ch4_test_images
658+
label_file: ic15/det/test/det_gt.txt
659+
sample_ratio: 1.0
660+
use_minddata: True <-- 设置该选项
661+
...
662+
```
663+
664+
2. 如训练服务器CPU核数较多,尝试调高配置文件中的`train.loader.num_workers`选项,提升数据预取的线程数:
665+
666+
``` yaml
667+
...
668+
train:
669+
...
670+
loader:
671+
shuffle: True
672+
batch_size: 10
673+
drop_remainder: True
674+
num_workers: 12 <-- 设置该选项
675+
...
676+
```

docs/en/tutorials/frequently_asked_questions.md

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
- [`RunTimeError:The device address tpe is wrong`](#q6-runtimeerror-the-device-address-type-is-wrong-type-name-in-addresscpu-type-name-in-contextascend)
88
- [Problems related to model converting](#q7-problems-related-to-model-converting)
99
- [Problems related to inference](#q8-problems-related-to-inference)
10+
- [Training speed of DBNet not as fast as expexted](#q9-training-speed-of-dbnet-not-as-fast-as-expexted)
1011

1112
### Q1 Undefined symbol
1213

@@ -607,3 +608,58 @@ Reason:
607608
608609
- Use suitable model. For example, it may fail and pass detection model to `--rec_model_path` parameter.
609610
- Use inference model(not training model) to do converting.
611+
612+
613+
### Q9 Training speed of DBNet not as fast as expexted
614+
615+
When traning DBNet series networks (including DBNet MobileNetV3, DBNet ResNet-18, DBNet ResNet-50, and DBNet++ ResNet-50) using following command, the training speed is not as fast as expexted. For instance, the training speed of DBNet MobileNetV3 can reach only 80fps which is slower than the expecting 100fps.
616+
617+
``` bash
618+
python tools/train.py -c configs/det/dbnet/db_mobilenetv3_icdar15.yaml
619+
```
620+
621+
This problem is due to the complex data pre-processing procedures of DBNet. The data pre-processing procedures will become the performance bottleneck if the computation ability of a CPU core of the training server is relatively weak.
622+
623+
**Solutions**
624+
625+
1. Try to set the `train.dataset.use_minddata` and `eval.dataset.use_minddata` in the configuration file to `True`. MindOCR will execute parts of data pre-processing procedures using MindSpore[MindData](https://www.mindspore.cn/docs/zh-CN/master/api_python/dataset/dataset_method/operation/mindspore.dataset.Dataset.map.html?highlight=map#mindspore.dataset.Dataset.map):
626+
627+
``` yaml
628+
...
629+
train:
630+
ckpt_save_dir: './tmp_det'
631+
dataset_sink_mode: True
632+
dataset:
633+
type: DetDataset
634+
dataset_root: /data/ocr_datasets
635+
data_dir: ic15/det/train/ch4_training_images
636+
label_file: ic15/det/train/det_gt.txt
637+
sample_ratio: 1.0
638+
use_minddata: True <-- Set this configuration
639+
...
640+
eval:
641+
ckpt_load_path: tmp_det/best.ckpt
642+
dataset_sink_mode: False
643+
dataset:
644+
type: DetDataset
645+
dataset_root: /data/ocr_datasets
646+
data_dir: ic15/det/test/ch4_test_images
647+
label_file: ic15/det/test/det_gt.txt
648+
sample_ratio: 1.0
649+
use_minddata: True <-- Set this configuration
650+
...
651+
```
652+
653+
2. Try to set the `train.loader.num_workers` in the configuration file to a larger value to enhance the number of threads fetching dataset if the training server has enough CPU cores:
654+
655+
``` yaml
656+
...
657+
train:
658+
...
659+
loader:
660+
shuffle: True
661+
batch_size: 10
662+
drop_remainder: True
663+
num_workers: 12 <-- Set this configuration
664+
...
665+
```

mindocr/data/builder.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import mindspore as ms
66

7+
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
78
from .det_dataset import DetDataset, SynthTextDataset
89
from .kie_dataset import KieDataset
910
from .layout_dataset import PublayNetDataset
@@ -140,6 +141,9 @@ def build_dataset(
140141
assert dataset_class_name in supported_dataset_types, "Invalid dataset name"
141142
dataset_class = eval(dataset_class_name)
142143
dataset_args = dict(is_train=is_train, **dataset_config)
144+
if "use_minddata" in dataset_args and dataset_args["use_minddata"]:
145+
minddata_op_list = _parse_minddata_op(dataset_args)
146+
143147
dataset = dataset_class(**dataset_args)
144148

145149
dataset_column_names = dataset.get_output_columns()
@@ -158,8 +162,13 @@ def build_dataset(
158162
)
159163

160164
# 2. data mapping using minddata C lib (optional)
161-
# ds = ds.map(operations=transform_list, input_columns=['image', 'label'], num_parallel_workers=8,
162-
# python_multiprocessing=True)
165+
if "use_minddata" in dataset_args and dataset_args["use_minddata"]:
166+
ds = ds.map(
167+
operations=minddata_op_list,
168+
input_columns=["image"],
169+
num_parallel_workers=num_workers,
170+
python_multiprocessing=True,
171+
)
163172

164173
# 3. create loader
165174
# get batch of dataset by collecting batch_size consecutive data rows and apply batch operations
@@ -242,3 +251,30 @@ def _check_batch_size(num_samples, ori_batch_size=32, refine=True):
242251
f"dropped/padded in graph mode."
243252
)
244253
return bs
254+
255+
256+
def _parse_minddata_op(dataset_args):
257+
minddata_op_idx = []
258+
minddata_op_list = []
259+
for i, transform_dict in enumerate(dataset_args["transform_pipeline"]):
260+
if "RandomColorAdjust" in transform_dict.keys():
261+
minddata_op_idx.append(i)
262+
color_adjust_op = ms.dataset.vision.RandomColorAdjust(
263+
brightness=transform_dict["RandomColorAdjust"]["brightness"],
264+
saturation=transform_dict["RandomColorAdjust"]["saturation"],
265+
)
266+
minddata_op_list.append(color_adjust_op)
267+
continue
268+
if "NormalizeImage" in transform_dict.keys():
269+
minddata_op_idx.append(i)
270+
normalize_op = ms.dataset.vision.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
271+
minddata_op_list.append(normalize_op)
272+
continue
273+
if "ToCHWImage" in transform_dict.keys():
274+
minddata_op_idx.append(i)
275+
change_swap_op = ms.dataset.vision.HWC2CHW()
276+
minddata_op_list.append(change_swap_op)
277+
continue
278+
for _ in range(len(minddata_op_idx)):
279+
dataset_args["transform_pipeline"].pop(minddata_op_idx.pop())
280+
return minddata_op_list

0 commit comments

Comments
 (0)