Skip to content

Commit aa969fc

Browse files
committed
unit tests
1 parent 625c510 commit aa969fc

File tree

3 files changed

+82
-17
lines changed

3 files changed

+82
-17
lines changed
Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,19 @@
1-
from .sec_qa import sec_qa_v1, sec_qa_v1_5_shot, sec_qa_v2, sec_qa_v2_5_shot
1+
from .sec_qa import (
2+
format_system_message_fewshot,
3+
record_to_sample,
4+
sample_to_fewshot,
5+
sec_qa_v1,
6+
sec_qa_v1_5_shot,
7+
sec_qa_v2,
8+
sec_qa_v2_5_shot,
9+
)
210

3-
__all__ = ["sec_qa_v1", "sec_qa_v1_5_shot", "sec_qa_v2", "sec_qa_v2_5_shot"]
11+
__all__ = [
12+
"format_system_message_fewshot",
13+
"record_to_sample",
14+
"sample_to_fewshot",
15+
"sec_qa_v1",
16+
"sec_qa_v1_5_shot",
17+
"sec_qa_v2",
18+
"sec_qa_v2_5_shot",
19+
]

src/inspect_evals/sec_qa/sec_qa.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,7 @@ def sec_qa_v1_5_shot() -> Task:
4545
return Task(
4646
dataset=dataset,
4747
solver=[
48-
system_message(
49-
SYSTEM_MESSAGE_FEWSHOT.format(
50-
examples="\n\n".join(
51-
[sample_to_fewshot(sample=sample) for sample in fewshot_samples]
52-
)
53-
)
54-
),
48+
system_message(format_system_message_fewshot(fewshot_samples)),
5549
multiple_choice(),
5650
],
5751
scorer=choice(),
@@ -77,13 +71,7 @@ def sec_qa_v2_5_shot() -> Task:
7771
return Task(
7872
dataset=dataset,
7973
solver=[
80-
system_message(
81-
SYSTEM_MESSAGE_FEWSHOT.format(
82-
examples="\n\n".join(
83-
[sample_to_fewshot(sample=sample) for sample in fewshot_samples]
84-
)
85-
)
86-
),
74+
system_message(format_system_message_fewshot(fewshot_samples)),
8775
multiple_choice(),
8876
],
8977
scorer=choice(),
@@ -108,13 +96,21 @@ def record_to_sample(record: dict[str, Any]) -> Sample:
10896
)
10997

11098

99+
def format_system_message_fewshot(fewshot_samples: Dataset) -> str:
    """Build the few-shot system message from the given samples.

    Each sample is rendered with ``sample_to_fewshot``; the rendered examples
    are joined with a blank line between them and substituted into the
    ``examples`` placeholder of ``SYSTEM_MESSAGE_FEWSHOT``.

    Args:
        fewshot_samples: Iterable of samples to render as worked examples.

    Returns:
        The formatted system message text.
    """
    rendered_examples = [sample_to_fewshot(sample=shot) for shot in fewshot_samples]
    examples_block = "\n\n".join(rendered_examples)
    return SYSTEM_MESSAGE_FEWSHOT.format(examples=examples_block)
105+
106+
111107
def sample_to_fewshot(sample: Sample) -> str:
112108
prob_str = f"""QUESTION:\n{sample.input}"""
113109
labeled_choices = "\n"
114110
for i, letter_label in enumerate(["A", "B", "C", "D"]):
115111
labeled_choices = (
116112
labeled_choices
117-
+ f"{letter_label}: {sample.choices[i] if sample.choices is not None else None}\n"
113+
+ f"{letter_label}) {sample.choices[i] if sample.choices is not None else None}\n"
118114
)
119115
choices_str = f"""CHOICES:{labeled_choices}"""
120116
ans_str = f"""ANSWER: {sample.target}"""

tests/sec_qa/test_sec_qa.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from typing import Any
2+
3+
from inspect_ai.dataset import Sample
4+
5+
from inspect_evals.sec_qa import record_to_sample
6+
from inspect_evals.sec_qa.sec_qa import format_system_message_fewshot, sample_to_fewshot
7+
8+
# Raw record fixture shared by all tests in this module; it is passed to
# record_to_sample() to obtain a Sample.
EXAMPLE_RECORD: dict[str, Any] = {
    "Question": "Which of the following is a common indicator of an SQL injection attack?",
    "A": "Frequent changes in user account permissions.",
    "B": "Decreased performance of endpoint protection systems.",
    "C": "Unusually high data upload rates to a web server.",
    "D": "Sudden uptick in SQL queries, far beyond the usual baseline for the application.",
    "Answer": "D",
}

# Expected few-shot rendering of EXAMPLE_RECORD, written one output line per
# string literal so the layout (including blank lines) is easy to read.
EXPECTED_FORMATTED_SAMPLE_TO_FEWSHOT = (
    "QUESTION:\n"
    "Which of the following is a common indicator of an SQL injection attack?\n"
    "\n"
    "CHOICES:\n"
    "A) Frequent changes in user account permissions.\n"
    "B) Decreased performance of endpoint protection systems.\n"
    "C) Unusually high data upload rates to a web server.\n"
    "D) Sudden uptick in SQL queries, far beyond the usual baseline for the application.\n"
    "\n"
    "\n"
    "ANSWER: D"
)
18+
19+
20+
def test_record_to_sample():
    """Check that record_to_sample maps EXAMPLE_RECORD onto the expected Sample."""
    expected_sample = Sample(
        input="Which of the following is a common indicator of an SQL injection attack?",
        choices=[
            "Frequent changes in user account permissions.",
            "Decreased performance of endpoint protection systems.",
            "Unusually high data upload rates to a web server.",
            "Sudden uptick in SQL queries, far beyond the usual baseline for the application.",
        ],
        target="D",
    )
    # Converted sample stays on the left-hand side of the comparison.
    assert record_to_sample(EXAMPLE_RECORD) == expected_sample
33+
34+
35+
def test_format_system_message_fewshot():
    """Check the system message produced for two identical few-shot samples."""
    shots = [record_to_sample(EXAMPLE_RECORD) for _ in range(2)]
    # Header and each rendered example are separated by one blank line.
    expected_message = "\n\n".join(
        [
            "The following are multiple choice questions about Computer Security. Some examples are provided below.",
            EXPECTED_FORMATTED_SAMPLE_TO_FEWSHOT,
            EXPECTED_FORMATTED_SAMPLE_TO_FEWSHOT,
        ]
    )
    assert format_system_message_fewshot(shots) == expected_message
48+
49+
50+
def test_sample_to_fewshot():
    """Check that a single sample renders to the expected few-shot text."""
    sample = record_to_sample(EXAMPLE_RECORD)
    assert sample_to_fewshot(sample) == EXPECTED_FORMATTED_SAMPLE_TO_FEWSHOT

0 commit comments

Comments
 (0)