Skip to content

Commit cabaa1f

Browse files
committed
Support left padding and prefix post-processing for models like ChatGLM
1 parent bb1ed78 commit cabaa1f

File tree

2 files changed

+33
-10
lines changed

2 files changed

+33
-10
lines changed

bigcode_eval/utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def complete_code(
268268
batch["input_len"].max().item()
269269
)
270270

271-
inputs = batch["ids"][:, : batch["input_len"]]
271+
inputs = batch["ids"][:, : batch["input_len"]] if tokenizer.padding_side == "right" else batch["ids"]
272272
if "ids_encoder" in batch:
273273
if is_wrapped:
274274
generated_tokens = accelerator.unwrap_model(model).generate(
@@ -365,7 +365,7 @@ def update_code_gens(
365365
postprocess,
366366
code_gens,
367367
gen_token_dict,
368-
):
368+
):
369369
for sample, generated_tokens in gen_token_dict.items():
370370
for s in generated_tokens:
371371
if INFILL_MODE or tokenizer.eos_token in task.stop_words:
@@ -378,6 +378,13 @@ def update_code_gens(
378378
gen_code = tokenizer.decode(
379379
s, skip_special_tokens=False, clean_up_tokenization_spaces=False
380380
)
381+
try:
382+
# some tokenizers add a multi-token prefix to the generation (e.g. ChatGLM)
383+
tokenizer_prefix = tokenizer.decode(tokenizer.get_prefix_tokens())
384+
if gen_code.startswith(f"{tokenizer_prefix}"):
385+
gen_code = gen_code[len(tokenizer_prefix):].lstrip()
386+
except:
387+
pass
381388
if INFILL_MODE:
382389
gen_code = _parse_infill(gen_code, tokenizer)
383390
if INSTRUCTION_MODE:

main.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,11 @@ def parse_args():
109109
action="store_true",
110110
help="Load model in 4bit",
111111
)
112+
parser.add_argument(
113+
"--left_padding",
114+
action="store_true",
115+
help="Force left padding, needed for models like chatglm3-6b",
116+
)
112117
parser.add_argument(
113118
"--limit",
114119
type=int,
@@ -311,14 +316,25 @@ def main():
311316
model.merge_and_unload()
312317
print("Merge complete.")
313318

314-
tokenizer = AutoTokenizer.from_pretrained(
315-
args.model,
316-
revision=args.revision,
317-
trust_remote_code=args.trust_remote_code,
318-
use_auth_token=args.use_auth_token,
319-
truncation_side="left",
320-
padding_side="right", # padding on the right is needed to cut off padding in `complete_code`
321-
)
319+
if args.left_padding:
320+
# left padding is required for some models like chatglm3-6b
321+
tokenizer = AutoTokenizer.from_pretrained(
322+
args.model,
323+
revision=args.revision,
324+
trust_remote_code=args.trust_remote_code,
325+
use_auth_token=args.use_auth_token,
326+
padding_side="left",
327+
)
328+
else:
329+
# used by default for most models
330+
tokenizer = AutoTokenizer.from_pretrained(
331+
args.model,
332+
revision=args.revision,
333+
trust_remote_code=args.trust_remote_code,
334+
use_auth_token=args.use_auth_token,
335+
truncation_side="left",
336+
padding_side="right",
337+
)
322338
if not tokenizer.eos_token:
323339
if tokenizer.bos_token:
324340
tokenizer.eos_token = tokenizer.bos_token

0 commit comments

Comments
 (0)