"""SEvenLLM: Benchmarking, Eliciting, and Enhancing Abilities of Large Language Models in Cyber Threat Intelligence

Hangyuan Ji, Jian Yang, Linzheng Chai, Chaoren Wei, Liqun Yang,
Yunlong Duan, Yunli Wang, Tianzhen Sun, Hongcheng Guo, Tongliang Li,
Changyu Ren, Zhoujun Li

https://arxiv.org/abs/2405.03446

# Eval for MCQs (Understanding) in Simplified Chinese (Zh)
inspect eval inspect_evals/sevenllm_mcq_zh

# Eval for MCQs (Understanding) in English (En)
inspect eval inspect_evals/sevenllm_mcq_en

# Eval for QA (Generation) in Simplified Chinese (Zh)
inspect eval inspect_evals/sevenllm_qa_zh

# Eval for QA (Generation) in English (En)
inspect eval inspect_evals/sevenllm_qa_en
"""

import json
import re
from typing import Any

from inspect_ai import Task, task
from inspect_ai.dataset import Dataset, Sample, json_dataset
from inspect_ai.scorer import choice
from inspect_ai.solver import generate, multiple_choice, prompt_template

from inspect_evals.sevenllm.scorers import rouge_l_scorer, semantic_similarity_scorer

BENCHMARK_DATASET_URL = "https://huggingface.co/datasets/Multilingual-Multimodal-NLP/SEVENLLM-Dataset/raw/main/test.jsonl"
# This prompt is adapted from the authors' inference code.
# See: https://github.com/CSJianYang/SEevenLLM/blob/main/infer.py
# The same prompt is used for both languages.
TEMPLATE = r"""
Below is an instruction that describes a task, paired with an input that provides further context.
Write a response that appropriately completes the request. If you are given options (A, B, C, D),
return just the option letter and not the entire text. If you're required to return a JSON, only return a JSON and
nothing else.

{prompt}
"""


@task
def sevenllm_mcq_zh() -> Task:
    """Inspect task implementing the SEvenLLM benchmark for MCQs in Simplified Chinese."""
    return Task(
        dataset=get_sevenllm_dataset(language="zh", data_format="mcq"),
        solver=[prompt_template(template=TEMPLATE), multiple_choice()],
        scorer=choice(),
    )


@task
def sevenllm_mcq_en() -> Task:
    """Inspect task implementing the SEvenLLM benchmark for MCQs in English."""
    return Task(
        dataset=get_sevenllm_dataset(language="en", data_format="mcq"),
        solver=[prompt_template(template=TEMPLATE), multiple_choice()],
        scorer=choice(),
    )


@task
def sevenllm_qa_zh() -> Task:
    """Inspect task implementing the SEvenLLM benchmark for QA in Simplified Chinese."""
    return Task(
        dataset=get_sevenllm_dataset(language="zh", data_format="qa"),
        solver=[prompt_template(template=TEMPLATE), generate()],
        scorer=[rouge_l_scorer(is_zh=True), semantic_similarity_scorer()],
    )


@task
def sevenllm_qa_en() -> Task:
    """Inspect task implementing the SEvenLLM benchmark for QA in English."""
    return Task(
        dataset=get_sevenllm_dataset(language="en", data_format="qa"),
        solver=[prompt_template(template=TEMPLATE), generate()],
        scorer=[rouge_l_scorer(is_zh=False), semantic_similarity_scorer()],
    )


def contains_zh(text: str) -> bool:
    """Return True if the text contains Chinese (CJK) characters."""
    # Regular expression matching the CJK Unified Ideographs range
    # (\u4e00-\u9fff), which covers Simplified Chinese text.
    pattern = re.compile(r"[\u4e00-\u9fff]")

    return bool(pattern.search(text))
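
# For example (hypothetical inputs): contains_zh("恶意软件") returns True, while
# contains_zh("malware") returns False, so such a record is tagged "en".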


def record_to_sample(record: dict[str, Any]) -> Sample:
    """Transform a dataset record into a Sample for the Task."""
    instruction = record["instruction"]
    record_format = "qa" if isinstance(instruction, str) else "mcq"
    text = instruction if isinstance(instruction, str) else instruction["question"]
    record_language = "zh" if contains_zh(text) else "en"

    sample = {
        "id": record["id"],
        "input": f"{text}\n\n{record['input']}\n\n",
        "metadata": {
            "category": record["category"],
            "cot": record["thought"],
            "language": record_language,
            "format": record_format,
        },
        "target": json.dumps(record["output"], ensure_ascii=False)
        if record_format == "qa"
        else str(record["output"]),
    }

    if record_format == "mcq":
        sample["choices"] = [
            str(instruction["choice"]["A"]),
            str(instruction["choice"]["B"]),
            str(instruction["choice"]["C"]),
            str(instruction["choice"]["D"]),
        ]

    return Sample(**sample)
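
# Illustrative sketch of the two record schemas handled above (field values are
# hypothetical, not taken from the dataset). A QA record carries a plain-string
# instruction, while an MCQ record nests a question plus four choices:
#
#   {"id": 1, "category": "...", "thought": "...", "input": "<incident report>",
#    "instruction": "Summarize the key events of this incident.",
#    "output": {"summary": "..."}}                                   # -> "qa"
#
#   {"id": 2, "category": "...", "thought": "...", "input": "<incident report>",
#    "instruction": {"question": "Which attack technique was used?",
#                    "choice": {"A": "...", "B": "...", "C": "...", "D": "..."}},
#    "output": "A"}                                                  # -> "mcq"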


def get_sevenllm_dataset(language: str, data_format: str) -> Dataset:
    """Get a filtered dataset from the SEvenLLM benchmark based on language and format."""
    # Cannot use `hf_dataset` here because the benchmark JSONL file contains JSON
    # records with multiple schemas, which violates PyArrow validation checks.
    dataset = json_dataset(
        json_file=BENCHMARK_DATASET_URL,
        sample_fields=record_to_sample,
    )

    return dataset.filter(
        lambda sample: sample.metadata["format"] == data_format
        and sample.metadata["language"] == language
    )
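
# Minimal usage sketch (assumes inspect-ai and this package are installed; the
# model name below is only an example, not a requirement of the benchmark):
#
#   from inspect_ai import eval
#   eval(sevenllm_mcq_en(), model="openai/gpt-4o")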