|
30 | 30 | ) |
31 | 31 |
|
32 | 32 |
|
@task
def musr(
    domain: Literal[
        "murder_mysteries", "object_placements", "team_allocation"
    ] = "murder_mysteries",
    prompt_technique: Literal["regular", "cot", "cot+"] = "regular",
    example_count: int = 0,
) -> Task:
    """Inspect task implementing the MuSR benchmark.

    Args:
        domain (Literal["murder_mysteries", "object_placements", "team_allocation"]):
            Which domain in the dataset to evaluate. Defaults to "murder_mysteries".
        prompt_technique (Literal["regular", "cot", "cot+"]): The prompt technique to
            use. "regular" includes only the narrative and the question. "cot" uses
            chain-of-thought prompting. "cot+" includes a hint. Defaults to "regular".
        example_count (int): Number of solved examples to include at the beginning of
            each prompt. Defaults to 0. Currently only supports 1 example.
    """
    # Resolve the question template first so an invalid prompt configuration
    # fails before the dataset download is attempted.
    question_template = get_domain_prompt(domain, prompt_technique, example_count)

    # Each MuSR domain is published as its own split of the HF dataset.
    musr_dataset = hf_dataset(
        path="TAUR-Lab/MuSR",
        split=domain,
        sample_fields=record_to_sample,
        shuffle=True,
        auto_id=True,
    )

    return Task(
        dataset=musr_dataset,
        solver=[
            system_message(SYSTEM_PROMPT),
            multiple_choice(template=question_template),
        ],
        scorer=choice(),
    )
33 | 70 | def get_domain_prompt( |
34 | 71 | domain: Literal["murder_mysteries", "object_placements", "team_allocation"], |
35 | 72 | prompt_technique: Literal["regular", "cot", "cot+"], |
@@ -80,40 +117,3 @@ def record_to_sample(record: Dict[str, Any]) -> Sample: |
80 | 117 | choices=ast.literal_eval(record["choices"]), |
81 | 118 | target=chr(ord("A") + int(record["answer_index"])), |
82 | 119 | ) |
83 | | - |
84 | | - |
85 | | -@task |
86 | | -def musr( |
87 | | - domain: Literal[ |
88 | | - "murder_mysteries", "object_placements", "team_allocation" |
89 | | - ] = "murder_mysteries", |
90 | | - prompt_technique: Literal["regular", "cot", "cot+"] = "regular", |
91 | | - example_count: int = 0, |
92 | | -) -> Task: |
93 | | - """Inspect task implementing the MuSR benchmark. |
94 | | -
|
95 | | - Args: |
96 | | - domain (Literal["murder_mysteries", "object_placements", "team_allocation"]): Which domain in the dataset to evaluate. |
97 | | - Defaults to "murder_mysteries". |
98 | | - prompt_technique (Literal["regular", "cot", "cot+"]): The prompt technique to use. "regular" includes only the narrative |
99 | | - and the question. "cot" uses chain-of-thought prompting. "cot+" includes a hint. Defaults to "regular". |
100 | | - example_count (int): Number of solved examples to include at the beginning of each prompt. Defaults to 0. Currently only supports 1 example. |
101 | | - """ |
102 | | - prompt = get_domain_prompt(domain, prompt_technique, example_count) |
103 | | - |
104 | | - dataset = hf_dataset( |
105 | | - path="TAUR-Lab/MuSR", |
106 | | - split=domain, |
107 | | - sample_fields=record_to_sample, |
108 | | - shuffle=True, |
109 | | - auto_id=True, |
110 | | - ) |
111 | | - |
112 | | - return Task( |
113 | | - dataset=dataset, |
114 | | - solver=[ |
115 | | - system_message(SYSTEM_PROMPT), |
116 | | - multiple_choice(template=prompt), |
117 | | - ], |
118 | | - scorer=choice(), |
119 | | - ) |
0 commit comments