from typing import Dict, List

from tqdm import tqdm

from bigcode_eval.base import Task

_CITATION = """
@article{allal2023santacoder,
    title={SantaCoder: don't reach for the stars!},
    author={Allal, Loubna Ben and Li, Raymond and Kocetkov, Denis and Mou, Chenghao and Akiki, Christopher and Ferrandis, Carlos Munoz and Muennighoff, Niklas and Mishra, Mayank and Gu, Alex and Dey, Manan and others},
    journal={arXiv preprint arXiv:2301.03988},
    year={2023}
}
"""

LANGUAGES = [
    "py",
    "js",
    "java",
]


def create_all_tasks():
    """Creates a dictionary mapping each task name to its task class."""
    return {
        "santacoder_fim": SantaCoderFIM,
        "starcoder_fim": StarCoderFIM,
    }


def initialize_empty_metrics(languages: List[str]) -> Dict[str, float]:
    """Initializes per-language exact-match and sample counters to zero."""
    metrics = {}
    for lang in languages:
        metrics[f"n_accurate_{lang}"] = 0.0
        metrics[f"n_count_{lang}"] = 0.0
    return metrics


def aggregate_per_lang_accuracy(
    metrics: Dict[str, float], languages: List[str]
) -> Dict[str, float]:
    """Computes exact-match accuracy per language from the raw counters."""
    em_metrics = {}
    for lang in languages:
        # avoid div by 0
        acc = (
            metrics[f"n_accurate_{lang}"] / metrics[f"n_count_{lang}"]
            if metrics[f"n_count_{lang}"]
            else 0
        )
        em_metrics[f"{lang} Exact Match"] = acc

    return em_metrics
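# For illustration (hypothetical counts): {"n_accurate_py": 1.0, "n_count_py": 2.0, ...}
# aggregates to {"py Exact Match": 0.5, ...}; languages with a zero count report 0.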


class SantaCoderFIM(Task):
    """Exact-match fill-in-the-middle (FIM) task from the SantaCoder paper,
    covering Python, JavaScript, and Java."""

    DATASET_PATH = "bigcode/santacoder-fim-task"

    def __init__(
        self,
        fim_prefix: str = "<fim-prefix>",
        fim_middle: str = "<fim-middle>",
        fim_suffix: str = "<fim-suffix>",
    ):
        stop_words = ["<|endoftext|>", "<|filename|>"]
        super().__init__(
            stop_words=stop_words,
            requires_execution=False,
        )
        self.fim_prefix = fim_prefix
        self.fim_middle = fim_middle
        self.fim_suffix = fim_suffix

    def get_dataset(self):
        """Returns the dataset for the task, or an iterable of any object that get_prompt can handle."""
        dataset = self.dataset["train"]
        return dataset

    def get_prompt(self, doc):
        """Builds the prompt for the LM to generate from."""
        return f"""{self.fim_prefix}{doc["prompt"]}{self.fim_suffix}{doc["suffix"]}{self.fim_middle}"""
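    # Illustration (hypothetical doc, default SantaCoder sentinels): a doc with
    # prompt 'def add(a, b):\n    return ' and suffix '\n' is rendered as
    # '<fim-prefix>def add(a, b):\n    return <fim-suffix>\n<fim-middle>',
    # and the model is expected to emit the missing middle (e.g. 'a + b').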

    def get_reference(self, doc):
        """Builds the reference solution for the doc (sample from the test dataset)."""
        return doc["canonical_solution"]

    def postprocess_generation(self, generation, idx):
        """Defines the postprocessing for a LM generation.
        :param generation: str
            code generation from LM
        :param idx: int
            index of doc in the dataset to which the generation belongs
        """
        doc = self.get_dataset()[idx]
        prompt = self.get_prompt(doc)
        # strip the prompt, then truncate at the first stop word
        output = generation[len(prompt) :]
        return self._stop_at_stop_token(output, self.stop_words)

    def process_results(self, generations, references):
        """Takes the list of LM generations and evaluates them against ground truth references,
        returning the metric for the generations as in {"metric_name": result}.
        :param generations: list(list(str))
            list of lists containing generations
        :param references: list(str)
            list of str containing references
        :return: dict[str, float]
        """
        metrics = initialize_empty_metrics(LANGUAGES)
        for idx, (gen, reference) in tqdm(
            enumerate(zip(generations, references)), total=len(references)
        ):
            language = self.get_dataset()[idx]["language"]
            for g in gen:
                metrics[f"n_accurate_{language}"] += int(g.strip() == reference.strip())

            metrics[f"n_count_{language}"] += len(gen)

        em_metrics = aggregate_per_lang_accuracy(metrics, LANGUAGES)

        return em_metrics


class StarCoderFIM(SantaCoderFIM):
    """SantaCoder FIM task evaluated with the StarCoder-style FIM sentinel tokens."""

    DATASET_PATH = "bigcode/santacoder-fim-task"

    def __init__(self):
        # The parent already sets the stop words and requires_execution=False;
        # only the sentinel tokens differ (underscores instead of dashes).
        super().__init__(
            fim_prefix="<fim_prefix>",
            fim_middle="<fim_middle>",
            fim_suffix="<fim_suffix>",
        )
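

# A minimal, self-contained sanity check for the helper functions above (an added
# sketch, not part of the original evaluation flow; it needs no dataset or model).
if __name__ == "__main__":
    counters = initialize_empty_metrics(LANGUAGES)
    counters["n_accurate_py"] += 1
    counters["n_count_py"] += 2
    # Expected: {"py Exact Match": 0.5, "js Exact Match": 0, "java Exact Match": 0}
    print(aggregate_per_lang_accuracy(counters, LANGUAGES))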