Skip to content

Commit 25add04

Browse files
authored
Merge pull request #70 from benlipkin/FixPrefixFIM
prepend and parse prefix cli arg correctly when doing FIM
2 parents cc916af + 2643079 commit 25add04

File tree

1 file changed

+16
-17
lines changed

1 file changed

+16
-17
lines changed

lm_eval/utils.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,9 @@ def __iter__(self):
4646
elif isinstance(prompt_contents, dict):
4747
assert set(prompt_contents.keys()) == {"prefix", "suffix"}
4848
infill.append(True)
49-
prompt = self.prefix + self._make_infill_prompt(**prompt_contents)
49+
prompt = self._make_infill_prompt(
50+
**prompt_contents, preprefix=self.prefix
51+
)
5052
else:
5153
raise ValueError(f"Unsupported prompt format: {type(prompt_contents)}")
5254
prompts.append(prompt)
@@ -83,18 +85,18 @@ def __iter__(self):
8385
"input_len": outputs.attention_mask[sample].sum(),
8486
}
8587

86-
def _make_infill_prompt(self, prefix, suffix):
88+
def _make_infill_prompt(self, prefix, suffix, preprefix=""):
8789
"""Make a prompt for infilling.
8890
Currently supported only for official InCoder and SantaCoder implementations.
8991
"""
9092
model_id = self.tokenizer.name_or_path
9193
if model_id in ["facebook/incoder-1B", "facebook/incoder-6B"]:
9294
self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
93-
return f"{prefix}<|mask:0|>{suffix}<|mask:0|>"
95+
return f"{preprefix}{prefix}<|mask:0|>{suffix}<|mask:0|>"
9496
elif model_id in ["bigcode/santacoder"]:
95-
return f"<fim-prefix>{prefix}<fim-suffix>{suffix}<fim-middle>"
96-
elif model_id in ["bigcode/large-model"]:
97-
return f"<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"
97+
return f"<fim-prefix>{preprefix}{prefix}<fim-suffix>{suffix}<fim-middle>"
98+
elif model_id in ["bigcode/starcoder", "bigcode/starcoderbase"]:
99+
return f"<fim_prefix>{preprefix}{prefix}<fim_suffix>{suffix}<fim_middle>"
98100
else:
99101
raise ValueError(f"Infilling not yet supported for: {model_id}")
100102

@@ -160,7 +162,7 @@ def parse_infill(code, tokenizer):
160162
prefix, rest = code.split("<fim-suffix>", 1)
161163
suffix, infill = rest.split("<fim-middle>", 1)
162164
infill = infill.split("<|endoftext|>")[0]
163-
elif model_id in ["bigcode/large-model"]:
165+
elif model_id in ["bigcode/starcoder", "bigcode/starcoderbase"]:
164166
prefix, rest = code.split("<fim_suffix>", 1)
165167
suffix, infill = rest.split("<fim_middle>", 1)
166168
infill = infill.split("<|endoftext|>")[0]
@@ -177,29 +179,26 @@ def parse_infill(code, tokenizer):
177179
code_gens = [[] for _ in range(n_tasks)]
178180
for sample, generated_tokens in gen_token_dict.items():
179181
for s in generated_tokens:
180-
if INFILL_MODE:
181-
gen_code = parse_infill(
182-
tokenizer.decode(
183-
s, skip_special_tokens=False, clean_up_tokenization_spaces=False
184-
),
185-
tokenizer,
186-
)
187-
elif tokenizer.eos_token in task.stop_words:
182+
if INFILL_MODE or tokenizer.eos_token in task.stop_words:
188183
gen_code = tokenizer.decode(
189184
s, skip_special_tokens=False, clean_up_tokenization_spaces=False
190185
)
186+
if INFILL_MODE:
187+
gen_code = parse_infill(gen_code, tokenizer)
191188
else:
192189
gen_code = tokenizer.decode(
193190
s, skip_special_tokens=True, clean_up_tokenization_spaces=True
194191
)
192+
if not INFILL_MODE:
193+
gen_code = gen_code[len(prefix) :]
195194
if postprocess:
196195
code_gens[sample].append(
197-
task.postprocess_generation(gen_code[len(prefix) :], int(sample))
196+
task.postprocess_generation(gen_code, int(sample))
198197
)
199198
else:
200199
warnings.warn(
201200
"model output is not postprocessed, this might lower evaluation scores"
202201
)
203-
code_gens[sample].append(gen_code[len(prefix) :])
202+
code_gens[sample].append(gen_code)
204203

205204
return code_gens

0 commit comments

Comments (0)