"""
ClassEval: class-level Python code generation evaluation.

Based on the paper https://arxiv.org/pdf/2308.01861 .
The dataset is available at https://huggingface.co/datasets/FudanSELab/ClassEval
and at https://github.com/FudanSELab/ClassEval .

This is an inspect_ai implementation of the paper's benchmark.
"""

import re
from typing import Any

from inspect_ai import Epochs, Task, task
from inspect_ai.dataset import Sample, hf_dataset
from inspect_ai.scorer import (
    CORRECT,
    INCORRECT,
    Score,
    Scorer,
    Target,
    mean,
    scorer,
    std,
)
from inspect_ai.solver import (
    TaskState,
    generate,
    system_message,
)
from inspect_ai.util import ExecResult, sandbox

from .utils import construct_prompt

# Timeout (in seconds) for executing the generated code during scoring.
VERIFY_TIMEOUT = 30


@task
def class_eval(few_shot: int = 1, few_shot_seed: int = 42) -> Task:
    """Inspect Task implementation of ClassEval.

    Args:
        few_shot (int): Number of generations per sample (epochs); the task
            reports pass@few_shot over these generations.
        few_shot_seed (int): Seed for sampling few-shot examples (reserved;
            not currently used in this implementation).
    """

    dataset = hf_dataset(
        path="FudanSELab/ClassEval",
        split="test",
        sample_fields=record_to_sample,
    )

    INSTRUCTION = (
        "You are an expert Python programmer. You will be given a task, "
        "and the tests that your code must pass."
    )

    return Task(
        dataset=dataset,
        solver=[system_message(INSTRUCTION), generate()],
        scorer=class_eval_scorer(),
        epochs=Epochs(few_shot, [f"pass_at_{few_shot}"]),
        sandbox="docker",
    )

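# Note: Epochs(k, [f"pass_at_{k}"]) runs each sample k times and reduces the
# per-epoch scores with inspect_ai's pass_at reducer. For reference, the
# unbiased pass@k estimator from the Codex paper (also used by ClassEval),
# given n generations of which c pass, is sketched by:
#
#   from math import comb
#
#   def pass_at_k(n: int, c: int, k: int) -> float:
#       if n - c < k:
#           return 1.0
#       return 1.0 - comb(n - c, k) / comb(n, k)

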
@scorer(metrics=[mean(), std()])
def class_eval_scorer() -> Scorer:
    """Score a ClassEval completion.

    Extracts the Python code from the model output, appends the dataset's
    test cases, and executes the result in the sandbox. The score is CORRECT
    if all test cases pass and INCORRECT otherwise.
    """

    async def score(state: TaskState, target: Target) -> Score:
        generated_code = find_code(state.output.completion)
        # Append the dataset's test suite to the extracted class definition.
        code = generated_code + "\n" + state.metadata["test"]
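
        # For illustration (hypothetical class name; assumes the dataset's
        # `test` field contains a runnable unittest suite, as in ClassEval):
        #
        #   class StockPortfolio: ...                      # model-generated
        #
        #   import unittest
        #   class StockPortfolioTest(unittest.TestCase): ...
        #
        # The sandboxed `python -c` call below exits non-zero when a test
        # fails, which is what `result.success` reflects.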

        explanation = "The following code was executed:\n\n```python\n"
        explanation += code
        explanation += "\n```\n"

        try:
            result = await sandbox().exec(
                cmd=["python", "-c", code],
                timeout=VERIFY_TIMEOUT,
            )

            if result.success:
                explanation += "All test cases passed.\n"
            else:
                explanation += "Code did not pass all test cases.\n"
                if result.stderr:
                    explanation += (
                        f"See details below.\n```python\n{result.stderr}\n```\n"
                    )
        except TimeoutError:
            result = ExecResult(False, 1, "", "Verification timed out.")
            explanation += "Verification timed out."

        return Score(
            value=CORRECT if result.success else INCORRECT,
            answer=generated_code,
            explanation=explanation,
        )

    return score


def find_code(completion: str) -> str:
    """Extract the first fenced Python block from a completion.

    Falls back to the raw completion if no fenced block is found.
    """
    pattern = re.compile(r"```python\n(.*?)```", re.DOTALL)
    matches = pattern.findall(completion)
    extracted_answer = matches[0] if len(matches) >= 1 else completion
    return str(extracted_answer)
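
# Example: find_code("Here:\n```python\nx = 1\n```") returns "x = 1\n".
# If the completion contains no fenced block, it is returned unchanged.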


def record_to_sample(
    record: dict[str, Any],
) -> Sample:
    """Map a ClassEval dataset record to an inspect_ai Sample."""
    return Sample(
        input=construct_prompt(record),
        target=record["solution_code"],
        id=record["task_id"],
        metadata={
            "task_id": record["task_id"],
            "skeleton": record["skeleton"],
            "test": record["test"],
            "solution_code": record["solution_code"],
            "import_statement": record["import_statement"],
            "class_description": record["class_description"],
            "methods_info": record["methods_info"],
            "class_name": record["class_name"],
            "test_classes": record["test_classes"],
            "class_constructor": record["class_constructor"],
            "fields": record["fields"],
        },
    )