Skip to content

Commit e9b0331

Browse files
author
Johannes Hötter
committed
adds label injection
1 parent 502e06d commit e9b0331

File tree

1 file changed

+57
-1
lines changed

1 file changed

+57
-1
lines changed

kern/adapter/rasa.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
import os
33
from collections import OrderedDict
44

5+
CONSTANT_OUTSIDE = "OUTSIDE"
6+
CONSTANT_LABEL_BEGIN = "B-"
7+
CONSTANT_LABEL_INTERMEDIATE = "I-"
8+
59

610
class literal(str):
711
pass
@@ -25,15 +29,67 @@ def build_literal_from_iterable(iterable):
2529
return "\n".join([f"- {value}" for value in iterable]) + "\n"
2630

2731

32+
def inject_label_in_text(row, text_name, tokenized_label_task, constant_outside):
33+
string = ""
34+
token_list = row[f"{text_name}__tokenized"]
35+
36+
close_multitoken_label = False
37+
multitoken_label = False
38+
for idx, token in enumerate(token_list):
39+
40+
if idx < len(token_list) - 1:
41+
token_next = token_list[idx + 1]
42+
label_next = row[tokenized_label_task][idx + 1]
43+
if label_next.startswith(CONSTANT_LABEL_INTERMEDIATE):
44+
multitoken_label = True
45+
else:
46+
if multitoken_label:
47+
close_multitoken_label = True
48+
multitoken_label = False
49+
num_whitespaces = token_next.idx - (token.idx + len(token))
50+
else:
51+
num_whitespaces = 0
52+
whitespaces = " " * num_whitespaces
53+
54+
label = row[tokenized_label_task][idx]
55+
if label != constant_outside:
56+
if multitoken_label:
57+
if label.startswith(CONSTANT_LABEL_BEGIN):
58+
string += f"[{token.text}{whitespaces}"
59+
else:
60+
string += f"{token.text}{whitespaces}"
61+
else:
62+
if close_multitoken_label:
63+
string += f"{token.text}]({label[2:]}){whitespaces}"
64+
close_multitoken_label = False
65+
else:
66+
string += f"[{token.text}]({label[2:]}){whitespaces}"
67+
else:
68+
string += f"{token.text}{whitespaces}"
69+
return string
70+
71+
2872
def build_intent_yaml(
2973
client,
3074
text_name,
3175
intent_label_task,
3276
metadata_label_task=None,
77+
tokenized_label_task=None,
3378
dir_name="data",
3479
file_name="nlu.yml",
80+
constant_outside=CONSTANT_OUTSIDE,
3581
):
36-
df = client.get_record_export(tokenize=False)
82+
df = client.get_record_export(tokenize=(tokenized_label_task is not None))
83+
84+
if tokenized_label_task is not None:
85+
text_name_injected = f"{text_name}__injected"
86+
df[text_name_injected] = df.apply(
87+
lambda x: inject_label_in_text(
88+
x, text_name, tokenized_label_task, constant_outside
89+
),
90+
axis=1,
91+
)
92+
text_name = text_name_injected
3793

3894
nlu_list = []
3995
for label, df_sub_label in df.groupby(intent_label_task):

0 commit comments

Comments
 (0)