11# coding: utf-8
22
3- # The Beyond metric estimates the beyond@k metric for code synthesis efficiency.
3+ # The Beyond metric estimates the Beyond@k metric for code synthesis efficiency.
4+ # The sandbox is inspired by OpenAI's release
5+ # https://github.com/openai/human-eval/blob/master/human_eval/execution.py
46
7+ from concurrent .futures import ThreadPoolExecutor , as_completed
8+ from typing import Optional , Dict , Any , List
9+ from multiprocessing import Manager , Process
10+ from tqdm import tqdm
11+ import numpy as np
12+ import faulthandler
13+ import contextlib
14+ import itertools
15+ import platform
16+ import tempfile
17+ import signal
18+ import json
19+ import time
20+ import os
21+ import io
522CITATION = """
623@article{du2024mercury,
724 title={Mercury: An Efficiency Benchmark for LLM Code Synthesis},
1128}
1229"""
1330
14- import io
15- import os
16- import time
17- import math
18- import json
19- import signal
20- import tempfile
21- import platform
22- import itertools
23- import contextlib
24- import faulthandler
25- import numpy as np
26- from tqdm import tqdm
27- from multiprocessing import Manager , Process
28- from typing import Optional , Dict , Any , List
29- from concurrent .futures import ThreadPoolExecutor , as_completed
30-
3131
3232# Timeout Exception
3333class TimeoutException (Exception ):
@@ -38,7 +38,7 @@ class TimeoutException(Exception):
# Mirror of contextlib.redirect_stdout/redirect_stderr for the *input* side:
# contextlib._RedirectStream does the save/replace/restore of a sys stream,
# and `_stream` names which attribute of `sys` to swap. Used so sandboxed
# code can be fed stdin from an in-memory buffer.
# NOTE(review): relies on the private contextlib._RedirectStream helper —
# stable in practice, but not a public API.
class RedirectStdin (contextlib ._RedirectStream ):
    """Context manager for temporarily receiving stdin from another source."""
    _stream = 'stdin'
