Skip to content

Commit 0835b12

Browse files
addressed issues in class_eval.py
1 parent ccbbc37 commit 0835b12

File tree

2 files changed

+20
-22
lines changed

2 files changed

+20
-22
lines changed

src/inspect_evals/class_eval/class_eval.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -33,23 +33,11 @@
3333

3434
from utils import construct_prompt
3535

36-
37-
INSTRUCTION = """
38-
39-
You are an expert Python programmer. You will be given a task, and the tests that your code must pass.
40-
41-
"""
42-
4336
# Timeout for scoring.
4437
VERIFY_TIMEOUT = 30
4538

4639
@task
47-
def class_eval(
48-
k_shot: int = 1,
49-
solver: Solver | None = None,
50-
instruction_prompt: str = INSTRUCTION,
51-
scorer: Scorer | list[Scorer] | None = None,
52-
) -> Task:
40+
def class_eval(few_shot: int = 1, few_shot_seed: int = 42) -> Task:
5341
"""Inspect Task implementation of ClassEval.
5442
5543
Args:
@@ -63,24 +51,35 @@ def class_eval(
6351
dataset = hf_dataset(
6452
path="FudanSELab/ClassEval",
6553
split="test",
66-
sample_fields=record_to_sample,
54+
sample_fields=record_to_sample
6755
)
6856

69-
solver = [system_message(instruction_prompt), generate()]
57+
INSTRUCTION = """
58+
59+
You are an expert Python programmer. You will be given a task, and the tests that your code must pass.
60+
61+
"""
62+
63+
solver = [system_message(INSTRUCTION), generate()]
7064

7165
scorer = class_eval_scorer()
7266

7367
return Task(
7468
dataset=dataset,
7569
solver=solver or generate(),
7670
scorer=scorer,
77-
epochs = Epochs(k_shot, [f"pass_at_{k_shot}"]),
71+
epochs = Epochs(few_shot, [f"pass_at_{few_shot}"]),
7872
sandbox="docker",
7973
)
8074

8175

8276
@scorer(metrics=[mean(), std()])
8377
def class_eval_scorer() -> Scorer:
78+
'''
79+
This is the scorer for class eval. It will first identify the python code in the output of
80+
the model. Then we append the test cases to the code and execute it. If the code passes all the test cases,
81+
we return a CORRECT score. Otherwise, we return an INCORRECT score.
82+
'''
8483

8584
async def score(state:TaskState, target: Target) -> Score:
8685

@@ -105,10 +104,7 @@ async def score(state:TaskState, target: Target) -> Score:
105104
else:
106105
explanation += "Code did not pass all test cases.\n"
107106
if result.stderr:
108-
explanation += "See details below.\n"
109-
explanation += "```python\n"
110-
explanation += result.stderr + "\n"
111-
explanation += "```\n"
107+
explanation += f"See details below.\n ```python\n {result.stderr} \n ```\n"
112108
except TimeoutError:
113109
result = ExecResult(False, 1, "", "Verification timed out.")
114110
explanation += "Verification timed out."
@@ -131,10 +127,12 @@ def find_code(completion: str) -> str:
131127
return str(extracted_answer)
132128

133129

134-
# map class_eval record into inspect sample
135130
def record_to_sample(
136131
record: dict[str, Any],
137132
) -> Sample:
133+
'''
134+
Maps class_eval record into inspect sample
135+
'''
138136
return Sample(
139137
input = construct_prompt(record),
140138
target = record["solution_code"],

src/inspect_evals/class_eval/test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_find_code(self):
2525
self.assertEqual(find_code(raw_code), sample_code)
2626

2727
def test_task(self):
28-
task = class_eval(k_shot = 5)
28+
task = class_eval(few_shot = 5)
2929
self.assertEqual(task.dataset.name, 'FudanSELab/ClassEval')
3030
self.assertEqual(task.epochs, 5)
3131
self.assertEqual(task.sandbox.type, "docker")

0 commit comments

Comments
 (0)