diff --git a/openhands-agent-server/openhands/agent_server/conversation_router.py b/openhands-agent-server/openhands/agent_server/conversation_router.py index c179b55cf..b70fba3a3 100644 --- a/openhands-agent-server/openhands/agent_server/conversation_router.py +++ b/openhands-agent-server/openhands/agent_server/conversation_router.py @@ -16,6 +16,7 @@ GenerateTitleResponse, SendMessageRequest, SetConfirmationPolicyRequest, + SetSecurityAnalyzerRequest, StartConversationRequest, Success, UpdateConversationRequest, @@ -237,6 +238,23 @@ async def set_conversation_confirmation_policy( return Success() +@conversation_router.post( + "/{conversation_id}/security_analyzer", + responses={404: {"description": "Item not found"}}, +) +async def set_conversation_security_analyzer( + conversation_id: UUID, + request: SetSecurityAnalyzerRequest, + conversation_service: ConversationService = Depends(get_conversation_service), +) -> Success: + """Set the security analyzer for a conversation.""" + event_service = await conversation_service.get_event_service(conversation_id) + if event_service is None: + raise HTTPException(status.HTTP_404_NOT_FOUND) + await event_service.set_security_analyzer(request.security_analyzer) + return Success() + + @conversation_router.patch( "/{conversation_id}", responses={404: {"description": "Item not found"}} ) diff --git a/openhands-agent-server/openhands/agent_server/event_service.py b/openhands-agent-server/openhands/agent_server/event_service.py index 620eeb382..3de57d721 100644 --- a/openhands-agent-server/openhands/agent_server/event_service.py +++ b/openhands-agent-server/openhands/agent_server/event_service.py @@ -20,6 +20,7 @@ ConversationState, ) from openhands.sdk.event.conversation_state import ConversationStateUpdateEvent +from openhands.sdk.security.analyzer import SecurityAnalyzerBase from openhands.sdk.security.confirmation_policy import ConfirmationPolicyBase from openhands.sdk.utils.async_utils import AsyncCallbackWrapper from openhands.sdk.utils.cipher import Cipher @@ -303,6 +304,17 @@ async def set_confirmation_policy(self, policy: ConfirmationPolicyBase): None, self._conversation.set_confirmation_policy, policy ) + async def set_security_analyzer( + self, security_analyzer: SecurityAnalyzerBase | None + ): + """Set the security analyzer for the conversation.""" + if not self._conversation: + raise ValueError("inactive_service") + loop = asyncio.get_running_loop() + await loop.run_in_executor( + None, self._conversation.set_security_analyzer, security_analyzer + ) + async def close(self): await self._pub_sub.close() if self._conversation: diff --git a/openhands-agent-server/openhands/agent_server/models.py b/openhands-agent-server/openhands/agent_server/models.py index a19080f5e..7c85e13a7 100644 --- a/openhands-agent-server/openhands/agent_server/models.py +++ b/openhands-agent-server/openhands/agent_server/models.py @@ -14,6 +14,7 @@ ConversationState, ) from openhands.sdk.llm.utils.metrics import MetricsSnapshot +from openhands.sdk.security.analyzer import SecurityAnalyzerBase from openhands.sdk.security.confirmation_policy import ( ConfirmationPolicyBase, NeverConfirm, @@ -165,6 +166,14 @@ class SetConfirmationPolicyRequest(BaseModel): policy: ConfirmationPolicyBase = Field(description="The confirmation policy to set") +class SetSecurityAnalyzerRequest(BaseModel): + "Payload to set security analyzer for a conversation" + + security_analyzer: SecurityAnalyzerBase | None = Field( + description="The security analyzer to set" + ) + + class 
UpdateConversationRequest(BaseModel): """Payload to update conversation metadata.""" diff --git a/openhands-sdk/openhands/sdk/agent/agent.py b/openhands-sdk/openhands/sdk/agent/agent.py index 7586471e5..e93bcdd54 100644 --- a/openhands-sdk/openhands/sdk/agent/agent.py +++ b/openhands-sdk/openhands/sdk/agent/agent.py @@ -1,7 +1,8 @@ import json -from pydantic import ValidationError +from pydantic import ValidationError, model_validator +import openhands.sdk.security.analyzer as analyzer import openhands.sdk.security.risk as risk from openhands.sdk.agent.base import AgentBase from openhands.sdk.agent.utils import fix_malformed_tool_arguments @@ -41,7 +42,6 @@ should_enable_observability, ) from openhands.sdk.observability.utils import extract_action_name -from openhands.sdk.security.confirmation_policy import NeverConfirm from openhands.sdk.security.llm_analyzer import LLMSecurityAnalyzer from openhands.sdk.tool import ( Action, @@ -72,9 +72,20 @@ class Agent(AgentBase): >>> agent = Agent(llm=llm, tools=tools) """ - @property - def _add_security_risk_prediction(self) -> bool: - return isinstance(self.security_analyzer, LLMSecurityAnalyzer) + @model_validator(mode="before") + @classmethod + def _add_security_prompt_as_default(cls, data): + """Ensure llm_security_analyzer=True is always set before initialization.""" + if not isinstance(data, dict): + return data + + kwargs = data.get("system_prompt_kwargs") or {} + if not isinstance(kwargs, dict): + kwargs = {} + + kwargs.setdefault("llm_security_analyzer", True) + data["system_prompt_kwargs"] = kwargs + return data def init_state( self, @@ -85,18 +96,6 @@ def init_state( # TODO(openhands): we should add test to test this init_state will actually # modify state in-place - # Validate security analyzer configuration once during initialization - if self._add_security_risk_prediction and isinstance( - state.confirmation_policy, NeverConfirm - ): - # If security analyzer is enabled, we always need a policy that is not - # NeverConfirm, otherwise we are just predicting risks without using them, - # and waste tokens! - logger.warning( - "LLM security analyzer is enabled but confirmation " - "policy is set to NeverConfirm" - ) - llm_convertible_messages = [ event for event in state.events if isinstance(event, LLMConvertibleEvent) ] @@ -105,10 +104,15 @@ def init_state( event = SystemPromptEvent( source="agent", system_prompt=TextContent(text=self.system_message), + # Always expose a 'security_risk' parameter in tool schemas. + # This ensures the schema remains consistent, even if the + # security analyzer is disabled. Validation of this field + # happens dynamically at runtime depending on the analyzer + # configured. This allows weaker models to omit risk field + # and bypass validation requirements when analyzer is disabled. + # For detailed logic, see `_extract_security_risk` method. 
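+            # Note that the analyzer can also be changed at runtime via
+            # Conversation.set_security_analyzer(...).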
tools=[ - t.to_openai_tool( - add_security_risk_prediction=self._add_security_risk_prediction - ) + t.to_openai_tool(add_security_risk_prediction=True) for t in self.tools_map.values() ], ) @@ -176,7 +180,7 @@ def step( tools=list(self.tools_map.values()), include=None, store=False, - add_security_risk_prediction=self._add_security_risk_prediction, + add_security_risk_prediction=True, extra_body=self.llm.litellm_extra_body, ) else: @@ -184,7 +188,7 @@ def step( messages=_messages, tools=list(self.tools_map.values()), extra_body=self.llm.litellm_extra_body, - add_security_risk_prediction=self._add_security_risk_prediction, + add_security_risk_prediction=True, ) except FunctionCallValidationError as e: logger.warning(f"LLM generated malformed function call: {e}") @@ -230,6 +234,7 @@ def step( tool_call, llm_response_id=llm_response.id, on_event=on_event, + security_analyzer=state.security_analyzer, thought=thought_content if i == 0 else [], # Only first gets thought @@ -300,10 +305,10 @@ def _requires_user_confirmation( # If a security analyzer is registered, use it to grab the risks of the actions # involved. If not, we'll set the risks to UNKNOWN. - if self.security_analyzer is not None: + if state.security_analyzer is not None: risks = [ risk - for _, risk in self.security_analyzer.analyze_pending_actions( + for _, risk in state.security_analyzer.analyze_pending_actions( action_events ) ] @@ -319,11 +324,44 @@ def _requires_user_confirmation( return False + def _extract_security_risk( + self, + arguments: dict, + tool_name: str, + read_only_tool: bool, + security_analyzer: analyzer.SecurityAnalyzerBase | None = None, + ) -> risk.SecurityRisk: + requires_sr = isinstance(security_analyzer, LLMSecurityAnalyzer) + raw = arguments.pop("security_risk", None) + + # Default risk value for action event + # Tool is marked as read-only so security risk can be ignored + if read_only_tool: + return risk.SecurityRisk.UNKNOWN + + # Raises exception if failed to pass risk field when expected + # Exception will be sent back to agent as error event + # Strong models like GPT-5 can correct itself by retrying + if requires_sr and raw is None: + raise ValueError( + f"Failed to provide security_risk field in tool '{tool_name}'" + ) + + # When using weaker models without security analyzer + # safely ignore missing security risk fields + if not requires_sr and raw is None: + return risk.SecurityRisk.UNKNOWN + + # Raises exception if invalid risk enum passed by LLM + security_risk = risk.SecurityRisk(raw) + return security_risk + def _get_action_event( self, tool_call: MessageToolCall, llm_response_id: str, on_event: ConversationCallbackType, + security_analyzer: analyzer.SecurityAnalyzerBase | None = None, thought: list[TextContent] | None = None, reasoning_content: str | None = None, thinking_blocks: list[ThinkingBlock | RedactedThinkingBlock] | None = None, @@ -369,25 +407,18 @@ def _get_action_event( # Fix malformed arguments (e.g., JSON strings for list/dict fields) arguments = fix_malformed_tool_arguments(arguments, tool.action_type) - - # if the tool has a security_risk field (when security analyzer is set), - # pop it out as it's not part of the tool's action schema - if ( - _predicted_risk := arguments.pop("security_risk", None) - ) is not None and self.security_analyzer is not None: - try: - security_risk = risk.SecurityRisk(_predicted_risk) - except ValueError: - logger.warning( - f"Invalid security_risk value from LLM: {_predicted_risk}" - ) - + security_risk = self._extract_security_risk( + 
arguments, + tool.name, + tool.annotations.readOnlyHint if tool.annotations else False, + security_analyzer, + ) assert "security_risk" not in arguments, ( "Unexpected 'security_risk' key found in tool arguments" ) action: Action = tool.action_from_arguments(arguments) - except (json.JSONDecodeError, ValidationError) as e: + except (json.JSONDecodeError, ValidationError, ValueError) as e: err = ( f"Error validating args {tool_call.arguments} for tool " f"'{tool.name}': {e}" diff --git a/openhands-sdk/openhands/sdk/agent/base.py b/openhands-sdk/openhands/sdk/agent/base.py index e0b2e2236..380ee4d62 100644 --- a/openhands-sdk/openhands/sdk/agent/base.py +++ b/openhands-sdk/openhands/sdk/agent/base.py @@ -1,20 +1,20 @@ import os import re import sys +import warnings from abc import ABC, abstractmethod from collections.abc import Generator, Iterable from typing import TYPE_CHECKING, Any -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator -import openhands.sdk.security.analyzer as analyzer from openhands.sdk.context.agent_context import AgentContext from openhands.sdk.context.condenser import CondenserBase, LLMSummarizingCondenser from openhands.sdk.context.prompts.prompt import render_template from openhands.sdk.llm import LLM from openhands.sdk.logger import get_logger from openhands.sdk.mcp import create_mcp_tools -from openhands.sdk.security.llm_analyzer import LLMSecurityAnalyzer +from openhands.sdk.security import analyzer from openhands.sdk.tool import BUILT_IN_TOOLS, Tool, ToolDefinition, resolve_tool from openhands.sdk.utils.models import DiscriminatedUnionMixin from openhands.sdk.utils.pydantic_diff import pretty_pydantic_diff @@ -27,6 +27,13 @@ logger = get_logger(__name__) +AGENT_SECURITY_ANALYZER_DEPRECATION_WARNING = ( + "Agent.security_analyzer is deprecated and will be removed " + "in a future release. Use `conversation = Conversation(); " + "conversation.set_security_analyzer(...)` instead." +) + + class AgentBase(DiscriminatedUnionMixin, ABC): """Abstract base class for OpenHands agents.
@@ -122,11 +129,13 @@ class AgentBase(DiscriminatedUnionMixin, ABC): description="Optional kwargs to pass to the system prompt Jinja2 template.", examples=[{"cli_mode": True}], ) + security_analyzer: analyzer.SecurityAnalyzerBase | None = Field( default=None, description="Optional security analyzer to evaluate action risks.", examples=[{"kind": "LLMSecurityAnalyzer"}], ) + condenser: CondenserBase | None = Field( default=None, description="Optional condenser to use for condensing conversation history.", @@ -147,6 +156,22 @@ class AgentBase(DiscriminatedUnionMixin, ABC): # Runtime materialized tools; private and non-serializable _tools: dict[str, ToolDefinition] = PrivateAttr(default_factory=dict) + @model_validator(mode="before") + @classmethod + def _coerce_inputs(cls, data): + if not isinstance(data, dict): + return data + d = dict(data) + + if "security_analyzer" in d and d["security_analyzer"]: + warnings.warn( + AGENT_SECURITY_ANALYZER_DEPRECATION_WARNING, + DeprecationWarning, + stacklevel=3, + ) + + return d + @property def prompt_dir(self) -> str: """Returns the directory where this class's module file is located.""" @@ -164,13 +189,7 @@ def name(self) -> str: @property def system_message(self) -> str: """Compute system message on-demand to maintain statelessness.""" - # Prepare template kwargs, including cli_mode if available template_kwargs = dict(self.system_prompt_kwargs) - if self.security_analyzer: - template_kwargs["llm_security_analyzer"] = bool( - isinstance(self.security_analyzer, LLMSecurityAnalyzer) - ) - system_message = render_template( prompt_dir=self.prompt_dir, template_name=self.system_prompt_filename, @@ -198,6 +217,16 @@ def init_state( def _initialize(self, state: "ConversationState"): """Create an AgentBase instance from an AgentSpec.""" + + # 1) Migrate deprecated analyzer → state (if present) + if self.security_analyzer and not state.security_analyzer: + state.security_analyzer = self.security_analyzer + # 2) Clear on the immutable model (allowed via object.__setattr__) + try: + object.__setattr__(self, "security_analyzer", None) + except Exception: + logger.warning("Could not clear deprecated Agent.security_analyzer") + if self._tools: logger.warning("Agent already initialized; skipping re-initialization.") return @@ -297,8 +326,6 @@ def resolve_diff_from_deserialized(self, persisted: "AgentBase") -> "AgentBase": updates["condenser"] = new_condenser # Allow security_analyzer to differ - use the runtime (self) version - # This allows users to add/remove security analyzers mid-conversation - # (e.g., when switching to weaker LLMs that can't handle security_risk field) updates["security_analyzer"] = self.security_analyzer # Create maps by tool name for easy lookup diff --git a/openhands-sdk/openhands/sdk/conversation/base.py b/openhands-sdk/openhands/sdk/conversation/base.py index 6213061b5..d423bc91e 100644 --- a/openhands-sdk/openhands/sdk/conversation/base.py +++ b/openhands-sdk/openhands/sdk/conversation/base.py @@ -14,6 +14,7 @@ should_enable_observability, start_active_span, ) +from openhands.sdk.security.analyzer import SecurityAnalyzerBase from openhands.sdk.security.confirmation_policy import ( ConfirmationPolicyBase, NeverConfirm, @@ -49,6 +50,11 @@ def confirmation_policy(self) -> ConfirmationPolicyBase: """The confirmation policy.""" ... + @property + def security_analyzer(self) -> SecurityAnalyzerBase | None: + """The security analyzer.""" + ... 
+ @property def activated_knowledge_skills(self) -> list[str]: """List of activated knowledge skills.""" @@ -145,13 +151,12 @@ def is_confirmation_mode_active(self) -> bool: """Check if confirmation mode is active. Returns True if BOTH conditions are met: - 1. The agent has a security analyzer set (not None) + 1. The conversation state has a security analyzer set (not None) 2. The confirmation policy is active """ return ( - self.state.agent.security_analyzer is not None - and self.confirmation_policy_active + self.state.security_analyzer is not None and self.confirmation_policy_active ) @abstractmethod diff --git a/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py b/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py index 25bf893f9..690c0b6bd 100644 --- a/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py +++ b/openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py @@ -28,6 +28,7 @@ from openhands.sdk.llm.llm_registry import LLMRegistry from openhands.sdk.logger import get_logger from openhands.sdk.observability.laminar import observe +from openhands.sdk.security.analyzer import SecurityAnalyzerBase from openhands.sdk.security.confirmation_policy import ( ConfirmationPolicyBase, ) @@ -403,6 +404,11 @@ def update_secrets(self, secrets: Mapping[str, SecretValue]) -> None: secret_registry.update_secrets(secrets) logger.info(f"Added {len(secrets)} secrets to conversation") + def set_security_analyzer(self, analyzer: SecurityAnalyzerBase | None) -> None: + """Set the security analyzer for the conversation.""" + with self._state: + self._state.security_analyzer = analyzer + def close(self) -> None: """Close the conversation and clean up all tool executors.""" if self._cleanup_initiated: diff --git a/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py b/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py index e35e69220..6f360dd23 100644 --- a/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py +++ b/openhands-sdk/openhands/sdk/conversation/impl/remote_conversation.py @@ -29,6 +29,7 @@ from openhands.sdk.llm import LLM, Message, TextContent from openhands.sdk.logger import get_logger from openhands.sdk.observability.laminar import observe +from openhands.sdk.security.analyzer import SecurityAnalyzerBase from openhands.sdk.security.confirmation_policy import ( ConfirmationPolicyBase, ) @@ -343,6 +344,16 @@ def confirmation_policy(self) -> ConfirmationPolicyBase: ) return ConfirmationPolicyBase.model_validate(policy_data) + @property + def security_analyzer(self) -> SecurityAnalyzerBase | None: + """The security analyzer.""" + info = self._get_conversation_info() + analyzer_data = info.get("security_analyzer") + if analyzer_data: + return SecurityAnalyzerBase.model_validate(analyzer_data) + + return None + @property def activated_knowledge_skills(self) -> list[str]: """List of activated knowledge skills.""" @@ -597,6 +608,16 @@ def set_confirmation_policy(self, policy: ConfirmationPolicyBase) -> None: json=payload, ) + def set_security_analyzer(self, analyzer: SecurityAnalyzerBase | None) -> None: + """Set the security analyzer for the remote conversation.""" + payload = {"security_analyzer": analyzer.model_dump() if analyzer else analyzer} + _send_request( + self._client, + "POST", + f"/api/conversations/{self._id}/security_analyzer", + json=payload, + ) + def reject_pending_actions(self, reason: str = "User rejected the action") -> None: # Equivalent to rejecting confirmation: 
pause _send_request( diff --git a/openhands-sdk/openhands/sdk/conversation/state.py b/openhands-sdk/openhands/sdk/conversation/state.py index ddfdd0a79..437a420c0 100644 --- a/openhands-sdk/openhands/sdk/conversation/state.py +++ b/openhands-sdk/openhands/sdk/conversation/state.py @@ -17,6 +17,7 @@ from openhands.sdk.event.base import Event from openhands.sdk.io import FileStore, InMemoryFileStore, LocalFileStore from openhands.sdk.logger import get_logger +from openhands.sdk.security.analyzer import SecurityAnalyzerBase from openhands.sdk.security.confirmation_policy import ( ConfirmationPolicyBase, NeverConfirm, @@ -81,6 +82,10 @@ class ConversationState(OpenHandsModel): default=ConversationExecutionStatus.IDLE ) confirmation_policy: ConfirmationPolicyBase = NeverConfirm() + security_analyzer: SecurityAnalyzerBase | None = Field( + default=None, + description="Optional security analyzer to evaluate action risks.", + ) activated_knowledge_skills: list[str] = Field( default_factory=list, @@ -204,6 +209,8 @@ def create( max_iterations=max_iterations, stuck_detection=stuck_detection, ) + # Record existing analyzer configuration in state + state.security_analyzer = state.security_analyzer state._fs = file_store state._events = EventLog(file_store, dir_path=EVENTS_DIR) state.stats = ConversationStats() diff --git a/tests/agent_server/test_conversation_router.py b/tests/agent_server/test_conversation_router.py index dc3106083..55308dbb0 100644 --- a/tests/agent_server/test_conversation_router.py +++ b/tests/agent_server/test_conversation_router.py @@ -22,6 +22,7 @@ from openhands.agent_server.utils import utc_now from openhands.sdk import LLM, Agent, TextContent, Tool from openhands.sdk.conversation.state import ConversationExecutionStatus +from openhands.sdk.security.llm_analyzer import LLMSecurityAnalyzer from openhands.sdk.workspace import LocalWorkspace @@ -76,6 +77,12 @@ def mock_event_service(): return service +@pytest.fixture +def llm_security_analyzer(): + """Create an LLMSecurityAnalyzer for testing.""" + return LLMSecurityAnalyzer() + + @pytest.fixture def sample_start_conversation_request(): """Create a sample StartConversationRequest for testing.""" @@ -1169,3 +1176,92 @@ def test_generate_conversation_title_invalid_params( assert response.status_code == 422 # Validation error finally: client.app.dependency_overrides.clear() + + +def test_set_conversation_security_analyzer_success( + client, + sample_conversation_id, + mock_conversation_service, + mock_event_service, + llm_security_analyzer, +): + """Test successful setting of security analyzer via API endpoint.""" + # Setup mocks + mock_conversation_service.get_event_service.return_value = mock_event_service + mock_event_service.set_security_analyzer.return_value = None + + # Override dependency + client.app.dependency_overrides[get_conversation_service] = ( + lambda: mock_conversation_service + ) + + # Make request + response = client.post( + f"/api/conversations/{sample_conversation_id}/security_analyzer", + json={"security_analyzer": llm_security_analyzer.model_dump()}, + ) + + # Verify response + assert response.status_code == 200 + assert response.json() == {"success": True} + + # Verify service calls + mock_conversation_service.get_event_service.assert_called_once_with( + sample_conversation_id + ) + mock_event_service.set_security_analyzer.assert_called_once() + + +def test_set_conversation_security_analyzer_with_none( + client, sample_conversation_id, mock_conversation_service, mock_event_service +): + """Test setting 
security analyzer to None via API endpoint.""" + # Setup mocks + mock_conversation_service.get_event_service.return_value = mock_event_service + mock_event_service.set_security_analyzer.return_value = None + + # Override dependency + client.app.dependency_overrides[get_conversation_service] = ( + lambda: mock_conversation_service + ) + + # Make request with None analyzer + response = client.post( + f"/api/conversations/{sample_conversation_id}/security_analyzer", + json={"security_analyzer": None}, + ) + + # Verify response + assert response.status_code == 200 + assert response.json() == {"success": True} + + # Verify service calls + mock_conversation_service.get_event_service.assert_called_once_with( + sample_conversation_id + ) + mock_event_service.set_security_analyzer.assert_called_once_with(None) + + +def test_security_analyzer_endpoint_with_malformed_analyzer_data( + client, sample_conversation_id, mock_conversation_service, mock_event_service +): + """Test endpoint behavior with malformed security analyzer data.""" + # Setup mocks + mock_conversation_service.get_event_service.return_value = mock_event_service + mock_event_service.set_security_analyzer.return_value = None + + # Override dependency + client.app.dependency_overrides[get_conversation_service] = ( + lambda: mock_conversation_service + ) + + # Test with invalid analyzer type (should be rejected) + response = client.post( + f"/api/conversations/{sample_conversation_id}/security_analyzer", + json={"security_analyzer": {"kind": "InvalidAnalyzerType"}}, + ) + + # Should return validation error for unknown analyzer type + assert response.status_code == 422 + response_data = response.json() + assert "detail" in response_data diff --git a/tests/cross/test_agent_reconciliation.py b/tests/cross/test_agent_reconciliation.py index d1387e28b..d2f170362 100644 --- a/tests/cross/test_agent_reconciliation.py +++ b/tests/cross/test_agent_reconciliation.py @@ -517,6 +517,6 @@ def test_conversation_restart_adding_security_analyzer(): # Verify conversation loaded successfully assert new_conversation.id == conversation_id - assert new_conversation.agent.security_analyzer is not None - assert isinstance(new_conversation.agent.security_analyzer, LLMSecurityAnalyzer) + assert new_conversation.state.security_analyzer is not None + assert isinstance(new_conversation.state.security_analyzer, LLMSecurityAnalyzer) assert len(new_conversation.state.events) > 0 diff --git a/tests/cross/test_remote_conversation_live_server.py b/tests/cross/test_remote_conversation_live_server.py index 5b3a4f706..57fbf488f 100644 --- a/tests/cross/test_remote_conversation_live_server.py +++ b/tests/cross/test_remote_conversation_live_server.py @@ -569,3 +569,151 @@ def fake_completion_with_cost( assert stats_from_field, "Expected non-empty stats in the 'stats' field after run()" conv.close() + + +def test_security_risk_field_with_live_server( + server_env, monkeypatch: pytest.MonkeyPatch +): + """Integration test validating security_risk field functionality. + + This test validates the fix for issue #819 where security_risk field handling + was inconsistent. It tests that: + 1. Actions execute successfully with security_risk provided + 2. Actions execute successfully without security_risk (defaults to UNKNOWN) + + This is a regression test spawning a real agent server to ensure end-to-end + functionality of security_risk field handling. 
+ """ + from openhands.sdk.security.llm_analyzer import LLMSecurityAnalyzer + + # Track which completion call we're on to control behavior + call_count = {"count": 0} + + def fake_completion_with_tool_calls( + self, + messages, + tools, + return_metrics=False, + add_security_risk_prediction=False, + **kwargs, + ): # type: ignore[no-untyped-def] + from openhands.sdk.llm.llm_response import LLMResponse + from openhands.sdk.llm.message import Message + from openhands.sdk.llm.utils.metrics import MetricsSnapshot + + call_count["count"] += 1 + + # First call: return tool call WITHOUT security_risk + # (to test error event when analyzer is configured) + if call_count["count"] == 1: + litellm_msg = LiteLLMMessage.model_validate( + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "finish", + "arguments": '{"message": "Task complete"}', + }, + } + ], + } + ) + # Second call: return tool call WITH security_risk + # (to test successful execution after error) + elif call_count["count"] == 2: + litellm_msg = LiteLLMMessage.model_validate( + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_2", + "type": "function", + "function": { + "name": "finish", + "arguments": ( + '{"message": "Task complete", ' + '"security_risk": "LOW"}' + ), + }, + } + ], + } + ) + # Third call: simple message to finish + else: + litellm_msg = LiteLLMMessage.model_validate( + {"role": "assistant", "content": "Done"} + ) + + raw_response = ModelResponse( + id=f"test-resp-{call_count['count']}", + created=int(time.time()), + model="test-model", + choices=[Choices(index=0, finish_reason="stop", message=litellm_msg)], + ) + + message = Message.from_llm_chat_message(litellm_msg) + metrics_snapshot = MetricsSnapshot( + model_name="test-model", + accumulated_cost=0.0, + max_budget_per_task=None, + accumulated_token_usage=None, + ) + + return LLMResponse( + message=message, metrics=metrics_snapshot, raw_response=raw_response + ) + + monkeypatch.setattr( + LLM, "completion", fake_completion_with_tool_calls, raising=True + ) + + # Create an Agent with LLMSecurityAnalyzer + # Using empty tools list since tools need to be registered in the server + llm = LLM(model="gpt-4", api_key=SecretStr("test")) + agent = Agent( + llm=llm, + tools=[], + security_analyzer=LLMSecurityAnalyzer(), + ) + + workspace = RemoteWorkspace( + host=server_env["host"], working_dir="/tmp/workspace/project" + ) + conv: RemoteConversation = Conversation(agent=agent, workspace=workspace) + + # Step 1: Send message WITHOUT security_risk - should still execute (defaults to + # UNKNOWN) + conv.send_message("Complete the task") + conv.run() + + # Wait for action event - should succeed even without security_risk + found_action_without_risk = False + for attempt in range(50): # up to ~5s + events = conv.state.events + for e in events: + if isinstance(e, ActionEvent) and e.tool_name == "finish": + # Verify it has a security risk attribute + assert hasattr(e, "security_risk"), ( + "Expected ActionEvent to have security_risk attribute" + ) + found_action_without_risk = True + break + if found_action_without_risk: + break + time.sleep(0.1) + + assert found_action_without_risk, ( + "Expected to find ActionEvent with finish tool even without security_risk" + ) + + conv.close() + + # The test validates that: + # 1. Actions can be executed without security_risk (defaults to UNKNOWN) + # 2. 
ActionEvent always has a security_risk attribute diff --git a/tests/sdk/agent/test_agent_immutability.py b/tests/sdk/agent/test_agent_immutability.py index 158a363e5..961f4077a 100644 --- a/tests/sdk/agent/test_agent_immutability.py +++ b/tests/sdk/agent/test_agent_immutability.py @@ -5,7 +5,6 @@ from openhands.sdk.agent.agent import Agent from openhands.sdk.llm import LLM -from openhands.sdk.security.llm_analyzer import LLMSecurityAnalyzer class TestAgentImmutability: @@ -55,31 +54,6 @@ def test_system_message_is_computed_property(self): keyword in msg1.lower() for keyword in ["assistant", "help", "task", "user"] ) - def test_agent_with_different_configs_are_different(self): - """Test that agents with different configs produce different system messages.""" - # Use LLMSecurityAnalyzer so that the security risk assessment section is - # included and cli_mode differences will be visible in the system message - security_analyzer = LLMSecurityAnalyzer() - agent1 = Agent( - llm=self.llm, - tools=[], - security_analyzer=security_analyzer, - system_prompt_kwargs={"cli_mode": True}, - ) - agent2 = Agent( - llm=self.llm, - tools=[], - security_analyzer=security_analyzer, - system_prompt_kwargs={"cli_mode": False}, - ) - - # System messages should be different due to cli_mode - msg1 = agent1.system_message - msg2 = agent2.system_message - - # They should be different (cli_mode affects the template) - assert msg1 != msg2 - def test_condenser_property_access(self): """Test that condenser property works correctly.""" # Test with None condenser @@ -151,13 +125,9 @@ def test_multiple_agents_are_independent(self): def test_agent_model_copy_creates_new_instance(self): """Test that model_copy creates a new Agent instance with modified fields.""" - # Use LLMSecurityAnalyzer so that the security risk assessment section is - # included and cli_mode differences will be visible in the system message - security_analyzer = LLMSecurityAnalyzer() original_agent = Agent( llm=self.llm, tools=[], - security_analyzer=security_analyzer, system_prompt_kwargs={"cli_mode": True}, ) diff --git a/tests/sdk/agent/test_extract_security_risk.py b/tests/sdk/agent/test_extract_security_risk.py new file mode 100644 index 000000000..3d6e293e7 --- /dev/null +++ b/tests/sdk/agent/test_extract_security_risk.py @@ -0,0 +1,183 @@ +"""Tests for Agent._extract_security_risk method. + +This module tests the _extract_security_risk method which handles extraction +and validation of security risk parameters from tool arguments. 
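+The extracted value is popped from the tool arguments before the action is validated.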
+""" + +import pytest +from pydantic import SecretStr + +from openhands.sdk.agent import Agent +from openhands.sdk.event import ActionEvent +from openhands.sdk.llm import LLM +from openhands.sdk.security.analyzer import SecurityAnalyzerBase +from openhands.sdk.security.llm_analyzer import LLMSecurityAnalyzer +from openhands.sdk.security.risk import SecurityRisk + + +class MockNonLLMAnalyzer(SecurityAnalyzerBase): + """Mock security analyzer that is not an LLMSecurityAnalyzer.""" + + def security_risk(self, action: ActionEvent) -> SecurityRisk: + return SecurityRisk.LOW + + +@pytest.fixture +def mock_llm(): + """Create a mock LLM for testing.""" + return LLM( + usage_id="test-llm", + model="test-model", + api_key=SecretStr("test-key"), + base_url="http://test", + ) + + +@pytest.fixture +def agent_with_llm_analyzer(mock_llm): + """Create an agent with LLMSecurityAnalyzer.""" + agent = Agent(llm=mock_llm) + return agent, LLMSecurityAnalyzer() + + +@pytest.fixture +def agent_with_non_llm_analyzer(mock_llm): + """Create an agent with non-LLM security analyzer.""" + agent = Agent(llm=mock_llm) + return agent, MockNonLLMAnalyzer() + + +@pytest.fixture +def agent_without_analyzer(mock_llm): + """Create an agent without security analyzer.""" + agent = Agent(llm=mock_llm) + return agent, None + + +@pytest.mark.parametrize( + "agent_fixture,security_risk_value,expected_result,should_raise", + [ + # Case 1: LLM analyzer set, security risk passed, extracted properly + ("agent_with_llm_analyzer", "LOW", SecurityRisk.LOW, False), + ("agent_with_llm_analyzer", "MEDIUM", SecurityRisk.MEDIUM, False), + ("agent_with_llm_analyzer", "HIGH", SecurityRisk.HIGH, False), + ("agent_with_llm_analyzer", "UNKNOWN", SecurityRisk.UNKNOWN, False), + # Case 2: analyzer is not set, security risk is passed, extracted properly + ("agent_with_non_llm_analyzer", "LOW", SecurityRisk.LOW, False), + ("agent_with_non_llm_analyzer", "MEDIUM", SecurityRisk.MEDIUM, False), + ("agent_with_non_llm_analyzer", "HIGH", SecurityRisk.HIGH, False), + ("agent_with_non_llm_analyzer", "UNKNOWN", SecurityRisk.UNKNOWN, False), + ("agent_without_analyzer", "LOW", SecurityRisk.LOW, False), + ("agent_without_analyzer", "MEDIUM", SecurityRisk.MEDIUM, False), + ("agent_without_analyzer", "HIGH", SecurityRisk.HIGH, False), + ("agent_without_analyzer", "UNKNOWN", SecurityRisk.UNKNOWN, False), + # Case 3: LLM analyzer set, security risk not passed, ValueError raised + ("agent_with_llm_analyzer", None, None, True), + # Case 4: analyzer is not set, security risk is not passed, UNKNOWN returned + ("agent_with_non_llm_analyzer", None, SecurityRisk.UNKNOWN, False), + ("agent_without_analyzer", None, SecurityRisk.UNKNOWN, False), + # Case 5: invalid security risk value passed, ValueError raised + ("agent_with_llm_analyzer", "INVALID", None, True), + ("agent_with_non_llm_analyzer", "INVALID", None, True), + ("agent_without_analyzer", "INVALID", None, True), + ], +) +def test_extract_security_risk( + request, agent_fixture, security_risk_value, expected_result, should_raise +): + """Test _extract_security_risk method with various scenarios.""" + # Get the agent fixture + agent, security_analyzer = request.getfixturevalue(agent_fixture) + + # Prepare arguments + arguments = {"some_param": "value"} + if security_risk_value is not None: + arguments["security_risk"] = security_risk_value + + tool_name = "test_tool" + + if should_raise: + with pytest.raises(ValueError): + agent._extract_security_risk(arguments, tool_name, False, security_analyzer) + else: + result = 
agent._extract_security_risk( + arguments, tool_name, False, security_analyzer + ) + assert result == expected_result + + # Verify that security_risk was popped from arguments + assert "security_risk" not in arguments + # Verify other arguments remain + assert arguments["some_param"] == "value" + + +def test_extract_security_risk_arguments_mutation(): + """Test that arguments dict is properly mutated (security_risk is popped).""" + agent = Agent( + llm=LLM( + usage_id="test-llm", + model="test-model", + api_key=SecretStr("test-key"), + base_url="http://test", + ) + ) + + # Test with security_risk present + arguments = {"param1": "value1", "security_risk": "LOW", "param2": "value2"} + original_args = arguments.copy() + + result = agent._extract_security_risk(arguments, "test_tool", False, None) + + # Verify result + assert result == SecurityRisk.LOW + + # Verify security_risk was popped + assert "security_risk" not in arguments + + # Verify other parameters remain + assert arguments["param1"] == original_args["param1"] + assert arguments["param2"] == original_args["param2"] + assert len(arguments) == 2 # Only 2 params should remain + + +def test_extract_security_risk_with_empty_arguments(): + """Test _extract_security_risk with empty arguments dict.""" + agent = Agent( + llm=LLM( + usage_id="test-llm", + model="test-model", + api_key=SecretStr("test-key"), + base_url="http://test", + ) + ) + + arguments = {} + result = agent._extract_security_risk(arguments, "test_tool", False, None) + + # Should return UNKNOWN when no analyzer and no security_risk + assert result == SecurityRisk.UNKNOWN + assert arguments == {} # Should remain empty + + +def test_extract_security_risk_with_read_only_tool(): + """Test _extract_security_risk with read only tool.""" + agent = Agent( + llm=LLM( + usage_id="test-llm", + model="test-model", + api_key=SecretStr("test-key"), + base_url="http://test", + ) + ) + + # Test with readOnlyHint=True - should return UNKNOWN regardless of security_risk + arguments = {"param1": "value1", "security_risk": "HIGH"} + result = agent._extract_security_risk( + arguments, "test_tool", True, LLMSecurityAnalyzer() + ) + + # Should return UNKNOWN when read_only_tool is True + assert result == SecurityRisk.UNKNOWN + # security_risk should still be popped from arguments + assert "security_risk" not in arguments + assert arguments["param1"] == "value1" diff --git a/tests/sdk/agent/test_security_analyzer_backwards_compatibility.py b/tests/sdk/agent/test_security_analyzer_backwards_compatibility.py new file mode 100644 index 000000000..107469c8c --- /dev/null +++ b/tests/sdk/agent/test_security_analyzer_backwards_compatibility.py @@ -0,0 +1,73 @@ +"""Test backwards compatibility for security_analyzer field migration from Agent to ConversationState.""" # noqa: E501 + +import uuid + +from openhands.sdk.agent import Agent +from openhands.sdk.conversation.impl.local_conversation import LocalConversation +from openhands.sdk.io.local import LocalFileStore +from openhands.sdk.llm.llm import LLM +from openhands.sdk.security.llm_analyzer import LLMSecurityAnalyzer +from openhands.sdk.workspace.local import LocalWorkspace +from openhands.sdk.workspace.workspace import Workspace + + +def test_security_analyzer_migrates_and_is_cleared(): + llm = LLM(model="test-model", api_key=None) + agent = Agent(llm=llm, security_analyzer=LLMSecurityAnalyzer()) + + assert agent.security_analyzer is not None + + conversation = LocalConversation( + agent=agent, workspace=LocalWorkspace(working_dir="/tmp") + ) + + 
assert agent.security_analyzer is None + assert conversation.state.security_analyzer is not None + + +def test_security_analyzer_reconciliation_and_migration(tmp_path): + # Create conversation state that + # has agent with no security analyzer + DUMMY_BASE_STATE = """{"id": "2d73fc17-6d31-4a5c-ba0d-19c80888bdf3", "agent": {"kind": "Agent", "llm": {"model": "litellm_proxy/claude-sonnet-4-20250514", "api_key": "**********", "base_url": "https://llm-proxy.app.all-hands.dev/", "openrouter_site_url": "https://docs.all-hands.dev/", "openrouter_app_name": "OpenHands", "num_retries": 5, "retry_multiplier": 8.0, "retry_min_wait": 8, "retry_max_wait": 64, "max_message_chars": 30000, "temperature": 0.0, "top_p": 1.0, "max_input_tokens": 1000000, "max_output_tokens": 64000, "drop_params": true, "modify_params": true, "disable_stop_word": false, "caching_prompt": true, "log_completions": false, "log_completions_folder": "logs/completions", "reasoning_effort": "high", "extended_thinking_budget": 200000, "service_id": "agent", "metadata": {"trace_version": "1.0.0", "tags": ["app:openhands", "model:litellm_proxy/claude-sonnet-4-20250514", "type:agent", "web_host:unspecified", "openhands_sdk_version:1.0.0", "openhands_tools_version:1.0.0"], "session_id": "2d73fc17-6d31-4a5c-ba0d-19c80888bdf3"}, "OVERRIDE_ON_SERIALIZE": ["api_key", "aws_access_key_id", "aws_secret_access_key"]}, "tools": [{"name": "BashTool", "params": {}}, {"name": "FileEditorTool", "params": {}}, {"name": "TaskTrackerTool", "params": {}}], "mcp_config": {"mcpServers": {"fetch": {"command": "uvx", "args": ["mcp-server-fetch"]}, "repomix": {"command": "npx", "args": ["-y", "repomix@1.4.2", "--mcp"]}, "new_fetch": {"command": "npm", "args": ["mcp-server-fetch"], "env": {}, "transport": "stdio"}}}, "filter_tools_regex": "^(?!repomix)(.*)|^repomix.*pack_codebase.*$", "agent_context": {"microagents": [], "system_message_suffix": "You current working directory is: /Users/rohitmalhotra/Documents/Openhands/Openhands/openhands-cli"}, "system_prompt_filename": "system_prompt.j2", "system_prompt_kwargs": {"cli_mode": true}, "security_analyzer": null, "condenser": {"kind": "LLMSummarizingCondenser", "llm": {"model": "litellm_proxy/claude-sonnet-4-20250514", "api_key": "**********", "base_url": "https://llm-proxy.app.all-hands.dev/", "openrouter_site_url": "https://docs.all-hands.dev/", "openrouter_app_name": "OpenHands", "num_retries": 5, "retry_multiplier": 8.0, "retry_min_wait": 8, "retry_max_wait": 64, "max_message_chars": 30000, "temperature": 0.0, "top_p": 1.0, "max_input_tokens": 1000000, "max_output_tokens": 64000, "drop_params": true, "modify_params": true, "disable_stop_word": false, "caching_prompt": true, "log_completions": false, "log_completions_folder": "logs/completions", "reasoning_effort": "high", "extended_thinking_budget": 200000, "service_id": "condenser", "metadata": {"trace_version": "1.0.0", "tags": ["app:openhands", "model:litellm_proxy/claude-sonnet-4-20250514", "type:condenser", "web_host:unspecified", "openhands_sdk_version:1.0.0", "openhands_tools_version:1.0.0"], "session_id": "2d73fc17-6d31-4a5c-ba0d-19c80888bdf3"}, "OVERRIDE_ON_SERIALIZE": ["api_key", "aws_access_key_id", "aws_secret_access_key"]}, "max_size": 80, "keep_first": 4}}, "workspace": {"kind": "LocalWorkspace", "working_dir": "/Users/rohitmalhotra/Documents/Openhands/Openhands/openhands-cli"}, "persistence_dir": "/Users/rohitmalhotra/.openhands/conversations/2d73fc17-6d31-4a5c-ba0d-19c80888bdf3", "max_iterations": 500, "stuck_detection": true, "agent_status": 
"idle", "confirmation_policy": {"kind": "AlwaysConfirm"}, "activated_knowledge_microagents": [], "stats": {"service_to_metrics": {"agent": {"model_name": "litellm_proxy/claude-sonnet-4-20250514", "accumulated_cost": 0.0, "accumulated_token_usage": {"model": "litellm_proxy/claude-sonnet-4-20250514", "prompt_tokens": 0, "completion_tokens": 0, "cache_read_tokens": 0, "cache_write_tokens": 0, "reasoning_tokens": 0, "context_window": 0, "per_turn_token": 0, "response_id": ""}, "costs": [], "response_latencies": [], "token_usages": []}, "condenser": {"model_name": "litellm_proxy/claude-sonnet-4-20250514", "accumulated_cost": 0.0, "accumulated_token_usage": {"model": "litellm_proxy/claude-sonnet-4-20250514", "prompt_tokens": 0, "completion_tokens": 0, "cache_read_tokens": 0, "cache_write_tokens": 0, "reasoning_tokens": 0, "context_window": 0, "per_turn_token": 0, "response_id": ""}, "costs": [], "response_latencies": [], "token_usages": []}}}""" # noqa: E501 + + llm = LLM(model="test-model", api_key=None) + file_store = LocalFileStore(root=str(tmp_path)) + file_store.write( + "conversations/2d73fc17-6d31-4a5c-ba0d-19c80888bdf3/base_state.json", + DUMMY_BASE_STATE, + ) + + # Update agent security analyzer to test reconciliation + agent = Agent(llm=llm, security_analyzer=LLMSecurityAnalyzer()) + + # Creating conversation should migrate security analyzer + conversation = LocalConversation( + agent=agent, + workspace=Workspace(working_dir="/tmp"), + persistence_dir=str(tmp_path), + conversation_id=uuid.UUID("2d73fc17-6d31-4a5c-ba0d-19c80888bdf3"), + ) + + assert isinstance(conversation.state.security_analyzer, LLMSecurityAnalyzer) + assert agent.security_analyzer is None + + +def test_agent_serialize_deserialize_does_not_change_analyzer(tmp_path): + """ + Just serializing and deserializing should not wipe + security analyzer information. Only when a conversation is + created should the security analyzer information be transferred. 
+    """ + + llm = LLM(model="test-model", api_key=None) + agent = Agent(llm=llm, security_analyzer=LLMSecurityAnalyzer()) + + agent = Agent.model_validate_json(agent.model_dump_json()) + assert isinstance(agent.security_analyzer, LLMSecurityAnalyzer) + + conversation = LocalConversation( + agent=agent, workspace=Workspace(working_dir="/tmp") + ) + + assert isinstance(conversation.state.security_analyzer, LLMSecurityAnalyzer) + assert agent.security_analyzer is None diff --git a/tests/sdk/agent/test_security_policy_integration.py b/tests/sdk/agent/test_security_policy_integration.py index 585ac2b94..96c406567 100644 --- a/tests/sdk/agent/test_security_policy_integration.py +++ b/tests/sdk/agent/test_security_policy_integration.py @@ -15,7 +15,6 @@ from openhands.sdk.conversation import Conversation from openhands.sdk.event import ActionEvent, AgentErrorEvent from openhands.sdk.llm import LLM, Message, TextContent -from openhands.sdk.security.llm_analyzer import LLMSecurityAnalyzer def test_security_policy_in_system_message(): @@ -91,7 +90,6 @@ def test_security_policy_template_rendering(): def test_llm_security_analyzer_template_kwargs(): """Test that agent sets template_kwargs appropriately when security analyzer is LLMSecurityAnalyzer.""" # noqa: E501 - # Create agent with LLMSecurityAnalyzer agent = Agent( llm=LLM( usage_id="test-llm", model="test-model", api_key=SecretStr("test-key"), base_url="http://test", ), - security_analyzer=LLMSecurityAnalyzer(), ) - # Access the system_message property to trigger template_kwargs computation + # Get system message (security analyzer context is automatically included) system_message = agent.system_message # Verify that the security risk assessment section is included in the system prompt @@ -118,7 +115,7 @@ def test_llm_security_analyzer_sandbox_mode(): """Test that agent includes sandbox mode security risk assessment when cli_mode=False.""" # noqa: E501 - # Create agent with LLMSecurityAnalyzer and cli_mode=False + # Create agent with cli_mode=False agent = Agent( llm=LLM( usage_id="test-llm", model="test-model", api_key=SecretStr("test-key"), base_url="http://test", ), - security_analyzer=LLMSecurityAnalyzer(), system_prompt_kwargs={"cli_mode": False}, ) - # Access the system_message property to trigger template_kwargs computation + # Get system message (security analyzer context is automatically included) system_message = agent.system_message # Verify that the security risk assessment section is included with sandbox mode content # noqa: E501 assert "" in system_message assert "# Security Risk Policy" in system_message @@ -144,7 +142,7 @@ assert "**Global Rules**" in system_message -def test_no_security_analyzer_excludes_risk_assessment(): +def test_no_security_analyzer_still_includes_risk_assessment(): -    """Test that security risk assessment section is excluded when no security analyzer is set.""" # noqa: E501 +    """Test that the security risk assessment section is included even when no security analyzer is set.""" # noqa: E501 # Create agent without security analyzer agent = Agent( llm=LLM( usage_id="test-llm", model="test-model", api_key=SecretStr("test-key"), base_url="http://test", ) ) - # Get the system message + # Get the system message with no security analyzer system_message = agent.system_message - # Verify that the security risk assessment section is NOT included + # Verify that the security risk assessment section is included - assert "" not in system_message - assert "# Security Risk Policy" not in system_message -
assert ( - "When using tools that support the security_risk parameter" - not in system_message - ) + assert "" in system_message + assert "# Security Risk Policy" in system_message + assert "When using tools that support the security_risk parameter" in system_message -def test_non_llm_security_analyzer_excludes_risk_assessment(): +def test_non_llm_security_analyzer_still_includes_risk_assessment(): -    """Test that security risk assessment section is excluded when security analyzer is not LLMSecurityAnalyzer.""" # noqa: E501 +    """Test that the security risk assessment section is still included when the security analyzer is not LLMSecurityAnalyzer.""" # noqa: E501 from openhands.sdk.security.analyzer import SecurityAnalyzerBase from openhands.sdk.security.risk import SecurityRisk @@ -192,12 +187,9 @@ def security_risk(self, action: ActionEvent) -> SecurityRisk: system_message = agent.system_message - # Verify that the security risk assessment section is NOT included + # Verify that the security risk assessment section is included - assert "" not in system_message - assert "# Security Risk Policy" not in system_message - assert ( - "When using tools that support the security_risk parameter" - not in system_message - ) + assert "" in system_message + assert "# Security Risk Policy" in system_message + assert "When using tools that support the security_risk parameter" in system_message def _tool_response(name: str, args_json: str) -> ModelResponse: diff --git a/tests/sdk/conversation/local/test_confirmation_mode.py b/tests/sdk/conversation/local/test_confirmation_mode.py index 0176aee01..4344cd77a 100644 --- a/tests/sdk/conversation/local/test_confirmation_mode.py +++ b/tests/sdk/conversation/local/test_confirmation_mode.py @@ -658,31 +658,31 @@ def test_pause_during_confirmation_preserves_waiting_status(self): def test_is_confirmation_mode_active_property(self): """Test the is_confirmation_mode_active property behavior.""" # Initially, no security analyzer and NeverConfirm policy - assert self.conversation.state.agent.security_analyzer is None + assert self.conversation.state.security_analyzer is None assert self.conversation.state.confirmation_policy == NeverConfirm() assert not self.conversation.confirmation_policy_active assert not self.conversation.is_confirmation_mode_active # Set confirmation policy to AlwaysConfirm, but still no security analyzer self.conversation.set_confirmation_policy(AlwaysConfirm()) - assert self.conversation.state.agent.security_analyzer is None + assert self.conversation.state.security_analyzer is None assert self.conversation.state.confirmation_policy == AlwaysConfirm() assert self.conversation.confirmation_policy_active # Still False because no security analyzer assert not self.conversation.is_confirmation_mode_active - # Create agent with security analyzer + # Create agent and set security analyzer on conversation state from openhands.sdk.security.llm_analyzer import LLMSecurityAnalyzer - agent_with_analyzer = Agent( + agent = Agent( llm=self.llm, tools=[Tool(name="test_tool")], - security_analyzer=LLMSecurityAnalyzer(), ) - conversation_with_analyzer = Conversation(agent=agent_with_analyzer) + conversation_with_analyzer = Conversation(agent=agent) + conversation_with_analyzer.set_security_analyzer(LLMSecurityAnalyzer()) # Initially with security analyzer but NeverConfirm policy - assert conversation_with_analyzer.state.agent.security_analyzer is not None + assert conversation_with_analyzer.state.security_analyzer is not None assert conversation_with_analyzer.state.confirmation_policy == NeverConfirm() assert not conversation_with_analyzer.confirmation_policy_active # False because policy is NeverConfirm @@ -690,7 +690,7 @@ def 
test_is_confirmation_mode_active_property(self): # Set confirmation policy to AlwaysConfirm with security analyzer conversation_with_analyzer.set_confirmation_policy(AlwaysConfirm()) - assert conversation_with_analyzer.state.agent.security_analyzer is not None + assert conversation_with_analyzer.state.security_analyzer is not None assert conversation_with_analyzer.state.confirmation_policy == AlwaysConfirm() assert conversation_with_analyzer.confirmation_policy_active # True because both conditions are met diff --git a/tests/sdk/conversation/local/test_state_serialization.py b/tests/sdk/conversation/local/test_state_serialization.py index ccf5ab3d1..0c068a391 100644 --- a/tests/sdk/conversation/local/test_state_serialization.py +++ b/tests/sdk/conversation/local/test_state_serialization.py @@ -130,6 +130,7 @@ def test_conversation_state_persistence_save_load(): assert loaded_state.agent.__class__ == agent.__class__ # Test model_dump equality assert loaded_state.model_dump(mode="json") == state.model_dump(mode="json") + # Also verify key fields are preserved assert loaded_state.id == state.id assert len(loaded_state.events) == len(state.events) @@ -544,4 +545,5 @@ def test_conversation_with_agent_different_llm_config(): assert new_conversation._state.agent.llm.api_key.get_secret_value() == "new-key" # Test that the core state structure is preserved (excluding agent differences) new_dump = new_conversation._state.model_dump(mode="json", exclude={"agent"}) + assert new_dump == original_state_dump