@@ -23,16 +23,18 @@ def __init__(self, model_class, num_workers=5):
2323
2424
2525 def evaluate_pass_k (self , problems , unit_tests , batch_size = 1 , max_length = 600 ,
26- top_p = 0.95 , k = [1 ,10 ,100 ],
26+ top_p = 0.95 , k = [1 ,10 ,100 ], temperature = 1.2 ,
2727 num_return_sequences = 200 , sequences_per_chunk = 10 , num_workers = 1 ):
2828 # Load dataset
29- data_loader = DataLoader (problems , batch_size = batch_size )
29+ # Keep batch_size = 1: generation below reads only index 0 of each batch (prompt_ids[0], attention_masks[0])
30+ data_loader = DataLoader (problems , batch_size = batch_size )
3031 data_loader = self .accelerator .prepare (data_loader )
31-
32+ model_name = type ( self . model_class ). __name__
3233 # Initialize stopping criteria
3334 gen_kwargs = {
3435 "do_sample" : True ,
3536 "top_p" : top_p ,
37+ "temperature" : temperature ,
3638 "stopping_criteria" : StoppingCriteriaList ([EndOfFunctionCriteria (0 , EOF_STRINGS , self .model_class .get_tokenizer ())]),
3739 }
3840
@@ -54,7 +56,6 @@ def evaluate_pass_k(self, problems, unit_tests, batch_size=1, max_length=600,
5456 input_ids = prompt_ids [0 , :attention_masks [0 ].sum ().item ()]
5557
5658 input_data = self .model_class .get_tokenizer ().decode (input_ids , skip_special_tokens = True , clean_up_tokenization_spaces = True )
57-
5859 batch_generated_ids = self .model_class .get_model ().generate (
5960 input_ids = input_ids .unsqueeze (0 ),
6061 attention_mask = attention_masks [0 , :attention_masks [0 ].sum ().item ()].unsqueeze (0 ),
@@ -66,14 +67,16 @@ def evaluate_pass_k(self, problems, unit_tests, batch_size=1, max_length=600,
6667 gen_codes = self .model_class .get_tokenizer ().batch_decode (batch_generated_ids ,
6768 skip_special_tokens = True , clean_up_tokenization_spaces = True )
6869
69- for item in gen_codes :
70- cleaned = remove_last_block (item )
71- solutions_per_chunk .append (cleaned )
70+ for i ,item in enumerate (gen_codes ):
71+ result = remove_last_block (item )
72+ if model_name == "Seq2SeqModel" :
73+ result = f"{ input_data } { result } "
74+
75+ solutions_per_chunk .append (result )
7276
7377 solutions .append (solutions_per_chunk )
7478 dataloader_pbar .set_description (f"Processing step { step + 1 } /{ len (data_loader )} " )
7579
76-
7780 pass_at_k , _ = self .code_eval .compute (
7881 references = unit_tests , predictions = solutions , k = k , num_workers = num_workers
7982 )
0 commit comments