diff --git a/examples/basic/agents_sdk.py b/examples/basic/agents_sdk.py
index a446fb5..4ade9d1 100644
--- a/examples/basic/agents_sdk.py
+++ b/examples/basic/agents_sdk.py
@@ -25,6 +25,7 @@
                     "categories": ["hate", "violence", "self-harm"],
                 },
             },
+            {"name": "Contains PII", "config": {"entities": ["US_SSN", "PHONE_NUMBER", "EMAIL_ADDRESS"]}},
         ],
     },
     "input": {
@@ -75,11 +76,15 @@ async def main() -> None:
         except EOFError:
             print("\nExiting.")
             break
-        except InputGuardrailTripwireTriggered:
+        except InputGuardrailTripwireTriggered as exc:
             print("🛑 Input guardrail triggered!")
+            print(exc.guardrail_result.guardrail.name)
+            print(exc.guardrail_result.output.output_info)
             continue
-        except OutputGuardrailTripwireTriggered:
+        except OutputGuardrailTripwireTriggered as exc:
             print("🛑 Output guardrail triggered!")
+            print(exc.guardrail_result.guardrail.name)
+            print(exc.guardrail_result.output.output_info)
             continue
 
 
diff --git a/src/guardrails/agents.py b/src/guardrails/agents.py
index 4f3202c..31d5722 100644
--- a/src/guardrails/agents.py
+++ b/src/guardrails/agents.py
@@ -257,7 +257,7 @@ async def tool_input_gr(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOu
                     media_type="text/plain",
                     guardrails=[guardrail],
                     suppress_tripwire=True,
-                    stage_name=f"tool_input_{guardrail_name.lower().replace(' ', '_')}",
+                    stage_name="tool_input",
                     raise_guardrail_errors=raise_guardrail_errors,
                 )
 
@@ -312,7 +312,7 @@ async def tool_output_gr(data: ToolOutputGuardrailData) -> ToolGuardrailFunction
                     media_type="text/plain",
                     guardrails=[guardrail],
                     suppress_tripwire=True,
-                    stage_name=f"tool_output_{guardrail_name.lower().replace(' ', '_')}",
+                    stage_name="tool_output",
                     raise_guardrail_errors=raise_guardrail_errors,
                 )
 
@@ -338,6 +338,69 @@ async def tool_output_gr(data: ToolOutputGuardrailData) -> ToolGuardrailFunction
         return tool_output_gr
 
 
+def _extract_text_from_input(input_data: Any) -> str:
+    """Extract text from input_data, handling both string and conversation history formats.
+
+    The Agents SDK may pass input_data in different formats:
+    - String: Direct text input
+    - List of dicts: Conversation history with message objects
+
+    Args:
+        input_data: Input from Agents SDK (string or list of messages)
+
+    Returns:
+        The input itself if it is a string, the latest user message's text ("" if none), or str(input_data) otherwise
+    """
+    # If it's already a string, return it
+    if isinstance(input_data, str):
+        return input_data
+
+    # If it's a list (conversation history), extract the latest user message
+    if isinstance(input_data, list):
+        if not input_data:
+            return ""  # Empty list returns empty string
+
+        # Iterate from the end to find the latest user message
+        for msg in reversed(input_data):
+            if isinstance(msg, dict):
+                role = msg.get("role")
+                if role == "user":
+                    content = msg.get("content")
+                    # Content can be a string or a list of content parts
+                    if isinstance(content, str):
+                        return content
+                    elif isinstance(content, list):
+                        if not content:
+                            # Empty content list returns empty string (consistent with no text parts found)
+                            return ""
+                        # Extract text from content parts
+                        text_parts = []
+                        for part in content:
+                            if isinstance(part, dict):
+                                # Check for various text field names (avoid falsy empty string issue)
+                                text = None
+                                for field in ['text', 'input_text', 'output_text']:
+                                    if field in part:
+                                        text = part[field]
+                                        break
+                                # Preserve empty strings, only filter None
+                                if text is not None and isinstance(text, str):
+                                    text_parts.append(text)
+                        if text_parts:
+                            return " ".join(text_parts)
+                        # No text parts found, return empty string
+                        return ""
+                    # If content is something else, try to stringify it
+                    elif content is not None:
+                        return str(content)
+
+        # No user message found in list
+        return ""
+
+    # Fallback: convert to string
+    return str(input_data)
+
+
 def _create_agents_guardrails_from_config(
     config: str | Path | dict[str, Any], stages: list[str], guardrail_type: str = "input", context: Any = None, raise_guardrail_errors: bool = False
 ) -> list[Any]:
