
Commit a1b4a79

Add StudentEval from LLM4Code 2024
1 parent 1b0147c commit a1b4a79

File tree: 3 files changed, +201 −1 lines

bigcode_eval/tasks/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -4,7 +4,8 @@
 from . import (apps, codexglue_code_to_text, codexglue_text_to_text, conala,
                concode, ds1000, gsm, humaneval, humanevalplus, humanevalpack,
                instruct_humaneval, instruct_wizard_humaneval, mbpp, mbppplus,
-               multiple, parity, python_bugs, quixbugs, recode, santacoder_fim)
+               multiple, parity, python_bugs, quixbugs, recode, santacoder_fim,
+               studenteval)

 TASK_REGISTRY = {
     **apps.create_all_tasks(),
@@ -28,6 +29,7 @@
     **instruct_humaneval.create_all_tasks(),
     **recode.create_all_tasks(),
     **santacoder_fim.create_all_tasks(),
+    "studenteval": studenteval.StudentEval,
 }

 ALL_TASKS = sorted(list(TASK_REGISTRY))
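
A quick way to see the effect of this change (illustrative only, not part of the commit; assumes the harness and its dependencies are installed):

# Illustrative sketch: the new task is reachable through the registry built above.
from bigcode_eval.tasks import ALL_TASKS, TASK_REGISTRY

assert "studenteval" in ALL_TASKS
task_cls = TASK_REGISTRY["studenteval"]  # studenteval.StudentEval
print(task_cls.DATASET_PATH)             # "wellesley-easel/StudentEval"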

bigcode_eval/tasks/studenteval.py

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
"""
StudentEval is a dataset of 1,749 prompts for 48 problems, authored by 80
students who have only completed a one-semester Python programming class.
Unlike many other benchmarks, it has multiple prompts per problem and multiple
attempts by the same participant.

Web page: https://huggingface.co/datasets/wellesley-easel/StudentEval
"""

from bigcode_eval.base import Task
from datasets import load_dataset
from multiprocessing import cpu_count
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import tempfile
import pandas as pd
import numpy as np
import subprocess

_CITATION = """\
@misc{babe2023studenteval,
    title={StudentEval: A Benchmark of Student-Written Prompts for Large Language Models of Code},
    author={Hannah McLean Babe and Sydney Nguyen and Yangtian Zi and Arjun Guha and Molly Q Feldman and Carolyn Jane Anderson},
    year={2023},
    eprint={2306.04556},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}"""

EXECUTION_TIMEOUT = 15


# Source: Chen et al. Evaluating Large Language Models Trained on Code. 2021
def _estimator(n: int, c: int, k: int) -> float:
    """
    Calculates 1 - comb(n - c, k) / comb(n, k).
    """
    assert c <= n, "c must be at most n"
    if n - c < k:
        return 1.0
    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))


def _run_assembled_program(item):
    """
    Runs the program with a timeout. The result dictionary has a "success" key
    that is 1 on success and 0 on failure. It also includes the keys necessary
    to group results (problem, prompt, and group) and report results for each
    subset of StudentEval.
    """
    with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
        f.write(item["program"])
        f.flush()
        try:
            result = subprocess.run(
                ["python3", f.name],
                timeout=EXECUTION_TIMEOUT,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                stdin=subprocess.DEVNULL,
            )
            exit_code = result.returncode
        except subprocess.TimeoutExpired:
            exit_code = 1
    return {
        "problem": item["problem"],
        "prompt": item["prompt"],
        "group": item["group"],
        "success": 1 if exit_code == 0 else 0,
    }


def _get_group(item):
    """
    These boolean flags are mutually exclusive in the dataset. We turn them
    into a string for easy grouping with Pandas.
    """
    if item["is_first_success"]:
        return "First Success"
    if item["is_last_success"]:
        return "Last Success"
    if item["is_first_failure"]:
        return "First Failure"
    if item["is_last_failure"]:
        return "Last Failure"
    return None


class StudentEval(Task):
    DATASET_PATH = "wellesley-easel/StudentEval"

    def __init__(self):
        self.stop_words = ["\ndef", "\nclass", "\nif", "\nprint"]
        self.requires_execution = True
        self.dataset = load_dataset(path=self.DATASET_PATH)
        # NOTE(Arjun Guha): Avoiding .filter so that we don't get a datasets
        # cache item on disk.
        self.dataset = [
            item for item in self.dataset["test"] if _get_group(item) is not None
        ]

    def get_dataset(self):
        return self.dataset

    def get_prompt(self, doc):
        return doc["prompt"].rstrip()

    # For a task with tests, the reference solution is the suite of tests.
    def get_reference(self, doc):
        return {
            "prompt": doc["prompt"],
            "assertions": doc["assertions"],
            "problem": doc["problem"],
            "group": _get_group(doc),
        }

    def postprocess_generation(self, generation, idx):
        """Defines the postprocessing for a LM generation.
        :param generation: str
            code generation from LM
        :param idx: int
            index of doc in the dataset to which the generation belongs
        """
        prompt = self.get_prompt(self.dataset[idx])
        generation = generation[len(prompt) :]
        return prompt + self._stop_at_stop_token(generation, self.stop_words)

    def process_results(self, generations, references):
        """Takes the list of LM generations and evaluates them against ground
        truth references, returning the metric for the generations.
        :param generations: list(list(str))
            list of lists containing generations
        :param references: list({ "assertions": list(str), "problem": str })
            list of reference solutions
        """
        worklist = []
        for generations, reference in zip(generations, references):
            # NOTE(Arjun Guha): This can be more efficient. At low temperature, we get
            # lots of repeated completions, so this ends up running the same program
            # repeatedly. The original StudentEval code runs each generation once.
            for generation in generations:
                item = {
                    "program": generation + "\n\n" + reference["assertions"],
                    "prompt": reference["prompt"],
                    "problem": reference["problem"],
                    "group": reference["group"],
                }
                worklist.append(item)

        with ThreadPoolExecutor(max_workers=cpu_count() - 1) as executor:
            results_df = pd.DataFrame(
                list(
                    tqdm(
                        executor.map(_run_assembled_program, worklist),
                        total=len(worklist),
                    )
                )
            )

        # Calculate pass@1 for each prompt
        results_df = results_df.groupby(["problem", "prompt", "group"]).agg(
            c=("success", np.sum), n=("success", "count")
        )
        results_df.reset_index(inplace=True)
        results_df["pass1"] = results_df.apply(
            lambda row: _estimator(row["n"], row["c"], 1), axis=1
        )

        # Calculate mean pass@1 for each group
        results_df = results_df.groupby(["group"]).agg(pass1=("pass1", np.mean))

        # Turn into JSON-serializable records
        results_df.reset_index(inplace=True)
        results_df = results_df.to_dict(orient="records")
        return results_df
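
For context on what actually gets executed (a sketch, not part of the commit): each worklist item concatenates a completion with the dataset's instructor-written assertions, and _run_assembled_program writes that text to a temporary file and runs it with python3, counting exit code 0 as success. The completion and assertion below are made-up stand-ins, not dataset content.

# Hypothetical assembled program, in the shape produced by process_results above.
generation = "def total(lst):\n    return sum(lst)\n"
assertions = "assert total([1, 2, 3]) == 6"
program = generation + "\n\n" + assertions
# Writing `program` to a temporary .py file and running `python3` on it exits
# with code 0, which _run_assembled_program records as success == 1.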

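One detail worth noting about the metric (not stated in the file itself): with k = 1, the Chen et al. estimator telescopes to c / n, so the per-prompt pass@1 values averaged per group are plain success rates. A standalone check of that identity, reusing the formula above:

# Sanity check of the pass@k estimator for k == 1 (illustrative, standalone).
import numpy as np

def estimator(n: int, c: int, k: int) -> float:
    if n - c < k:
        return 1.0
    return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

assert abs(estimator(20, 5, 1) - 5 / 20) < 1e-9  # pass@1 == c / n when k == 1
assert estimator(20, 20, 1) == 1.0               # n - c < k: every sample passed
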
docs/README.md

Lines changed: 21 additions & 0 deletions
@@ -382,6 +382,27 @@ accelerate launch main.py \
   --allow_code_execution
 ```

+### StudentEval
+
+[StudentEval](https://huggingface.co/datasets/wellesley-easel/StudentEval) is a
+dataset of 1,749 prompts for 48 problems, authored by 80 students who have only
+completed a one-semester Python programming class. Unlike many other benchmarks,
+it has multiple prompts per problem and multiple attempts by the same
+participant. Each problem is accompanied by a set of instructor-written test
+cases.
+
+```python
+accelerate launch main.py \
+  --model <MODEL_NAME> \
+  --max_length_generation 512 \
+  --tasks studenteval \
+  --temperature 0.2 \
+  --top_p 0.95 \
+  --do_sample True \
+  --n_samples 20 \
+  --batch_size 20 \
+  --allow_code_execution
+```

 ## Code generation benchmarks without unit tests
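
For reference, the metric that process_results returns for this task is a list of per-group records; the group names come from _get_group in studenteval.py, and the pass1 values below are placeholders, not measured results.

# Shape of the StudentEval metric (placeholder values only).
[
    {"group": "First Failure", "pass1": 0.0},
    {"group": "First Success", "pass1": 0.0},
    {"group": "Last Failure", "pass1": 0.0},
    {"group": "Last Success", "pass1": 0.0},
]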
