Commit d7fd1e1

Accuracy checker support for bloomz-560m (#3801)

* Support for attention mask in wikitext2raw annotation converter
* Support for bloomz-560m
* Added transformers and scipy to requirements
* Unified tokenizer preparation
* Changed scipy import location
* Returned to the original log_softmax implementation
* Cosmetic changes
* Returned to log_softmax from scipy: the original log_softmax implementation hit a division by zero on the whole dataset
* A few data type fixes
* Corrected meta shapes preparation
* Corrected fit_to_input to return the proper data type
* Aligned the scipy version with the other requirements
* Removed an unneeded change
* README update; conditional use of log_softmax from scipy
1 parent e13f63f

5 files changed: +47 −9 lines

tools/accuracy_checker/openvino/tools/accuracy_checker/annotation_converters/README.md

Lines changed: 5 additions & 0 deletions
```diff
@@ -370,6 +370,11 @@ The main difference between this converter and `super_resolution` in data organization
 * `squad_emb` - converts the Stanford Question Answering Dataset ([SQuAD](https://rajpurkar.github.io/SQuAD-explorer/)) to `Question Answering Embedding Annotation`. **Note: This converter not only converts data to metric specific format but also tokenize and encodes input for model.**
   * `testing_file` - path to testing file.
   * `vocab_file` - path to model co vocabulary file.
+  * `class_token_first` - Add [CLS] token to the begin of sequence. If False, will be added as the last token.
+  * `enable_padding` - pad input sequence to max length.
+  * `tokenizer_dir` - path to a directory containing vocabulary files required by the transformers tokenizer
+  * `model_id` - model id of a predefined tokenizer hosted inside a model repo on huggingface.co.
+  * `lower_case` - converts output to lower case.
   * `max_seq_length` - maximum total input sequence length after word-piece tokenization (Optional, default value is 128).
   * `max_query_length` - maximum number of tokens for the question (Optional, default value is 64).
   * `lower_case` - allows switching tokens to lower case register. It is useful for working with uncased models (Optional, default value is False)
```
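Taken together with the converter changes below, these options let a config swap the local vocab/merges files for a transformers tokenizer. A rough sketch of how the new keys could combine (a Python mapping standing in for the usual YAML; the file names are hypothetical, `model_id` taking precedence over the local BPE files is an assumption, and the model id is the one this PR targets):

```python
# Illustrative only, not a verified config: real Accuracy Checker configs are
# YAML, the file names here are made up, and model_id is assumed to take
# precedence over the local BPE files.
converter_config = {
    'converter': 'wikitext2raw',
    'testing_file': 'wiki.test.raw',
    'vocab_file': 'gpt2-vocab.json',        # BPE fallback inputs (hypothetical)
    'merges_file': 'gpt2-merges.txt',
    'model_id': 'bigscience/bloomz-560m',   # predefined tokenizer on huggingface.co
    'lower_case': False,
    'max_seq_length': 128,
}
```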

tools/accuracy_checker/openvino/tools/accuracy_checker/annotation_converters/wikitext2raw.py

Lines changed: 25 additions & 5 deletions
```diff
@@ -17,9 +17,10 @@
 import numpy as np
 
 from ..representation import LanguageModelingAnnotation
-from ..config import PathField, NumberField
+from ..config import PathField, NumberField, StringField, BoolField
 from ..utils import UnsupportedPackage
 from .format_converter import BaseFormatConverter, ConverterReturn
+from ._nlp_common import get_tokenizer
 
 try:
     from tokenizers import Tokenizer, pre_tokenizers, decoders
@@ -42,6 +43,19 @@ def parameters(cls):
             'testing_file': PathField(description="Path to testing file."),
             'merges_file': PathField(description="Path to merges file."),
             'vocab_file': PathField(description='Path to vocabulary file.'),
+            'class_token_first': BoolField(
+                optional=True, default=True,
+                description='Add [CLS] token to the begin of sequence. If False, will be added as the last token.'),
+            'enable_padding': BoolField(optional=True, default=True, description='pad input sequence to max length'),
+            'tokenizer_dir': PathField(
+                optional=True, is_directory=True,
+                description='A path to a directory containing vocabulary files required by the transformers tokenizer'
+            ),
+            'model_id': StringField(
+                optional=True,
+                description='The model id of a predefined tokenizer hosted inside a model repo on huggingface.co'
+            ),
+            'lower_case': BoolField(optional=True, default=False, description='converts output to lower case'),
             'max_seq_length': NumberField(
                 description='The maximum total input sequence length after tokenization.',
                 optional=True, default=128, value_type=int
@@ -57,29 +71,35 @@ def configure(self):
         self.vocab_file = self.get_value_from_config('vocab_file')
         self.merges_file = self.get_value_from_config('merges_file')
         self.max_seq_length = int(self.get_value_from_config('max_seq_length'))
-        self.tokenizer = Tokenizer(BPE.from_file(str(self.vocab_file), str(self.merges_file)))
+        self.model_id = self.get_value_from_config('model_id')
+        self.lower_case = self.get_value_from_config('lower_case')
+        self.tokenizer, self.external_tok = get_tokenizer(self.config, self.lower_case)
+        if not self.external_tok:
+            self.tokenizer = Tokenizer(BPE.from_file(str(self.vocab_file), str(self.merges_file)))
+            self.tokenizer.decoder = decoders.ByteLevel()
         self.tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
-        self.tokenizer.decoder = decoders.ByteLevel()
 
     def convert(self, check_content=False, progress_callback=None, progress_interval=100, **kwargs):
         with open(str(self.testing_file), encoding="utf-8") as f:
             text = f.read()
 
-        tokens = self.tokenizer.encode_batch([text])
+        tokens = self.tokenizer([text]) if self.external_tok else self.tokenizer.encode_batch([text])
 
         encoding = tokens[0]
         annotations = []
         unique_id = 1000000000
         for idx in range(0, len(encoding.ids) - self.max_seq_length + 1, self.max_seq_length):
             ids = encoding.ids[idx: idx + self.max_seq_length]
             tokens = encoding.tokens[idx:idx + self.max_seq_length]
-            identifier = ['input_ids_{}'.format(idx), 'labels_{}'.format(idx)]
+            attention_mask = encoding.attention_mask[idx:idx + self.max_seq_length]
+            identifier = ['input_ids_{}'.format(idx), 'input_mask_{}'.format(idx), 'labels_{}'.format(idx)]
             annotation = LanguageModelingAnnotation(
                 identifier,
                 np.array(unique_id),
                 np.array([ids]),
                 tokens,
                 labels=np.array(ids),
+                input_mask=np.array([attention_mask])
             )
             annotations.append(annotation)
             unique_id += 1
```
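`get_tokenizer` itself lives in `_nlp_common` and its body is not part of this commit, so the sketch below is an assumption about its contract, inferred from how `configure()` and `convert()` use the result:

```python
# A sketch only: get_tokenizer is imported from ._nlp_common and is not shown
# in this diff, so the signature details and fallback order are assumptions.
from transformers import AutoTokenizer

def get_tokenizer(config, lower_case):
    """Return (tokenizer, is_external).

    is_external is True when a transformers tokenizer could be built from
    'model_id' or 'tokenizer_dir'; otherwise (None, False) tells the caller
    to keep its own tokenizers-library BPE pipeline. (How lower_case is
    applied is not visible in this diff, so it is left unused here.)
    """
    source = config.get('model_id') or config.get('tokenizer_dir')
    if source is None:
        return None, False
    # Must be a fast tokenizer: convert() indexes the batch result with [0]
    # and reads .ids / .tokens / .attention_mask from the Encoding object,
    # which only fast tokenizers provide.
    tokenizer = AutoTokenizer.from_pretrained(str(source), use_fast=True)
    return tokenizer, True
```

With `model_id` set to `bigscience/bloomz-560m`, `self.tokenizer([text])[0]` in `convert` yields a fast-tokenizer `Encoding`, so the same `.ids`/`.tokens`/`.attention_mask` slicing serves both the external and the BPE path.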

tools/accuracy_checker/openvino/tools/accuracy_checker/launcher/openvino_launcher.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -806,7 +806,7 @@ def fit_to_input(self, data, layer_name, layout, precision, template=None):
         data = data.astype(precision)
         if layer_name in self.dyn_input_layers:
             self._do_reshape = not self.is_dynamic
-            return data, template
+            return data
         data_shape = np.shape(data)
         if data_shape != layer_shape:
             if self.allow_reshape_input:
@@ -836,7 +836,7 @@ def _data_to_blob_dyn(layer_rang, data, layout, template=None):
             template = [1] * (np.ndim(data) - len(template)) + template
         if len(template) > np.ndim(data):
             template = template[0]
-        if len(layout) == len(data_shape):
+        if layout and len(layout) == len(data_shape):
             if template is not None:
                 new_template = [template[l_dim] for l_dim in layout]
                 template = new_template
```
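Both hunks are defensive fixes ("Corrected fit_to_input to return the proper data type" in the commit message): dynamic input layers now get back just the data, and `_data_to_blob_dyn` no longer assumes a layout exists. A standalone illustration of the new guard, with assumed values, since a layout can plausibly be `None` or empty for a fully dynamic input:

```python
# Assumed values for illustration: a dynamic input for which no layout was
# resolved. With the old check, len(None) raises TypeError immediately.
layout = None
data_shape = (1, 128)

if layout and len(layout) == len(data_shape):   # short-circuits on falsy layout
    print('reorder template dimensions by layout')
else:
    print('leave template as-is')               # taken when layout is None or ()
```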

tools/accuracy_checker/openvino/tools/accuracy_checker/metrics/language_modeling.py

Lines changed: 13 additions & 1 deletion
```diff
@@ -15,6 +15,13 @@
 """
 
 import numpy as np
+from ..utils import UnsupportedPackage
+
+try:
+    from scipy.special import log_softmax as scipy_log_softmax
+except ImportError as import_error:
+    scipy_log_softmax = UnsupportedPackage('scipy', import_error.msg)
+
 
 from ..representation import LanguageModelingAnnotation, LanguageModelingPrediction
 from .metric import PerImageEvaluationMetric
@@ -33,7 +40,12 @@ def __init__(self, *args, **kwargs):
 
     def update(self, annotation, prediction):
         def cross_entropy(logits, target):
-            return nll_loss(log_softmax(logits, 1), target)
+            log_softmax_res = log_softmax(logits, 1)
+            if -np.inf in log_softmax_res:
+                if isinstance(scipy_log_softmax, UnsupportedPackage):
+                    scipy_log_softmax.raise_error(self.__provider__)
+                log_softmax_res = scipy_log_softmax(logits, 1)
+            return nll_loss(log_softmax_res, target)
 
         def log_softmax(x, dim):
             e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
```
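This is the conditional fallback the commit message describes: the hand-rolled `log_softmax` produced a division by zero on the whole dataset, so scipy's implementation takes over whenever the naive result contains `-inf`. A standalone demo with an assumed extreme logit gap reproduces the failure mode: `exp(x - max)` underflows to zero, NumPy reports the subsequent `log(0)` as a divide-by-zero and returns `-inf`, while `scipy.special.log_softmax` stays in log space and remains finite.

```python
import numpy as np
from scipy.special import log_softmax

# Assumed logits with an extreme gap; gaps of this size are what reportedly
# broke the original implementation on the full dataset.
logits = np.array([[0.0, -800.0]])

# Hand-rolled version, as defined in the metric: exp(-800) underflows to 0.0,
# so np.log warns about a divide by zero and emits -inf.
e_x = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
naive = np.log(e_x / np.sum(e_x, axis=-1, keepdims=True))
print(naive)                   # [[  0. -inf]]  -> nll_loss turns this into inf

# scipy computes x - max - log(sum(exp(x - max))) and never takes log(0):
print(log_softmax(logits, 1))  # [[   0. -800.]] -- finite
```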

tools/accuracy_checker/openvino/tools/accuracy_checker/representation/nlp_representation.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -44,11 +44,12 @@ def __init__(self, identifier=''):
 
 
 class LanguageModelingAnnotation(LanguageModeling):
-    def __init__(self, identifier, unique_id, input_ids, tokens, labels=None):
+    def __init__(self, identifier, unique_id, input_ids, tokens, labels=None, input_mask=None):
         super().__init__(identifier)
         self.unique_id = unique_id
         self.tokens = tokens
         self.input_ids = input_ids
+        self.input_mask = input_mask
         self.labels = labels if labels is not None else []
 
 
```
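The new keyword argument defaults to `None`, so existing call sites are untouched while the converter above can attach the attention mask. A toy construction (made-up values; assumes the package is importable as `openvino.tools.accuracy_checker`, matching the file paths in this commit):

```python
import numpy as np
from openvino.tools.accuracy_checker.representation import LanguageModelingAnnotation

ids = [3, 17, 52]  # made-up token ids

# Old-style call keeps working: input_mask simply stays None.
legacy = LanguageModelingAnnotation(
    ['input_ids_0', 'labels_0'], np.array(1000000000),
    np.array([ids]), ['a', 'b', 'c'], labels=np.array(ids))

# New-style call mirrors what the wikitext2raw converter now produces.
masked = LanguageModelingAnnotation(
    ['input_ids_0', 'input_mask_0', 'labels_0'], np.array(1000000000),
    np.array([ids]), ['a', 'b', 'c'], labels=np.array(ids),
    input_mask=np.array([[1, 1, 1]]))
```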
