Skip to content

Commit 5bbeb1f

Browse files
[MGSM] Enable auto_id, add language tag, consistent shuffling (take 2) (rllm-org#136)
* [MGSM] Enable auto_id, add language tag, consistent shuffling. Fixes rllm-org#92
* unique sample_id
1 parent 3e73bf9 commit 5bbeb1f

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

src/inspect_evals/mgsm/mgsm.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
from time import time
12
from typing import List
23

34
from inspect_ai import Task, task
@@ -68,6 +69,7 @@ def load_mgsm_dataset(
6869
languages = ALL_LANGUAGES
6970

7071
samples = []
72+
seed = int(time()) # use same shuffling seed for all languages
7173

7274
for lang in languages:
7375
csv_filename = LANG_TO_FPATH[lang]
@@ -77,14 +79,18 @@ def load_mgsm_dataset(
7779
dialect="excel-tab",
7880
limit=limit_samples_per_lang,
7981
shuffle=shuffle,
82+
seed=seed,
83+
auto_id=True,
8084
delimiter="\t",
8185
)
8286

8387
lang_samples = lang_dataset.samples # type: ignore
88+
cot_template = LANG_TO_INSTRUCTIONS[lang]
8489

85-
if use_cot:
86-
cot_template = LANG_TO_INSTRUCTIONS[lang]
87-
for sample in lang_samples:
90+
for sample in lang_samples:
91+
sample.metadata = {"language": lang, "question_id": sample.id}
92+
sample.id = f"{lang}_{sample.id}"
93+
if use_cot:
8894
cot_prompt = cot_template.format(prompt=sample.input)
8995
sample.input = cot_prompt
9096

@@ -105,6 +111,7 @@ def mgsm(
105111
languages=languages,
106112
limit_samples_per_lang=limit_samples_per_lang,
107113
use_cot=use_cot,
114+
shuffle=shuffle,
108115
)
109116

110117
task = Task(

0 commit comments

Comments
 (0)