Skip to content

Commit 2d5991a

Browse files
committed
Update
1 parent ba34e71 commit 2d5991a

File tree

4 files changed

+5
-489
lines changed

4 files changed

+5
-489
lines changed

bigcode_eval/generation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __call__(self, input_ids, scores, **kwargs):
3838
"""Returns true if generated sequence is too long."""
3939
return input_ids.shape[1] > int(self.input_length * self.multiplier)
4040

41+
4142
def parallel_generations(
4243
task,
4344
dataset,

bigcode_eval/tasks/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
from . import (apps, codexglue_code_to_text, codexglue_text_to_text, conala,
55
concode, ds1000, gsm, humaneval, humanevalplus, humanevalpack,
66
instruct_humaneval, instruct_wizard_humaneval, mbpp, mbppplus,
7-
multiple, parity, python_bugs, quixbugs, recode, santacoder_fim, mercury)
7+
multiple, parity, python_bugs, quixbugs, recode, santacoder_fim,
8+
mercury)
89

910
TASK_REGISTRY = {
1011
**apps.create_all_tasks(),

bigcode_eval/utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def __iter__(self):
120120
"n_copies (n_samples/batch_size) was changed from 1 to 2 because n_tasks isn't proportional to num devices"
121121
)
122122

123-
for sample in tqdm(range(self.n_tasks), desc="Task Encoding"):
123+
for sample in range(self.n_tasks):
124124
for _ in range(self.n_copies):
125125
if self.has_encoder:
126126
yield {
@@ -220,6 +220,7 @@ def _parse_instruction(code, instruction_tokens):
220220
shift = len("```python")
221221
return code[idx + shift :]
222222

223+
223224
def complete_code(
224225
task,
225226
accelerator,
@@ -248,13 +249,11 @@ def complete_code(
248249
code_gens: List[List[Optional[str]]] = [[] for _ in range(n_tasks)]
249250
generations = [] if not intermediate_generations else intermediate_generations
250251
gen_token_dict = defaultdict(list) # dict of list of generated tokens
251-
252252
for step, batch in tqdm(
253253
enumerate(dataloader),
254254
total=math.ceil(
255255
n_tasks * dataloader.dataset.n_copies / accelerator.num_processes
256256
),
257-
desc="batch generation",
258257
):
259258
with torch.no_grad():
260259
if task.stop_words:

0 commit comments

Comments (0)