41-
41+
4242# WriteOnly IO
4343class WriteOnlyStringIO (io .StringIO ):
4444 """ StringIO that throws an exception when it's read from """
@@ -55,7 +55,8 @@ def readlines(self, *args, **kwargs):
5555 def readable (self , * args , ** kwargs ):
5656 """ Returns True if the IO object can be read. """
5757 return False
58-
58+
59+
5960class Sandbox (object ):
6061 @staticmethod
6162 @contextlib .contextmanager
@@ -104,18 +105,23 @@ def chdir(root):
104105 @staticmethod
105106 def reliability_guard (maximum_memory_bytes : Optional [int ] = None ):
106107 """
107- This disables various destructive functions and prevents the generated code from interfering with the test (e.g. fork bomb, killing other processes, removing filesystem files, etc.)
108+ This disables various destructive functions and prevents the generated code from interfering with the test
109+ (e.g. fork bomb, killing other processes, removing filesystem files, etc.)
108110
109- WARNING
110- This function is NOT a security sandbox. Untrusted code, including, model-generated code, should not be blindly executed outside of one. See the Codex paper for more information about OpenAI's code sandbox, and proceed with caution.
111+ ## WARNING ##
112+ This function is NOT a security sandbox. Untrusted code, including, model-generated code, should not be blindly executed outside of one.
113+ See the Codex paper for more information about OpenAI's code sandbox, and proceed with caution.
111114 """
112115
113116 if maximum_memory_bytes is not None :
114117 import resource
115- resource .setrlimit (resource .RLIMIT_AS , (maximum_memory_bytes , maximum_memory_bytes ))
116- resource .setrlimit (resource .RLIMIT_DATA , (maximum_memory_bytes , maximum_memory_bytes ))
118+ resource .setrlimit (resource .RLIMIT_AS ,
119+ (maximum_memory_bytes , maximum_memory_bytes ))
120+ resource .setrlimit (resource .RLIMIT_DATA ,
121+ (maximum_memory_bytes , maximum_memory_bytes ))
117122 if platform .uname ().system != 'Darwin' :
118- resource .setrlimit (resource .RLIMIT_STACK , (maximum_memory_bytes , maximum_memory_bytes ))
123+ resource .setrlimit (resource .RLIMIT_STACK ,
124+ (maximum_memory_bytes , maximum_memory_bytes ))
119125
120126 faulthandler .disable ()
121127
@@ -170,7 +176,7 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
170176 sys .modules ['resource' ] = None
171177 sys .modules ['psutil' ] = None
172178 sys .modules ['tkinter' ] = None
173-
179+
174180 @staticmethod
175181 def unsafe_execute (sample , result ):
176182 with Sandbox .create_tempdir ():
@@ -210,53 +216,62 @@ def unsafe_execute(sample, result):
210216 exec ("from collections import deque, defaultdict, OrderedDict" , namespace )
211217 exec ("from typing import List, Optional, Tuple" , namespace )
212218 exec ("from functools import lru_cache, cache" , namespace )
213-
219+
214220 exec ("class ListNode(object):\n \t def __init__(self, val=0, next=None):\n \t \t self.val = val\n \t \t self.next = next" , namespace )
215221 exec ("class TreeNode(object):\n \t def __init__(self, val=0, left=None, right=None):\n \t \t self.val = val\n \t \t self.left = left\n \t \t self.right = right" , namespace )
216-
222+
217223 exec ("def print(*args):pass" , namespace )
218-
224+
219225 total , passed = 0 , 0
220226 runtime = 0
221227 with Sandbox .swallow_io ():
222- with Sandbox .time_limit (sample ['timeout' ]):
228+ with Sandbox .time_limit (sample ['timeout' ]):
223229 try :
224230 exec (sample ['solution' ], namespace )
225231 exec (f"solution=Solution()" , namespace )
226232 exec (sample ['convert_offline' ], namespace )
227233 exec (sample ['evaluate_offline' ], namespace )
228234 except Exception as e :
229- result .append ({"status" : "failed@load" , "runtime" : runtime , "error" : e })
230-
235+ result .append (
236+ {"status" : "failed@load" , "runtime" : runtime , "error" : e })
237+
231238 try :
232239 start_time = time .time ()
233240 for test_case in sample ['test_cases' ]:
234241 namespace ['inputs' ] = test_case ['input' ]
235242 namespace ['expected' ] = test_case ['expected' ]
236- exec ("inputs, expected = convert_offline((inputs, expected))" , namespace )
237- exec (f"outputs = solution.{ sample ['entry_point' ]} (*inputs)" , namespace )
238- exec (f"passed = evaluate_offline(inputs, outputs, expected)" , namespace )
243+ exec (
244+ "inputs, expected = convert_offline((inputs, expected))" , namespace )
245+ exec (
246+ f"outputs = solution.{ sample ['entry_point' ]} (*inputs)" , namespace )
247+ exec (
248+ f"passed = evaluate_offline(inputs, outputs, expected)" , namespace )
239249 total += 1
240250 passed += (1 if namespace ['passed' ] else 0 )
241251 end_time = time .time ()
242252 runtime = end_time - start_time
243253 except Exception as e :
244- result .append ({"status" : "failed@eval" , "runtime" : runtime , "error" : e })
245-
254+ result .append (
255+ {"status" : "failed@eval" , "runtime" : runtime , "error" : e })
256+
246257 if total == passed :
247- result .append ({"status" : "passed" , "runtime" : runtime , "error" : "None" })
258+ result .append (
259+ {"status" : "passed" , "runtime" : runtime , "error" : "None" })
248260 else :
249- result .append ({"status" : "failed@cases" , "runtime" : runtime , "error" : "case error" })
261+ result .append ({"status" : "failed@cases" ,
262+ "runtime" : runtime , "error" : "case error" })
250263 except TimeoutException :
251- result .append ({"status" : "failed@timeout" , "runtime" : runtime , "error" : "execution time out" })
264+ result .append (
265+ {"status" : "failed@timeout" , "runtime" : runtime , "error" : "execution time out" })
252266 except BaseException as e :
253- result .append ({"status" : "failed@error" , "runtime" : runtime , "error" : e })
267+ result .append ({"status" : "failed@error" ,
268+ "runtime" : runtime , "error" : e })
254269
255270 # Needed for cleaning up.
256271 shutil .rmtree = rmtree
257272 os .rmdir = rmdir
258273 os .chdir = chdir
259-
274+
260275 @staticmethod
261276 def run_sample (sample ) -> Dict :
262277 """
@@ -273,31 +288,33 @@ def run_sample(sample) -> Dict:
273288 p .kill ()
274289
275290 if not result :
276- result .append ({"status" : "failed@timeout" , "runtime" : sample ['timeout' ], "error" : "sandbox time out" })
277-
291+ result .append (
292+ {"status" : "failed@timeout" , "runtime" : sample ['timeout' ], "error" : "sandbox time out" })
293+
278294 return dict (
279- result = result [0 ]['status' ],
280- runtime = result [0 ]['runtime' ],
281- index = sample ['solution_index' ],
295+ result = result [0 ]['status' ],
296+ runtime = result [0 ]['runtime' ],
297+ index = sample ['solution_index' ],
282298 error = result [0 ]['error' ],
283299 )
284300
285301 @staticmethod
286302 def run_samples (samples , n_workers = 4 ):
287303 with ThreadPoolExecutor (max_workers = n_workers ) as executor :
288304 futures , results = list (), list ()
289-
305+
290306 for sample in samples :
291307 args = (sample ,)
292308 future = executor .submit (Sandbox .run_sample , * args )
293309 futures .append (future )
294-
310+
295311 for future in tqdm (as_completed (futures ), total = len (futures ), desc = 'Reading futures' ):
296312 result = future .result ()
297313 results .append (result )
298-
314+
299315 return results
300316
317+
301318def estimate_pass_at_k (num_samples , num_correct , k ):
302319 """Estimates pass@k of each problem and returns them in an array."""
303320
@@ -315,54 +332,56 @@ def estimator(n: int, c: int, k: int) -> float:
315332
316333 return np .array ([estimator (int (n ), int (c ), k ) for n , c in zip (num_samples_it , num_correct )])
317334
335+
def estimate_beyond_at_k(beyonds, k):
    """Estimate the Beyond@k score averaged over all problems.

    For each problem, take the mean of the first k per-generation beyond
    scores (lists shorter than k are divided by k unchanged, matching the
    original slicing behavior), then average those means across problems.

    Args:
        beyonds: list of per-problem lists of beyond scores.
        k: number of leading generations to average over.

    Returns:
        The mean Beyond@k value as a float.
    """
    per_problem = [sum(scores[:k]) / k for scores in beyonds]
    return sum(per_problem) / len(beyonds)
320338
339+
321340def compute_beyond_eval (generations_list , reference_list , timeout = 10 ):
322341 sandbox = Sandbox ()
323-
342+
324343 scores = {
325344 "Easy" : dict (total_c = list (), correct_c = list (), beyond_c = list ()),
326345 "Medium" : dict (total_c = list (), correct_c = list (), beyond_c = list ()),
327346 "Hard" : dict (total_c = list (), correct_c = list (), beyond_c = list ()),
328347 "Average" : dict (total_c = list (), correct_c = list (), beyond_c = list ()),
329348 }
330-
349+
331350 errors = {
332- "Easy" : {"failed@load" : 0 ,"failed@eval" : 0 ,'failed@cases' : 0 ,"failed@timeout" : 0 ,"failed@error" : 0 ,"passed" :0 },
333- "Medium" : {"failed@load" : 0 ,"failed@eval" : 0 ,"failed@cases" : 0 ,"failed@timeout" : 0 ,"failed@error" : 0 ,"passed" :0 },
334- "Hard" : {"failed@load" : 0 ,"failed@eval" : 0 ,"failed@cases" : 0 ,"failed@timeout" : 0 ,"failed@error" : 0 ,"passed" :0 },
351+ "Easy" : {"failed@load" : 0 , "failed@eval" : 0 , 'failed@cases' : 0 , "failed@timeout" : 0 , "failed@error" : 0 , "passed" : 0 },
352+ "Medium" : {"failed@load" : 0 , "failed@eval" : 0 , "failed@cases" : 0 , "failed@timeout" : 0 , "failed@error" : 0 , "passed" : 0 },
353+ "Hard" : {"failed@load" : 0 , "failed@eval" : 0 , "failed@cases" : 0 , "failed@timeout" : 0 , "failed@error" : 0 , "passed" : 0 },
335354 }
336-
355+
337356 for generations , instance in tqdm (zip (generations_list , reference_list ), total = len (generations_list ), desc = 'compute_beyond_eval' ):
338357 # Construct runtime distribution from sample solutions
339358 runtimes = list ()
340359 for index , solution in tqdm (enumerate (instance ['solutions' ]), desc = "Construct runtime distribution from sample solutions" ):
341360 sample = {
342- "solution" : solution ['solution' ],
343- "convert_offline" : instance ['convert_offline' ],
344- "evaluate_offline" : instance ['evaluate_offline' ],
345- "entry_point" : instance ['entry_point' ],
346- "test_cases" : json .loads (instance ['test_cases' ]),
347- "solution_index" : index ,
348- "timeout" : timeout
349- }
361+ "solution" : solution ['solution' ],
362+ "convert_offline" : instance ['convert_offline' ],
363+ "evaluate_offline" : instance ['evaluate_offline' ],
364+ "entry_point" : instance ['entry_point' ],
365+ "test_cases" : json .loads (instance ['test_cases' ]),
366+ "solution_index" : index ,
367+ "timeout" : timeout
368+ }
350369 result = sandbox .run_sample (sample )
351-
370+
352371 if result ['result' ] == "passed" :
353372 runtimes += [result ['runtime' ]]
354-
373+
355374 # Calculate Range
356375 runtimes = sorted (runtimes )
357376 min_runtime = min (runtimes )
358377 max_runtime = max (runtimes )
359-
378+
360379 # Evaluate generated solutions
361380 t_c , p_c = 0 , 0
362381 b_l = list ()
363382 difficulty = instance ['difficulty' ]
364-
365- for index , solution in tqdm (enumerate (generations ), desc = "generation execution" , total = len (generations )):
383+
384+ for index , solution in tqdm (enumerate (generations ), desc = "generation execution" , total = len (generations )):
366385 sample = {
367386 "solution" : solution ,
368387 "convert_offline" : instance ['convert_offline' ],
@@ -372,53 +391,52 @@ def compute_beyond_eval(generations_list, reference_list, timeout=10):
372391 "solution_index" : index ,
373392 "timeout" : timeout ,
374393 }
375-
394+
376395 results = [sandbox .run_sample (sample ) for _ in range (3 )]
377396 print (results [0 ])
378397 t_c += 1
379-
398+
380399 # Calculate Beyond
381400 if results [0 ]['result' ] == "passed" :
382401 runtime = sum ([r ['runtime' ] for r in results ]) / len (results )
383402 p_c += 1
384403 else :
385404 runtime = float ('inf' )
386-
405+
387406 # Statistic Errors
388407 errors [difficulty ][results [0 ]['result' ]] += 1
389-
408+
390409 runtime = min (runtime , max_runtime )
391410 runtime = max (runtime , min_runtime )
392- b_l += [(max_runtime - runtime ) / (max_runtime - min_runtime )]
393-
411+ b_l += [(max_runtime - runtime ) / (max_runtime - min_runtime )]
412+
394413 scores [difficulty ]['total_c' ] += [t_c ]
395414 scores [difficulty ]['correct_c' ] += [p_c ]
396415 scores [difficulty ]['beyond_c' ] += [b_l ]
397-
416+
398417 scores ['Average' ]['total_c' ] += [t_c ]
399418 scores ['Average' ]['correct_c' ] += [p_c ]
400419 scores ['Average' ]['beyond_c' ] += [b_l ]
401-
420+
402421 # print(f'total: {t_c}')
403422 # print(f'correct: {p_c}')
404423 # print(f'beyond: {b_l}')
405424 # print("-" * 60)
406-
425+
407426 results = dict ()
408427 for difficulty in ['Easy' , "Medium" , "Hard" , "Average" ]:
409428 total = np .array (scores [difficulty ]['total_c' ])
410429 correct = np .array (scores [difficulty ]['correct_c' ])
411430 beyond = scores [difficulty ]['beyond_c' ]
412-
413- pass_at_k = {f"{ difficulty } _pass@{ k } " : estimate_pass_at_k (total , correct , k ).mean () for k in [1 ,3 ,5 ,10 ,15 ,20 ,30 ,50 ,100 ] if (total >= k ).all ()}
414- beyond_at_k = {f"{ difficulty } _beyond@{ k } " : estimate_beyond_at_k (beyond , k ) for k in [1 ,3 ,5 ,10 ,15 ,20 ,30 ,50 ,100 ] if (total >= k ).all ()}
415-
431+
432+ pass_at_k = {f"{ difficulty } _pass@{ k } " : estimate_pass_at_k (total , correct , k ).mean (
433+ ) for k in [1 , 3 , 5 , 10 , 15 , 20 , 30 , 50 , 100 ] if (total >= k ).all ()}
434+ beyond_at_k = {f"{ difficulty } _beyond@{ k } " : estimate_beyond_at_k (
435+ beyond , k ) for k in [1 , 3 , 5 , 10 , 15 , 20 , 30 , 50 , 100 ] if (total >= k ).all ()}
436+
416437 results .update (pass_at_k )
417438 results .update (beyond_at_k )
418-
439+
419440 results .update (errors )
420441
421442 return results
422-
423-
424-
0 commit comments