
Commit b8f4b6c

Refactor metric, improve documentation, cleanup
- Refactored the metric computation in `_utils.py`:
  - Split the `ws_metric` function into three separate functions for better modularity:
    - `preprocess_scores`: Preprocesses scores into a DataFrame with computed weights and biases.
    - `compute_accuracy`: Computes weighted accuracy from preprocessed scores.
    - `compute_bias`: Computes weighted bias from preprocessed scores.
  - Updated `ws_accuracy` and `ws_bias` in `worldsense.py` to utilize the new functions.
- Improved documentation:
  - Added detailed explanations of problem types and grades in `README.md`, clarifying how `problemname` is formed.
  - Included a comprehensive docstring for the `worldsense` task in `worldsense.py`, explaining the task's purpose and usage.
1 parent 38e6af4 commit b8f4b6c

3 files changed: 87 additions & 76 deletions

src/inspect_evals/worldsense/README.md

Lines changed: 15 additions & 4 deletions
@@ -4,20 +4,31 @@
 ## Dataset
-Here is an example prompt from the dataset:
+Here is an example prompt ("description") from the dataset (`problemname == "Compl.trivial"`):

 >"Alice is enrolled in 3 courses per week: Alice takes history before computer science and economics before history.
 >Choose one of the following alternatives: (1) Alice takes history in between economics and computer science, (2) Alice takes history outside of the time range between economics and computer science, or (3) it is impossible to decide.
 >Think carefully, and only respond with one of these possible options (1), (2), or (3).

-The model is then tasked to pick the correct choice.
+There are three problem types:
+
+1. "Infer" (inference): Determine whether a given statement about a description is true or false.
+2. "Compl" (completion): Select which of three statements is true about a description, including an option for when it is not possible to decide.
+3. "Consist" (consistency): Determine whether a statement about a description is possible or impossible.
+
+In addition, there are two grades:
+
+1. "trivial": can be solved from the statements alone
+2. "normal": requires a world model to solve
+
+A problemname is formed by the concatenation "<type>.<grade>".

 ## Scoring

 - Simple accuracy
 - Standard error
-- Weighted accuracy by `tuple_ID`, `problemname`, and `problemsize` (reported in the paper)
-- Weighted bias by `tuple_ID`, `problemname`, and `problemsize` (reported in the paper)
+- Weighted accuracy by `tuple_ID`, `problemname`, and `problemsize` (as reported in the paper)
+- Weighted bias by `tuple_ID`, `problemname`, and `problemsize` (as reported in the paper)

 In addition to built-in metrics, the main results are weighted accuracy and bias. Here the primary unit is a `tuple_ID`, which corresponds to one "description" or scenario as above. For each, up to three answer option sets are provided to ensure that all possible correct answers for the options are selected exactly once. Weights are used to average over multiple `tuple_ID`s.
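
The weighting scheme in the paragraph above is easiest to see on toy data. The following sketch is an illustration only (not from the repository; the `tuple_ID`s, weights, and correctness values are made up) and mirrors the per-`tuple_ID` normalization that `_utils.py` applies within a single `problemname`/`problemsize` bucket:

```python
import pandas as pd

# Toy data: two descriptions (tuple_IDs); weights and correctness are made up.
toy = pd.DataFrame(
    {
        "tuple_ID": ["t1", "t1", "t2"],
        "problemname": ["Compl.trivial"] * 3,
        "problemsize": [3, 3, 3],
        "value": [1.0, 0.0, 1.0],   # 1.0 = answered correctly, 0.0 = not
        "weight": [0.5, 0.5, 1.0],  # per-answer-option weights (illustrative)
    }
)

# Weight each response, then normalize within a description (tuple_ID) ...
toy["weighted"] = toy["value"] * toy["weight"]
per_description = (
    toy.groupby(["tuple_ID", "problemname", "problemsize"])
    .agg(value=("weighted", "sum"), weight=("weight", "sum"))
    .reset_index()
)
per_description["value"] /= per_description["weight"]

# ... then average across descriptions: (0.5 + 1.0) / 2 = 0.75
print(per_description["value"].mean())
```

Normalizing by the summed weight before averaging keeps each description's contribution equal, no matter how many answer option sets it has.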

src/inspect_evals/worldsense/_utils.py

Lines changed: 58 additions & 69 deletions
@@ -2,48 +2,33 @@
 from inspect_ai.scorer import Score, ValueToFloat, value_to_float


-def ws_metric(
-    scores: list[Score],
-    kind: str,
-    to_float: ValueToFloat = value_to_float(),
-) -> float:
-    """Compute the weighted accuracy or bias metric.
-
-    Args:
-        scores (list[Score]): List of Score objects containing evaluation data.
-        kind (str): Type of metric to compute ('acc' for accuracy or 'bias').
-        to_float (ValueToFloat, optional): Function to convert `Score` values to floats. Defaults to `value_to_float()`.
-
-    Returns:
-        float: The computed metric value.
-    """
-    # Build DataFrame from the list of Score objects
-    data = []
-    for score in scores:
-        value = to_float(score.value)
-        answer = score.answer
-        metadata = score.metadata or {}
-
-        tuple_ID = metadata.get("tuple_ID")
-        problemname = metadata.get("problemname")
-        problemsize = metadata.get("problemsize")
-
-        if None in (value, answer, tuple_ID, problemname, problemsize):
-            continue
-
-        data.append(
-            {
-                "value": value,
-                "answer": answer,
-                "tuple_ID": tuple_ID,
-                "problemsize": problemsize,
-                "problemname": problemname,
-            }
+def preprocess_scores(
+    scores: list[Score], to_float: ValueToFloat = value_to_float()
+) -> pd.DataFrame:
+    """Preprocesses a list of Score objects into a DataFrame with computed weights and biases."""
+    processed_scores = [
+        {
+            "value": to_float(score.value),
+            "answer": score.answer,
+            "tuple_ID": score.metadata.get("tuple_ID"),
+            "problemname": score.metadata.get("problemname"),
+            "problemsize": score.metadata.get("problemsize"),
+        }
+        for score in scores
+        if score.metadata
+        and None
+        not in (
+            to_float(score.value),
+            score.answer,
+            score.metadata.get("tuple_ID"),
+            score.metadata.get("problemname"),
+            score.metadata.get("problemsize"),
         )
+    ]

-    df = pd.DataFrame(data)
+    score_df = pd.DataFrame(processed_scores)

-    # Define mappings for bias and weight
+    # Mappings for bias and weight
     bias_mapping = {
         "1": 1,
         "2": 1,
@@ -63,46 +48,50 @@ def ws_metric(
         "IMPOSSIBLE": 0.5,
     }

-    # Calculate weight and bias values
-    df["weight"] = df["answer"].map(weight_mapping).astype(float)
-    df["bias"] = df["answer"].map(bias_mapping).astype(float)
-    df["value"] = df["value"].astype(float)
-    df["value"] *= df["weight"]
-    df["bias"] *= df["weight"]
+    # Calculate weighted values and biases
+    score_df["weight"] = score_df["answer"].map(weight_mapping).astype(float)
+    score_df["bias"] = score_df["answer"].map(bias_mapping).astype(float)
+    score_df["value"] = score_df["value"].astype(float) * score_df["weight"]
+    score_df["bias"] *= score_df["weight"]

-    # Step 5.1: Group by 'tuple_ID', 'problemname', 'problemsize' and sum
-    grouped = (
-        df.groupby(["tuple_ID", "problemname", "problemsize"])
+    # Group and normalize
+    grouped_scores = (
+        score_df.groupby(["tuple_ID", "problemname", "problemsize"])
         .agg({"value": "sum", "bias": "sum", "weight": "sum"})
         .reset_index()
     )

-    # Step 5.2: Normalize 'value' and 'bias' by dividing by total 'weight'
-    grouped["value"] = grouped["value"] / grouped["weight"].where(
-        grouped["weight"] != 0, 1
+    grouped_scores["value"] = grouped_scores["value"] / grouped_scores["weight"].where(
+        grouped_scores["weight"] != 0, 1
     )
-    grouped["bias"] = grouped["bias"] / grouped["weight"].where(
-        grouped["weight"] != 0, 1
+    grouped_scores["bias"] = grouped_scores["bias"] / grouped_scores["weight"].where(
+        grouped_scores["weight"] != 0, 1
     )

-    # Step 6: Compute mean 'acc' and 'bias' grouped by 'problemname' and 'problemsize'
-    summaries = (
-        grouped.groupby(["problemname", "problemsize"])
-        .agg({"value": "mean", "bias": "mean"})
+    return grouped_scores
+
+
+def compute_accuracy(grouped_scores: pd.DataFrame) -> float:
+    """Compute the weighted accuracy from preprocessed scores."""
+    problem_summary = (
+        grouped_scores.groupby(["problemname", "problemsize"])
+        .agg({"value": "mean"})
         .reset_index()
     )
+    final_summary = (
+        problem_summary.groupby("problemname").agg({"value": "mean"}).reset_index()
+    )
+    return float(final_summary["value"].mean())
+

-    # Step 7: Compute overall mean 'acc' and 'bias' grouped by 'problemname'
-    final_summaries = (
-        summaries.groupby("problemname")
-        .agg({"value": "mean", "bias": "mean"})
+def compute_bias(grouped_scores: pd.DataFrame) -> float:
+    """Compute the weighted bias from preprocessed scores."""
+    problem_summary = (
+        grouped_scores.groupby(["problemname", "problemsize"])
+        .agg({"bias": "mean"})
         .reset_index()
     )
-
-    # Compute the final metric
-    if kind == "acc":
-        return float(final_summaries["value"].mean())
-    elif kind == "bias":
-        return float(final_summaries["bias"].mean())
-    else:
-        raise ValueError("Invalid kind argument, must be 'acc' or 'bias'")
+    final_summary = (
+        problem_summary.groupby("problemname").agg({"bias": "mean"}).reset_index()
+    )
+    return float(final_summary["bias"].mean())
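
For orientation, here is a hedged sketch of how the refactored helpers compose, replacing the old single `ws_metric(scores, kind=...)` entry point. It is not part of the commit: the `Score` values and metadata are toy data, and it assumes `Score` can be constructed with the `value`/`answer`/`metadata` fields that `preprocess_scores` reads.

```python
from inspect_ai.scorer import Score

from inspect_evals.worldsense._utils import (
    compute_accuracy,
    compute_bias,
    preprocess_scores,
)

# Two toy scores for the same description (tuple_ID); the metadata keys match
# what preprocess_scores() reads above.
scores = [
    Score(
        value="C",  # value_to_float() maps "C" (correct) to 1.0
        answer="1",
        metadata={"tuple_ID": 1, "problemname": "Compl.trivial", "problemsize": 3},
    ),
    Score(
        value="I",  # and "I" (incorrect) to 0.0
        answer="2",
        metadata={"tuple_ID": 1, "problemname": "Compl.trivial", "problemsize": 3},
    ),
]

# Previously: ws_metric(scores, kind="acc") / ws_metric(scores, kind="bias").
grouped = preprocess_scores(scores)  # one row per (tuple_ID, problemname, problemsize)
print(compute_accuracy(grouped))     # weighted accuracy
print(compute_bias(grouped))         # weighted bias
```

Separating the preprocessing from the two reductions removes the `kind` flag and keeps each metric a small, single-purpose function.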

src/inspect_evals/worldsense/worldsense.py

Lines changed: 14 additions & 3 deletions
@@ -38,11 +38,20 @@
 )
 from inspect_ai.solver import TaskState, generate

-from ._utils import ws_metric
+from ._utils import compute_accuracy, compute_bias, preprocess_scores


 @task
 def worldsense(problemnames: str | list[str] = []) -> Task:
+    """
+    Task for evaluating reasoning related to world descriptions. There are three problem types ("Infer", "Compl", "Consist") and two grades ("trivial", "normal"). A problemname is formed by concatenation "<type>.<grade>". See README for details.
+
+    Args:
+        problemnames (str | list[str], optional): A string or list of strings specifying the names of problems or tasks to filter the dataset. If provided, it filters the dataset to samples that contain matching metadata for the specified problem names.
+
+    Returns:
+        Task: A task object configured with a dataset filtered by problem names (if specified), a solver, a scoring pattern for evaluating task responses, and custom metrics.
+    """
     # filter dataset if requested
     problemnames = problemnames if isinstance(problemnames, list) else [problemnames]
     if len(problemnames) > 0:
@@ -152,7 +161,8 @@ def ws_accuracy() -> Metric:
     """Compute weighted accuracy metric."""

     def metric(scores: list[Score]) -> float:
-        return ws_metric(scores, kind="acc")
+        grouped_scores = preprocess_scores(scores)
+        return compute_accuracy(grouped_scores)

     return metric

@@ -162,6 +172,7 @@ def ws_bias() -> Metric:
     """Compute weighted bias metric."""

     def metric(scores: list[Score]) -> float:
-        return ws_metric(scores, kind="bias")
+        grouped_scores = preprocess_scores(scores)
+        return compute_bias(grouped_scores)

     return metric
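
As a usage note, the sketch below (not part of this commit) shows the `problemnames` filter described in the new docstring, run programmatically through `inspect_ai`'s `eval()`; the model string is illustrative only.

```python
from inspect_ai import eval

from inspect_evals.worldsense.worldsense import worldsense

# Restrict the dataset to the trivial completion problems described in the README;
# the model name is an example, not a requirement.
logs = eval(worldsense(problemnames="Compl.trivial"), model="openai/gpt-4o")
```

Passing a single string works because the task wraps it into a list before filtering, as shown in the diff above.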
