"""
WorldSense: A Synthetic Benchmark for Grounded Reasoning
in Large Language Models

Youssef Benchekroun, Megi Dervishi, Mark Ibrahim, Jean-Baptiste Gaya, Xavier Martinet, Grégoire Mialon, Thomas Scialom, Emmanuel Dupoux, Dieuwke Hupkes, Pascal Vincent
https://arxiv.org/pdf/2311.15930

# eval all problemnames w/ 500 randomly selected samples
inspect eval worldsense --limit 500

# eval selected problem
inspect eval worldsense -T problemnames=Compl.normal
inspect eval worldsense -T problemnames=Consist.normal
inspect eval worldsense -T problemnames=Infer.normal,Infer.trivial
"""  # noqa: D205

import bz2
import json
from typing import Any, Callable, Dict

import requests
from inspect_ai import Task, task
from inspect_ai.dataset import Dataset, MemoryDataset, Sample
from inspect_ai.model import GenerateConfig
from inspect_ai.scorer import (
    Metric,
    Score,
    Scorer,
    Target,
    accuracy,
    metric,
    pattern,
    scorer,
    stderr,
)
from inspect_ai.solver import TaskState, generate

from ._utils import compute_accuracy, compute_bias, preprocess_scores


@task
def worldsense(problemnames: str | list[str] = []) -> Task:
    """
    Task for evaluating reasoning about world descriptions. There are three problem types ("Infer", "Compl", "Consist") and two grades ("trivial", "normal"); a problem name is the concatenation "<type>.<grade>". See the README for details.

    Args:
        problemnames (str | list[str], optional): A problem name or list of problem names to filter the dataset by. If provided, only samples whose metadata matches one of the given problem names are evaluated.

    Returns:
        Task: A task configured with the (optionally filtered) dataset, a solver, a pattern scorer for extracting responses, and custom metrics.
    """
    # filter dataset if requested
    problemnames = problemnames if isinstance(problemnames, list) else [problemnames]
    if len(problemnames) > 0:
        task_dataset = dataset.filter(
            name=f"{dataset.name}-{'-'.join(problemnames)}",
            predicate=lambda sample: sample.metadata is not None
            and sample.metadata.get("problemname") in problemnames,
        )
    else:
        task_dataset = dataset

    return Task(
        dataset=task_dataset,
        solver=generate(),
        # match the leading answer token: a choice number or a truth/possibility label
        scorer=pattern_with_metadata(
            r"^\(?\s*(1|2|3|TRUE|FALSE|IMPOSSIBLE|POSSIBLE)\s*\)?"
        ),
        metrics=[accuracy(), stderr(), ws_accuracy(), ws_bias()],
        config=GenerateConfig(temperature=0.0),
    )
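
# For programmatic use, the same filtering can be requested directly, for example
# worldsense(problemnames="Compl.normal") or
# worldsense(problemnames=["Infer.normal", "Infer.trivial"]).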


def record_to_sample(record: Dict[str, Any]) -> Sample:
    """Convert a raw trials.jsonl record into a Sample, mapping the obfuscated gold response back to its answer label."""
    goldresp_mapping: Dict[str, str] = {
        "Emmanuel": "TRUE",
        "Megi": "FALSE",
        "Dieuwke": "POSSIBLE",
        "Pascal": "IMPOSSIBLE",
        "Mark": "1",
        "Youssef": "2",
        "Yoda": "3",
    }

    return Sample(
        input=record["text"],
        choices=record["expectedresp"],
        target=goldresp_mapping.get(record["goldresp_obfusc"], "None"),
        metadata={
            "tuple_ID": record["tuple_ID"],
            "problemname": record["problemname"],
            "problemsize": record["problemsize"],
        },
    )
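
# Illustrative sketch (hypothetical field values, not taken from the dataset):
# a record like
#   {"text": "<world description + question>", "expectedresp": ["TRUE", "FALSE"],
#    "goldresp_obfusc": "Emmanuel", "tuple_ID": 12345,
#    "problemname": "Infer.normal", "problemsize": 3}
# becomes a Sample with input=text, choices=["TRUE", "FALSE"], target="TRUE",
# and tuple_ID/problemname/problemsize preserved in metadata for the custom metrics.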


def load_worldsense_dataset(
    sample_fields: Callable[[dict[str, Any]], Sample],
    shuffle: bool = True,
) -> Dataset:
    """
    Load the worldsense dataset from a bz2 file directly into memory and return a Dataset.

    Args:
        sample_fields (Callable): Function to map records to samples.
        shuffle (bool): Whether to shuffle the dataset. Default is True.

    Returns:
        Dataset: The loaded and decompressed dataset in memory.
    """
    url = "https://github.com/facebookresearch/worldsense/raw/bd81d945077f169cf95ff39207f788f86e4645e9/data/worldsense/test_set/trials.jsonl.bz2"

    # Download and decompress the bz2 data directly into memory
    response = requests.get(url)
    response.raise_for_status()
    decompressed_data = bz2.decompress(response.content).decode("utf-8")

    # Parse the decompressed data into records
    lines = decompressed_data.strip().splitlines()
    samples = []
    for line in lines:
        record = json.loads(line)
        sample = sample_fields(record)
        if isinstance(sample, list):
            samples.extend(sample)
        else:
            samples.append(sample)

    # Create a MemoryDataset
    dataset = MemoryDataset(samples=samples, name="worldsense", location=url)

    # Shuffle if needed
    if shuffle:
        dataset.shuffle()

    return dataset


# Note: the dataset is downloaded and decompressed when this module is imported.
dataset = load_worldsense_dataset(sample_fields=record_to_sample, shuffle=True)


@scorer(metrics=[accuracy(), stderr()])
def pattern_with_metadata(
    pattern_str: str, ignore_case: bool = True, match_all: bool = False
) -> Scorer:
    """Wrap the built-in pattern scorer so each Score carries the sample's metadata (used by the custom worldsense metrics)."""
    base_scorer = pattern(pattern_str, ignore_case, match_all)

    async def score(state: TaskState, target: Target) -> Score:
        base_score = await base_scorer(state, target)
        base_score.metadata = state.metadata
        return base_score

    return score


@metric
def ws_accuracy() -> Metric:
    """Compute weighted accuracy metric."""

    def metric(scores: list[Score]) -> float:
        grouped_scores = preprocess_scores(scores)
        return compute_accuracy(grouped_scores)

    return metric


@metric
def ws_bias() -> Metric:
    """Compute weighted bias metric."""

    def metric(scores: list[Score]) -> float:
        grouped_scores = preprocess_scores(scores)
        return compute_bias(grouped_scores)

    return metric
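

# Minimal usage sketch (an assumption, not part of the original module): because of
# the relative import above, this file must be run as part of its package, e.g.
# `python -m <package>.worldsense`, with a default model configured (for example
# via the INSPECT_EVAL_MODEL environment variable).
if __name__ == "__main__":
    from inspect_ai import eval as inspect_eval

    # Evaluate the two "Infer" problem names on 50 samples from the shuffled dataset.
    inspect_eval(
        worldsense(problemnames=["Infer.normal", "Infer.trivial"]),
        limit=50,
    )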