Skip to content

Commit a07f79c

Browse files
committed
Implemented the v1 5-shot
1 parent b844272 commit a07f79c

File tree

3 files changed

+63
-5
lines changed

3 files changed

+63
-5
lines changed

src/inspect_evals/_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from .piqa import piqa
5050
from .pubmedqa import pubmedqa
5151
from .race_h import race_h
52-
from .sec_qa import sec_qa_v1, sec_qa_v2
52+
from .sec_qa import sec_qa_v1, sec_qa_v1_5_shot, sec_qa_v2
5353
from .squad import squad
5454
from .swe_bench import swe_bench
5555
from .truthfulqa import truthfulqa
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .sec_qa import sec_qa_v1, sec_qa_v2
1+
from .sec_qa import sec_qa_v1, sec_qa_v1_5_shot, sec_qa_v2
22

3-
__all__ = ["sec_qa_v1", "sec_qa_v2"]
3+
__all__ = ["sec_qa_v1", "sec_qa_v1_5_shot", "sec_qa_v2"]

src/inspect_evals/sec_qa/sec_qa.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,19 @@
1717
The following are multiple choice questions about Computer Security.
1818
""".strip()
1919

20+
# Few-shot variant of SYSTEM_MESSAGE. The "{examples}" placeholder is filled
# in with str.format() by the few-shot task.
#
# The line break between the two parts is written out explicitly: a trailing
# .strip() applies only to the literal it is attached to (method calls bind
# tighter than "+"), so a newline at the start of that literal is removed and
# the sentences would otherwise run together with no whitespace between them.
SYSTEM_MESSAGE_FEWSHOT = (
    SYSTEM_MESSAGE
    + "\n"
    + """
Some examples are provided below.

{examples}
""".strip()
)
28+
2029

2130
@task
2231
def sec_qa_v1() -> Task:
23-
"""Inspect Task implementing the SecQA benchmark v1"""
32+
"""Inspect Task implementing the SecQA benchmark v1 0-shot"""
2433
# dataset
2534
dataset = hf_dataset(
2635
path="zefang-liu/secqa",
@@ -38,9 +47,46 @@ def sec_qa_v1() -> Task:
3847
)
3948

4049

50+
@task
def sec_qa_v1_5_shot() -> Task:
    """Inspect Task implementing the SecQA benchmark v1 5-shot

    Every sample of the ``dev`` split is rendered with ``sample_to_fewshot``
    and substituted into ``SYSTEM_MESSAGE_FEWSHOT``; the ``val`` split is the
    dataset that is evaluated.

    NOTE(review): the number of shots is whatever the ``dev`` split contains;
    nothing here limits it to five - confirm against the dataset.
    """

    def load_split(split: str) -> Any:
        # Both splits come from the same subset and use the same conversion.
        return hf_dataset(
            path="zefang-liu/secqa",
            name="secqa_v1",  # aka subset
            split=split,
            sample_fields=record_to_sample,
            trust=True,
        )

    # Few-shot examples are loaded first, then the evaluated dataset.
    fewshot_samples = load_split("dev")
    dataset = load_split("val")

    # One formatted example per dev sample, separated by blank lines.
    rendered_examples = "\n\n".join(
        sample_to_fewshot(sample=example) for example in fewshot_samples
    )
    fewshot_prompt = SYSTEM_MESSAGE_FEWSHOT.format(examples=rendered_examples)

    # define task
    return Task(
        dataset=dataset,
        solver=[
            system_message(fewshot_prompt),
            multiple_choice(),
        ],
        scorer=choice(),
    )
85+
86+
4187
@task
4288
def sec_qa_v2() -> Task:
43-
"""Inspect Task implementing the SecQA benchmark v2"""
89+
"""Inspect Task implementing the SecQA benchmark v2 0-shot"""
4490
# dataset
4591
dataset = hf_dataset(
4692
path="zefang-liu/secqa",
@@ -64,3 +110,15 @@ def record_to_sample(record: dict[str, Any]) -> Sample:
64110
choices=[record["A"], record["B"], record["C"], record["D"]],
65111
target=record["Answer"],
66112
)
113+
114+
115+
def sample_to_fewshot(sample: Sample) -> str:
    """Render one dataset sample as a worked example for the few-shot prompt.

    The result has three sections separated by blank lines: the question, the
    answer options labelled ``A``, ``B``, ... (one per line), and the target.

    Args:
        sample: Sample providing ``input``, ``choices`` and ``target``.

    Returns:
        The formatted example text.
    """
    # `choices` may be None; treat that as "no options" rather than indexing
    # into it or printing "None" placeholders.
    choices = sample.choices if sample.choices is not None else []

    # Label by position so any number of options works, and join with
    # newlines: interpolating the list object itself would embed its Python
    # repr (brackets, quotes, escapes) in the prompt.
    labeled_choices = "\n".join(
        f"{chr(ord('A') + index)}: {option}" for index, option in enumerate(choices)
    )

    prob_str = f"QUESTION:\n{sample.input}"
    choices_str = f"CHOICES:\n{labeled_choices}"
    ans_str = f"ANSWER: {sample.target}"

    return f"{prob_str}\n\n{choices_str}\n\n{ans_str}"

0 commit comments

Comments
 (0)