Skip to content

Commit 625c510

Browse files
committed
Added v2 fewshot and refactored
1 parent a07f79c commit 625c510

File tree

3 files changed

+52
-54
lines changed

3 files changed

+52
-54
lines changed

src/inspect_evals/_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from .piqa import piqa
5050
from .pubmedqa import pubmedqa
5151
from .race_h import race_h
52-
from .sec_qa import sec_qa_v1, sec_qa_v1_5_shot, sec_qa_v2
52+
from .sec_qa import sec_qa_v1, sec_qa_v1_5_shot, sec_qa_v2, sec_qa_v2_5_shot
5353
from .squad import squad
5454
from .swe_bench import swe_bench
5555
from .truthfulqa import truthfulqa
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .sec_qa import sec_qa_v1, sec_qa_v1_5_shot, sec_qa_v2
1+
from .sec_qa import sec_qa_v1, sec_qa_v1_5_shot, sec_qa_v2, sec_qa_v2_5_shot
22

3-
__all__ = ["sec_qa_v1", "sec_qa_v1_5_shot", "sec_qa_v2"]
3+
__all__ = ["sec_qa_v1", "sec_qa_v1_5_shot", "sec_qa_v2", "sec_qa_v2_5_shot"]

src/inspect_evals/sec_qa/sec_qa.py

Lines changed: 49 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -8,66 +8,40 @@
88
from typing import Any
99

1010
from inspect_ai import Task, task
11-
from inspect_ai.dataset import Sample, hf_dataset
11+
from inspect_ai.dataset import Dataset, Sample, hf_dataset
1212
from inspect_ai.scorer import choice
1313
from inspect_ai.solver import multiple_choice, system_message
1414

1515
# setup for problem + instructions for providing answer
16-
SYSTEM_MESSAGE = """
17-
The following are multiple choice questions about Computer Security.
18-
""".strip()
19-
20-
SYSTEM_MESSAGE_FEWSHOT = (
21-
SYSTEM_MESSAGE
22-
+ """
23-
Some examples are provided below.
16+
SYSTEM_MESSAGE_FEWSHOT = """
17+
The following are multiple choice questions about Computer Security. Some examples are provided below.
2418
2519
{examples}
2620
""".strip()
27-
)
21+
22+
DATASET_PATH = "zefang-liu/secqa"
23+
DATASET_SUBSET_NAME_V1 = "secqa_v1"
24+
DATASET_SUBSET_NAME_V2 = "secqa_v2"
25+
DATASET_TEST_SPLIT = "test"
26+
DATASET_FEWSHOT_SPLIT = "dev"
2827

2928

3029
@task
3130
def sec_qa_v1() -> Task:
3231
"""Inspect Task implementing the SecQA benchmark v1 0-shot"""
33-
# dataset
34-
dataset = hf_dataset(
35-
path="zefang-liu/secqa",
36-
name="secqa_v1", # aka subset
37-
split="dev",
38-
sample_fields=record_to_sample,
39-
trust=True,
40-
)
41-
42-
# define task
32+
dataset = retrieve_hf_dataset(DATASET_SUBSET_NAME_V1, DATASET_TEST_SPLIT)
4333
return Task(
4434
dataset=dataset,
45-
solver=[system_message(SYSTEM_MESSAGE), multiple_choice()],
35+
solver=[multiple_choice()],
4636
scorer=choice(),
4737
)
4838

4939

5040
@task
5141
def sec_qa_v1_5_shot() -> Task:
5242
"""Inspect Task implementing the SecQA benchmark v1 5-shot"""
53-
fewshot_samples = hf_dataset(
54-
path="zefang-liu/secqa",
55-
name="secqa_v1", # aka subset
56-
split="dev",
57-
sample_fields=record_to_sample,
58-
trust=True,
59-
)
60-
61-
# dataset
62-
dataset = hf_dataset(
63-
path="zefang-liu/secqa",
64-
name="secqa_v1", # aka subset
65-
split="val",
66-
sample_fields=record_to_sample,
67-
trust=True,
68-
)
69-
70-
# define task
43+
fewshot_samples = retrieve_hf_dataset(DATASET_SUBSET_NAME_V1, DATASET_FEWSHOT_SPLIT)
44+
dataset = retrieve_hf_dataset(DATASET_SUBSET_NAME_V1, DATASET_TEST_SPLIT)
7145
return Task(
7246
dataset=dataset,
7347
solver=[
@@ -87,23 +61,45 @@ def sec_qa_v1_5_shot() -> Task:
8761
@task
8862
def sec_qa_v2() -> Task:
8963
"""Inspect Task implementing the SecQA benchmark v2 0-shot"""
90-
# dataset
91-
dataset = hf_dataset(
92-
path="zefang-liu/secqa",
93-
name="secqa_v2", # aka subset
94-
split="dev",
95-
sample_fields=record_to_sample,
96-
trust=True,
64+
dataset = retrieve_hf_dataset(DATASET_SUBSET_NAME_V2, DATASET_TEST_SPLIT)
65+
return Task(
66+
dataset=dataset,
67+
solver=[multiple_choice()],
68+
scorer=choice(),
9769
)
9870

99-
# define task
71+
72+
@task
73+
def sec_qa_v2_5_shot() -> Task:
74+
"""Inspect Task implementing the SecQA benchmark v2 5-shot"""
75+
fewshot_samples = retrieve_hf_dataset(DATASET_SUBSET_NAME_V2, DATASET_FEWSHOT_SPLIT)
76+
dataset = retrieve_hf_dataset(DATASET_SUBSET_NAME_V2, DATASET_TEST_SPLIT)
10077
return Task(
10178
dataset=dataset,
102-
solver=[system_message(SYSTEM_MESSAGE), multiple_choice()],
79+
solver=[
80+
system_message(
81+
SYSTEM_MESSAGE_FEWSHOT.format(
82+
examples="\n\n".join(
83+
[sample_to_fewshot(sample=sample) for sample in fewshot_samples]
84+
)
85+
)
86+
),
87+
multiple_choice(),
88+
],
10389
scorer=choice(),
10490
)
10591

10692

93+
def retrieve_hf_dataset(name: str, split: str) -> Dataset:
94+
return hf_dataset(
95+
path=DATASET_PATH,
96+
name=name,
97+
split=split,
98+
sample_fields=record_to_sample,
99+
trust=True,
100+
)
101+
102+
107103
def record_to_sample(record: dict[str, Any]) -> Sample:
108104
return Sample(
109105
input=record["Question"],
@@ -114,10 +110,12 @@ def record_to_sample(record: dict[str, Any]) -> Sample:
114110

115111
def sample_to_fewshot(sample: Sample) -> str:
116112
prob_str = f"""QUESTION:\n{sample.input}"""
117-
labeled_choices = []
113+
labeled_choices = "\n"
118114
for i, letter_label in enumerate(["A", "B", "C", "D"]):
119-
labeled_choice = f"{letter_label}: {sample.choices[i] if sample.choices is not None else None}"
120-
labeled_choices.append(labeled_choice)
115+
labeled_choices = (
116+
labeled_choices
117+
+ f"{letter_label}: {sample.choices[i] if sample.choices is not None else None}\n"
118+
)
121119
choices_str = f"""CHOICES:{labeled_choices}"""
122120
ans_str = f"""ANSWER: {sample.target}"""
123121

0 commit comments

Comments (0)