Skip to content

Commit 4e13fdb

Browse files
Dedupe types, share defaults
1 parent 1c99383 commit 4e13fdb

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

src/inspect_evals/musr/musr.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,19 @@
2929
TEAM_ALLOCATION_HINT,
3030
)
3131

32+
DomainType = Literal["murder_mysteries", "object_placements", "team_allocation"]
33+
PromptTechniqueType = Literal["regular", "cot", "cot+"]
34+
35+
DEFAULT_DOMAIN: DomainType = "murder_mysteries"
36+
DEFAULT_PROMPT_TECHNIQUE: PromptTechniqueType = "regular"
37+
DEFAULT_EXAMPLE_COUNT = 0
38+
3239

3340
@task
3441
def musr(
35-
domain: Literal[
36-
"murder_mysteries", "object_placements", "team_allocation"
37-
] = "murder_mysteries",
38-
prompt_technique: Literal["regular", "cot", "cot+"] = "regular",
39-
example_count: int = 0,
42+
domain: DomainType = DEFAULT_DOMAIN,
43+
prompt_technique: PromptTechniqueType = DEFAULT_PROMPT_TECHNIQUE,
44+
example_count: int = DEFAULT_EXAMPLE_COUNT,
4045
) -> Task:
4146
"""Inspect task implementing the MuSR benchmark.
4247
@@ -68,9 +73,9 @@ def musr(
6873

6974

7075
def get_domain_prompt(
71-
domain: Literal["murder_mysteries", "object_placements", "team_allocation"],
72-
prompt_technique: Literal["regular", "cot", "cot+"],
73-
example_count: int,
76+
domain: DomainType = DEFAULT_DOMAIN,
77+
prompt_technique: PromptTechniqueType = DEFAULT_PROMPT_TECHNIQUE,
78+
example_count: int = DEFAULT_EXAMPLE_COUNT,
7479
) -> str:
7580
domain_info = {
7681
"murder_mysteries": {

0 commit comments

Comments
 (0)