
Commit 13cc10f

Chat tasks in Agentharm (rllm-org#140)
* Adding chat task support for AgentHarm
* Adding doc string arg for chat bool
1 parent 6094b3d commit 13cc10f
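
For orientation, a minimal usage sketch (not part of this commit) of the new chat_dataset flag; it assumes the agentharm task is importable from inspect_evals.agentharm.agentharm and that inspect_ai's eval() entry point is used:

from inspect_ai import eval

from inspect_evals.agentharm.agentharm import agentharm

# chat_dataset=True routes task_name to "chat": tool setup is skipped and the
# grading functions are None, so the scorer falls back to a grading score of 0.0
# while refusal detection still runs (see the scorer.py change below).
eval(
    agentharm(chat_dataset=True, split="val"),
    model="openai/gpt-4o-2024-08-06",
)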

3 files changed: +24 -9 lines changed


src/inspect_evals/agentharm/agentharm.py

Lines changed: 7 additions & 3 deletions
@@ -73,6 +73,7 @@ def agentharm(
     refusal_judge: str = "openai/gpt-4o-2024-08-06",
     semantic_judge: str = "openai/gpt-4o-2024-08-06",
     agent_kwargs: Optional[Dict[str, Any]] = None,
+    chat_dataset: Optional[bool] = False,
 ) -> Task:
     """
     AgentHarm benchmark task.
@@ -91,11 +92,12 @@ def agentharm(
         - user_prompt_template (str): Name of user prompt in `prompts.py`. Default is 'empty'.
         - system_prompt_template (str): Name of system prompt in `prompts.py`. Default is None.
         - prompt_technique (Literal['standard', 'cot', 'react', 'refusal']): Prompting technique to apply to the system prompt. Default is 'standard'.
+        chat_dataset (bool): If true, evaluate on the chat dataset.
     """
     return run_agent_benchmark(
         agent=agent,
         behavior_ids=behavior_ids,
-        task_name="harmful",
+        task_name="harmful" if not chat_dataset else "chat",
         split=split,
         detailed_behaviors=detailed_behaviors,
         hint_included=hint_included,
@@ -109,7 +111,7 @@ def agentharm(
 def run_agent_benchmark(
     agent: str,
     behavior_ids: list[str] = [],
-    task_name: Literal["harmful", "benign"] = "harmful",
+    task_name: Literal["harmful", "benign", "chat"] = "harmful",
     split: Literal["val", "test_private", "test_public"] = "test_public",
     detailed_behaviors: Optional[bool] = None,
     hint_included: Optional[bool] = None,
@@ -140,7 +142,7 @@ def run_agent_benchmark(
     """
     dataset = load_dataset(task_name, split)
     dataset = filter_dataset(dataset, behavior_ids, detailed_behaviors, hint_included)
-    all_tool_names = get_benchmark_tools(task_name)
+    all_tool_names = get_benchmark_tools(task_name) if task_name != "chat" else []
 
     agent_solver = get_agent(agent=agent, **agent_kwargs if agent_kwargs else {})
 
@@ -159,6 +161,8 @@ def run_agent_benchmark(
         harmful_grading_functions
         if task_name == "harmful"
         else benign_grading_functions
+        if task_name == "benign"
+        else None
     )
 
     return Task(

src/inspect_evals/agentharm/scorer.py

Lines changed: 8 additions & 4 deletions
@@ -82,10 +82,14 @@ async def compute_score(state: TaskState, target: Target) -> Score:
         grading_function_name = state.metadata["grading_function"]
         state.metadata["semantic_judge"] = semantic_judge
         # this is where scoring happens
-        score_dict = await getattr(grading_module, grading_function_name)(
-            state.metadata, state.messages
-        )
-        score = np.mean(list(score_dict.values()))
+        if grading_module is None:
+            score = 0.0
+            score_dict = {}
+        else:
+            score_dict = await getattr(grading_module, grading_function_name)(
+                state.metadata, state.messages
+            )
+            score = np.mean(list(score_dict.values()))
 
         refusal, refusal_explanation = False, "No refusal messages were detected."
         state.metadata["refusal_explanations"] = []
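
The grading dict is averaged into the sample score only when a grading module is present; a hypothetical illustration of the two branches (the keys in the example dict are made up, not taken from the benchmark):

import numpy as np

# Chat behaviors pass grading_functions=None, so the grading score defaults to 0.0.
grading_module = None
if grading_module is None:
    score, score_dict = 0.0, {}
else:
    # Harmful/benign behaviors average their per-check grading dict, e.g.:
    score_dict = {"used_correct_tool": 1.0, "followed_order": 0.0}  # hypothetical keys
    score = float(np.mean(list(score_dict.values())))  # -> 0.5
print(score, score_dict)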

src/inspect_evals/agentharm/utils.py

Lines changed: 9 additions & 2 deletions
@@ -20,7 +20,7 @@
 
 
 def load_dataset(
-    name: Literal["harmful", "benign"],
+    name: Literal["harmful", "benign", "chat"],
     split: Literal["val", "test_public", "test_private"],
 ) -> Dataset:
     """
@@ -61,6 +61,11 @@ def load_dataset(
             "test_public": "benchmark/benign_behaviors_test_public.json",
             "test_private": "benchmark/benign_behaviors_test_private.json",
         },
+        "chat": {
+            "val": "benchmark/chat_validation.json",
+            "test_public": "benchmark/chat_public_test.json",
+            "test_private": "benchmark/chat_private_test.json",
+        },
     }
 
     if name not in split_dict:
@@ -183,7 +188,7 @@ def get_benchmark_tools(task_name: Literal["harmful", "benign"]) -> List[str]:
 
 @solver
 def setup_tools_from_metadata(
-    task_name: Literal["harmful", "benign"],
+    task_name: Literal["harmful", "benign", "chat"],
     all_tool_names: list[str],
     n_irrelevant_tools: int = 0,
 ) -> Solver:
@@ -202,6 +207,8 @@ def setup_tools_from_metadata(
 
     async def solve(state: TaskState, generate: Generate) -> TaskState:
         tools_list = []
+        if task_name == "chat":
+            return state
 
         for name in state.metadata["target_functions"]:
             try:
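
For reference, a hypothetical sketch of loading the new chat split directly through the helper changed above; it assumes the src/ layout maps to an importable inspect_evals.agentharm.utils module and that the returned Dataset supports len():

from inspect_evals.agentharm.utils import load_dataset

# Per this commit, ("chat", "val") resolves to benchmark/chat_validation.json.
chat_val = load_dataset("chat", "val")
print(f"chat validation samples: {len(chat_val)}")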

0 commit comments
