Skip to content

Commit a656ca7

Browse files
malhotra5, openhands-agent, and xingyaoww
authored
Refactor: Always include risk fields (#1052)
Co-authored-by: openhands <openhands@all-hands.dev>
Co-authored-by: Xingyao Wang <xingyao@all-hands.dev>
1 parent: a482ab1 · commit: a656ca7

18 files changed

+714
-114
lines changed

openhands-agent-server/openhands/agent_server/conversation_router.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
GenerateTitleResponse,
1717
SendMessageRequest,
1818
SetConfirmationPolicyRequest,
19+
SetSecurityAnalyzerRequest,
1920
StartConversationRequest,
2021
Success,
2122
UpdateConversationRequest,
@@ -237,6 +238,23 @@ async def set_conversation_confirmation_policy(
237238
return Success()
238239

239240

241+
@conversation_router.post(
242+
"/{conversation_id}/security_analyzer",
243+
responses={404: {"description": "Item not found"}},
244+
)
245+
async def set_conversation_security_analyzer(
246+
conversation_id: UUID,
247+
request: SetSecurityAnalyzerRequest,
248+
conversation_service: ConversationService = Depends(get_conversation_service),
249+
) -> Success:
250+
"""Set the security analyzer for a conversation."""
251+
event_service = await conversation_service.get_event_service(conversation_id)
252+
if event_service is None:
253+
raise HTTPException(status.HTTP_404_NOT_FOUND)
254+
await event_service.set_security_analyzer(request.security_analyzer)
255+
return Success()
256+
257+
240258
@conversation_router.patch(
241259
"/{conversation_id}", responses={404: {"description": "Item not found"}}
242260
)

openhands-agent-server/openhands/agent_server/event_service.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
ConversationState,
2121
)
2222
from openhands.sdk.event.conversation_state import ConversationStateUpdateEvent
23+
from openhands.sdk.security.analyzer import SecurityAnalyzerBase
2324
from openhands.sdk.security.confirmation_policy import ConfirmationPolicyBase
2425
from openhands.sdk.utils.async_utils import AsyncCallbackWrapper
2526
from openhands.sdk.utils.cipher import Cipher
@@ -303,6 +304,17 @@ async def set_confirmation_policy(self, policy: ConfirmationPolicyBase):
303304
None, self._conversation.set_confirmation_policy, policy
304305
)
305306

307+
async def set_security_analyzer(
308+
self, security_analyzer: SecurityAnalyzerBase | None
309+
):
310+
"""Set the security analyzer for the conversation."""
311+
if not self._conversation:
312+
raise ValueError("inactive_service")
313+
loop = asyncio.get_running_loop()
314+
await loop.run_in_executor(
315+
None, self._conversation.set_security_analyzer, security_analyzer
316+
)
317+
306318
async def close(self):
307319
await self._pub_sub.close()
308320
if self._conversation:

