Skip to content

Commit f6a7724

Browse files
PR comments
1 parent 296981d commit f6a7724

File tree

2 files changed

+163
-50
lines changed

2 files changed

+163
-50
lines changed

src/inspect_evals/docvqa/docvqa.py

Lines changed: 68 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,61 @@
3131
{question}
3232
"""
3333

34+
# Local cache directory where extracted (and downsized) DocVQA page images are stored.
IMAGE_BASE_DIR = Path(user_cache_dir("inspect_evals")) / "docvqa_images"
35+
36+
37+
def _levenshtein_distance(str1: str, str2: str) -> int:
38+
"""Computes a Levenshtein distance, same as Levenshtein.distance in the python-Levenshtein package."""
39+
# Create a matrix of size (len(str1) + 1) x (len(str2) + 1)
40+
matrix = [[0 for j in range(len(str2) + 1)] for i in range(len(str1) + 1)]
41+
42+
# Initialize the first row and column
43+
for i in range(len(str1) + 1):
44+
matrix[i][0] = i
45+
for j in range(len(str2) + 1):
46+
matrix[0][j] = j
47+
48+
# Fill in the rest of the matrix
49+
for i in range(1, len(str1) + 1):
50+
for j in range(1, len(str2) + 1):
51+
matrix[i][j] = min(
52+
matrix[i - 1][j] + 1, # deletion
53+
matrix[i][j - 1] + 1, # insertion
54+
matrix[i - 1][j - 1] + int(str1[i - 1] != str2[j - 1]), # substitution
55+
)
56+
57+
return matrix[len(str1)][len(str2)]
58+
59+
60+
def _best_normalized_levenshtein_similiarity(
    completion: str, ground_truths: list[str], threshold: float
) -> float:
    """
    Compute the Normalized Levenshtein Similarity against each ground truth
    and return the best, as defined in equation (1) of
    https://arxiv.org/pdf/1907.00490.pdf

    Note that the "average" in ANLS is computed by the accuracy metric -- not
    here. This function computes the term inside the summation of equation (1).

    Args:
        completion: The answer extracted from the model output.
        ground_truths: Acceptable reference answers; the best match wins.
        threshold: Normalized-distance cutoff; comparisons at or above it
            score 0.0 (the paper uses 0.5).

    Returns:
        The highest similarity in [0.0, 1.0] across all ground truths.
    """
    best_score = 0.0
    for ground_truth in ground_truths:
        if len(ground_truth) == 0 and len(completion) == 0:
            # Both empty is a perfect match; also avoids max(0, 0) == 0
            # as a divisor below. (Fixed: was the int literal `1`, which
            # contradicted the declared `float` return type.)
            best_score = 1.0
            break
        # Comparison is case-insensitive.
        levenshtein_distance = _levenshtein_distance(
            completion.lower(), ground_truth.lower()
        )
        normed_levenshtein_distance = levenshtein_distance / max(
            len(completion), len(ground_truth)
        )
        if normed_levenshtein_distance < threshold:
            score = 1.0 - normed_levenshtein_distance
        else:
            score = 0.0
        if score > best_score:
            best_score = score
            if best_score == 1.0:
                # An exact match cannot be improved upon; stop early.
                break
    return best_score
3489

3590
@task
3691
def docvqa() -> Task:
@@ -52,30 +107,9 @@ def docvqa() -> Task:
52107

53108
@scorer(metrics=[accuracy(), stderr()])
54109
def docvqa_scorer() -> Scorer:
55-
def distance(str1: str, str2: str) -> int:
56-
# Create a matrix of size (len(str1) + 1) x (len(str2) + 1)
57-
matrix = [[0 for j in range(len(str2) + 1)] for i in range(len(str1) + 1)]
58-
59-
# Initialize the first row and column
60-
for i in range(len(str1) + 1):
61-
matrix[i][0] = i
62-
for j in range(len(str2) + 1):
63-
matrix[0][j] = j
64-
65-
# Fill in the rest of the matrix
66-
for i in range(1, len(str1) + 1):
67-
for j in range(1, len(str2) + 1):
68-
matrix[i][j] = min(
69-
matrix[i - 1][j] + 1, # deletion
70-
matrix[i][j - 1] + 1, # insertion
71-
matrix[i - 1][j - 1]
72-
+ int(str1[i - 1] != str2[j - 1]), # substitution
73-
)
74-
75-
return matrix[len(str1)][len(str2)]
76-
77-
async def get_ANLS_score(state: TaskState, target: Target) -> Score:
78-
"""Follows https://arxiv.org/pdf/1907.00490.pdf"""
110+
async def normalized_levenshtein_similiarity_score(
111+
state: TaskState, target: Target
112+
) -> Score:
79113
threshold = 0.5
80114
ground_truths = target.target
81115
match = re.search(
@@ -85,25 +119,10 @@ async def get_ANLS_score(state: TaskState, target: Target) -> Score:
85119
)
86120
if match:
87121
completion = match.groups()[0]
88-
best_score = 0.0
89-
for ground_truth in ground_truths:
90-
if len(ground_truth) == 0 and len(completion) == 0:
91-
best_score = 1
92-
break
93-
levenshtein_distance = distance(
94-
completion.lower(), ground_truth.lower()
95-
)
96-
normed_levenshtein_distance = levenshtein_distance / max(
97-
len(completion), len(ground_truth)
98-
)
99-
if normed_levenshtein_distance < threshold:
100-
score = 1.0 - normed_levenshtein_distance
101-
else:
102-
score = 0.0
103-
if score > best_score:
104-
best_score = score
105122
return Score(
106-
value=best_score,
123+
value=_best_normalized_levenshtein_similiarity(
124+
completion, ground_truths, threshold
125+
),
107126
answer=completion,
108127
)
109128

