@@ -46,7 +46,9 @@ def __iter__(self):
             elif isinstance(prompt_contents, dict):
                 assert set(prompt_contents.keys()) == {"prefix", "suffix"}
                 infill.append(True)
-                prompt = self.prefix + self._make_infill_prompt(**prompt_contents)
+                prompt = self._make_infill_prompt(
+                    **prompt_contents, preprefix=self.prefix
+                )
             else:
                 raise ValueError(f"Unsupported prompt format: {type(prompt_contents)}")
             prompts.append(prompt)
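
The change above stops gluing self.prefix onto the outside of the FIM template and instead threads it through as preprefix, so it ends up after the <fim_prefix> sentinel. A minimal sketch of the difference, with a made-up task prefix and infill contents (none of the values below come from the repo):

    prefix_ctx = "# language: python\n"  # hypothetical self.prefix
    contents = {"prefix": "def add(a, b):\n    return ", "suffix": "\n"}

    # Before: the task prefix sat outside the FIM template the model saw in training.
    old = prefix_ctx + (
        f"<fim_prefix>{contents['prefix']}<fim_suffix>{contents['suffix']}<fim_middle>"
    )

    # After: the task prefix is injected inside the template, right after <fim_prefix>.
    new = (
        f"<fim_prefix>{prefix_ctx}{contents['prefix']}"
        f"<fim_suffix>{contents['suffix']}<fim_middle>"
    )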
@@ -83,18 +85,18 @@ def __iter__(self):
                 "input_len": outputs.attention_mask[sample].sum(),
             }
 
-    def _make_infill_prompt(self, prefix, suffix):
+    def _make_infill_prompt(self, prefix, suffix, preprefix=""):
         """Make a prompt for infilling.
         Currently supported only for official InCoder and SantaCoder implementations.
         """
         model_id = self.tokenizer.name_or_path
         if model_id in ["facebook/incoder-1B", "facebook/incoder-6B"]:
             self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
-            return f"{prefix}<|mask:0|>{suffix}<|mask:0|>"
+            return f"{preprefix}{prefix}<|mask:0|>{suffix}<|mask:0|>"
         elif model_id in ["bigcode/santacoder"]:
-            return f"<fim-prefix>{prefix}<fim-suffix>{suffix}<fim-middle>"
-        elif model_id in ["bigcode/large-model"]:
-            return f"<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"
+            return f"<fim-prefix>{preprefix}{prefix}<fim-suffix>{suffix}<fim-middle>"
+        elif model_id in ["bigcode/starcoder", "bigcode/starcoderbase"]:
+            return f"<fim_prefix>{preprefix}{prefix}<fim_suffix>{suffix}<fim_middle>"
         else:
             raise ValueError(f"Infilling not yet supported for: {model_id}")
 
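
One easy-to-miss detail in the branch above: SantaCoder's FIM sentinels use dashes (<fim-prefix>) while StarCoder's use underscores (<fim_prefix>). A quick standalone check of the two templates, on hypothetical contents:

    preprefix, prefix, suffix = "# ctx\n", "def f():\n    return ", "\n"
    santa = f"<fim-prefix>{preprefix}{prefix}<fim-suffix>{suffix}<fim-middle>"
    star = f"<fim_prefix>{preprefix}{prefix}<fim_suffix>{suffix}<fim_middle>"
    assert "<fim-middle>" in santa and "<fim_middle>" in star  # dash vs. underscore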
@@ -160,7 +162,7 @@ def parse_infill(code, tokenizer):
         prefix, rest = code.split("<fim-suffix>", 1)
         suffix, infill = rest.split("<fim-middle>", 1)
         infill = infill.split("<|endoftext|>")[0]
-    elif model_id in ["bigcode/large-model"]:
+    elif model_id in ["bigcode/starcoder", "bigcode/starcoderbase"]:
         prefix, rest = code.split("<fim_suffix>", 1)
         suffix, infill = rest.split("<fim_middle>", 1)
         infill = infill.split("<|endoftext|>")[0]
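
The new StarCoder branch simply inverts the template built by _make_infill_prompt: split off everything before <fim_suffix>, then everything before <fim_middle>, and keep what the model generated up to <|endoftext|>. A sketch on a fabricated completion string:

    code = (
        "<fim_prefix>def add(a, b):\n    return <fim_suffix>\n"
        "<fim_middle>a + b<|endoftext|>"
    )
    prefix, rest = code.split("<fim_suffix>", 1)
    suffix, infill = rest.split("<fim_middle>", 1)
    infill = infill.split("<|endoftext|>")[0]
    assert infill == "a + b"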
@@ -177,29 +179,26 @@ def parse_infill(code, tokenizer):
     code_gens = [[] for _ in range(n_tasks)]
     for sample, generated_tokens in gen_token_dict.items():
         for s in generated_tokens:
-            if INFILL_MODE:
-                gen_code = parse_infill(
-                    tokenizer.decode(
-                        s, skip_special_tokens=False, clean_up_tokenization_spaces=False
-                    ),
-                    tokenizer,
-                )
-            elif tokenizer.eos_token in task.stop_words:
+            if INFILL_MODE or tokenizer.eos_token in task.stop_words:
                 gen_code = tokenizer.decode(
                     s, skip_special_tokens=False, clean_up_tokenization_spaces=False
                 )
+                if INFILL_MODE:
+                    gen_code = parse_infill(gen_code, tokenizer)
             else:
                 gen_code = tokenizer.decode(
                     s, skip_special_tokens=True, clean_up_tokenization_spaces=True
                 )
+            if not INFILL_MODE:
+                gen_code = gen_code[len(prefix) :]
             if postprocess:
                 code_gens[sample].append(
-                    task.postprocess_generation(gen_code[len(prefix) :], int(sample))
+                    task.postprocess_generation(gen_code, int(sample))
                 )
             else:
                 warnings.warn(
                     "model output is not postprocessed, this might lower evaluation scores"
                 )
-                code_gens[sample].append(gen_code[len(prefix) :])
+                code_gens[sample].append(gen_code)
 
     return code_gens
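
The restructuring above decodes each sample exactly once and then branches, instead of nesting a second decode inside the INFILL_MODE case, and it moves prefix stripping to a single spot shared by the postprocessed and raw paths. A condensed sketch that should behave the same, assuming the surrounding names (s, prefix, INFILL_MODE, task, tokenizer) from this function:

    keep_specials = INFILL_MODE or tokenizer.eos_token in task.stop_words
    gen_code = tokenizer.decode(
        s,
        skip_special_tokens=not keep_specials,
        clean_up_tokenization_spaces=not keep_specials,
    )
    if INFILL_MODE:
        gen_code = parse_infill(gen_code, tokenizer)  # FIM sentinels still needed here
    else:
        gen_code = gen_code[len(prefix) :]  # drop the prompt prefix once, up front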