Commit 1ee5a42

Merge pull request rllm-org#51 from owenparsons/niah_aset
NIAH task implementation | ASET - Arcadia Impact
2 parents 2f03b4a + f2acc98 commit 1ee5a42

16 files changed, +2205 -3 lines changed

README.md

Lines changed: 11 additions & 2 deletions
@@ -352,10 +352,19 @@ The questions were generated by GPT-4 based on the "Computer Systems Security: P

 - ### [MuSR: Testing the Limits of Chain-of-thought with Multistep Soft Reasoning](src/inspect_evals/musr)
   Evaluating models on multistep soft reasoning tasks in the form of free text narratives.
   <sub><sup>Contributed by: [@farrelmahaztra](https://github.com/farrelmahaztra)</sub></sup>
-  ```
+
+  ```bash
   inspect eval inspect_evals/musr
   ```

+- ### [Needle in a Haystack (NIAH): In-Context Retrieval Benchmark for Long Context LLMs](src/inspect_evals/niah)
+  NIAH evaluates in-context retrieval ability of long context LLMs by testing a model's ability to extract factual information from long-context inputs.
+
+
+  ```bash
+  inspect eval inspect_evals/niah
+  ```
+
 - ### [PAWS: Paraphrase Adversaries from Word Scrambling](src/inspect_evals/paws)
   Evaluating models on the task of paraphrase detection by providing pairs of sentences that are either paraphrases or not.
   <sub><sup>Contributed by: [@meltemkenis](https://github.com/meltemkenis)</sub></sup>

@@ -434,4 +443,4 @@ The questions were generated by GPT-4 based on the "Computer Systems Security: P

   inspect eval inspect_evals/agie_lsat_lr
   ```

-<!-- /Eval Listing: Automatically Generated -->
+<!-- /Eval Listing: Automatically Generated -->

pyproject.toml

Lines changed: 4 additions & 1 deletion
@@ -40,7 +40,7 @@ convention = "google"

 [tool.pytest.ini_options]
 minversion = "7.0"
-addopts = "-rA --doctest-modules --color=yes"
+addopts = "-rA --doctest-modules --color=yes -m 'not dataset_download'"
 testpaths = ["tests"]
 doctest_optionflags = ["NORMALIZE_WHITESPACE", "IGNORE_EXCEPTION_DETAIL"]
 norecursedirs = [

@@ -51,6 +51,9 @@ norecursedirs = [

 asyncio_mode = "auto"
 asyncio_default_fixture_loop_scope = "function"
 log_level = "warning"
+markers = [
+    "dataset_download: marks tests that download datasets (deselect with '-m \"not dataset_download\"')",  # (disabled by default)
+]

 [tool.mypy]
 exclude = [
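With the new `dataset_download` marker registered and deselected by default via `addopts`, a test that pulls the NIAH source data would be tagged roughly as follows. This is a sketch only; the test module and body are illustrative and not part of this PR.

```python
import pytest

from inspect_evals.niah.utils.dataset_generation import get_data_from_hf


@pytest.mark.dataset_download  # deselected by default via addopts: -m 'not dataset_download'
def test_niah_dataset_download():
    # Requires network access: pulls the OpenCompass needle/haystack data from Hugging Face
    hf_data = get_data_from_hf()
    assert hf_data is not None
```

Such tests can then be run explicitly with `pytest -m dataset_download`.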

src/inspect_evals/_registry.py

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@

 from .mmlu_pro import mmlu_pro
 from .mmmu import mmmu_multiple_choice, mmmu_open
 from .musr import musr
+from .niah import niah
 from .paws import paws
 from .piqa import piqa
 from .pubmedqa import pubmedqa

src/inspect_evals/niah/README.md

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@

# Inspect Task Evaluation README

## Overview

The Inspect implementation for NIAH (Needle in a Haystack) is designed to evaluate the in-context retrieval ability of long-context language models (LLMs). The main goal of this evaluation is to place a random fact or statement (the "needle") in the middle of a long context window (the "haystack") and ask the model to retrieve this statement. The evaluation iterates over various document depths and context lengths to measure performance.

This evaluation was contributed by [Owen Parsons](https://owenparsons.github.io).

## Installation

To get started, install the required Python packages:

```bash
pip install inspect_ai
pip install git+https://github.com/UKGovernmentBEIS/inspect_evals
```

## Running the Evaluation

You can evaluate against models using the following command:

```bash
inspect eval inspect_evals/niah --model openai/gpt-3.5-turbo-16k
```

If you prefer not to specify the `--model` argument each time you run an evaluation, create a `.env` configuration file in your working directory that defines the `INSPECT_EVAL_MODEL` environment variable along with your API key.

### Example of the .env file:

```
INSPECT_EVAL_MODEL=openai/gpt-3.5-turbo-16k
```

You can also use the `-T` option to define input values for the evaluation. For example:

```bash
inspect eval niah.py --model openai/gpt-3.5-turbo -T n_needles=20
```

## Configuration Variables

Here are the configuration variables used in the NIAH evaluation, along with their types and default values (a programmatic example follows the table):

| Variable | Type | Default Value | Description |
|-------------------|----------|---------------|-----------------------------------------------------------------------------------------------|
| `min_context`     | `int`    | `10000`       | Minimum context length to evaluate. |
| `max_context`     | `int`    | `120000`      | Maximum context length to evaluate. |
| `n_contexts`      | `int`    | `15`          | The number of contexts to evaluate. |
| `n_positions`     | `int`    | `15`          | The number of positions to evaluate for a given context length. |
| `start_buffer`    | `int`    | `0`           | Buffer at the top of the context to avoid placing needles. (Example: if `start_buffer` is `100`, the first needle position would aim to be at the 100th token in the context.) |
| `end_buffer`      | `int`    | `0`           | Buffer at the bottom of the context to avoid placing needles. (Example: for a context length of `1000`, if `end_buffer` is `100`, the final needle position would aim to be at the 900th token in the context.) |
| `n_needles`       | `int`    | `1`           | The number of needles to sample. |
| `sample_method`   | `str`    | `"fixed"`     | Method for sampling the needles (`"fixed"`, `"sequential"`, or `"random"`). |
| `fixed_index`     | `int`    | `0`           | The index of the needle to use when `sample_method` is set to `"fixed"`. |
| `n_runs`          | `int`    | `1`           | The number of runs for the evaluation. |
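The same variables can also be set from Python rather than via `-T` flags; a minimal sketch, assuming `inspect_ai.eval()` as the programmatic entry point and leaving every parameter not shown at its default:

```python
from inspect_ai import eval

from inspect_evals.niah import niah

# Sketch: a small sweep of 3 context lengths x 5 positions, two randomly sampled needles
eval(
    niah(
        min_context=5000,
        max_context=20000,
        n_contexts=3,
        n_positions=5,
        n_needles=2,
        sample_method="random",
    ),
    model="openai/gpt-3.5-turbo-16k",
)
```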
## Scoring Metric

This benchmark uses a modified version of `model_graded_qa()` that passes the scorer only the question related to the needle, rather than the original long-context prompt used during the benchmark task. This avoids the scoring model having to handle long-context inputs during scoring.

The scoring criteria are taken from [Greg Kamradt's implementation](https://github.com/gkamradt/LLMTest_NeedleInAHaystack) and are shown below:

```
Score 1: The answer is completely unrelated to the reference.
Score 3: The answer has minor relevance but does not align with the reference.
Score 5: The answer has moderate relevance but contains inaccuracies.
Score 7: The answer aligns with the reference but has minor omissions.
Score 10: The answer is completely accurate and aligns perfectly with the reference.
```
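The custom wrapper used by this task (`custom_model_graded_qa_with_history_scorer`) builds on the same idea as the stock scorer's `include_history` hook. A minimal sketch of that pattern, assuming `include_history` accepts a callable from `TaskState` to `str` (as the wrapper's usage in `niah.py` suggests); the `needle_question` metadata key is taken from this task's samples:

```python
from inspect_ai.scorer import model_graded_qa
from inspect_ai.solver import TaskState


def needle_question_only(state: TaskState) -> str:
    # Show the grading model only the short needle question, not the full haystack prompt
    return f"Question: {state.metadata['needle_question']}"


# Assumption: include_history takes a Callable[[TaskState], str]; otherwise use the
# repo's custom scorer directly as done in niah.py.
scorer = model_graded_qa(include_history=needle_question_only)
```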
## Dataset

The dataset used for the evaluation is read from [OpenCompass](https://opencompass.readthedocs.io/en/latest/advanced_guides/needleinahaystack_eval.html), and the specific dataset is generated based on the values defined in the configuration.

## Dataset Construction

The final dataset used by the NIAH evaluation is generated from the OpenCompass data and the configuration variables by following the series of steps summarised below (a sketch of steps 2 and 3 follows the list).

1. Model Encoding: An encoder for the specified model is created using the tiktoken library, which facilitates tokenisation for text processing.

2. Context Length Generation: Context lengths are generated for the specified range of context values given by `min_context`, `max_context` and `n_contexts`.

3. Needle Position Generation: Needle positions are determined across the generated contexts based on the number of needles specified by `n_needles`, accounting for a buffer at the start/end of the document if specified by `start_buffer` or `end_buffer`.

4. Data Retrieval: The relevant haystack and needle datasets are extracted from the Hugging Face dataset repository.

5. Needle Filtering: The needles are filtered to include only English entries.

6. Needle Sampling: Needles are sampled based on the chosen method (fixed, sequential, or random).

7. Needle Adjustment: The needles are repeated and shifted to prepare for multiple runs.

8. Token Length Calculations: The maximum combined token lengths for the needles, questions, and answers are computed, along with the token counts for the main and question prompts.

9. Context Reading and Trimming: The relevant context texts are read, combined and trimmed to match the required context token lengths.

10. Period Indices Calculation: Period indices are determined for each context length to facilitate the insertion of needles at appropriate positions between sentences.

11. Model Context Length Verification: A check is performed to ensure that the model context length is larger than the specified context lengths.

12. Dataset Creation: The final dataset is constructed by combining the contexts, needle insertions, and associated metadata.

13. DataFrame Conversion: The complete dataset is converted into a Pandas DataFrame.

These steps are carried out in the `generate_context_with_needles` function from [utils.py](./utils.py).
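To make steps 2 and 3 concrete, here is a minimal sketch of how evenly spaced context lengths and needle positions could be derived from the configuration variables; it is an illustration under those assumptions, not the code in `utils.py`:

```python
import numpy as np

# Assumed parameterisation, mirroring the configuration variables above
min_context, max_context, n_contexts = 10000, 120000, 15
n_positions, start_buffer, end_buffer = 15, 0, 0

# Step 2: evenly spaced context lengths between min_context and max_context
context_lengths = np.linspace(min_context, max_context, n_contexts, dtype=int)

# Step 3: for each context length, evenly spaced needle positions inside the buffers
needle_positions = {
    int(length): np.linspace(start_buffer, length - end_buffer, n_positions, dtype=int)
    for length in context_lengths
}

print(context_lengths[:3])          # e.g. [10000 17857 25714]
print(needle_positions[10000][:3])  # e.g. [0 714 1428]
```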

src/inspect_evals/niah/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@

from .niah import niah

__all__ = ["niah"]

src/inspect_evals/niah/niah.py

Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@

from typing import Literal

import tiktoken
from inspect_ai import Task, task
from inspect_ai.model import get_model
from inspect_ai.solver import (
    Generate,
    Solver,
    TaskState,
    generate,
    prompt_template,
    solver,
)
from tiktoken.core import Encoding

from inspect_evals.niah.utils.dataset_generation import (
    ExpParams,
    Haystack,
    generate_full_context,
    get_data_from_hf,
)
from inspect_evals.niah.utils.prompting import (
    MAIN_PROMPT,
    QUESTION_PROMPT,
    create_question_prompt,
)
from inspect_evals.niah.utils.sample_generation import (
    generate_samples,
    needle_into_haystack,
    sample_df_to_dataset,
)
from inspect_evals.niah.utils.scoring import (
    custom_model_graded_qa_with_history_scorer,
    return_metadata_variable_as_history,
)
from inspect_evals.niah.utils.text_utils import get_model_or_default

# Define a token buffer for max context to avoid potential edge case issues with long outputs
TOKEN_BUFFER = 100


@task
def niah(
    min_context: int = 10000,
    max_context: int = 120000,
    n_contexts: int = 15,
    n_positions: int = 15,
    start_buffer: int = 0,
    end_buffer: int = 0,
    n_needles: int = 1,
    sample_method: Literal["fixed", "sequential", "random"] = "fixed",
    fixed_index: int = 0,
    n_runs: int = 1,
) -> Task:
    """
    Inspect Task implementation for NIAH (Needle in a Haystack).

    This function generates a task that evaluates the model on a dataset with varying context sizes and needle positions.
    Needles are inserted into the context to measure the model's ability to retrieve relevant information.

    Args:
        min_context (int): Minimum context length to evaluate. Default is 10000.
        max_context (int): Maximum context length to evaluate. Default is 120000.
        n_contexts (int): The number of contexts to evaluate. Default is 15.
        n_positions (int): The number of positions to evaluate for a given context length. Default is 15.
        start_buffer (int): Buffer at the top of the context to avoid placing needles. Default is 0.
        end_buffer (int): Buffer at the bottom of the context to avoid placing needles. Default is 0.
        n_needles (int): The number of needles to sample. Default is 1.
        sample_method (Literal["fixed", "sequential", "random"]): Method for sampling the needles.
            If "fixed", a single specific needle index is used for all trials.
            If "random", a new needle is randomly sampled for each trial.
            If "sequential", needles are sequentially sampled across trials.
            Default is "fixed".
        fixed_index (int): The index of the needle to use when `sample_method` is "fixed" or
            the index of the starting position when `sample_method` is "sequential". Default is 0.
        n_runs (int): The number of runs per set of experimental parameters. Default is 1.

    Returns:
        Task: A Task object containing the dataset, the solver configuration, and a custom scorer with metadata handling.
    """
    # Get the active model
    model = get_model()

    # Use default model name for tokenisation if no tokeniser found for current model
    tokeniser_model_name = get_model_or_default(model.name)

    # Create an encoder for given model
    enc = tiktoken.encoding_for_model(tokeniser_model_name)

    # Import OpenCompass 'Needle in a Haystack' dataset from HF
    hf_data = get_data_from_hf()

    # Generate ExpParams object for storing experimental parameters.
    exp_params = ExpParams(
        min_context=min_context,
        max_context=max_context,
        n_contexts=n_contexts,
        n_positions=n_positions,
        start_buffer=start_buffer,
        end_buffer=end_buffer,
        n_needles=n_needles,
        sample_method=sample_method,
        fixed_index=fixed_index,
        n_runs=n_runs,
        main_prompt=MAIN_PROMPT,
        question_prompt=QUESTION_PROMPT,
        token_buffer=TOKEN_BUFFER,
        model_name=model.name,
    )

    # Generate the haystack for the largest context length required. Smaller context lengths will trim haystack.
    haystack = generate_full_context(hf_data, exp_params, enc)

    # Generate a DataFrame with Sample information
    samples_df = generate_samples(hf_data, haystack, exp_params, enc)

    # Convert Sample DataFrame to Dataset
    dataset = sample_df_to_dataset(samples_df)

    # Return the Task
    return Task(
        dataset=dataset,
        solver=[
            add_to_haystack(
                haystack, enc
            ),  # Take needle and other information from Sample to generate combined haystack and needle text.
            prompt_template(MAIN_PROMPT),
            generate(),
        ],
        # Custom wrapper used to allow for grouped scoring and parsing metadata to scorer
        scorer=custom_model_graded_qa_with_history_scorer(
            include_history=return_metadata_variable_as_history,
        ),
    )


@solver
def add_to_haystack(haystack: Haystack, enc: Encoding) -> Solver:
    """
    Custom solver function.

    Inserts a specified prompt (needle) into a larger context (haystack) string based on provided Sample metadata parameters.

    Args:
        haystack (Haystack): Haystack object containing the complete context (haystack) in which the needle will be embedded.
        enc (Encoding): The tokeniser encoding object, used to convert text to tokens.

    Returns:
        Solver: An asynchronous solver function that takes `TaskState` and `Generate`
        and returns an updated `TaskState` with prompt text.
    """

    async def solve(state: TaskState, generate: Generate) -> TaskState:
        prompt = state.user_prompt
        metadata = state.metadata
        full_context = needle_into_haystack(
            haystack.encoded_context,
            prompt.text,
            metadata["haystack_length"],
            metadata["position"],
            enc,
        )
        prompt.text = create_question_prompt(full_context, metadata["needle_question"])
        return state

    return solve
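The helper `needle_into_haystack` lives in the utils modules, which are not rendered in this portion of the diff. As a rough sketch of the kind of token-level insertion it performs, one could trim the pre-encoded haystack to the sample's context length and splice the needle in at the sentence boundary nearest the target position. This is an illustration under those assumptions, not the PR's implementation:

```python
import tiktoken
from tiktoken.core import Encoding


def insert_needle_sketch(
    encoded_context: list[int],
    needle: str,
    context_length: int,
    position: int,
    enc: Encoding,
) -> str:
    """Illustrative only: trim the haystack and insert the needle near `position`."""
    # Trim the pre-encoded haystack to the required context length
    tokens = encoded_context[:context_length]
    # Clamp the target position and walk back to the closest sentence boundary
    target = min(max(position, 0), len(tokens))
    period_token = enc.encode(".")[0]
    insert_at = target
    while insert_at > 0 and tokens[insert_at - 1] != period_token:
        insert_at -= 1
    # Splice the needle tokens into the haystack and decode back to text
    needle_tokens = enc.encode(" " + needle)
    return enc.decode(tokens[:insert_at] + needle_tokens + tokens[insert_at:])
```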
