
Commit 4b05ad8

Format
1 parent 31d4b47 commit 4b05ad8

File tree: 2 files changed, +106 -89 lines changed


bigcode_eval/tasks/custom_metrics/beyond_eval.py

Lines changed: 105 additions & 87 deletions
@@ -1,7 +1,24 @@
 # coding: utf-8

-# The Beyond metric estimates the beyond@k metric for code synthesis efficiency.
+# The Beyond metric estimates the Beyond@k metric for code synthesis efficiency.
+# The sandbox is inspired by OpenAI's release
+# https://github.com/openai/human-eval/blob/master/human_eval/execution.py

+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Optional, Dict, Any, List
+from multiprocessing import Manager, Process
+from tqdm import tqdm
+import numpy as np
+import faulthandler
+import contextlib
+import itertools
+import platform
+import tempfile
+import signal
+import json
+import time
+import os
+import io
 CITATION = """
 @article{du2024mercury,
     title={Mercury: An Efficiency Benchmark for LLM Code Synthesis},
@@ -11,23 +28,6 @@
 }
 """

-import io
-import os
-import time
-import math
-import json
-import signal
-import tempfile
-import platform
-import itertools
-import contextlib
-import faulthandler
-import numpy as np
-from tqdm import tqdm
-from multiprocessing import Manager, Process
-from typing import Optional, Dict, Any, List
-from concurrent.futures import ThreadPoolExecutor, as_completed
-

 # Timeout Exception
 class TimeoutException(Exception):
@@ -38,7 +38,7 @@ class TimeoutException(Exception):
 class RedirectStdin(contextlib._RedirectStream):
     """Context manager for temporarily receiving stdin from another source."""
     _stream = 'stdin'
-
+
 # WriteOnly IO
 class WriteOnlyStringIO(io.StringIO):
     """ StringIO that throws an exception when it's read from """
@@ -55,7 +55,8 @@ def readlines(self, *args, **kwargs):
     def readable(self, *args, **kwargs):
         """ Returns True if the IO object can be read. """
         return False
-
+
+
 class Sandbox(object):
     @staticmethod
     @contextlib.contextmanager
@@ -104,18 +105,23 @@ def chdir(root):
     @staticmethod
     def reliability_guard(maximum_memory_bytes: Optional[int] = None):
         """
-        This disables various destructive functions and prevents the generated code from interfering with the test (e.g. fork bomb, killing other processes, removing filesystem files, etc.)
+        This disables various destructive functions and prevents the generated code from interfering with the test
+        (e.g. fork bomb, killing other processes, removing filesystem files, etc.)

-        WARNING
-        This function is NOT a security sandbox. Untrusted code, including, model-generated code, should not be blindly executed outside of one. See the Codex paper for more information about OpenAI's code sandbox, and proceed with caution.
+        ## WARNING ##
+        This function is NOT a security sandbox. Untrusted code, including, model-generated code, should not be blindly executed outside of one.
+        See the Codex paper for more information about OpenAI's code sandbox, and proceed with caution.
         """

         if maximum_memory_bytes is not None:
             import resource
-            resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
-            resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
+            resource.setrlimit(resource.RLIMIT_AS,
+                               (maximum_memory_bytes, maximum_memory_bytes))
+            resource.setrlimit(resource.RLIMIT_DATA,
+                               (maximum_memory_bytes, maximum_memory_bytes))
             if platform.uname().system != 'Darwin':
-                resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
+                resource.setrlimit(resource.RLIMIT_STACK,
+                                   (maximum_memory_bytes, maximum_memory_bytes))

         faulthandler.disable()

@@ -170,7 +176,7 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
         sys.modules['resource'] = None
         sys.modules['psutil'] = None
         sys.modules['tkinter'] = None
-
+
     @staticmethod
     def unsafe_execute(sample, result):
         with Sandbox.create_tempdir():
@@ -210,53 +216,62 @@ def unsafe_execute(sample, result):
                 exec("from collections import deque, defaultdict, OrderedDict", namespace)
                 exec("from typing import List, Optional, Tuple", namespace)
                 exec("from functools import lru_cache, cache", namespace)
-
+
                 exec("class ListNode(object):\n\tdef __init__(self, val=0, next=None):\n\t\tself.val = val\n\t\tself.next = next", namespace)
                 exec("class TreeNode(object):\n\tdef __init__(self, val=0, left=None, right=None):\n\t\tself.val = val\n\t\tself.left = left\n\t\tself.right = right", namespace)
-
+
                 exec("def print(*args):pass", namespace)
-
+
                 total, passed = 0, 0
                 runtime = 0
                 with Sandbox.swallow_io():
-                    with Sandbox.time_limit(sample['timeout']):
+                    with Sandbox.time_limit(sample['timeout']):
                         try:
                             exec(sample['solution'], namespace)
                             exec(f"solution=Solution()", namespace)
                             exec(sample['convert_offline'], namespace)
                             exec(sample['evaluate_offline'], namespace)
                         except Exception as e:
-                            result.append({"status": "failed@load", "runtime": runtime, "error": e})
-
+                            result.append(
+                                {"status": "failed@load", "runtime": runtime, "error": e})
+
                         try:
                             start_time = time.time()
                             for test_case in sample['test_cases']:
                                 namespace['inputs'] = test_case['input']
                                 namespace['expected'] = test_case['expected']
-                                exec("inputs, expected = convert_offline((inputs, expected))", namespace)
-                                exec(f"outputs = solution.{sample['entry_point']}(*inputs)", namespace)
-                                exec(f"passed = evaluate_offline(inputs, outputs, expected)", namespace)
+                                exec(
+                                    "inputs, expected = convert_offline((inputs, expected))", namespace)
+                                exec(
+                                    f"outputs = solution.{sample['entry_point']}(*inputs)", namespace)
+                                exec(
+                                    f"passed = evaluate_offline(inputs, outputs, expected)", namespace)
                                 total += 1
                                 passed += (1 if namespace['passed'] else 0)
                             end_time = time.time()
                             runtime = end_time-start_time
                         except Exception as e:
-                            result.append({"status": "failed@eval", "runtime": runtime, "error": e})
-
+                            result.append(
+                                {"status": "failed@eval", "runtime": runtime, "error": e})
+
                         if total == passed:
-                            result.append({"status": "passed", "runtime": runtime, "error": "None"})
+                            result.append(
+                                {"status": "passed", "runtime": runtime, "error": "None"})
                         else:
-                            result.append({"status": "failed@cases", "runtime": runtime, "error": "case error"})
+                            result.append({"status": "failed@cases",
+                                           "runtime": runtime, "error": "case error"})
             except TimeoutException:
-                result.append({"status": "failed@timeout", "runtime": runtime, "error": "execution time out"})
+                result.append(
+                    {"status": "failed@timeout", "runtime": runtime, "error": "execution time out"})
             except BaseException as e:
-                result.append({"status": "failed@error", "runtime": runtime, "error": e})
+                result.append({"status": "failed@error",
+                               "runtime": runtime, "error": e})

             # Needed for cleaning up.
             shutil.rmtree = rmtree
             os.rmdir = rmdir
             os.chdir = chdir
-
+
     @staticmethod
     def run_sample(sample) -> Dict:
         """
@@ -273,31 +288,33 @@ def run_sample(sample) -> Dict:
             p.kill()

         if not result:
-            result.append({"status": "failed@timeout", "runtime": sample['timeout'], "error": "sandbox time out"})
-
+            result.append(
+                {"status": "failed@timeout", "runtime": sample['timeout'], "error": "sandbox time out"})
+
         return dict(
-            result=result[0]['status'],
-            runtime=result[0]['runtime'],
-            index=sample['solution_index'],
+            result=result[0]['status'],
+            runtime=result[0]['runtime'],
+            index=sample['solution_index'],
             error=result[0]['error'],
         )

     @staticmethod
     def run_samples(samples, n_workers=4):
         with ThreadPoolExecutor(max_workers=n_workers) as executor:
             futures, results = list(), list()
-
+
             for sample in samples:
                 args = (sample,)
                 future = executor.submit(Sandbox.run_sample, *args)
                 futures.append(future)
-
+
             for future in tqdm(as_completed(futures), total=len(futures), desc='Reading futures'):
                 result = future.result()
                 results.append(result)
-
+
             return results

+
 def estimate_pass_at_k(num_samples, num_correct, k):
     """Estimates pass@k of each problem and returns them in an array."""

@@ -315,54 +332,56 @@ def estimator(n: int, c: int, k: int) -> float:

     return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])

+
 def estimate_beyond_at_k(beyonds, k):
     return sum([sum(b[:k]) / k for b in beyonds]) / len(beyonds)

+
 def compute_beyond_eval(generations_list, reference_list, timeout=10):
     sandbox = Sandbox()
-
+
     scores = {
         "Easy": dict(total_c=list(), correct_c=list(), beyond_c=list()),
         "Medium": dict(total_c=list(), correct_c=list(), beyond_c=list()),
         "Hard": dict(total_c=list(), correct_c=list(), beyond_c=list()),
         "Average": dict(total_c=list(), correct_c=list(), beyond_c=list()),
     }
-
+
     errors = {
-        "Easy": {"failed@load": 0,"failed@eval": 0,'failed@cases': 0,"failed@timeout": 0,"failed@error": 0,"passed":0},
-        "Medium": {"failed@load": 0,"failed@eval": 0,"failed@cases": 0,"failed@timeout": 0,"failed@error": 0,"passed":0},
-        "Hard": {"failed@load": 0,"failed@eval": 0,"failed@cases": 0,"failed@timeout": 0,"failed@error": 0,"passed":0},
+        "Easy": {"failed@load": 0, "failed@eval": 0, 'failed@cases': 0, "failed@timeout": 0, "failed@error": 0, "passed": 0},
+        "Medium": {"failed@load": 0, "failed@eval": 0, "failed@cases": 0, "failed@timeout": 0, "failed@error": 0, "passed": 0},
+        "Hard": {"failed@load": 0, "failed@eval": 0, "failed@cases": 0, "failed@timeout": 0, "failed@error": 0, "passed": 0},
     }
-
+
     for generations, instance in tqdm(zip(generations_list, reference_list), total=len(generations_list), desc='compute_beyond_eval'):
         # Construct runtime distribution from sample solutions
         runtimes = list()
         for index, solution in tqdm(enumerate(instance['solutions']), desc="Construct runtime distribution from sample solutions"):
             sample = {
-                "solution": solution['solution'],
-                "convert_offline": instance['convert_offline'],
-                "evaluate_offline": instance['evaluate_offline'],
-                "entry_point": instance['entry_point'],
-                "test_cases": json.loads(instance['test_cases']),
-                "solution_index": index,
-                "timeout": timeout
-            }
+                "solution": solution['solution'],
+                "convert_offline": instance['convert_offline'],
+                "evaluate_offline": instance['evaluate_offline'],
+                "entry_point": instance['entry_point'],
+                "test_cases": json.loads(instance['test_cases']),
+                "solution_index": index,
+                "timeout": timeout
+            }
             result = sandbox.run_sample(sample)
-
+
             if result['result'] == "passed":
                 runtimes += [result['runtime']]
-
+
         # Calculate Range
         runtimes = sorted(runtimes)
         min_runtime = min(runtimes)
         max_runtime = max(runtimes)
-
+
         # Evaluate generated solutions
         t_c, p_c = 0, 0
         b_l = list()
         difficulty = instance['difficulty']
-
-        for index, solution in tqdm(enumerate(generations), desc="generation execution", total=len(generations)):
+
+        for index, solution in tqdm(enumerate(generations), desc="generation execution", total=len(generations)):
             sample = {
                 "solution": solution,
                 "convert_offline": instance['convert_offline'],
@@ -372,53 +391,52 @@ def compute_beyond_eval(generations_list, reference_list, timeout=10):
                 "solution_index": index,
                 "timeout": timeout,
             }
-
+
             results = [sandbox.run_sample(sample) for _ in range(3)]
             print(results[0])
             t_c += 1
-
+
             # Calculate Beyond
             if results[0]['result'] == "passed":
                 runtime = sum([r['runtime'] for r in results]) / len(results)
                 p_c += 1
             else:
                 runtime = float('inf')
-
+
             # Statistic Errors
             errors[difficulty][results[0]['result']] += 1
-
+
             runtime = min(runtime, max_runtime)
             runtime = max(runtime, min_runtime)
-            b_l += [(max_runtime - runtime) / (max_runtime - min_runtime)]
-
+            b_l += [(max_runtime - runtime) / (max_runtime - min_runtime)]
+
         scores[difficulty]['total_c'] += [t_c]
         scores[difficulty]['correct_c'] += [p_c]
         scores[difficulty]['beyond_c'] += [b_l]
-
+
         scores['Average']['total_c'] += [t_c]
         scores['Average']['correct_c'] += [p_c]
         scores['Average']['beyond_c'] += [b_l]
-
+
         # print(f'total: {t_c}')
         # print(f'correct: {p_c}')
         # print(f'beyond: {b_l}')
         # print("-" * 60)
-
+
     results = dict()
     for difficulty in ['Easy', "Medium", "Hard", "Average"]:
         total = np.array(scores[difficulty]['total_c'])
         correct = np.array(scores[difficulty]['correct_c'])
         beyond = scores[difficulty]['beyond_c']
-
-        pass_at_k = {f"{difficulty}_pass@{k}": estimate_pass_at_k(total, correct, k).mean() for k in [1,3,5,10,15,20,30,50,100] if (total >= k).all()}
-        beyond_at_k = {f"{difficulty}_beyond@{k}": estimate_beyond_at_k(beyond, k) for k in [1,3,5,10,15,20,30,50,100] if (total >= k).all()}
-
+
+        pass_at_k = {f"{difficulty}_pass@{k}": estimate_pass_at_k(total, correct, k).mean(
+        ) for k in [1, 3, 5, 10, 15, 20, 30, 50, 100] if (total >= k).all()}
+        beyond_at_k = {f"{difficulty}_beyond@{k}": estimate_beyond_at_k(
+            beyond, k) for k in [1, 3, 5, 10, 15, 20, 30, 50, 100] if (total >= k).all()}
+
         results.update(pass_at_k)
         results.update(beyond_at_k)
-
+
     results.update(errors)

     return results
-
-
-
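Note: the two estimators this file aggregates are easy to sanity-check outside the harness. The sketch below is illustrative and not part of the commit: estimate_beyond_at_k is copied from the diff above, while the body of estimator inside estimate_pass_at_k only appears as hunk context here, so it is assumed to be the standard unbiased pass@k estimator from OpenAI's human-eval release; the iterator plumbing between them is likewise reconstructed from the elided lines.

import itertools

import numpy as np


def estimate_pass_at_k(num_samples, num_correct, k):
    """Estimates pass@k of each problem and returns them in an array."""

    def estimator(n: int, c: int, k: int) -> float:
        # Assumed body: the standard unbiased estimator 1 - C(n-c, k) / C(n, k),
        # computed as a running product to avoid large binomial coefficients.
        if n - c < k:
            return 1.0
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

    if isinstance(num_samples, int):
        num_samples_it = itertools.repeat(num_samples, len(num_correct))
    else:
        assert len(num_samples) == len(num_correct)
        num_samples_it = iter(num_samples)

    return np.array([estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)])


def estimate_beyond_at_k(beyonds, k):
    # Each inner list holds per-generation scores in [0, 1]:
    # (max_runtime - runtime) / (max_runtime - min_runtime), where a failed
    # generation gets runtime = inf, clamped to max_runtime, i.e. a score of 0.
    return sum([sum(b[:k]) / k for b in beyonds]) / len(beyonds)


# Toy check: one problem, 4 generations, 2 of them passing.
print(estimate_pass_at_k([4], [2], k=2))                  # ~0.833
print(estimate_beyond_at_k([[0.9, 0.0, 0.4, 0.0]], k=2))  # 0.45
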
bigcode_eval/tasks/mercury.py

Lines changed: 1 addition & 2 deletions
@@ -15,7 +15,7 @@
 _CITATION = """
 @article{du2024mercury,
     title={Mercury: An Efficiency Benchmark for LLM Code Synthesis},
-    author={Du, Mingzhe and Luu, Anh Tuan and Ji, Bin and Ng, See-Kiong},
+    author={Du, Mingzhe and Luu, Anh Tuan and Ji, Bin and Qian, Liu and Ng, See-Kiong},
     journal={arXiv preprint arXiv:2402.07844},
     year={2024}
 }
@@ -31,7 +31,6 @@ class Mercury(Task):

     def __init__(self, prompt):
         super().__init__(
-            # stop_words=["\nclass", "\ndef", "\n#", "\n@", "\nprint", "\nif", "\n```", "<file_sep>", "<|end▁of▁sentence|>"],
             stop_words=["\nclass", "\ndef", "\n#", "\n@", "\nprint", "\nif", "\n```", "<file_sep>", "<|end▁of▁sentence|>", "\n###", "\n\n\n\n\n", "<|endoftext|>"],
             requires_execution=True,
        )
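
Note: this hunk only drops the stale commented-out stop_words line; the surviving list is what the harness uses to truncate each raw generation at the earliest stop sequence before execution. A minimal sketch of that truncation, assuming a helper of the usual shape (the name stop_at_stop_token and its exact signature are illustrative, not taken from this commit):

def stop_at_stop_token(decoded_string: str, stop_tokens: list) -> str:
    """Cut a generation at the first occurrence of any stop token."""
    min_stop_index = len(decoded_string)
    for stop_token in stop_tokens:
        stop_index = decoded_string.find(stop_token)
        if stop_index != -1 and stop_index < min_stop_index:
            min_stop_index = stop_index
    return decoded_string[:min_stop_index]


generation = "        return left\n### Explanation: the loop narrows the search window."
print(stop_at_stop_token(generation, ["\nclass", "\n###", "<|endoftext|>"]))
# -> "        return left"

Without a stop sequence such as "\n###" in the list, trailing commentary like the example above would survive truncation and break execution-based scoring.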
