
Commit e842ba1

[docs] Add Google-style docstrings for dspy/evaluate/metrics.py (#8954)
* docs(metrics): add Google-style docstrings for public metrics
* docs(metrics): address review feedback (concise openings, mkdocs block examples); revert non-doc changes
* fixes

Co-authored-by: chenmoneygithub <chen.qian@databricks.com>
1 parent 6c43880 commit e842ba1

File tree

1 file changed: +204 -12 lines changed


dspy/evaluate/metrics.py

Lines changed: 204 additions & 12 deletions
@@ -9,24 +9,102 @@


 def EM(prediction, answers_list):  # noqa: N802
-    assert isinstance(answers_list, list)
+    """Compute the Exact Match (EM) metric between a prediction and reference answers.
+
+    Returns True if any reference exactly matches the prediction after normalization;
+    otherwise False. Normalization applies Unicode NFD, lowercasing, punctuation
+    removal, English article removal ("a", "an", "the"), and whitespace collapse.
+
+    Args:
+        prediction (str): Predicted answer string.
+        answers_list (list[str]): List of reference answers.
+
+    Returns:
+        bool: Whether any reference exactly equals the prediction after normalization.
+
+    Example:
+    ```python
+    EM("The Eiffel Tower", ["Eiffel Tower", "Louvre"])  # True
+    EM("paris", ["Paris"])  # True
+    EM("paris", ["Paris, France"])  # False
+    ```
+    """
+    if not isinstance(answers_list, list):
+        raise ValueError(f"`answers_list` must be a list, got {type(answers_list)}")

     return max(em_score(prediction, ans) for ans in answers_list)


 def F1(prediction, answers_list):  # noqa: N802
-    assert isinstance(answers_list, list)
+    """Compute the maximum token-level F1 score against reference answers.
+
+    Strings are normalized (same as in `EM`) and whitespace-tokenized. The function
+    returns the maximum F1 over all provided references.
+
+    Args:
+        prediction (str): Predicted answer string.
+        answers_list (list[str]): List of reference answers.
+
+    Returns:
+        float: Highest F1 score in [0.0, 1.0].
+
+    Example:
+    ```python
+    round(F1("Eiffel Tower is in Paris", ["Paris"]), 2)  # 0.33
+    ```
+    """
+    if not isinstance(answers_list, list):
+        raise ValueError(f"`answers_list` must be a list, got {type(answers_list)}")

     return max(f1_score(prediction, ans) for ans in answers_list)


 def HotPotF1(prediction, answers_list):  # noqa: N802
-    assert isinstance(answers_list, list)
+    """Compute the maximum HotPotQA-style F1 score against reference answers.
+
+    Like `F1`, but if either normalized side is one of {"yes", "no", "noanswer"}
+    and they differ, the score is 0. Otherwise, standard token-level F1 is used.
+
+    Args:
+        prediction (str): Predicted answer.
+        answers_list (list[str]): List of reference answers.
+
+    Returns:
+        float: Highest HotPotQA-style F1 in [0.0, 1.0].
+
+    Example:
+    ```python
+    HotPotF1("yes", ["no"])  # 0.0
+    ```
+    """
+    if not isinstance(answers_list, list):
+        raise ValueError(f"`answers_list` must be a list, got {type(answers_list)}")

     return max(hotpot_f1_score(prediction, ans) for ans in answers_list)


 def normalize_text(s):
+    """Normalize text for string and token comparisons.
+
+    Steps:
+    1) Unicode NFD normalization
+    2) lowercasing
+    3) punctuation removal
+    4) English article removal ("a", "an", "the")
+    5) whitespace collapse
+
+    Args:
+        s (str): Input string.
+
+    Returns:
+        str: Normalized string.
+
+    Example:
+    ```python
+    normalize_text("The, Eiffel Tower!")  # "eiffel tower"
+    ```
+    """
     s = unicodedata.normalize("NFD", s)

     def remove_articles(text):
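The documented EM and F1 values are easy to sanity-check outside the docstrings. Below is a minimal sketch that reproduces them; it assumes the functions are importable from `dspy.evaluate.metrics`, which is simply the module path of the file changed in this commit.

```python
# Sanity checks for the new docstring examples.
# Assumption: these functions are importable from dspy.evaluate.metrics,
# the module path of the file changed in this commit.
from dspy.evaluate.metrics import EM, F1, normalize_text

assert normalize_text("The, Eiffel Tower!") == "eiffel tower"  # punctuation and article stripped
assert EM("The Eiffel Tower", ["Eiffel Tower", "Louvre"])      # exact match after normalization
assert not EM("paris", ["Paris, France"])                      # extra tokens break exact match

# F1 takes the max over references: 5 prediction tokens, 1 overlap with "paris"
# -> precision 0.2, recall 1.0, F1 = 2 * 0.2 * 1.0 / 1.2 ≈ 0.33
assert round(F1("Eiffel Tower is in Paris", ["Paris"]), 2) == 0.33
```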
@@ -46,15 +124,42 @@ def lower(text):


 def em_score(prediction, ground_truth):
-    return normalize_text(prediction) == normalize_text(ground_truth)
+    """Compute boolean exact match after normalization.

+    Args:
+        prediction (str): Predicted answer.
+        ground_truth (str): Reference answer.

-# See: https://github.com/hotpotqa/hotpot/blob/master/hotpot_evaluate_v1.py
-# See: https://rajpurkar.github.io/SQuAD-explorer/ under Evaluation Script
-# See: QReCC's
+    Returns:
+        bool: True if normalized strings are identical; otherwise False.
+
+    Example:
+    ```python
+    em_score("Paris", "paris")  # True
+    ```
+    """
+    return normalize_text(prediction) == normalize_text(ground_truth)


 def f1_score(prediction, ground_truth):
+    """Compute token-level F1 between prediction and reference (after normalization).
+
+    Strings are normalized (see `normalize_text`) and split by whitespace. F1 is
+    computed from token precision and recall. If there is no token overlap, returns 0.
+    If both sides are empty, a diagnostic message is printed; score remains 0.
+
+    Args:
+        prediction (str): Predicted answer.
+        ground_truth (str): Reference answer.
+
+    Returns:
+        float: F1 score in [0.0, 1.0].
+
+    Example:
+    ```python
+    round(f1_score("the Eiffel Tower", "Eiffel Tower"), 2)  # 1.0
+    ```
+    """
     prediction_tokens = normalize_text(prediction).split()
     ground_truth_tokens = normalize_text(ground_truth).split()
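For reviewers who want to see the arithmetic behind the quoted F1 values, here is a plain-Python re-derivation of token-level F1 as described in the docstrings above. The `toy_f1` helper is illustrative only and not part of metrics.py.

```python
from collections import Counter

# Illustrative re-derivation of token-level F1; toy_f1 is a hypothetical
# helper for this note, not a function defined in metrics.py.
def toy_f1(prediction_tokens, ground_truth_tokens):
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(prediction_tokens)
    recall = num_same / len(ground_truth_tokens)
    return 2 * precision * recall / (precision + recall)

# "the Eiffel Tower" vs "Eiffel Tower": normalization drops the article,
# both sides tokenize to ["eiffel", "tower"], so F1 is exactly 1.0.
print(toy_f1(["eiffel", "tower"], ["eiffel", "tower"]))  # 1.0

# "Eiffel Tower is in Paris" vs "Paris": precision 1/5, recall 1/1 -> F1 ≈ 0.33.
print(round(toy_f1(["eiffel", "tower", "is", "in", "paris"], ["paris"]), 2))  # 0.33
```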

@@ -76,6 +181,23 @@ def f1_score(prediction, ground_truth):


 def hotpot_f1_score(prediction, ground_truth):
+    """Compute HotPotQA-style token F1 with special labels.
+
+    If either normalized string is in {"yes", "no", "noanswer"} and they differ,
+    the score is 0. Otherwise compute standard token F1 after normalization.
+
+    Args:
+        prediction (str): Predicted answer.
+        ground_truth (str): Reference answer.
+
+    Returns:
+        float: HotPotQA-style F1 score in [0.0, 1.0].
+
+    Example:
+    ```python
+    hotpot_f1_score("no", "yes")  # 0.0
+    ```
+    """
     normalized_prediction = normalize_text(prediction)
     normalized_ground_truth = normalize_text(ground_truth)
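A quick illustration of the special-label rule described in the `hotpot_f1_score` docstring, under the same `dspy.evaluate.metrics` import assumption as above:

```python
# Assumption: importable from dspy.evaluate.metrics (module path in this diff).
from dspy.evaluate.metrics import f1_score, hotpot_f1_score

# Mismatched special labels ("yes"/"no"/"noanswer") short-circuit to 0.
print(hotpot_f1_score("no", "yes"))  # 0

# Outside the special labels it falls back to ordinary token-level F1.
print(round(hotpot_f1_score("the Eiffel Tower", "Eiffel Tower"), 2))  # 1.0
print(round(f1_score("the Eiffel Tower", "Eiffel Tower"), 2))         # 1.0, same value
```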

@@ -97,6 +219,24 @@ def hotpot_f1_score(prediction, ground_truth):


 def precision_score(prediction, ground_truth):
+    """Compute token-level precision of prediction against reference (after normalization).
+
+    Precision is (# overlapping tokens) / (# tokens in prediction). If there is no
+    token overlap, returns 0. If both sides are empty, a diagnostic message is printed;
+    precision remains 0.
+
+    Args:
+        prediction (str): Predicted answer.
+        ground_truth (str): Reference answer.
+
+    Returns:
+        float: Precision in [0.0, 1.0].
+
+    Example:
+    ```python
+    precision_score("eiffel tower in paris", "eiffel tower")  # 0.5
+    ```
+    """
     prediction_tokens = normalize_text(prediction).split()
     ground_truth_tokens = normalize_text(ground_truth).split()
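Precision is asymmetric in its two arguments, so it is worth spelling out the worked value next to the example above. A plain-Python sketch (illustrative only):

```python
from collections import Counter

# Worked precision example matching the docstring above (illustrative only).
prediction_tokens = "eiffel tower in paris".split()  # 4 tokens after normalization
ground_truth_tokens = "eiffel tower".split()         # 2 tokens

common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())                      # 2 overlapping tokens

precision = num_same / len(prediction_tokens)        # 2 / 4 = 0.5
recall = num_same / len(ground_truth_tokens)         # 2 / 2 = 1.0
f1 = 2 * precision * recall / (precision + recall)   # ≈ 0.67: that is the F1, not the precision
print(precision, recall, round(f1, 2))
```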

@@ -105,22 +245,23 @@
     if len(prediction_tokens) == len(ground_truth_tokens) == 0:
         # Unlike most tasks, QReCC and SQuAD-2.0 assign 1.0 in this edge case. We don't for uniformity.
-        print_message("\n#> Precision Metric: Rare edge case of len(prediction_tokens) == len(ground_truth_tokens) == 0.\n")
+        print_message(
+            "\n#> Precision Metric: Rare edge case of len(prediction_tokens) == len(ground_truth_tokens) == 0.\n"
+        )

     if num_same == 0:
         return 0

     precision = 1.0 * num_same / len(prediction_tokens)
-
     return precision


 def _passage_match(passages: list[str], answers: list[str]) -> bool:
-    """Returns True if any of the passages contains the answer."""
+    """Return True if any passage contains any answer (normalized & DPR-normalized)."""
     from dspy.dsp.utils import DPR_normalize, has_answer

     def passage_has_answers(passage: str, answers: list[str]) -> bool:
-        """Returns True if the passage contains the answer."""
+        """Return True if the passage contains any of the answers."""
         return has_answer(
             tokenized_answers=[DPR_normalize(normalize_text(ans)) for ans in answers],
             text=normalize_text(passage),
@@ -130,15 +271,44 @@ def passage_has_answers(passage: str, answers: list[str]) -> bool:


 def _answer_match(prediction, answers, frac=1.0):
-    """Returns True if the prediction matches any of the answers."""
+    """Return True if prediction matches any answer.

+    When `frac >= 1.0`, require exact match (EM). Otherwise, return whether the
+    maximum token-level F1 across answers is at least `frac`.
+    """
     if frac >= 1.0:
         return EM(prediction, answers)

     return F1(prediction, answers) >= frac


 def answer_exact_match(example, pred, trace=None, frac=1.0):
+    """Evaluate exact match or F1-thresholded match for an example/prediction pair.
+
+    If `example.answer` is a string, compare `pred.answer` against it. If it's a list,
+    compare against any of the references. When `frac >= 1.0` (default), use EM;
+    otherwise require that the maximum F1 across references is at least `frac`.
+
+    Args:
+        example: `dspy.Example` object with field `answer` (str or list[str]).
+        pred: `dspy.Prediction` object with field `answer` (str).
+        trace: Unused; reserved for compatibility.
+        frac (float, optional): Threshold in [0.0, 1.0]. `1.0` means EM.
+
+    Returns:
+        bool: True if the match condition holds; otherwise False.
+
+    Example:
+    ```python
+    import dspy
+
+    example = dspy.Example(answer=["Eiffel Tower", "Louvre"])
+    pred = dspy.Prediction(answer="The Eiffel Tower")
+
+    answer_exact_match(example, pred, frac=1.0)  # equivalent to EM, True
+    answer_exact_match(example, pred, frac=0.5)  # True
+    ```
+    """
     if isinstance(example.answer, str):
         return _answer_match(pred.answer, [example.answer], frac=frac)
     elif isinstance(example.answer, list):
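Since `answer_exact_match` takes an extra `frac` argument, using it as a relaxed metric usually means fixing `frac` up front. A small sketch, assuming the `dspy.evaluate.metrics` import path from this diff:

```python
from functools import partial

import dspy
# Assumption: importable from dspy.evaluate.metrics (module path in this diff).
from dspy.evaluate.metrics import answer_exact_match

example = dspy.Example(answer=["Eiffel Tower", "Louvre"])
pred = dspy.Prediction(answer="The Eiffel Tower")

strict = answer_exact_match(example, pred)              # default frac=1.0, i.e. exact match
relaxed_metric = partial(answer_exact_match, frac=0.5)  # reusable F1-thresholded metric
print(strict, relaxed_metric(example, pred))            # True True
```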
@@ -148,6 +318,28 @@ def answer_exact_match(example, pred, trace=None, frac=1.0):


 def answer_passage_match(example, pred, trace=None):
+    """Return True if any passage in `pred.context` contains the answer(s).
+
+    Strings are normalized (and passages also use DPR normalization internally).
+
+    Args:
+        example: `dspy.Example` object with field `answer` (str or list[str]).
+        pred: `dspy.Prediction` object with field `context` (list[str]) containing passages.
+        trace: Unused; reserved for compatibility.
+
+    Returns:
+        bool: True if any passage contains any reference answer; otherwise False.
+
+    Example:
+    ```python
+    import dspy
+
+    example = dspy.Example(answer="Eiffel Tower")
+    pred = dspy.Prediction(context=["The Eiffel Tower is in Paris.", "..."])
+
+    answer_passage_match(example, pred)  # True
+    ```
+    """
     if isinstance(example.answer, str):
         return _passage_match(pred.context, [example.answer])
     elif isinstance(example.answer, list):
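Taken together, the two public example-level metrics cover both the answer string and the retrieved passages. A combined sketch under the same import-path assumption:

```python
import dspy
# Same import-path assumption as the sketches above.
from dspy.evaluate.metrics import answer_exact_match, answer_passage_match

example = dspy.Example(answer="Eiffel Tower")
pred = dspy.Prediction(
    answer="The Eiffel Tower",
    context=["The Eiffel Tower is in Paris.", "The Louvre is a museum."],
)

print(answer_exact_match(example, pred))    # True: normalized answers are identical
print(answer_passage_match(example, pred))  # True: the first passage contains the answer
```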