@@ -355,7 +418,7 @@ def _create_agents_guardrails_from_config(
             If False (default), treat guardrail errors as safe and continue execution.
 
     Returns:
-        List of guardrail functions that can be used with Agents SDK
+        List of guardrail functions (one per individual guardrail) ready for Agents SDK
 
     Raises:
         ImportError: If agents package is not available
@@ -372,17 +435,15 @@ def _create_agents_guardrails_from_config(
     # Load and parse the pipeline configuration
     pipeline = load_pipeline_bundles(config)
 
-    # Instantiate guardrails for requested stages and filter out tool-level guardrails
-    stage_guardrails = {}
+    # Collect all individual guardrails from requested stages (filter out tool-level)
+    all_guardrails = []
     for stage_name in stages:
         stage = getattr(pipeline, stage_name, None)
         if stage:
-            all_guardrails = instantiate_guardrails(stage, default_spec_registry)
+            stage_guardrails = instantiate_guardrails(stage, default_spec_registry)
             # Filter out tool-level guardrails - they're handled separately
-            _, agent_level_guardrails = _separate_tool_level_from_agent_level(all_guardrails)
-            stage_guardrails[stage_name] = agent_level_guardrails
-        else:
-            stage_guardrails[stage_name] = []
+            _, agent_level_guardrails = _separate_tool_level_from_agent_level(stage_guardrails)
+            all_guardrails.extend(agent_level_guardrails)
 
     # Create default context if none provided
     if context is None:
@@ -394,58 +455,70 @@ class DefaultContext:
 
         context = DefaultContext(guardrail_llm=AsyncOpenAI())
 
-    def _create_stage_guardrail(stage_name: str):
-        async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str) -> GuardrailFunctionOutput:
-            """Guardrail function for a specific pipeline stage."""
+    def _create_individual_guardrail(guardrail):
+        """Create a function for a single specific guardrail."""
+        async def single_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str | list) -> GuardrailFunctionOutput:
+            """Guardrail function for a specific guardrail check.
+
+            Note: input_data is typed as str in Agents SDK, but can actually be a list
+            of message objects when conversation history is used. We handle both cases.
+            """
             try:
-                # Get guardrails for this stage (already filtered to exclude prompt injection)
-                guardrails = stage_guardrails.get(stage_name, [])
-                if not guardrails:
-                    return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False)
+                # Extract text from input_data (handle both string and conversation history formats)
+                text_data = _extract_text_from_input(input_data)
 
-                # Run the guardrails for this stage
+                # Run this single guardrail
                 results = await run_guardrails(
                     ctx=context,
-                    data=input_data,
+                    data=text_data,
                     media_type="text/plain",
-                    guardrails=guardrails,
+                    guardrails=[guardrail],  # Just this one guardrail
                     suppress_tripwire=True,  # We handle tripwires manually
-                    stage_name=stage_name,
+                    stage_name=guardrail_type,  # "input" or "output" - indicates which stage
                     raise_guardrail_errors=raise_guardrail_errors,
                 )
 
-                # Check if any tripwires were triggered
+                # Check if tripwire was triggered
                 for result in results:
                     if result.tripwire_triggered:
-                        guardrail_name = result.info.get("guardrail_name", "unknown") if isinstance(result.info, dict) else "unknown"
-                        return GuardrailFunctionOutput(output_info=f"Guardrail {guardrail_name} triggered tripwire", tripwire_triggered=True)
+                        # Return full metadata in output_info for consistency with tool guardrails
+                        return GuardrailFunctionOutput(output_info=result.info, tripwire_triggered=True)
 
                 return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False)
 
             except Exception as e:
                 if raise_guardrail_errors:
-                    # Re-raise the exception to stop execution
-                    raise e
+                    # Re-raise the exception to stop execution (preserve traceback)
+                    raise
                 else:
                     # Current behavior: treat errors as tripwires
