|
29 | 29 | TEAM_ALLOCATION_HINT, |
30 | 30 | ) |
31 | 31 |
|
| 32 | +DomainType = Literal["murder_mysteries", "object_placements", "team_allocation"] |
| 33 | +PromptTechniqueType = Literal["regular", "cot", "cot+"] |
| 34 | + |
| 35 | +DEFAULT_DOMAIN: DomainType = "murder_mysteries" |
| 36 | +DEFAULT_PROMPT_TECHNIQUE: PromptTechniqueType = "regular" |
| 37 | +DEFAULT_EXAMPLE_COUNT = 0 |
| 38 | + |
32 | 39 |
|
33 | 40 | @task |
34 | 41 | def musr( |
35 | | - domain: Literal[ |
36 | | - "murder_mysteries", "object_placements", "team_allocation" |
37 | | - ] = "murder_mysteries", |
38 | | - prompt_technique: Literal["regular", "cot", "cot+"] = "regular", |
39 | | - example_count: int = 0, |
| 42 | + domain: DomainType = DEFAULT_DOMAIN, |
| 43 | + prompt_technique: PromptTechniqueType = DEFAULT_PROMPT_TECHNIQUE, |
| 44 | + example_count: int = DEFAULT_EXAMPLE_COUNT, |
40 | 45 | ) -> Task: |
41 | 46 | """Inspect task implementing the MuSR benchmark. |
42 | 47 |
|
@@ -68,9 +73,9 @@ def musr( |
68 | 73 |
|
69 | 74 |
|
70 | 75 | def get_domain_prompt( |
71 | | - domain: Literal["murder_mysteries", "object_placements", "team_allocation"], |
72 | | - prompt_technique: Literal["regular", "cot", "cot+"], |
73 | | - example_count: int, |
| 76 | + domain: DomainType = DEFAULT_DOMAIN, |
| 77 | + prompt_technique: PromptTechniqueType = DEFAULT_PROMPT_TECHNIQUE, |
| 78 | + example_count: int = DEFAULT_EXAMPLE_COUNT, |
74 | 79 | ) -> str: |
75 | 80 | domain_info = { |
76 | 81 | "murder_mysteries": { |
|
0 commit comments