From af4d201a7eeb3c6a0b51b71f040851d2a68c9f3e Mon Sep 17 00:00:00 2001 From: Jonas Ehrenstein Date: Mon, 5 Jul 2021 12:06:19 +0200 Subject: [PATCH 1/2] Fix NER script: 'ALL_MODELS' is not needed and 'pretrained_config_archive_map' is deprecated. Error in predictions for label-ends leads to labels snaking the next one without accounting for 'O'-labels between them. --- models/ner.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/models/ner.py b/models/ner.py index 36f9c3f..801a9f9 100644 --- a/models/ner.py +++ b/models/ner.py @@ -29,10 +29,10 @@ logger = logging.getLogger(__name__) - -ALL_MODELS = sum( - [list(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)], - []) +# 'pretrained_config_archive_map' is deprecated and variable not actually in use +# ALL_MODELS = sum( +# [list(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)], +# []) MODEL_CLASSES = { 'bert': (BertConfig, BertForTokenClassification, BertTokenizer), @@ -407,14 +407,23 @@ def predict(self, tasks, **kwargs): result = [] + # bool for when 'O'-label is not appended in 'result' + skipped = False + for label, group in groupby(zip(preds, starts, scores), key=lambda i: re.sub('^(B-|I-)', '', i[0])): _, group_start, _ = list(group)[0] if len(result) > 0: if group_start == 0: result.pop(-1) - else: + # when 'O' is skipped when appending 'result', 'end' of previous 'result' sould not be changed, until new label is appended + elif not skipped: result[-1]['value']['end'] = group_start - 1 + # remove incorrect predictions, where 'end is smaller than 'start' + if result[-1]['value']['end'] is not None and result[-1]['value']['end'] < result[-1]['value']['start']: + result.pop(-1) + if label != 'O': + skipped = False result.append({ 'from_name': from_name, 'to_name': to_name, @@ -426,6 +435,9 @@ def predict(self, tasks, **kwargs): 'text': '...' } }) + else: + skipped = True + if result and result[-1]['value']['end'] is None: result[-1]['value']['end'] = len(string) results.append({ From 0a80880a4586720f7288ca6506b13e02373480d7 Mon Sep 17 00:00:00 2001 From: Jonas Ehrenstein Date: Tue, 13 Jul 2021 13:29:34 +0200 Subject: [PATCH 2/2] learnwhens v0.1 --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 8bc1d55..f5936ee 100644 --- a/.gitignore +++ b/.gitignore @@ -101,3 +101,6 @@ venv.bak/ # mypy .mypy_cache/ + +# Models +/predictwhens \ No newline at end of file