|
| 1 | +from typing import Any |
| 2 | + |
| 3 | +from inspect_ai.dataset import Sample |
| 4 | + |
| 5 | +from inspect_evals.sec_qa import record_to_sample |
| 6 | +from inspect_evals.sec_qa.sec_qa import format_system_message_fewshot, sample_to_fewshot |
| 7 | + |
# Raw multiple-choice record used as the shared input fixture for the tests
# in this module: a question, four lettered options and the answer letter.
EXAMPLE_RECORD: dict[str, Any] = dict(
    Question="Which of the following is a common indicator of an SQL injection attack?",
    A="Frequent changes in user account permissions.",
    B="Decreased performance of endpoint protection systems.",
    C="Unusually high data upload rates to a web server.",
    D="Sudden uptick in SQL queries, far beyond the usual baseline for the application.",
    Answer="D",
)
| 16 | + |
# Few-shot rendering the tests expect for EXAMPLE_RECORD: question, lettered
# choices, then the answer line. Split into adjacent literals for
# readability; the concatenated value is a single string.
EXPECTED_FORMATTED_SAMPLE_TO_FEWSHOT = (
    "QUESTION:\n"
    "Which of the following is a common indicator of an SQL injection attack?\n"
    "\n"
    "CHOICES:\n"
    "A) Frequent changes in user account permissions.\n"
    "B) Decreased performance of endpoint protection systems.\n"
    "C) Unusually high data upload rates to a web server.\n"
    "D) Sudden uptick in SQL queries, far beyond the usual baseline for the application.\n"
    "\n"
    "\n"
    "ANSWER: D"
)
| 18 | + |
| 19 | + |
def test_record_to_sample():
    """Check that record_to_sample maps EXAMPLE_RECORD onto the expected Sample."""
    actual = record_to_sample(EXAMPLE_RECORD)

    # Choices are listed in A-D order; the target is the answer letter.
    expected_choices = [
        "Frequent changes in user account permissions.",
        "Decreased performance of endpoint protection systems.",
        "Unusually high data upload rates to a web server.",
        "Sudden uptick in SQL queries, far beyond the usual baseline for the application.",
    ]
    expected = Sample(
        input="Which of the following is a common indicator of an SQL injection attack?",
        choices=expected_choices,
        target="D",
    )

    assert actual == expected
| 33 | + |
| 34 | + |
def test_format_system_message_fewshot():
    """Check the few-shot system message built from two copies of the example sample."""
    fewshot_samples = [record_to_sample(EXAMPLE_RECORD) for _ in range(2)]
    actual_message = format_system_message_fewshot(fewshot_samples)

    # Expected layout: the intro sentence followed by each rendered example,
    # with a blank line between consecutive parts.
    intro = (
        "The following are multiple choice questions about Computer Security. "
        "Some examples are provided below."
    )
    expected_message = "\n\n".join(
        [
            intro,
            EXPECTED_FORMATTED_SAMPLE_TO_FEWSHOT,
            EXPECTED_FORMATTED_SAMPLE_TO_FEWSHOT,
        ]
    )

    assert actual_message == expected_message
| 48 | + |
| 49 | + |
def test_sample_to_fewshot():
    """Check that sample_to_fewshot renders the example sample as the expected text."""
    example_sample = record_to_sample(EXAMPLE_RECORD)
    rendered = sample_to_fewshot(example_sample)
    assert rendered == EXPECTED_FORMATTED_SAMPLE_TO_FEWSHOT
0 commit comments