Commit 9d17482

Merge pull request #164 from bigcode-project/max/santacoder-fim

SantaCoder FIM task

2 parents 56ec144 + 8613c5c commit 9d17482

File tree

10 files changed: +178 −71 lines changed

README.md

Lines changed: 3 additions & 0 deletions

@@ -35,6 +35,9 @@ Below are the features and tasks of this framework:
 - [CoNaLa](https://huggingface.co/datasets/neulab/conala) for **Python** code generation (2-shot setting and evaluation with BLEU score).
 - [Concode](https://huggingface.co/datasets/code_x_glue_tc_text_to_code) for **Java** code generation (2-shot setting and evaluation with BLEU score).
 - 3 multilingual downstream classification tasks: [Java Complexity prediction](https://huggingface.co/datasets/codeparrot/codecomplex), [Java code equivalence prediction](https://huggingface.co/datasets/code_x_glue_cc_clone_detection_big_clone_bench), [C code defect prediction](https://huggingface.co/datasets/code_x_glue_cc_defect_detection).
+- [SantaCoder-FIM](https://huggingface.co/datasets/bigcode/santacoder-fim-task) for evaluating FIM on **Python** code using Exact Match. Further details are described in [SantaCoder](https://arxiv.org/abs/2301.03988). Includes two tasks:
+    - `StarCoderFIM`: which uses the default FIM tokens `"<fim_prefix>", "<fim_middle>", "<fim_suffix>"`, and
+    - `SantaCoderFIM`: which uses SantaCoder FIM tokens `"<fim-prefix>", "<fim-middle>", "<fim-suffix>"`
 
 More details about each task can be found in the documentation in [`docs/README.md`](https://github.com/bigcode-project/bigcode-evaluation-harness/blob/main/docs/README.md).
 ## Setup
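The two tasks differ only in which sentinel tokens are spliced around the code; a model trained with one sentinel vocabulary will not recognize the other. A minimal sketch of how either token set forms a prompt (`fim_prompt` is an illustrative helper, not a harness function):

```python
# Sentinel token sets: SantaCoder uses dashes, StarCoder uses underscores.
SANTA = ("<fim-prefix>", "<fim-middle>", "<fim-suffix>")
STAR = ("<fim_prefix>", "<fim_middle>", "<fim_suffix>")

def fim_prompt(prefix_code, suffix_code, tokens):
    # prefix token, code before the hole, suffix token, code after
    # the hole, then the middle token that cues the model to infill
    p, m, s = tokens
    return f"{p}{prefix_code}{s}{suffix_code}{m}"

assert fim_prompt("a", "b", STAR) == "<fim_prefix>a<fim_suffix>b<fim_middle>"
```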

bigcode_eval/base.py

Lines changed: 16 additions & 0 deletions

@@ -77,3 +77,19 @@ def process_results(self, generations, references):
         :return: dict[str: float]
         """
         pass
+
+    @staticmethod
+    def _stop_at_stop_token(decoded_string, stop_tokens):
+        """
+        Produces the prefix of decoded_string that ends at the first occurrence of
+        a stop_token.
+        WARNING: the decoded_string *must not* include the prompt, which may have stop tokens
+        itself.
+        """
+        min_stop_index = len(decoded_string)
+        for stop_token in stop_tokens:
+            stop_index = decoded_string.find(stop_token)
+            if stop_index != -1 and stop_index < min_stop_index:
+                min_stop_index = stop_index
+        return decoded_string[:min_stop_index]

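The helper hoisted into `base.py` can be exercised standalone; a minimal sketch reproducing its truncation logic:

```python
def stop_at_stop_token(decoded_string, stop_tokens):
    # Truncate at the earliest occurrence of any stop token.
    min_stop_index = len(decoded_string)
    for stop_token in stop_tokens:
        stop_index = decoded_string.find(stop_token)
        if stop_index != -1 and stop_index < min_stop_index:
            min_stop_index = stop_index
    return decoded_string[:min_stop_index]

# The string must already have the prompt stripped, or a stop token
# inside the prompt would cut the completion too early.
completion = "return x + y\n<|endoftext|>garbage"
assert stop_at_stop_token(completion, ["<|endoftext|>", "<|filename|>"]) == "return x + y\n"
```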
bigcode_eval/tasks/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -4,7 +4,7 @@
 from . import (apps, codexglue_code_to_text, codexglue_text_to_text, conala,
                concode, ds1000, gsm, humaneval, humanevalpack,
                instruct_humaneval, instruct_wizard_humaneval, mbpp, multiple,
-               parity, python_bugs, quixbugs, recode)
+               parity, python_bugs, quixbugs, recode, santacoder_fim)
 
 TASK_REGISTRY = {
     **apps.create_all_tasks(),
@@ -25,6 +25,7 @@
     **gsm.create_all_tasks(),
     **instruct_humaneval.create_all_tasks(),
     **recode.create_all_tasks(),
+    **santacoder_fim.create_all_tasks(),
 }
 
 ALL_TASKS = sorted(list(TASK_REGISTRY))

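Each task module contributes entries to `TASK_REGISTRY` through a `create_all_tasks` factory; a minimal standalone sketch of this registry pattern (the stub classes are illustrative, not the real `Task` subclasses):

```python
# Minimal sketch of the harness's task-registry pattern: each module
# exposes a create_all_tasks() factory returning {task_name: task_class}.
class SantaCoderFIM:  # stub standing in for the real Task subclass
    pass

class StarCoderFIM(SantaCoderFIM):
    pass

def create_all_tasks():
    return {
        "santacoder_fim": SantaCoderFIM,
        "starcoder_fim": StarCoderFIM,
    }

# Merging each module's dict builds the global registry.
TASK_REGISTRY = {**create_all_tasks()}
ALL_TASKS = sorted(list(TASK_REGISTRY))
assert ALL_TASKS == ["santacoder_fim", "starcoder_fim"]
```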
bigcode_eval/tasks/humaneval.py

Lines changed: 0 additions & 14 deletions

@@ -74,20 +74,6 @@ def get_reference(self, doc):
         entry_point = f"check({doc['entry_point']})"
         return "\n" + test_func + "\n" + entry_point
 
-    @staticmethod
-    def _stop_at_stop_token(decoded_string, stop_tokens):
-        """
-        Produces the prefix of decoded_string that ends at the first occurrence of
-        a stop_token.
-        WARNING: the decoded_string *must not* include the prompt, which may have stop tokens
-        itself.
-        """
-        min_stop_index = len(decoded_string)
-        for stop_token in stop_tokens:
-            stop_index = decoded_string.find(stop_token)
-            if stop_index != -1 and stop_index < min_stop_index:
-                min_stop_index = stop_index
-        return decoded_string[:min_stop_index]
 
     def postprocess_generation(self, generation, idx):
         """Defines the postprocessing for a LM generation.

bigcode_eval/tasks/instruct_humaneval.py

Lines changed: 0 additions & 14 deletions

@@ -55,20 +55,6 @@ def get_reference(self, doc):
         entry_point = f"check({doc['entry_point']})"
         return "\n" + test_func + "\n" + entry_point
 
-    @staticmethod
-    def _stop_at_stop_token(decoded_string, stop_tokens):
-        """
-        Produces the prefix of decoded_string that ends at the first occurrence of
-        a stop_token.
-        WARNING: the decoded_string *must not* include the prompt, which may have stop tokens
-        itself.
-        """
-        min_stop_index = len(decoded_string)
-        for stop_token in stop_tokens:
-            stop_index = decoded_string.find(stop_token)
-            if stop_index != -1 and stop_index < min_stop_index:
-                min_stop_index = stop_index
-        return decoded_string[:min_stop_index]
 
     def process_results(self, generations, references):
         """Takes the list of LM generations and evaluates them against ground truth references,

bigcode_eval/tasks/mbpp.py

Lines changed: 0 additions & 14 deletions

@@ -59,20 +59,6 @@ def get_reference(self, doc):
         """Builds the reference solution for the doc (sample from the test dataset)."""
         return "\n".join(doc["test_list"])
 
-    @staticmethod
-    def _stop_at_stop_token(decoded_string, stop_tokens):
-        """
-        Produces the prefix of decoded_string that ends at the first occurrence of
-        a stop_token.
-        WARNING: the decoded_string *must not* include the prompt, which may have stop tokens
-        itself.
-        """
-        min_stop_index = len(decoded_string)
-        for stop_token in stop_tokens:
-            stop_index = decoded_string.find(stop_token)
-            if stop_index != -1 and stop_index < min_stop_index:
-                min_stop_index = stop_index
-        return decoded_string[:min_stop_index]
 
     def postprocess_generation(self, generation, idx):
         """Defines the postprocessing for a LM generation.

bigcode_eval/tasks/multiple.py

Lines changed: 0 additions & 14 deletions

@@ -115,20 +115,6 @@ def remove_last_block(string, stop_words):
         # last string should be ""
         return "".join(string_list[:-2])
 
-    @staticmethod
-    def _stop_at_stop_token(decoded_string, stop_tokens):
-        """
-        Produces the prefix of decoded_string that ends at the first occurrence of
-        a stop_token.
-        WARNING: the decoded_string *must not* include the prompt, which may have stop tokens
-        itself.
-        """
-        min_stop_index = len(decoded_string)
-        for stop_token in stop_tokens:
-            stop_index = decoded_string.find(stop_token)
-            if stop_index != -1 and stop_index < min_stop_index:
-                min_stop_index = stop_index
-        return decoded_string[:min_stop_index]
 
     def postprocess_generation(self, generation, idx):
         """Defines the postprocessing for a LM generation.

bigcode_eval/tasks/recode.py

Lines changed: 0 additions & 14 deletions

@@ -96,20 +96,6 @@ def get_reference(self, doc):
             "test_code": test_code,
         }
 
-    @staticmethod
-    def _stop_at_stop_token(decoded_string, stop_tokens):
-        """
-        Produces the prefix of decoded_string that ends at the first occurrence of
-        a stop_token.
-        WARNING: the decoded_string *must not* include the prompt, which may have stop tokens
-        itself.
-        """
-        min_stop_index = len(decoded_string)
-        for stop_token in stop_tokens:
-            stop_index = decoded_string.find(stop_token)
-            if stop_index != -1 and stop_index < min_stop_index:
-                min_stop_index = stop_index
-        return decoded_string[:min_stop_index]
 
     def postprocess_generation(self, generation, idx):
         """
bigcode_eval/tasks/santacoder_fim.py

Lines changed: 134 additions & 0 deletions

@@ -0,0 +1,134 @@
+from typing import Dict, List
+
+from tqdm import tqdm
+
+from bigcode_eval.base import Task
+
+_CITATION = """
+@article{allal2023santacoder,
+  title={SantaCoder: don't reach for the stars!},
+  author={Allal, Loubna Ben and Li, Raymond and Kocetkov, Denis and Mou, Chenghao and Akiki, Christopher and Ferrandis, Carlos Munoz and Muennighoff, Niklas and Mishra, Mayank and Gu, Alex and Dey, Manan and others},
+  journal={arXiv preprint arXiv:2301.03988},
+  year={2023}
+}
+"""
+
+LANGUAGES = [
+    "py",
+    "js",
+    "java",
+]
+
+
+def create_all_tasks():
+    return {
+        "santacoder_fim": SantaCoderFIM,
+        "starcoder_fim": StarCoderFIM,
+    }
+
+
+def initialize_empty_metrics(languages: List[str]) -> Dict[str, float]:
+    metrics = {}
+    for lang in languages:
+        metrics[f"n_accurate_{lang}"] = 0.0
+        metrics[f"n_count_{lang}"] = 0.0
+    return metrics
+
+
+def aggregate_per_lang_accuracy(
+    metrics: Dict[str, float], languages: List[str]
+) -> Dict[str, float]:
+    em_metrics = {}
+    for lang in languages:
+        # avoid div by 0
+        acc = (
+            metrics[f"n_accurate_{lang}"] / metrics[f"n_count_{lang}"]
+            if metrics[f"n_count_{lang}"]
+            else 0
+        )
+        em_metrics[f"{lang} Exact Match"] = acc
+
+    return em_metrics
+
+
+class SantaCoderFIM(Task):
+    DATASET_PATH = "bigcode/santacoder-fim-task"
+
+    def __init__(
+        self,
+        fim_prefix: str = "<fim-prefix>",
+        fim_middle: str = "<fim-middle>",
+        fim_suffix: str = "<fim-suffix>",
+    ):
+        stop_words = ["<|endoftext|>", "<|filename|>"]
+        super().__init__(
+            stop_words=stop_words,
+            requires_execution=False,
+        )
+        self.fim_prefix = fim_prefix
+        self.fim_middle = fim_middle
+        self.fim_suffix = fim_suffix
+
+    def get_dataset(self):
+        """Returns dataset for the task or an iterable of any object, that get_prompt can handle"""
+        dataset = self.dataset["train"]
+        return dataset
+
+    def get_prompt(self, doc):
+        """Builds the prompt for the LM to generate from."""
+        return f"""{self.fim_prefix}{doc["prompt"]}{self.fim_suffix}{doc["suffix"]}{self.fim_middle}"""
+
+    def get_reference(self, doc):
+        """Builds the reference solution for the doc (sample from the test dataset)."""
+        return doc["canonical_solution"]
+
+    def postprocess_generation(self, generation, idx):
+        """Defines the postprocessing for a LM generation.
+        :param generation: str
+            code generation from LM
+        :param idx: int
+            index of doc in the dataset to which the generation belongs
+        """
+        doc = self.get_dataset()[idx]
+        prompt = self.get_prompt(doc)
+        output = generation[len(prompt) :]
+        return self._stop_at_stop_token(output, self.stop_words)
+
+    def process_results(self, generations, references):
+        """Takes the list of LM generations and evaluates them against ground truth references,
+        returning the metric for the generations as in {"metric_name": result}.
+        :param generations: list(list(str))
+            list of lists containing generations
+        :param references: list(str)
+            list of str containing references
+        :return: dict[str: float]
+        """
+        metrics = initialize_empty_metrics(LANGUAGES)
+        for idx, (gen, reference) in tqdm(enumerate(zip(generations, references))):
+            language = self.get_dataset()[idx]["language"]
+            for g in gen:
+                metrics[f"n_accurate_{language}"] += int(g.strip() == reference.strip())
+
+            metrics[f"n_count_{language}"] += len(gen)
+
+        em_metrics = aggregate_per_lang_accuracy(metrics, LANGUAGES)
+
+        return em_metrics
+
+
+class StarCoderFIM(SantaCoderFIM):
+    DATASET_PATH = "bigcode/santacoder-fim-task"
+
+    def __init__(self):
+        # stop words are set by the parent class, whose __init__ only
+        # accepts the FIM token arguments
+        super().__init__(
+            fim_prefix="<fim_prefix>",
+            fim_middle="<fim_middle>",
+            fim_suffix="<fim_suffix>",
+        )
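The exact-match bookkeeping in `process_results` and `aggregate_per_lang_accuracy` can be exercised without a model; a minimal sketch with made-up generations and references:

```python
# Sketch of the per-language exact-match accounting used above.
LANGUAGES = ["py", "js", "java"]

metrics = {}
for lang in LANGUAGES:
    metrics[f"n_accurate_{lang}"] = 0.0
    metrics[f"n_count_{lang}"] = 0.0

# Illustrative generations/references for two "py" problems;
# outer list is per problem, inner list is per sample (n_samples=1).
generations = [["a + b "], ["a - b"]]
references = ["a + b", "a * b"]
langs = ["py", "py"]

for idx, (gen, ref) in enumerate(zip(generations, references)):
    lang = langs[idx]
    for g in gen:
        # exact match after stripping surrounding whitespace
        metrics[f"n_accurate_{lang}"] += int(g.strip() == ref.strip())
    metrics[f"n_count_{lang}"] += len(gen)

em = {}
for lang in LANGUAGES:
    count = metrics[f"n_count_{lang}"]
    em[f"{lang} Exact Match"] = metrics[f"n_accurate_{lang}"] / count if count else 0

print(em["py Exact Match"])  # → 0.5
```

Languages with no problems in the batch keep a zero count, which the `if count else 0` guard turns into an accuracy of 0 rather than a division error.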

docs/README.md

Lines changed: 23 additions & 0 deletions

@@ -357,6 +357,29 @@ accelerate launch main.py \
 ```
 If you ever get index out-of-range errors try using a number of problems `limit` that is proportional to the number of devices you are using.
 
+### SantaCoder-FIM
+[SantaCoder-FIM](https://huggingface.co/datasets/bigcode/santacoder-fim-task): 4,792 tasks for FIM insertion described in [SantaCoder: don't reach for the stars!](https://arxiv.org/abs/2301.03988). The tasks are similar to other tasks without unit tests, with two key differences:
+1. Instead of BLEU score, Exact Match is used to score the generations.
+2. A zero-shot setting is used instead of 2-shot.
+
+SantaCoder-FIM includes 2 tasks:
+- `StarCoderFIM`: which uses the default FIM tokens `"<fim_prefix>", "<fim_middle>", "<fim_suffix>"`, and
+- `SantaCoderFIM`: which uses SantaCoder FIM tokens `"<fim-prefix>", "<fim-middle>", "<fim-suffix>"`
+Depending on the FIM tokens used to train the model, you will need to select the appropriate task for evaluation.
+
+We only do single generation `n_samples=1`, and use the same generation settings as before.
+Below is the command to run the evaluation:
+```bash
+accelerate launch main.py \
+  --model <MODEL_NAME> \
+  --max_length_generation <MAX_LENGTH> \
+  --tasks <TASK> \
+  --n_samples 1 \
+  --temperature 0.2 \
+  --batch_size 1
+```
+If you ever get index out-of-range errors try using a number of problems `limit` that is proportional to the number of devices you are using.
 
 ## Documentation generation task
 Code to text task from [CodeXGLUE](https://huggingface.co/datasets/code_x_glue_ct_code_to_text) is a benchmark for English documentation generation for 6 programming languages: Python, Go, Ruby, Java, JavaScript and PHP.