

 def EM(prediction, answers_list):  # noqa: N802
-    assert isinstance(answers_list, list)
+    """Compute the Exact Match (EM) metric between a prediction and reference answers.
+
+    Returns True if any reference exactly matches the prediction after normalization;
+    otherwise False. Normalization applies Unicode NFD, lowercasing, punctuation
+    removal, English article removal ("a", "an", "the"), and whitespace collapse.
+
+    Args:
+        prediction (str): Predicted answer string.
+        answers_list (list[str]): List of reference answers.
+
+    Returns:
+        bool: Whether any reference exactly equals the prediction after normalization.
+
+    Example:
+        ```python
+        EM("The Eiffel Tower", ["Eiffel Tower", "Louvre"])  # True
+
+        EM("paris", ["Paris"])  # True
+        EM("paris", ["Paris, France"])  # False
+        ```
+    """
+    if not isinstance(answers_list, list):
+        raise ValueError(f"`answers_list` must be a list, got {type(answers_list)}")

     return max(em_score(prediction, ans) for ans in answers_list)


 def F1(prediction, answers_list):  # noqa: N802
-    assert isinstance(answers_list, list)
+    """Compute the maximum token-level F1 score against reference answers.
+
+    Strings are normalized (same as in `EM`) and whitespace-tokenized. The function
+    returns the maximum F1 over all provided references.
+
+    Args:
+        prediction (str): Predicted answer string.
+        answers_list (list[str]): List of reference answers.
+
+    Returns:
+        float: Highest F1 score in [0.0, 1.0].
+
+    Example:
+        ```python
+        round(F1("Eiffel Tower is in Paris", ["Paris"]), 2)  # 0.33
+        ```
+    """
+    if not isinstance(answers_list, list):
+        raise ValueError(f"`answers_list` must be a list, got {type(answers_list)}")

     return max(f1_score(prediction, ans) for ans in answers_list)


 def HotPotF1(prediction, answers_list):  # noqa: N802
-    assert isinstance(answers_list, list)
+    """Compute the maximum HotPotQA-style F1 score against reference answers.
+
+    Like `F1`, but if either normalized side is one of {"yes", "no", "noanswer"}
+    and they differ, the score is 0. Otherwise, standard token-level F1 is used.
+
+    Args:
+        prediction (str): Predicted answer.
+        answers_list (list[str]): List of reference answers.
+
+    Returns:
+        float: Highest HotPotQA-style F1 in [0.0, 1.0].
+
+    Example:
+        ```python
+        HotPotF1("yes", ["no"])  # 0.0
+        ```
+    """
+    if not isinstance(answers_list, list):
+        raise ValueError(f"`answers_list` must be a list, got {type(answers_list)}")

     return max(hotpot_f1_score(prediction, ans) for ans in answers_list)


 def normalize_text(s):
+    """Normalize text for string and token comparisons.
+
+    Steps:
+        1) Unicode NFD normalization
+        2) lowercasing
+        3) punctuation removal
+        4) English article removal ("a", "an", "the")
+        5) whitespace collapse
+
+    Args:
+        s (str): Input string.
+
+    Returns:
+        str: Normalized string.
+
+    Example:
+        ```python
+        normalize_text("The, Eiffel Tower!")  # "eiffel tower"
+        ```
+    """
     s = unicodedata.normalize("NFD", s)

     def remove_articles(text):
@@ -46,15 +124,42 @@ def lower(text):


 def em_score(prediction, ground_truth):
-    return normalize_text(prediction) == normalize_text(ground_truth)
+    """Compute boolean exact match after normalization.

+    Args:
+        prediction (str): Predicted answer.
+        ground_truth (str): Reference answer.

-# See: https://github.com/hotpotqa/hotpot/blob/master/hotpot_evaluate_v1.py
-# See: https://rajpurkar.github.io/SQuAD-explorer/ under Evaluation Script
-# See: QReCC's
+    Returns:
+        bool: True if normalized strings are identical; otherwise False.
+
+    Example:
+        ```python
+        em_score("Paris", "paris")  # True
+        ```
+    """
+    return normalize_text(prediction) == normalize_text(ground_truth)


 def f1_score(prediction, ground_truth):
+    """Compute token-level F1 between prediction and reference (after normalization).
+
+    Strings are normalized (see `normalize_text`) and split by whitespace. F1 is
+    computed from token precision and recall. If there is no token overlap, returns 0.
+    If both sides are empty, a diagnostic message is printed; the score remains 0.
+
+    Args:
+        prediction (str): Predicted answer.
+        ground_truth (str): Reference answer.
+
+    Returns:
+        float: F1 score in [0.0, 1.0].
+
+    Example:
+        ```python
+        round(f1_score("the Eiffel Tower", "Eiffel Tower"), 2)  # 1.0
+        ```
+    """
     prediction_tokens = normalize_text(prediction).split()
     ground_truth_tokens = normalize_text(ground_truth).split()

@@ -76,6 +181,23 @@ def f1_score(prediction, ground_truth):


 def hotpot_f1_score(prediction, ground_truth):
+    """Compute HotPotQA-style token F1 with special labels.
+
+    If either normalized string is in {"yes", "no", "noanswer"} and they differ,
+    the score is 0. Otherwise compute standard token F1 after normalization.
+
+    Args:
+        prediction (str): Predicted answer.
+        ground_truth (str): Reference answer.
+
+    Returns:
+        float: HotPotQA-style F1 score in [0.0, 1.0].
+
+    Example:
+        ```python
+        hotpot_f1_score("no", "yes")  # 0.0
+        ```
+    """
     normalized_prediction = normalize_text(prediction)
     normalized_ground_truth = normalize_text(ground_truth)

@@ -97,6 +219,24 @@ def hotpot_f1_score(prediction, ground_truth):


 def precision_score(prediction, ground_truth):
222+ """Compute token-level precision of prediction against reference (after normalization).
223+
224+ Precision is (# overlapping tokens) / (# tokens in prediction). If there is no
225+ token overlap, returns 0. If both sides are empty, a diagnostic message is printed;
226+ precision remains 0.
227+
228+ Args:
229+ prediction (str): Predicted answer.
230+ ground_truth (str): Reference answer.
231+
232+ Returns:
233+ float: Precision in [0.0, 1.0].
234+
235+ Example:
236+ ```python
237+ precision_score("eiffel tower in paris", "eiffel tower") # 0.67
238+ ```
239+ """
     prediction_tokens = normalize_text(prediction).split()
     ground_truth_tokens = normalize_text(ground_truth).split()

@@ -105,22 +245,23 @@ def precision_score(prediction, ground_truth):

     if len(prediction_tokens) == len(ground_truth_tokens) == 0:
         # Unlike most tasks, QReCC and SQuAD-2.0 assign 1.0 in this edge case. We don't for uniformity.
-        print_message("\n#> Precision Metric: Rare edge case of len(prediction_tokens) == len(ground_truth_tokens) == 0.\n")
+        print_message(
+            "\n#> Precision Metric: Rare edge case of len(prediction_tokens) == len(ground_truth_tokens) == 0.\n"
+        )

     if num_same == 0:
         return 0

     precision = 1.0 * num_same / len(prediction_tokens)
-
     return precision


 def _passage_match(passages: list[str], answers: list[str]) -> bool:
-    """Returns True if any of the passages contains the answer."""
+    """Return True if any passage contains any answer (normalized & DPR-normalized)."""
     from dspy.dsp.utils import DPR_normalize, has_answer

     def passage_has_answers(passage: str, answers: list[str]) -> bool:
-        """Returns True if the passage contains the answer."""
+        """Return True if the passage contains any of the answers."""
         return has_answer(
             tokenized_answers=[DPR_normalize(normalize_text(ans)) for ans in answers],
             text=normalize_text(passage),
@@ -130,15 +271,44 @@ def passage_has_answers(passage: str, answers: list[str]) -> bool:


 def _answer_match(prediction, answers, frac=1.0):
-    """Returns True if the prediction matches any of the answers."""
+    """Return True if prediction matches any answer.

+    When `frac >= 1.0`, require exact match (EM). Otherwise, return whether the
+    maximum token-level F1 across answers is at least `frac`.
+    """
     if frac >= 1.0:
         return EM(prediction, answers)

     return F1(prediction, answers) >= frac


 def answer_exact_match(example, pred, trace=None, frac=1.0):
+    """Evaluate exact match or F1-thresholded match for an example/prediction pair.
+
+    If `example.answer` is a string, compare `pred.answer` against it. If it's a list,
+    compare against any of the references. When `frac >= 1.0` (default), use EM;
+    otherwise require that the maximum F1 across references is at least `frac`.
+
+    Args:
+        example: `dspy.Example` object with field `answer` (str or list[str]).
+        pred: `dspy.Prediction` object with field `answer` (str).
+        trace: Unused; reserved for compatibility.
+        frac (float, optional): Threshold in [0.0, 1.0]. `1.0` means EM.
+
+    Returns:
+        bool: True if the match condition holds; otherwise False.
+
+    Example:
+        ```python
+        import dspy
+
+        example = dspy.Example(answer=["Eiffel Tower", "Louvre"])
+        pred = dspy.Prediction(answer="The Eiffel Tower")
+
+        answer_exact_match(example, pred, frac=1.0)  # equivalent to EM, True
+        answer_exact_match(example, pred, frac=0.5)  # True
+        ```
+    """
     if isinstance(example.answer, str):
         return _answer_match(pred.answer, [example.answer], frac=frac)
     elif isinstance(example.answer, list):
@@ -148,6 +318,28 @@ def answer_exact_match(example, pred, trace=None, frac=1.0):


 def answer_passage_match(example, pred, trace=None):
+    """Return True if any passage in `pred.context` contains the answer(s).
+
+    Strings are normalized (and passages also use DPR normalization internally).
+
+    Args:
+        example: `dspy.Example` object with field `answer` (str or list[str]).
+        pred: `dspy.Prediction` object with field `context` (list[str]) containing passages.
+        trace: Unused; reserved for compatibility.
+
+    Returns:
+        bool: True if any passage contains any reference answer; otherwise False.
+
+    Example:
+        ```python
+        import dspy
+
+        example = dspy.Example(answer="Eiffel Tower")
+        pred = dspy.Prediction(context=["The Eiffel Tower is in Paris.", "..."])
+
+        answer_passage_match(example, pred)  # True
+        ```
+    """
     if isinstance(example.answer, str):
         return _passage_match(pred.context, [example.answer])
     elif isinstance(example.answer, list):
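Below is a quick sanity-check sketch of the behavior the new docstrings document. It is not part of the diff, and the `dspy.evaluate.metrics` import path is an assumption inferred from the `dspy.dsp.utils` import in this file; adjust it to wherever the module actually lives.

```python
# Sketch: exercising the metrics documented above.
# Assumption: the functions are importable from dspy.evaluate.metrics.
from dspy.evaluate.metrics import EM, F1, normalize_text, precision_score

# Normalization: lowercase, strip punctuation and articles, collapse whitespace.
assert normalize_text("The, Eiffel Tower!") == "eiffel tower"

# EM: True if any reference matches exactly after normalization.
assert EM("The Eiffel Tower", ["Eiffel Tower", "Louvre"])

# F1: one overlapping token ("paris") out of 5 predicted vs 1 reference tokens.
assert round(F1("Eiffel Tower is in Paris", ["Paris"]), 2) == 0.33

# Precision: 2 overlapping tokens out of 4 predicted tokens.
assert precision_score("eiffel tower in paris", "eiffel tower") == 0.5
```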