-                    return GuardrailFunctionOutput(output_info=f"Error running {stage_name} guardrails: {str(e)}", tripwire_triggered=True)
+                    # Return structured error info for consistency
+                    return GuardrailFunctionOutput(
+                        output_info={
+                            "error": str(e),
+                            "guardrail_name": guardrail.definition.name,
+                        },
+                        tripwire_triggered=True,
+                    )
+
+        # Set the function name to the guardrail name, spaces → underscores (e.g., "Contains PII" → "Contains_PII")
+        single_guardrail.__name__ = guardrail.definition.name.replace(" ", "_")
 
-        # Set the function name for debugging
-        stage_guardrail.__name__ = f"{stage_name}_guardrail"
-        return stage_guardrail
+        return single_guardrail
 
     guardrail_functions = []
 
-    for stage in stages:
-        stage_guardrail = _create_stage_guardrail(stage)
+    # Create one function per individual guardrail (Agents SDK runs them concurrently)
+    for guardrail in all_guardrails:
+        guardrail_func = _create_individual_guardrail(guardrail)
 
         # Decorate with the appropriate guardrail decorator
         if guardrail_type == "input":
-            stage_guardrail = input_guardrail(stage_guardrail)
+            guardrail_func = input_guardrail(guardrail_func)
         else:
-            stage_guardrail = output_guardrail(stage_guardrail)
+            guardrail_func = output_guardrail(guardrail_func)
 
-        guardrail_functions.append(stage_guardrail)
+        guardrail_functions.append(guardrail_func)
 
     return guardrail_functions
 
diff --git a/tests/unit/test_agents.py b/tests/unit/test_agents.py
index ea96c33..80c51a3 100644
--- a/tests/unit/test_agents.py
+++ b/tests/unit/test_agents.py
@@ -274,7 +274,7 @@ async def test_create_tool_guardrail_rejects_on_tripwire(monkeypatch: pytest.Mon
     agents._agent_conversation.set(({"role": "user", "content": "Original request"},))
 
     async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]:
-        assert kwargs["stage_name"] == "tool_input_test_guardrail"  # noqa: S101
+        assert kwargs["stage_name"] == "tool_input"  # noqa: S101  # Updated: now uses simple stage name
         history = kwargs["ctx"].get_conversation_history()
         assert history[-1]["tool_name"] == "weather"  # noqa: S101
         return [GuardrailResult(tripwire_triggered=True, info=expected_info)]
@@ -426,8 +426,10 @@ async def test_create_agents_guardrails_from_config_tripwire(monkeypatch: pytest
         lambda stage, registry=None: [_make_guardrail("Input Guard")] if stage is pipeline.input else [],
     )
 
+    expected_info = {"reason": "blocked", "guardrail_name": "Input Guard"}
+
     async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]:
-        return [GuardrailResult(tripwire_triggered=True, info={"reason": "blocked"})]
+        return [GuardrailResult(tripwire_triggered=True, info=expected_info)]
 
     monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails)
 
@@ -442,7 +444,9 @@ async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]:
     result = await guardrails[0](agents_module.RunContextWrapper(None), Agent("a", "b"), "hi")
 
     assert result.tripwire_triggered is True  # noqa: S101
-    assert result.output_info == "Guardrail unknown triggered tripwire"  # noqa: S101
+    # Updated: now returns full metadata dict instead of string
+    assert result.output_info == expected_info  # noqa: S101
+    assert result.output_info["reason"] == "blocked"  # noqa: S101
 
 
 @pytest.mark.asyncio
@@ -472,7 +476,10 @@ async def failing_run_guardrails(**kwargs: Any) -> list[GuardrailResult]:
     result = await guardrails[0](agents_module.RunContextWrapper(None), Agent("name", "instr"), "msg")
 
     assert result.tripwire_triggered is True  # noqa: S101
-    assert "Error running input guardrails" in result.output_info  # noqa: S101
+    # Updated: output_info is now a dict with error information
+    assert isinstance(result.output_info, dict)  # noqa: S101
+    assert "error" in result.output_info  # noqa: S101
+    assert "boom" in result.output_info["error"]  # noqa: S101
 
 
 @pytest.mark.asyncio