@@ -115,7 +134,7 @@ async def get_ANLS_score(state: TaskState, target: Target) -> Score:
115134
+ f"{state.output.completion}",
116135
)
117136

118-
return get_ANLS_score
137+
return normalized_levenshtein_similiarity_score
119138

120139

121140
@solver
@@ -131,27 +150,26 @@ async def solve(state: TaskState, generate: Generate) -> TaskState:
131150

132151
def record_to_sample(record: dict[str, Any]) -> Sample:
133152
# extract image
134-
IMAGE_BASE_DIR = Path(user_cache_dir("inspect_evals")) / "docvqa_images"
135-
image = Path(IMAGE_BASE_DIR / record["image"]["path"])
153+
image_path = Path(IMAGE_BASE_DIR / record["image"]["path"])
136154

137155
image_bytes = record["image"]["bytes"]
138156
assert is_image_png(image_bytes)
139157

140-
if not image.exists():
141-
print(f"Extracting {image.name}")
158+
if not image_path.exists():
159+
print(f"Extracting {image_path.name}")
142160
# ensure parent
143-
image.parent.mkdir(exist_ok=True, parents=True)
161+
image_path.parent.mkdir(exist_ok=True, parents=True)
144162
# reduce the image size
145163
img = Image.open(BytesIO(image_bytes))
146164
img.thumbnail((1024, 1024))
147165
# save preserving format
148-
img.save(image, format=img.format)
166+
img.save(image_path, format=img.format)
149167

150168
message: list[ChatMessage] = [
151169
ChatMessageUser(
152170
content=[
153171
ContentText(text=record["question"]),
154-
ContentImage(image=image.as_posix()),
172+
ContentImage(image=image_path.as_posix()),
155173
]
156174
)
157175
]

tests/docvqa/test_docvqa.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from inspect_evals.docvqa.docvqa import (
2+
_levenshtein_distance as levenshtein,
3+
_best_normalized_levenshtein_similiarity as best,
4+
)
5+
6+
7+
def test_levenshtein():
    """Sanity checks for the pure-Python Levenshtein implementation."""
    # (left, right, expected distance)
    cases = [
        ("", "", 0),  # empty strings
        ("a", "a", 0),  # same single char
        ("abc", "abc", 0),  # same string
        ("a", "", 1),  # single deletion
        ("", "a", 1),  # single insertion
        ("a", "b", 1),  # single substitution
        ("kitten", "sitting", 3),  # classic example
        ("sunday", "saturday", 3),  # real words
    ]
    for left, right, expected in cases:
        assert levenshtein(left, right) == expected
def test_best_normalized_levenshtein_distance():
    """Exercises the best-of-ground-truths normalized Levenshtein similarity."""

    def score(completion, ground_truths, threshold=2.0):
        return round(best(completion, ground_truths, threshold), 3)

    # Cases that resolve to an exact 1.0 match or a flat 0.0, default threshold
    exact_cases = [
        ("", [""], 1.0),  # empty strings
        ("a", ["a"], 1.0),  # single char match
        ("", ["a"], 0.0),  # empty vs char
        ("a", ["b"], 0.0),  # different chars
        ("color", ["color", "colour"], 1.0),  # exact match with variants
        ("theatre", ["theater", "theatre"], 1.0),  # regional spellings
        ("HELLO", ["hello", "hola"], 1.0),  # all case differences
        ("2nd floor", ["second floor", "2nd floor", "floor 2"], 1.0),  # numbers
        ("dept", ["department", "dept.", "dept"], 1.0),  # abbreviation matches
        ("new york", ["newyork", "new york", "ny"], 1.0),  # space variations
        ("café", ["cafe", "café", "caffè"], 1.0),  # accent marks
        ("kiwi", ["banana", "orange"], 0.0),  # no similarity
    ]
    for completion, truths, expected in exact_cases:
        assert score(completion, truths) == expected

    # Partial match: one deletion against the closest ground truth
    assert score("thetre", ["theater", "theatre"]) == round(1 - 1 / 7, 3)

    # Length differences: short completion vs longer options
    assert score("hi", ["hello", "hey"]) == round(1 - 2 / 3, 3)

    # Threshold behaviour: distances at/above the cutoff collapse to zero
    assert score("hi", ["hello", "hey"], 0.5) == 0.0
    assert score("hi", ["hello", "hey"], 0.75) == round(1 - 2 / 3, 3)

    # Relative ordering: the closer misspelling gets the better partial score
    state_truths = ["california", "calif", "ca"]
    assert score("californa", state_truths) > score("calfrnia", state_truths)

    # Minor spelling error in a long string stays close to 1.0
    assert (
        score("mississipi river", ["mississippi river", "river mississippi"]) > 0.9
    )

0 commit comments

Comments
 (0)