{question}
"""

# On-disk cache directory where extracted DocVQA page images are written
# (shared by record_to_sample; rooted at the platform user cache dir).
IMAGE_BASE_DIR = Path(user_cache_dir("inspect_evals")) / "docvqa_images"
35+
36+
37+ def _levenshtein_distance (str1 : str , str2 : str ) -> int :
38+ """Computes a Levenshtein distance, same as Levenshtein.distance in the python-Levenshtein package."""
39+ # Create a matrix of size (len(str1) + 1) x (len(str2) + 1)
40+ matrix = [[0 for j in range (len (str2 ) + 1 )] for i in range (len (str1 ) + 1 )]
41+
42+ # Initialize the first row and column
43+ for i in range (len (str1 ) + 1 ):
44+ matrix [i ][0 ] = i
45+ for j in range (len (str2 ) + 1 ):
46+ matrix [0 ][j ] = j
47+
48+ # Fill in the rest of the matrix
49+ for i in range (1 , len (str1 ) + 1 ):
50+ for j in range (1 , len (str2 ) + 1 ):
51+ matrix [i ][j ] = min (
52+ matrix [i - 1 ][j ] + 1 , # deletion
53+ matrix [i ][j - 1 ] + 1 , # insertion
54+ matrix [i - 1 ][j - 1 ] + int (str1 [i - 1 ] != str2 [j - 1 ]), # substitution
55+ )
56+
57+ return matrix [len (str1 )][len (str2 )]
58+
59+
def _best_normalized_levenshtein_similiarity(
    completion: str, ground_truths: list[str], threshold: float
) -> float:
    """
    Compute the Normalized Levenshtein Similarity against the best-matching ground truth.

    Follows equation (1) of https://arxiv.org/pdf/1907.00490.pdf.

    Note that the "average" of ANLS is computed by the accuracy metric -- not here.
    This function computes the term inside the summation of equation (1).

    Args:
        completion: Model answer; compared case-insensitively.
        ground_truths: Candidate reference answers.
        threshold: Normalized-distance cutoff; matches at or above it score 0.0.

    Returns:
        The best per-ground-truth similarity, in [0.0, 1.0].
    """
    lowered_completion = completion.lower()
    best_score = 0.0
    for ground_truth in ground_truths:
        lowered_truth = ground_truth.lower()
        # Two empty strings are a perfect match (and would otherwise divide by zero).
        if not lowered_completion and not lowered_truth:
            return 1.0
        distance = _levenshtein_distance(lowered_completion, lowered_truth)
        # Normalize by the lengths of the strings actually compared (the
        # lowercased forms) so the ratio stays in [0, 1] even for the rare
        # characters whose lowercase expansion changes the string length.
        normed_distance = distance / max(len(lowered_completion), len(lowered_truth))
        score = 1.0 - normed_distance if normed_distance < threshold else 0.0
        if score > best_score:
            best_score = score
            if best_score == 1.0:
                break  # cannot improve on a perfect match
    return best_score
88+
3489
3590@task
3691def docvqa () -> Task :
@@ -52,30 +107,9 @@ def docvqa() -> Task:
52107
53108@scorer (metrics = [accuracy (), stderr ()])
54109def docvqa_scorer () -> Scorer :
55- def distance (str1 : str , str2 : str ) -> int :
56- # Create a matrix of size (len(str1) + 1) x (len(str2) + 1)
57- matrix = [[0 for j in range (len (str2 ) + 1 )] for i in range (len (str1 ) + 1 )]
58-
59- # Initialize the first row and column
60- for i in range (len (str1 ) + 1 ):
61- matrix [i ][0 ] = i
62- for j in range (len (str2 ) + 1 ):
63- matrix [0 ][j ] = j
64-
65- # Fill in the rest of the matrix
66- for i in range (1 , len (str1 ) + 1 ):
67- for j in range (1 , len (str2 ) + 1 ):
68- matrix [i ][j ] = min (
69- matrix [i - 1 ][j ] + 1 , # deletion
70- matrix [i ][j - 1 ] + 1 , # insertion
71- matrix [i - 1 ][j - 1 ]
72- + int (str1 [i - 1 ] != str2 [j - 1 ]), # substitution
73- )
74-
75- return matrix [len (str1 )][len (str2 )]
76-
77- async def get_ANLS_score (state : TaskState , target : Target ) -> Score :
78- """Follows https://arxiv.org/pdf/1907.00490.pdf"""
110+ async def normalized_levenshtein_similiarity_score (
111+ state : TaskState , target : Target
112+ ) -> Score :
79113 threshold = 0.5
80114 ground_truths = target .target
81115 match = re .search (
@@ -85,25 +119,10 @@ async def get_ANLS_score(state: TaskState, target: Target) -> Score:
85119 )
86120 if match :
87121 completion = match .groups ()[0 ]
88- best_score = 0.0
89- for ground_truth in ground_truths :
90- if len (ground_truth ) == 0 and len (completion ) == 0 :
91- best_score = 1
92- break
93- levenshtein_distance = distance (
94- completion .lower (), ground_truth .lower ()
95- )
96- normed_levenshtein_distance = levenshtein_distance / max (
97- len (completion ), len (ground_truth )
98- )
99- if normed_levenshtein_distance < threshold :
100- score = 1.0 - normed_levenshtein_distance
101- else :
102- score = 0.0
103- if score > best_score :
104- best_score = score
105122 return Score (
106- value = best_score ,
123+ value = _best_normalized_levenshtein_similiarity (
124+ completion , ground_truths , threshold
125+ ),
107126 answer = completion ,
108127 )
109128
@@ -115,7 +134,7 @@ async def get_ANLS_score(state: TaskState, target: Target) -> Score:
115134 + f"{ state .output .completion } " ,
116135 )
117136
118- return get_ANLS_score
137+ return normalized_levenshtein_similiarity_score
119138
120139
121140@solver
@@ -131,27 +150,26 @@ async def solve(state: TaskState, generate: Generate) -> TaskState:
131150
132151def record_to_sample (record : dict [str , Any ]) -> Sample :
133152 # extract image
134- IMAGE_BASE_DIR = Path (user_cache_dir ("inspect_evals" )) / "docvqa_images"
135- image = Path (IMAGE_BASE_DIR / record ["image" ]["path" ])
153+ image_path = Path (IMAGE_BASE_DIR / record ["image" ]["path" ])
136154
137155 image_bytes = record ["image" ]["bytes" ]
138156 assert is_image_png (image_bytes )
139157
140- if not image .exists ():
141- print (f"Extracting { image .name } " )
158+ if not image_path .exists ():
159+ print (f"Extracting { image_path .name } " )
142160 # ensure parent
143- image .parent .mkdir (exist_ok = True , parents = True )
161+ image_path .parent .mkdir (exist_ok = True , parents = True )
144162 # reduce the image size
145163 img = Image .open (BytesIO (image_bytes ))
146164 img .thumbnail ((1024 , 1024 ))
147165 # save preserving format
148- img .save (image , format = img .format )
166+ img .save (image_path , format = img .format )
149167
150168 message : list [ChatMessage ] = [
151169 ChatMessageUser (
152170 content = [
153171 ContentText (text = record ["question" ]),
154- ContentImage (image = image .as_posix ()),
172+ ContentImage (image = image_path .as_posix ()),
155173 ]
156174 )
157175 ]
0 commit comments