
Commit 299f69c

Merge pull request rllm-org#65 from mjbroerman/feature/worldsense
WorldSense Benchmark Implementation
2 parents 74ba667 + b8f4b6c commit 299f69c

7 files changed (+326 −0 lines)

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -123,6 +123,7 @@ inspect_evals = "inspect_evals._registry"
 [project.optional-dependencies]
 swe_bench = ["swebench>=2.1.0","docker"]
 mathematics = ["sympy","antlr4-python3-runtime==4.13.2"]
+worldsense = ["pandas"]

 dev = [
     "inspect_ai@git+https://github.com/UKGovernmentBEIS/inspect_ai",

src/inspect_evals/_registry.py

Lines changed: 1 addition & 0 deletions
@@ -64,4 +64,5 @@
 from .vstar_bench import vstar_bench_ar, vstar_bench_srr
 from .winogrande import winogrande
 from .wmdp import wmdp_bio, wmdp_chem, wmdp_cyber
+from .worldsense import worldsense
 from .xstest import xstest
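
Adding this import is what exposes the task through the `inspect_evals = "inspect_evals._registry"` entry point shown in the pyproject.toml hunk above. As a minimal sketch (assuming the package is installed and a model provider is configured; the model identifier is a placeholder), the registered task can then be referenced by name rather than imported directly:

```python
from inspect_ai import eval

# Resolve the task by its registry name instead of importing it.
# "openai/gpt-4o" is a placeholder; substitute any configured provider/model.
logs = eval("inspect_evals/worldsense", model="openai/gpt-4o", limit=10)
```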
src/inspect_evals/worldsense/.gitignore

Lines changed: 9 additions & 0 deletions

trials.jsonl
example.jsonl
__pycache__/
/.quarto/
_output/
_quarto.yml
*.qmd
*.quarto_ipynb
img/
src/inspect_evals/worldsense/README.md

Lines changed: 37 additions & 0 deletions

# WorldSense

[WorldSense](https://arxiv.org/pdf/2311.15930) is a benchmark that measures reasoning over a world model while controlling for dataset bias.

## Dataset

Here is an example prompt ("description") from the dataset (`problemname == "Compl.trivial"`):

> "Alice is enrolled in 3 courses per week: Alice takes history before computer science and economics before history.
> Choose one of the following alternatives: (1) Alice takes history in between economics and computer science, (2) Alice takes history outside of the time range between economics and computer science, or (3) it is impossible to decide.
> Think carefully, and only respond with one of these possible options (1), (2), or (3)."

There are three problem types:

1. "Infer" (inference): determine whether a given statement about a description is true or false.
2. "Compl" (completion): select which of three statements is true about a description, including an option for when it is not possible to decide.
3. "Consist" (consistency): determine whether a statement about a description is possible or impossible.

In addition there are two grades:

1. "trivial": can be solved from the statements alone
2. "normal": requires a world model to solve

A problemname is formed by concatenating "<type>.<grade>", e.g. "Compl.trivial" in the example above.

## Scoring

- Simple accuracy
- Standard error
- Weighted accuracy by `tuple_ID`, `problemname`, and `problemsize` (as reported in the paper)
- Weighted bias by `tuple_ID`, `problemname`, and `problemsize` (as reported in the paper)

Beyond the built-in metrics, the main results are weighted accuracy and bias. The primary unit is a `tuple_ID`, which corresponds to one "description" (scenario) as above. For each description, up to three answer option sets are provided so that every possible correct answer across the options is selected exactly once. Weights are used to average over multiple `tuple_ID`s.

All answer options are coded as positive or negative (in the example above, option 1: +1, option 2: +1, option 3: −1), and all option sets share the same arrangement of option codings. Bias is calculated from these codings, weighted accordingly.

Problem size corresponds to the number of entities in the description; in the example above, `problemsize == 3`. Within a `problemname`, the score is the grand average over problem sizes. If multiple `problemname`s are specified (or none, meaning all), the score is the grand average over those as well.
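
To make the weighting concrete, here is a minimal sketch (not part of the commit) of the per-description arithmetic for the "Compl" example above; the option weights and bias codings mirror those hard-coded in `_utils.py`, and the model answers are hypothetical.

```python
# Option weights and bias codings for "Compl" problems, as used in _utils.py:
# options "1" and "2" are the positive alternatives (weight 0.25, bias +1),
# option "3" ("impossible to decide") is negative (weight 0.5, bias -1).
weights = {"1": 0.25, "2": 0.25, "3": 0.5}
bias_codes = {"1": +1, "2": +1, "3": -1}

# Hypothetical model answers to the three option sets of one tuple_ID,
# with 1.0/0.0 marking whether the pattern scorer judged each one correct.
answers = ["1", "2", "3"]
correct = [1.0, 1.0, 1.0]

total_weight = sum(weights[a] for a in answers)  # 1.0
weighted_accuracy = (
    sum(c * weights[a] for a, c in zip(answers, correct)) / total_weight
)  # 1.0
weighted_bias = (
    sum(bias_codes[a] * weights[a] for a in answers) / total_weight
)  # 0.25 + 0.25 - 0.5 = 0.0: balanced codings cancel for an unbiased responder

print(weighted_accuracy, weighted_bias)
```

Accuracy weights the 0/1 scores the same way, so each description contributes equally regardless of how many option sets it has; a model that always answers "1"/"TRUE"/"POSSIBLE" would push the bias toward +1, and one that always answers "3"/"FALSE"/"IMPOSSIBLE" toward −1.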
src/inspect_evals/worldsense/__init__.py

Lines changed: 3 additions & 0 deletions

from .worldsense import worldsense

__all__ = ["worldsense"]
src/inspect_evals/worldsense/_utils.py

Lines changed: 97 additions & 0 deletions

import pandas as pd
from inspect_ai.scorer import Score, ValueToFloat, value_to_float


def preprocess_scores(
    scores: list[Score], to_float: ValueToFloat = value_to_float()
) -> pd.DataFrame:
    """Preprocess a list of Score objects into a DataFrame with computed weights and biases."""
    processed_scores = [
        {
            "value": to_float(score.value),
            "answer": score.answer,
            "tuple_ID": score.metadata.get("tuple_ID"),
            "problemname": score.metadata.get("problemname"),
            "problemsize": score.metadata.get("problemsize"),
        }
        for score in scores
        if score.metadata
        and None
        not in (
            to_float(score.value),
            score.answer,
            score.metadata.get("tuple_ID"),
            score.metadata.get("problemname"),
            score.metadata.get("problemsize"),
        )
    ]

    score_df = pd.DataFrame(processed_scores)

    # Mappings for bias and weight
    bias_mapping = {
        "1": 1,
        "2": 1,
        "TRUE": 1,
        "POSSIBLE": 1,
        "3": -1,
        "FALSE": -1,
        "IMPOSSIBLE": -1,
    }
    weight_mapping = {
        "1": 0.25,
        "2": 0.25,
        "3": 0.5,
        "TRUE": 0.5,
        "POSSIBLE": 0.5,
        "FALSE": 0.5,
        "IMPOSSIBLE": 0.5,
    }

    # Calculate weighted values and biases
    score_df["weight"] = score_df["answer"].map(weight_mapping).astype(float)
    score_df["bias"] = score_df["answer"].map(bias_mapping).astype(float)
    score_df["value"] = score_df["value"].astype(float) * score_df["weight"]
    score_df["bias"] *= score_df["weight"]

    # Group by description (tuple_ID) and normalize by total weight
    grouped_scores = (
        score_df.groupby(["tuple_ID", "problemname", "problemsize"])
        .agg({"value": "sum", "bias": "sum", "weight": "sum"})
        .reset_index()
    )

    grouped_scores["value"] = grouped_scores["value"] / grouped_scores["weight"].where(
        grouped_scores["weight"] != 0, 1
    )
    grouped_scores["bias"] = grouped_scores["bias"] / grouped_scores["weight"].where(
        grouped_scores["weight"] != 0, 1
    )

    return grouped_scores


def compute_accuracy(grouped_scores: pd.DataFrame) -> float:
    """Compute the weighted accuracy from preprocessed scores."""
    problem_summary = (
        grouped_scores.groupby(["problemname", "problemsize"])
        .agg({"value": "mean"})
        .reset_index()
    )
    final_summary = (
        problem_summary.groupby("problemname").agg({"value": "mean"}).reset_index()
    )
    return float(final_summary["value"].mean())


def compute_bias(grouped_scores: pd.DataFrame) -> float:
    """Compute the weighted bias from preprocessed scores."""
    problem_summary = (
        grouped_scores.groupby(["problemname", "problemsize"])
        .agg({"bias": "mean"})
        .reset_index()
    )
    final_summary = (
        problem_summary.groupby("problemname").agg({"bias": "mean"}).reset_index()
    )
    return float(final_summary["bias"].mean())
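
For orientation, here is a minimal usage sketch (not part of the commit), assuming the helpers live at `inspect_evals.worldsense._utils`; the scores below are hand-built and hypothetical.

```python
from inspect_ai.scorer import Score

from inspect_evals.worldsense._utils import (
    compute_accuracy,
    compute_bias,
    preprocess_scores,
)

# Two hypothetical scores for the same description (tuple_ID 1): the model
# answered one option set with "1" (scored correct, "C") and another with
# "3" (scored incorrect, "I").
meta = {"tuple_ID": 1, "problemname": "Compl.trivial", "problemsize": 3}
scores = [
    Score(value="C", answer="1", metadata=dict(meta)),
    Score(value="I", answer="3", metadata=dict(meta)),
]

grouped = preprocess_scores(scores)
# value: (1.0 * 0.25 + 0.0 * 0.5) / 0.75 ≈ 0.33
# bias:  (+1 * 0.25 + -1 * 0.5) / 0.75 ≈ -0.33
print(compute_accuracy(grouped))
print(compute_bias(grouped))
```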
src/inspect_evals/worldsense/worldsense.py

Lines changed: 178 additions & 0 deletions

"""
WorldSense: A Synthetic Benchmark for Grounded Reasoning
in Large Language Models

Youssef Benchekroun, Megi Dervishi, Mark Ibrahim, Jean-Baptiste Gaya, Xavier Martinet, Grégoire Mialon, Thomas Scialom, Emmanuel Dupoux, Dieuwke Hupkes, Pascal Vincent
https://arxiv.org/pdf/2311.15930

# eval all problemnames w/ 500 randomly selected samples
inspect eval worldsense --limit 500

# add chain of thought
inspect eval worldsense --limit 500

# eval selected problem
inspect eval worldsense -T problemnames=Compl.normal
inspect eval worldsense -T problemnames=Consist.normal
inspect eval worldsense -T problemnames=Infer.normal,Infer.trivial
"""  # noqa: D205

import bz2
import json
from typing import Any, Callable, Dict

import requests
from inspect_ai import Task, task
from inspect_ai.dataset import Dataset, MemoryDataset, Sample
from inspect_ai.model import GenerateConfig
from inspect_ai.scorer import (
    Metric,
    Score,
    Scorer,
    Target,
    accuracy,
    metric,
    pattern,
    scorer,
    stderr,
)
from inspect_ai.solver import TaskState, generate

from ._utils import compute_accuracy, compute_bias, preprocess_scores


@task
def worldsense(problemnames: str | list[str] = []) -> Task:
    """
    Task for evaluating reasoning related to world descriptions. There are three problem types ("Infer", "Compl", "Consist") and two grades ("trivial", "normal"). A problemname is formed by concatenation "<type>.<grade>". See README for details.

    Args:
        problemnames (str | list[str], optional): A string or list of strings specifying the names of problems or tasks to filter the dataset. If provided, it filters the dataset to samples that contain matching metadata for the specified problem names.

    Returns:
        Task: A task object configured with a dataset filtered by problem names (if specified), a solver, a scoring pattern for evaluating task responses, and custom metrics.
    """
    # filter dataset if requested
    problemnames = problemnames if isinstance(problemnames, list) else [problemnames]
    if len(problemnames) > 0:
        task_dataset = dataset.filter(
            name=f"{dataset.name}-{'-'.join(problemnames)}",
            predicate=lambda sample: sample.metadata is not None
            and sample.metadata.get("problemname") in problemnames,
        )
    else:
        task_dataset = dataset

    return Task(
        dataset=task_dataset,
        solver=generate(),
        scorer=pattern_with_metadata(
            r"^\(?\s*(1|2|3|TRUE|FALSE|IMPOSSIBLE|POSSIBLE)\s*\)?"
        ),
        metrics=[accuracy(), stderr(), ws_accuracy(), ws_bias()],
        config=GenerateConfig(temperature=0.0),
    )


def record_to_sample(record: Dict[str, Any]) -> Sample:
    # The dataset obfuscates gold responses with author names; map them back
    # to the actual answer labels.
    goldresp_mapping: Dict[str, str] = {
        "Emmanuel": "TRUE",
        "Megi": "FALSE",
        "Dieuwke": "POSSIBLE",
        "Pascal": "IMPOSSIBLE",
        "Mark": "1",
        "Youssef": "2",
        "Yoda": "3",
    }

    return Sample(
        input=record["text"],
        choices=record["expectedresp"],
        target=goldresp_mapping.get(record["goldresp_obfusc"], "None"),
        metadata={
            "tuple_ID": record["tuple_ID"],
            "problemname": record["problemname"],
            "problemsize": record["problemsize"],
        },
    )


def load_worldsense_dataset(
    sample_fields: Callable[[dict[str, Any]], Sample],
    shuffle: bool = True,
) -> Dataset:
    """
    Load the worldsense dataset from a bz2 file directly into memory and return a Dataset.

    Args:
        sample_fields (Callable): Function to map records to samples.
        shuffle (bool): Whether to shuffle the dataset. Default is True.

    Returns:
        Dataset: The loaded and decompressed dataset in memory.
    """
    url = "https://github.com/facebookresearch/worldsense/raw/bd81d945077f169cf95ff39207f788f86e4645e9/data/worldsense/test_set/trials.jsonl.bz2"

    # Download and decompress the bz2 data directly into memory
    response = requests.get(url)
    response.raise_for_status()
    decompressed_data = bz2.decompress(response.content).decode("utf-8")

    # Parse the decompressed data into records
    lines = decompressed_data.strip().splitlines()
    samples = []
    for line in lines:
        record = json.loads(line)
        sample = sample_fields(record)
        if isinstance(sample, list):
            samples.extend(sample)
        else:
            samples.append(sample)

    # Create a MemoryDataset
    dataset = MemoryDataset(samples=samples, name="worldsense", location=url)

    # Shuffle if needed
    if shuffle:
        dataset.shuffle()

    return dataset


dataset = load_worldsense_dataset(sample_fields=record_to_sample, shuffle=True)


@scorer(metrics=[accuracy(), stderr()])
def pattern_with_metadata(
    pattern_str: str, ignore_case: bool = True, match_all: bool = False
) -> Scorer:
    # Wrap the built-in pattern scorer so that sample metadata (tuple_ID,
    # problemname, problemsize) is carried on each Score for the custom metrics.
    base_scorer = pattern(pattern_str, ignore_case, match_all)

    async def score(state: TaskState, target: Target) -> Score:
        base_score = await base_scorer(state, target)
        base_score.metadata = state.metadata
        return base_score

    return score


@metric
def ws_accuracy() -> Metric:
    """Compute weighted accuracy metric."""

    def metric(scores: list[Score]) -> float:
        grouped_scores = preprocess_scores(scores)
        return compute_accuracy(grouped_scores)

    return metric


@metric
def ws_bias() -> Metric:
    """Compute weighted bias metric."""

    def metric(scores: list[Score]) -> float:
        grouped_scores = preprocess_scores(scores)
        return compute_bias(grouped_scores)

    return metric
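
Finally, a minimal sketch (not part of the commit) of invoking the task from Python via `inspect_ai.eval`, assuming the package with its `worldsense` extra (pandas) is installed and a model provider is configured; the model identifier is a placeholder.

```python
from inspect_ai import eval

from inspect_evals.worldsense import worldsense

# Evaluate only the inference problems on a small random subset.
# "openai/gpt-4o" is a placeholder; substitute any configured provider/model.
logs = eval(
    worldsense(problemnames=["Infer.trivial", "Infer.normal"]),
    model="openai/gpt-4o",
    limit=100,
)
```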