openhands-agent-server/openhands/agent_server/models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
ConversationState,
1515
)
1616
from openhands.sdk.llm.utils.metrics import MetricsSnapshot
17+
from openhands.sdk.security.analyzer import SecurityAnalyzerBase
1718
from openhands.sdk.security.confirmation_policy import (
1819
ConfirmationPolicyBase,
1920
NeverConfirm,
@@ -165,6 +166,14 @@ class SetConfirmationPolicyRequest(BaseModel):
165166
policy: ConfirmationPolicyBase = Field(description="The confirmation policy to set")
166167

167168

169+
class SetSecurityAnalyzerRequest(BaseModel):
170+
"Payload to set security analyzer for a conversation"
171+
172+
security_analyzer: SecurityAnalyzerBase | None = Field(
173+
description="The security analyzer to set"
174+
)
175+
176+
168177
class UpdateConversationRequest(BaseModel):
169178
"""Payload to update conversation metadata."""
170179

openhands-sdk/openhands/sdk/agent/agent.py

Lines changed: 69 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import json
22

3-
from pydantic import ValidationError
3+
from pydantic import ValidationError, model_validator
44

5+
import openhands.sdk.security.analyzer as analyzer
56
import openhands.sdk.security.risk as risk
67
from openhands.sdk.agent.base import AgentBase
78
from openhands.sdk.agent.utils import fix_malformed_tool_arguments
@@ -41,7 +42,6 @@
4142
should_enable_observability,
4243
)
4344
from openhands.sdk.observability.utils import extract_action_name
44-
from openhands.sdk.security.confirmation_policy import NeverConfirm
4545
from openhands.sdk.security.llm_analyzer import LLMSecurityAnalyzer
4646
from openhands.sdk.tool import (
4747
Action,
@@ -72,9 +72,20 @@ class Agent(AgentBase):
7272
>>> agent = Agent(llm=llm, tools=tools)
7373
"""
7474

75-
@property
76-
def _add_security_risk_prediction(self) -> bool:
77-
return isinstance(self.security_analyzer, LLMSecurityAnalyzer)
75+
@model_validator(mode="before")
76+
@classmethod
77+
def _add_security_prompt_as_default(cls, data):
78+
"""Ensure llm_security_analyzer=True is always set before initialization."""
79+
if not isinstance(data, dict):
80+
return data
81+
82+
kwargs = data.get("system_prompt_kwargs") or {}
83+
if not isinstance(kwargs, dict):
84+
kwargs = {}
85+
86+
kwargs.setdefault("llm_security_analyzer", True)
87+
data["system_prompt_kwargs"] = kwargs
88+
return data
7889

7990
def init_state(
8091
self,
@@ -85,18 +96,6 @@ def init_state(
8596
# TODO(openhands): we should add test to test this init_state will actually
8697
# modify state in-place
8798

88-
# Validate security analyzer configuration once during initialization
89-
if self._add_security_risk_prediction and isinstance(
90-
state.confirmation_policy, NeverConfirm
91-
):
92-
# If security analyzer is enabled, we always need a policy that is not
93-
# NeverConfirm, otherwise we are just predicting risks without using them,
94-
# and waste tokens!
95-
logger.warning(
96-
"LLM security analyzer is enabled but confirmation "
97-
"policy is set to NeverConfirm"
98-
)
99-
10099
llm_convertible_messages = [
101100
event for event in state.events if isinstance(event, LLMConvertibleEvent)
102101
]
@@ -105,10 +104,15 @@ def init_state(
105104
event = SystemPromptEvent(
106105
source="agent",
107106
system_prompt=TextContent(text=self.system_message),
107+
# Always expose a 'security_risk' parameter in tool schemas.
108+
# This ensures the schema remains consistent, even if the
109+
# security analyzer is disabled. Validation of this field
110+
# happens dynamically at runtime depending on the analyzer
111+
# configured. This allows weaker models to omit risk field
112+
# and bypass validation requirements when analyzer is disabled.
113+
# For detailed logic, see `_extract_security_risk` method.
108114
tools=[
109-
t.to_openai_tool(
110-
add_security_risk_prediction=self._add_security_risk_prediction
111-
)
115+
t.to_openai_tool(add_security_risk_prediction=True)
112116
for t in self.tools_map.values()
113117
],
114118
)
@@ -176,15 +180,15 @@ def step(
176180
tools=list(self.tools_map.values()),
177181
include=None,
178182
store=False,
179-
add_security_risk_prediction=self._add_security_risk_prediction,
183+
add_security_risk_prediction=True,
180184
extra_body=self.llm.litellm_extra_body,
181185
)
182186
else:
183187
llm_response = self.llm.completion(
184188
messages=_messages,
185189
tools=list(self.tools_map.values()),
186190
extra_body=self.llm.litellm_extra_body,
187-
add_security_risk_prediction=self._add_security_risk_prediction,
191+
add_security_risk_prediction=True,
188192
)
189193
except FunctionCallValidationError as e:
190194
logger.warning(f"LLM generated malformed function call: {e}")
@@ -230,6 +234,7 @@ def step(
230234
tool_call,
231235
llm_response_id=llm_response.id,
232236
on_event=on_event,
237+
security_analyzer=state.security_analyzer,
233238
thought=thought_content
234239
if i == 0
235240
else [], # Only first gets thought
@@ -300,10 +305,10 @@ def _requires_user_confirmation(
300305

301306
# If a security analyzer is registered, use it to grab the risks of the actions
302307
# involved. If not, we'll set the risks to UNKNOWN.
303-
if self.security_analyzer is not None:
308+
if state.security_analyzer is not None:
304309
risks = [
305310
risk
306-
for _, risk in self.security_analyzer.analyze_pending_actions(
311+
for _, risk in state.security_analyzer.analyze_pending_actions(
307312
action_events
308313
)
309314
]
@@ -319,11 +324,44 @@ def _requires_user_confirmation(
319324

320325
return False
321326

327+
def _extract_security_risk(
328+
self,
329+
arguments: dict,
330+
tool_name: str,
331+
read_only_tool: bool,
332+
security_analyzer: analyzer.SecurityAnalyzerBase | None = None,
333+
) -> risk.SecurityRisk:
334+
requires_sr = isinstance(security_analyzer, LLMSecurityAnalyzer)
335+
raw = arguments.pop("security_risk", None)
336+
337+
# Default risk value for action event
338+
# Tool is marked as read-only so security risk can be ignored
339+
if read_only_tool:
340+
return risk.SecurityRisk.UNKNOWN
341+
342+
# Raises exception if failed to pass risk field when expected
343+
# Exception will be sent back to agent as error event
344+
# Strong models like GPT-5 can correct itself by retrying
345+
if requires_sr and raw is None:
346+
raise ValueError(
347+
f"Failed to provide security_risk field in tool '{tool_name}'"
348+
)
349+
350+
# When using weaker models without security analyzer
351+
# safely ignore missing security risk fields
352+
if not requires_sr and raw is None:
353+
return risk.SecurityRisk.UNKNOWN
354+
355+
# Raises exception if invalid risk enum passed by LLM
356+
security_risk = risk.SecurityRisk(raw)
357+
return security_risk
358+
322359
def _get_action_event(
323360
self,
324361
tool_call: MessageToolCall,
325362
llm_response_id: str,
326363
on_event: ConversationCallbackType,
364+
security_analyzer: analyzer.SecurityAnalyzerBase | None = None,
327365
thought: list[TextContent] | None = None,
328366
reasoning_content: str | None = None,
329367
thinking_blocks: list[ThinkingBlock | RedactedThinkingBlock] | None = None,
@@ -369,25 +407,18 @@ def _get_action_event(
369407

370408
# Fix malformed arguments (e.g., JSON strings for list/dict fields)
371409
arguments = fix_malformed_tool_arguments(arguments, tool.action_type)
372-
373-
# if the tool has a security_risk field (when security analyzer is set),
374-
# pop it out as it's not part of the tool's action schema
375-
if (
376-
_predicted_risk := arguments.pop("security_risk", None)
377-
) is not None and self.security_analyzer is not None:
378-
try:
379-
security_risk = risk.SecurityRisk(_predicted_risk)
380-
except ValueError:
381-
logger.warning(
382-
f"Invalid security_risk value from LLM: {_predicted_risk}"
383-
)
384-
410+
security_risk = self._extract_security_risk(
411+
arguments,
412+
tool.name,
413+
tool.annotations.readOnlyHint if tool.annotations else False,
414+
security_analyzer,
415+
)
385416
assert "security_risk" not in arguments, (
386417
"Unexpected 'security_risk' key found in tool arguments"
387418
)
388419

389420
action: Action = tool.action_from_arguments(arguments)
390-
except (json.JSONDecodeError, ValidationError) as e:
421+
except (json.JSONDecodeError, ValidationError, ValueError) as e:
391422
err = (
392423
f"Error validating args {tool_call.arguments} for tool "
393424
f"'{tool.name}': {e}"

openhands-sdk/openhands/sdk/agent/base.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,20 @@
11
import os
22
import re
33
import sys
4+
import warnings
45
from abc import ABC, abstractmethod
56
from collections.abc import Generator, Iterable
67
from typing import TYPE_CHECKING, Any
78

8-
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
9+
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
910

10-
import openhands.sdk.security.analyzer as analyzer
1111
from openhands.sdk.context.agent_context import AgentContext
1212
from openhands.sdk.context.condenser import CondenserBase, LLMSummarizingCondenser
1313
from openhands.sdk.context.prompts.prompt import render_template
1414
from openhands.sdk.llm import LLM
1515
from openhands.sdk.logger import get_logger
1616
from openhands.sdk.mcp import create_mcp_tools
17-
from openhands.sdk.security.llm_analyzer import LLMSecurityAnalyzer
17+
from openhands.sdk.security import analyzer
1818
from openhands.sdk.tool import BUILT_IN_TOOLS, Tool, ToolDefinition, resolve_tool
1919
from openhands.sdk.utils.models import DiscriminatedUnionMixin
2020
from openhands.sdk.utils.pydantic_diff import pretty_pydantic_diff
@@ -27,6 +27,13 @@
2727
logger = get_logger(__name__)
2828

2929

30+
AGENT_SECURITY_ANALYZER_DEPRECATION_WARNING = (
31+
"Agent.security_analyzer is deprecated and will be removed "
32+
"in a future release.\n\n use `conversation = Conversation();"
33+
"conversation.set_security_analyzer(...)` instead."
34+
)
35+
36+
3037
class AgentBase(DiscriminatedUnionMixin, ABC):
3138
"""Abstract base class for OpenHands agents.
3239
@@ -122,11 +129,13 @@ class AgentBase(DiscriminatedUnionMixin, ABC):
122129
description="Optional kwargs to pass to the system prompt Jinja2 template.",
123130
examples=[{"cli_mode": True}],
124131
)
132+
125133
security_analyzer: analyzer.SecurityAnalyzerBase | None = Field(
126134
default=None,
127135
description="Optional security analyzer to evaluate action risks.",
128136
examples=[{"kind": "LLMSecurityAnalyzer"}],
129137
)
138+
130139
condenser: CondenserBase | None = Field(
131140
default=None,
132141
description="Optional condenser to use for condensing conversation history.",
@@ -147,6 +156,22 @@ class AgentBase(DiscriminatedUnionMixin, ABC):
147156
# Runtime materialized tools; private and non-serializable
148157
_tools: dict[str, ToolDefinition] = PrivateAttr(default_factory=dict)
149158

159+
@model_validator(mode="before")
160+
@classmethod
161+
def _coerce_inputs(cls, data):
162+
if not isinstance(data, dict):
163+
return data
164+
d = dict(data)
165+
166+
if "security_analyzer" in d and d["security_analyzer"]:
167+
warnings.warn(
168+
AGENT_SECURITY_ANALYZER_DEPRECATION_WARNING,
169+
DeprecationWarning,
170+
stacklevel=3,
171+
)
172+
173+
return d
174+
150175
@property
151176
def prompt_dir(self) -> str:
152177
"""Returns the directory where this class's module file is located."""
@@ -164,13 +189,7 @@ def name(self) -> str:
164189
@property
165190
def system_message(self) -> str:
166191
"""Compute system message on-demand to maintain statelessness."""
167-
# Prepare template kwargs, including cli_mode if available
168192
template_kwargs = dict(self.system_prompt_kwargs)
169-
if self.security_analyzer:
170-
template_kwargs["llm_security_analyzer"] = bool(
171-
isinstance(self.security_analyzer, LLMSecurityAnalyzer)
172-
)
173-
174193
system_message = render_template(
175194
prompt_dir=self.prompt_dir,
176195
template_name=self.system_prompt_filename,
@@ -198,6 +217,16 @@ def init_state(
198217

199218
def _initialize(self, state: "ConversationState"):
200219
"""Create an AgentBase instance from an AgentSpec."""
220+
221+
# 1) Migrate deprecated analyzer → state (if present)
222+
if self.security_analyzer and not state.security_analyzer:
223+
state.security_analyzer = self.security_analyzer
224+
# 2) Clear on the immutable model (allowed via object.__setattr__)
225+
try:
226+
object.__setattr__(self, "security_analyzer", None)
227+
except Exception:
228+
logger.warning("Could not clear deprecated Agent.security_analyzer")
229+
201230
if self._tools:
202231
logger.warning("Agent already initialized; skipping re-initialization.")
203232
return
@@ -297,8 +326,6 @@ def resolve_diff_from_deserialized(self, persisted: "AgentBase") -> "AgentBase":
297326
updates["condenser"] = new_condenser
298327

299328
# Allow security_analyzer to differ - use the runtime (self) version
300-
# This allows users to add/remove security analyzers mid-conversation
301-
# (e.g., when switching to weaker LLMs that can't handle security_risk field)
302329
updates["security_analyzer"] = self.security_analyzer
303330

304331
# Create maps by tool name for easy lookup

0 commit comments

Comments
 (0)