Commit e42dae0

1. Added support for the relation extraction task in the LayoutXLM and VI-LayoutXLM models. (#643)
2. Fixed some bugs.
1 parent ec20d02 commit e42dae0

File tree

13 files changed (+627, -178 lines)

Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
system:
  mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
  distribute: False
  amp_level: 'O0'
  seed: 42
  log_interval: 10
  val_while_train: True
  val_start_epoch: 50
  drop_overflow_update: False

model:
  type: kie
  transform: null
  backbone:
    name: layoutxlm
    pretrained: True
    num_classes: &num_classes 7
    use_visual_backbone: True
    use_float16: True
  head:
    name: RelationExtractionHead
    use_visual_backbone: True
    use_float16: True
  pretrained:

postprocess:
  name: VQAReTokenLayoutLMPostProcess
  class_path: &class_path path/to/class_list_xfun.txt

metric:
  name: VQAReTokenMetric
  main_indicator: hmean

loss:
  name: VQAReTokenLayoutLMLoss

scheduler:
  scheduler: polynomial_decay
  lr: 5.0e-5
  min_lr: 2.0e-7
  num_epochs: 200
  warmup_epochs: 10

optimizer:
  opt: adam
  beta1: 0.9
  beta2: 0.999
  clip_norm: 10
  filter_bias_and_bn: False
  weight_decay: 0.0005

train:
  ckpt_save_dir: './tmp_kie_re'
  dataset_sink_mode: False
  dataset:
    type: KieDataset
    dataset_root: path/to/train_data/
    data_dir: XFUND/zh_train/image
    label_file: XFUND/zh_train/train.json
    sample_ratio: 1.0
    transform_pipeline:
      - DecodeImage:
          img_mode: RGB
          to_float32: False
      - VQATokenLabelEncode:
          contains_re: True
          algorithm: &algorithm LayoutXLM
          class_path: *class_path
          order_method: tb-yx
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQAReTokenRelation:
      - VQAReTokenChunk:
          max_seq_len: *max_seq_len
      - TensorizeEntitiesRelations:
          max_relation_len: 5000
      - LayoutResize:
          size: [224, 224]
      - NormalizeImage:
          bgr_to_rgb: False
          is_hwc: True
          mean: imagenet
          std: imagenet
      - ToCHWImage:
    # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
    output_columns:
      [
        "input_ids",
        "bbox",
        "attention_mask",
        "token_type_ids",
        "image",
        "question",
        "question_label",
        "answer",
        "answer_label",
        "relation_label",
      ]
    net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7, 8] # input indices for network forward func in output_columns
    label_column_index: [9] # input indices marked as label

  loader:
    shuffle: True
    batch_size: 8
    drop_remainder: True
    num_workers: 16

eval:
  ckpt_load_path: 'tmp_kie_re/best.ckpt'
  dataset_sink_mode: False
  dataset:
    type: KieDataset
    dataset_root: path/to/train_data/
    data_dir: XFUND/zh_val/image
    label_file: XFUND/zh_val/val.json
    sample_ratio: 1.0
    shuffle: False
    transform_pipeline:
      - DecodeImage:
          img_mode: RGB
          to_float32: False
      - VQATokenLabelEncode:
          contains_re: True
          algorithm: *algorithm
          class_path: *class_path
          order_method: tb-yx
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQAReTokenRelation:
      - VQAReTokenChunk:
          max_seq_len: *max_seq_len
      - TensorizeEntitiesRelations:
          max_relation_len: 5000
      - LayoutResize:
          size: [224, 224]
      - NormalizeImage:
          bgr_to_rgb: False
          is_hwc: True
          mean: imagenet
          std: imagenet
      - ToCHWImage:
    # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
    output_columns:
      [
        "input_ids",
        "bbox",
        "attention_mask",
        "token_type_ids",
        "image",
        "question",
        "question_label",
        "answer",
        "answer_label",
        "relation_label",
      ]
    net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7, 8] # input indices for network forward func in output_columns
    label_column_index: [9] # input indices marked as label

  loader:
    shuffle: False
    batch_size: 1
    drop_remainder: False
    num_workers: 1
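
For readers unfamiliar with how the dataloader columns are wired up: the dataloader yields the tensors listed in output_columns in that order, net_input_column_index selects which of them are passed to the network's forward function, and label_column_index selects which are handed to the loss. The sketch below is a minimal, hypothetical illustration of that index mapping in plain Python; split_batch and dummy_batch are made-up names for illustration, not MindOCR APIs.

    # Illustrative sketch (not MindOCR code): how the column-index settings
    # map dataloader output columns to network inputs and loss labels.
    output_columns = [
        "input_ids", "bbox", "attention_mask", "token_type_ids", "image",
        "question", "question_label", "answer", "answer_label", "relation_label",
    ]
    net_input_column_index = [0, 1, 2, 3, 4, 5, 6, 7, 8]
    label_column_index = [9]

    def split_batch(batch):
        """batch is a sequence of tensors ordered exactly like output_columns."""
        net_inputs = [batch[i] for i in net_input_column_index]  # -> network(*net_inputs)
        labels = [batch[i] for i in label_column_index]          # -> loss_fn(preds, *labels)
        return net_inputs, labels

    # Demonstration with string placeholders standing in for real tensors.
    dummy_batch = [f"<{name}>" for name in output_columns]
    net_inputs, labels = split_batch(dummy_batch)
    print(net_inputs)  # columns 0-8 go to the network forward pass
    print(labels)      # column 9, "relation_label", goes to the loss function
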
Lines changed: 165 additions & 0 deletions
@@ -0,0 +1,165 @@
system:
  mode: 0 # 0 for graph mode, 1 for pynative mode in MindSpore
  distribute: False
  amp_level: "O0"
  seed: 42
  log_interval: 10
  val_while_train: True
  val_start_epoch: 50
  drop_overflow_update: False

model:
  type: kie
  transform: null
  backbone:
    name: layoutxlm
    pretrained: True
    num_classes: &num_classes 7
    use_visual_backbone: False
    use_float16: True
  head:
    name: RelationExtractionHead
    use_visual_backbone: False
    use_float16: True
  pretrained:

postprocess:
  name: VQAReTokenLayoutLMPostProcess
  class_path: &class_path path/to/class_list_xfun.txt

metric:
  name: VQAReTokenMetric
  main_indicator: hmean

loss:
  name: VQAReTokenLayoutLMLoss

scheduler:
  scheduler: polynomial_decay
  lr: 5.0e-5
  min_lr: 2.0e-7
  num_epochs: 200
  warmup_epochs: 10

optimizer:
  opt: adam
  beta1: 0.9
  beta2: 0.999
  clip_norm: 10
  filter_bias_and_bn: False
  weight_decay: 0.0005

train:
  ckpt_save_dir: "./vi_layoutxlm_re"
  dataset_sink_mode: False
  dataset:
    type: KieDataset
    dataset_root: path/to/train_data/
    data_dir: XFUND/zh_train/image
    label_file: XFUND/zh_train/train.json
    sample_ratio: 1.0
    transform_pipeline:
      - DecodeImage:
          img_mode: RGB
          to_float32: False
      - VQATokenLabelEncode:
          contains_re: True
          algorithm: &algorithm LayoutXLM
          class_path: *class_path
          order_method: tb-yx
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQAReTokenRelation:
      - VQAReTokenChunk:
          max_seq_len: *max_seq_len
      - TensorizeEntitiesRelations:
          max_relation_len: 5000
      - LayoutResize:
          size: [224, 224]
      - NormalizeImage:
          bgr_to_rgb: False
          is_hwc: True
          mean: imagenet
          std: imagenet
      - ToCHWImage:
    # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
    output_columns:
      [
        "input_ids",
        "bbox",
        "attention_mask",
        "token_type_ids",
        "image",
        "question",
        "question_label",
        "answer",
        "answer_label",
        "relation_label",
      ]
    net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7, 8] # input indices for network forward func in output_columns
    label_column_index: [9] # input indices marked as label

  loader:
    shuffle: True
    batch_size: 8
    drop_remainder: True
    num_workers: 16

eval:
  ckpt_load_path: "vi_layoutxlm_re/best.ckpt"
  dataset_sink_mode: False
  dataset:
    type: KieDataset
    dataset_root: path/to/train_data/
    data_dir: XFUND/zh_val/image
    label_file: XFUND/zh_val/val.json
    sample_ratio: 1.0
    shuffle: False
    transform_pipeline:
      - DecodeImage:
          img_mode: RGB
          to_float32: False
      - VQATokenLabelEncode:
          contains_re: True
          algorithm: *algorithm
          class_path: *class_path
          order_method: tb-yx
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQAReTokenRelation:
      - VQAReTokenChunk:
          max_seq_len: *max_seq_len
      - TensorizeEntitiesRelations:
          max_relation_len: 5000
      - LayoutResize:
          size: [224, 224]
      - NormalizeImage:
          bgr_to_rgb: False
          is_hwc: True
          mean: imagenet
          std: imagenet
      - ToCHWImage:
    # the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
    output_columns:
      [
        "input_ids",
        "bbox",
        "attention_mask",
        "token_type_ids",
        "image",
        "question",
        "question_label",
        "answer",
        "answer_label",
        "relation_label",
      ]
    net_input_column_index: [0, 1, 2, 3, 4, 5, 6, 7, 8] # input indices for network forward func in output_columns
    label_column_index: [9] # input indices marked as label

  loader:
    shuffle: False
    batch_size: 1
    drop_remainder: False
    num_workers: 1
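
The only substantive differences from the first config are that this VI-LayoutXLM variant sets use_visual_backbone: False in both the backbone and the head and writes checkpoints to ./vi_layoutxlm_re instead of ./tmp_kie_re. Both files also rely on YAML anchors and aliases (&class_path / *class_path, &algorithm / *algorithm, &max_seq_len / *max_seq_len) so a value defined once in the train section is reused verbatim in the eval section. As a quick sanity check of that mechanism, here is a small, self-contained sketch using plain PyYAML; the snippet below is a simplified stand-in structure, not the actual config, and MindOCR's own config loader may handle this differently.

    import yaml  # PyYAML; standard YAML anchor/alias resolution

    snippet = """
    postprocess:
      class_path: &class_path path/to/class_list_xfun.txt
    train:
      max_seq_len: &max_seq_len 512
    eval:
      class_path: *class_path   # alias resolves to the anchored value above
      max_seq_len: *max_seq_len
    """

    cfg = yaml.safe_load(snippet)
    assert cfg["eval"]["class_path"] == "path/to/class_list_xfun.txt"
    assert cfg["eval"]["max_seq_len"] == 512
    print(cfg["eval"])  # {'class_path': 'path/to/class_list_xfun.txt', 'max_seq_len': 512}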