@@ -729,3 +736,363 @@ def test_guardrail_agent_with_empty_user_guardrails(monkeypatch: pytest.MonkeyPa
     assert isinstance(agent_instance, agents_module.Agent)  # noqa: S101
     assert agent_instance.input_guardrails == []  # noqa: S101
     assert agent_instance.output_guardrails == []  # noqa: S101
+
+
+# =============================================================================
+# Tests for _extract_text_from_input (new function for conversation history bug fix)
+# =============================================================================
+
+
+def test_extract_text_from_input_with_plain_string() -> None:
+    """Plain string input should be returned as-is."""
+    result = agents._extract_text_from_input("Hello world")
+    assert result == "Hello world"  # noqa: S101
+
+
+def test_extract_text_from_input_with_empty_string() -> None:
+    """Empty string should be returned as-is."""
+    result = agents._extract_text_from_input("")
+    assert result == ""  # noqa: S101
+
+
+def test_extract_text_from_input_with_single_message() -> None:
+    """Single message in list format should extract text content."""
+    input_data = [
+        {
+            "role": "user",
+            "type": "message",
+            "content": [
+                {
+                    "type": "input_text",
+                    "text": "What is the weather?",
+                }
+            ],
+        }
+    ]
+    result = agents._extract_text_from_input(input_data)
+    assert result == "What is the weather?"  # noqa: S101
+
+
+def test_extract_text_from_input_with_conversation_history() -> None:
+    """Multi-turn conversation should extract latest user message."""
+    input_data = [
+        {
+            "role": "user",
+            "type": "message",
+            "content": [{"type": "input_text", "text": "Hello"}],
+        },
+        {
+            "role": "assistant",
+            "type": "message",
+            "content": [{"type": "output_text", "text": "Hi there!"}],
+        },
+        {
+            "role": "user",
+            "type": "message",
+            "content": [{"type": "input_text", "text": "How are you?"}],
+        },
+    ]
+    result = agents._extract_text_from_input(input_data)
+    assert result == "How are you?"  # noqa: S101
+
+
+def test_extract_text_from_input_with_multiple_content_parts() -> None:
+    """Message with multiple text parts should be concatenated."""
+    input_data = [
+        {
+            "role": "user",
+            "type": "message",
+            "content": [
+                {"type": "input_text", "text": "Hello"},
+                {"type": "input_text", "text": "world"},
+            ],
+        }
+    ]
+    result = agents._extract_text_from_input(input_data)
+    assert result == "Hello world"  # noqa: S101
+
+
+def test_extract_text_from_input_with_non_text_content() -> None:
+    """Non-text content parts should be ignored."""
+    input_data = [
+        {
+            "role": "user",
+            "type": "message",
+            "content": [
+                {"type": "image", "url": "http://example.com/image.jpg"},
+                {"type": "input_text", "text": "What is this?"},
+            ],
+        }
+    ]
+    result = agents._extract_text_from_input(input_data)
+    assert result == "What is this?"  # noqa: S101
+
+
+def test_extract_text_from_input_with_string_content() -> None:
+    """Message with string content (legacy format) should work."""
+    input_data = [
+        {
+            "role": "user",
+            "content": "Simple string content",
+        }
+    ]
+    result = agents._extract_text_from_input(input_data)
+    assert result == "Simple string content"  # noqa: S101
+
+
+def test_extract_text_from_input_with_empty_list() -> None:
+    """Empty list should return empty string."""
+    result = agents._extract_text_from_input([])
+    assert result == ""  # noqa: S101
+
+
+def test_extract_text_from_input_with_no_user_messages() -> None:
+    """List with only assistant messages should return empty string."""
+    input_data = [
+        {
+            "role": "assistant",
+            "type": "message",
+            "content": [{"type": "output_text", "text": "Assistant message"}],
+        }
+    ]
+    result = agents._extract_text_from_input(input_data)
+    assert result == ""  # noqa: S101
+
+
+def test_extract_text_from_input_preserves_empty_strings() -> None:
+    """Empty strings in content parts should be preserved, not filtered out."""
+    input_data = [
+        {
+            "role": "user",
+            "type": "message",
+            "content": [
+                {"type": "input_text", "text": "Hello"},
+                {"type": "input_text", "text": ""},  # Empty string should be preserved
+                {"type": "input_text", "text": "World"},
+            ],
+        }
+    ]
+    result = agents._extract_text_from_input(input_data)
+    # Empty string should be included, resulting in extra space
+    assert result == "Hello  World"  # noqa: S101
+
+
+def test_extract_text_from_input_with_empty_content_list() -> None:
+    """Empty content list should return empty string, not stringified list."""
+    input_data = [
+        {
+            "role": "user",
+            "type": "message",
+            "content": [],  # Empty content list
+        }
+    ]
+    result = agents._extract_text_from_input(input_data)
+    # Should return empty string, not "[]"
+    assert result == ""  # noqa: S101
+
+
+# =============================================================================
+# Tests for updated agent-level guardrail behavior (stage_name and metadata)
+# =============================================================================
+
+
+@pytest.mark.asyncio
+async def test_agent_guardrail_uses_correct_stage_name(monkeypatch: pytest.MonkeyPatch) -> None:
+    """Agent guardrails should use simple stage names like 'input' or 'output'."""
+    pipeline = SimpleNamespace(pre_flight=None, input=SimpleNamespace(), output=None)
+    monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline)
+    monkeypatch.setattr(
+        runtime_module,
+        "instantiate_guardrails",
+        lambda stage, registry=None: [_make_guardrail("Moderation")] if stage is pipeline.input else [],
+    )
+
+    captured_stage_name = None
+
+    async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]:
+        nonlocal captured_stage_name
+        captured_stage_name = kwargs["stage_name"]
+        return [GuardrailResult(tripwire_triggered=False, info={})]
+
+    monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails)
+
+    guardrails = agents._create_agents_guardrails_from_config(
+        config={},
+        stages=["input"],
+        guardrail_type="input",
+        context=None,
+        raise_guardrail_errors=False,
+    )
+
+    await guardrails[0](agents_module.RunContextWrapper(None), Agent("a", "b"), "hello")
+
+    # Should use simple stage name "input", not guardrail name
+    assert captured_stage_name == "input"  # noqa: S101
+
+
+@pytest.mark.asyncio
+async def test_agent_guardrail_returns_full_metadata_on_trigger(monkeypatch: pytest.MonkeyPatch) -> None:
+    """Triggered agent guardrails should return full info dict in output_info."""
+    pipeline = SimpleNamespace(pre_flight=None, input=SimpleNamespace(), output=None)
+    monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline)
+    monkeypatch.setattr(
+        runtime_module,
+        "instantiate_guardrails",
+        lambda stage, registry=None: [_make_guardrail("Jailbreak")] if stage is pipeline.input else [],
+    )
+
+    expected_metadata = {
+        "guardrail_name": "Jailbreak",
+        "observation": "Jailbreak attempt detected",
+        "confidence": 0.95,
+        "threshold": 0.7,
+        "flagged": True,
+    }
+
+    async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]:
+        return [GuardrailResult(tripwire_triggered=True, info=expected_metadata)]
+
+    monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails)
+
+    guardrails = agents._create_agents_guardrails_from_config(
+        config={},
+        stages=["input"],
+        guardrail_type="input",
+        context=SimpleNamespace(guardrail_llm="llm"),
+        raise_guardrail_errors=False,
+    )
+
+    result = await guardrails[0](agents_module.RunContextWrapper(None), Agent("a", "b"), "hack the system")
+
+    assert result.tripwire_triggered is True  # noqa: S101
+    # Should return full metadata dict, not just a string
+    assert result.output_info == expected_metadata  # noqa: S101
+    assert result.output_info["confidence"] == 0.95  # noqa: S101
+
+
+@pytest.mark.asyncio
+async def test_agent_guardrail_function_has_descriptive_name(monkeypatch: pytest.MonkeyPatch) -> None:
+    """Agent guardrail functions should be named after their guardrail."""
+    pipeline = SimpleNamespace(pre_flight=None, input=SimpleNamespace(), output=None)
+    monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline)
+    monkeypatch.setattr(
+        runtime_module,
+        "instantiate_guardrails",
+        lambda stage, registry=None: [_make_guardrail("Contains PII")] if stage is pipeline.input else [],
+    )
+
+    async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]:
+        return [GuardrailResult(tripwire_triggered=False, info={})]
+
+    monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails)
+
+    guardrails = agents._create_agents_guardrails_from_config(
+        config={},
+        stages=["input"],
+        guardrail_type="input",
+        context=None,
+        raise_guardrail_errors=False,
+    )
+
+    # Function name should be the guardrail name (with underscores)
+    assert guardrails[0].__name__ == "Contains_PII"  # noqa: S101
+
+
+@pytest.mark.asyncio
+async def test_agent_guardrails_creates_individual_functions_per_guardrail(monkeypatch: pytest.MonkeyPatch) -> None:
+    """Should create one agent-level guardrail function per individual guardrail."""
+    pipeline = SimpleNamespace(pre_flight=None, input=SimpleNamespace(), output=None)
+    monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline)
+    monkeypatch.setattr(
+        runtime_module,
+        "instantiate_guardrails",
+        lambda stage, registry=None: [_make_guardrail("Moderation"), _make_guardrail("Jailbreak")] if stage is pipeline.input else [],
+    )
+
+    async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]:
+        return [GuardrailResult(tripwire_triggered=False, info={})]
+
+    monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails)
+
+    guardrails = agents._create_agents_guardrails_from_config(
+        config={},
+        stages=["input"],
+        guardrail_type="input",
+        context=None,
+        raise_guardrail_errors=False,
+    )
+
+    # Should have 2 separate guardrail functions
+    assert len(guardrails) == 2  # noqa: S101
+    assert guardrails[0].__name__ == "Moderation"  # noqa: S101
+    assert guardrails[1].__name__ == "Jailbreak"  # noqa: S101
+
+
+# =============================================================================
+# Tests for updated tool-level guardrail behavior (stage_name)
+# =============================================================================
+
+
+@pytest.mark.asyncio
+async def test_tool_guardrail_uses_correct_stage_name_input(monkeypatch: pytest.MonkeyPatch) -> None:
+    """Tool input guardrails should use 'tool_input' as stage_name."""
+    guardrail = _make_guardrail("Prompt Injection Detection")
+    agents._agent_session.set(None)
+    agents._agent_conversation.set(({"role": "user", "content": "Hello"},))
+
+    captured_stage_name = None
+
+    async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]:
+        nonlocal captured_stage_name
+        captured_stage_name = kwargs["stage_name"]
+        return [GuardrailResult(tripwire_triggered=False, info={})]
+
+    monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails)
+
+    tool_fn = agents._create_tool_guardrail(
+        guardrail=guardrail,
+        guardrail_type="input",
+        context=SimpleNamespace(guardrail_llm="client"),
+        raise_guardrail_errors=False,
+        block_on_violations=False,
+    )
+
+    data = agents_module.ToolInputGuardrailData(context=ToolContext(tool_name="weather", tool_arguments={"city": "Paris"}))
+    await tool_fn(data)
+
+    # Should use "tool_input", not a guardrail-specific name
+    assert captured_stage_name == "tool_input"  # noqa: S101
+
+
+@pytest.mark.asyncio
+async def test_tool_guardrail_uses_correct_stage_name_output(monkeypatch: pytest.MonkeyPatch) -> None:
+    """Tool output guardrails should use 'tool_output' as stage_name."""
+    guardrail = _make_guardrail("Prompt Injection Detection")
+    agents._agent_session.set(None)
+    agents._agent_conversation.set(({"role": "user", "content": "Hello"},))
+
+    captured_stage_name = None
+
+    async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]:
+        nonlocal captured_stage_name
+        captured_stage_name = kwargs["stage_name"]
+        return [GuardrailResult(tripwire_triggered=False, info={})]
+
+    monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails)
+
+    tool_fn = agents._create_tool_guardrail(
+        guardrail=guardrail,
+        guardrail_type="output",
+        context=SimpleNamespace(guardrail_llm="client"),
+        raise_guardrail_errors=False,
+        block_on_violations=False,
+    )
+
+    data = agents_module.ToolOutputGuardrailData(
+        context=ToolContext(tool_name="math", tool_arguments={"x": 1}),
+        output="Result: 42",
+    )
+    await tool_fn(data)
+
+    # Should use "tool_output", not a guardrail-specific name
+    assert captured_stage_name == "tool_output"  # noqa: S101