From 5a2c212575182402d6259de9f403f430ebf2d8d1 Mon Sep 17 00:00:00 2001 From: Steven C Date: Fri, 7 Nov 2025 15:15:32 -0500 Subject: [PATCH 1/5] Agent conversation handling --- examples/basic/agents_sdk.py | 13 +++- src/guardrails/agents.py | 121 +++++++++++++++++++++++++---------- 2 files changed, 98 insertions(+), 36 deletions(-) diff --git a/examples/basic/agents_sdk.py b/examples/basic/agents_sdk.py index a446fb5..ac1c59f 100644 --- a/examples/basic/agents_sdk.py +++ b/examples/basic/agents_sdk.py @@ -25,6 +25,7 @@ "categories": ["hate", "violence", "self-harm"], }, }, + {"name": "Contains PII", "config": {"entities": ["US_SSN", "PHONE_NUMBER", "EMAIL_ADDRESS"]}}, ], }, "input": { @@ -71,15 +72,23 @@ async def main() -> None: run_config=RunConfig(tracing_disabled=True), session=session, ) + agent = result.new_items[0].agent + print(f"Input guardrails: {[x.name for x in agent.input_guardrails]}") + breakpoint() + print(f"Output guardrails: {[x.name for x in agent.output_guardrails]}") print(f"Assistant: {result.final_output}") except EOFError: print("\nExiting.") break - except InputGuardrailTripwireTriggered: + except InputGuardrailTripwireTriggered as exc: print("🛑 Input guardrail triggered!") + print(exc.guardrail_result.guardrail.name) + print(exc.guardrail_result.output.output_info) continue - except OutputGuardrailTripwireTriggered: + except OutputGuardrailTripwireTriggered as exc: print("🛑 Output guardrail triggered!") + print(exc.guardrail_result.guardrail.name) + print(exc.guardrail_result.output.output_info) continue diff --git a/src/guardrails/agents.py b/src/guardrails/agents.py index 4f3202c..c7f8b52 100644 --- a/src/guardrails/agents.py +++ b/src/guardrails/agents.py @@ -257,7 +257,7 @@ async def tool_input_gr(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOu media_type="text/plain", guardrails=[guardrail], suppress_tripwire=True, - stage_name=f"tool_input_{guardrail_name.lower().replace(' ', '_')}", + stage_name="tool_input", raise_guardrail_errors=raise_guardrail_errors, ) @@ -312,7 +312,7 @@ async def tool_output_gr(data: ToolOutputGuardrailData) -> ToolGuardrailFunction media_type="text/plain", guardrails=[guardrail], suppress_tripwire=True, - stage_name=f"tool_output_{guardrail_name.lower().replace(' ', '_')}", + stage_name="tool_output", raise_guardrail_errors=raise_guardrail_errors, ) @@ -338,6 +338,53 @@ async def tool_output_gr(data: ToolOutputGuardrailData) -> ToolGuardrailFunction return tool_output_gr +def _extract_text_from_input(input_data: Any) -> str: + """Extract text from input_data, handling both string and conversation history formats. + + The Agents SDK may pass input_data in different formats: + - String: Direct text input + - List of dicts: Conversation history with message objects + + Args: + input_data: Input from Agents SDK (string or list of messages) + + Returns: + Extracted text string from the latest user message + """ + # If it's already a string, return it + if isinstance(input_data, str): + return input_data + + # If it's a list (conversation history), extract the latest user message + if isinstance(input_data, list) and len(input_data) > 0: + # Iterate from the end to find the latest user message + for msg in reversed(input_data): + if isinstance(msg, dict): + role = msg.get("role") + if role == "user": + content = msg.get("content") + # Content can be a string or a list of content parts + if isinstance(content, str): + return content + elif isinstance(content, list): + # Extract text from content parts + text_parts = [] + for part in content: + if isinstance(part, dict): + # Check for various text field names + text = part.get("text") or part.get("input_text") or part.get("output_text") + if text and isinstance(text, str): + text_parts.append(text) + if text_parts: + return " ".join(text_parts) + # If content is something else, try to stringify it + elif content is not None: + return str(content) + + # Fallback: convert to string + return str(input_data) + + def _create_agents_guardrails_from_config( config: str | Path | dict[str, Any], stages: list[str], guardrail_type: str = "input", context: Any = None, raise_guardrail_errors: bool = False ) -> list[Any]: @@ -355,7 +402,7 @@ def _create_agents_guardrails_from_config( If False (default), treat guardrail errors as safe and continue execution. Returns: - List of guardrail functions that can be used with Agents SDK + List of guardrail functions (one per individual guardrail) ready for Agents SDK Raises: ImportError: If agents package is not available @@ -372,17 +419,15 @@ def _create_agents_guardrails_from_config( # Load and parse the pipeline configuration pipeline = load_pipeline_bundles(config) - # Instantiate guardrails for requested stages and filter out tool-level guardrails - stage_guardrails = {} + # Collect all individual guardrails from requested stages (filter out tool-level) + all_guardrails = [] for stage_name in stages: stage = getattr(pipeline, stage_name, None) if stage: - all_guardrails = instantiate_guardrails(stage, default_spec_registry) + stage_guardrails = instantiate_guardrails(stage, default_spec_registry) # Filter out tool-level guardrails - they're handled separately - _, agent_level_guardrails = _separate_tool_level_from_agent_level(all_guardrails) - stage_guardrails[stage_name] = agent_level_guardrails - else: - stage_guardrails[stage_name] = [] + _, agent_level_guardrails = _separate_tool_level_from_agent_level(stage_guardrails) + all_guardrails.extend(agent_level_guardrails) # Create default context if none provided if context is None: @@ -394,31 +439,30 @@ class DefaultContext: context = DefaultContext(guardrail_llm=AsyncOpenAI()) - def _create_stage_guardrail(stage_name: str): - async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str) -> GuardrailFunctionOutput: - """Guardrail function for a specific pipeline stage.""" + def _create_individual_guardrail(guardrail): + """Create a function for a single specific guardrail.""" + async def single_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str) -> GuardrailFunctionOutput: + """Guardrail function for a specific guardrail check.""" try: - # Get guardrails for this stage (already filtered to exclude prompt injection) - guardrails = stage_guardrails.get(stage_name, []) - if not guardrails: - return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False) + # Extract text from input_data (handle both string and conversation history formats) + text_data = _extract_text_from_input(input_data) - # Run the guardrails for this stage + # Run this single guardrail results = await run_guardrails( ctx=context, - data=input_data, + data=text_data, media_type="text/plain", - guardrails=guardrails, + guardrails=[guardrail], # Just this one guardrail suppress_tripwire=True, # We handle tripwires manually - stage_name=stage_name, + stage_name=guardrail_type, # "input" or "output" - indicates which stage raise_guardrail_errors=raise_guardrail_errors, ) - # Check if any tripwires were triggered + # Check if tripwire was triggered for result in results: if result.tripwire_triggered: - guardrail_name = result.info.get("guardrail_name", "unknown") if isinstance(result.info, dict) else "unknown" - return GuardrailFunctionOutput(output_info=f"Guardrail {guardrail_name} triggered tripwire", tripwire_triggered=True) + # Return full metadata in output_info for consistency with tool guardrails + return GuardrailFunctionOutput(output_info=result.info, tripwire_triggered=True) return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False) @@ -428,24 +472,33 @@ async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data raise e else: # Current behavior: treat errors as tripwires - return GuardrailFunctionOutput(output_info=f"Error running {stage_name} guardrails: {str(e)}", tripwire_triggered=True) - - # Set the function name for debugging - stage_guardrail.__name__ = f"{stage_name}_guardrail" - return stage_guardrail + # Return structured error info for consistency + return GuardrailFunctionOutput( + output_info={ + "error": str(e), + "guardrail_name": guardrail.definition.name, + }, + tripwire_triggered=True, + ) + + # Set the function name to the guardrail name (e.g., "Moderation" → "Moderation") + single_guardrail.__name__ = guardrail.definition.name.replace(" ", "_") + + return single_guardrail guardrail_functions = [] - for stage in stages: - stage_guardrail = _create_stage_guardrail(stage) + # Create one function per individual guardrail (Agents SDK runs them concurrently) + for guardrail in all_guardrails: + guardrail_func = _create_individual_guardrail(guardrail) # Decorate with the appropriate guardrail decorator if guardrail_type == "input": - stage_guardrail = input_guardrail(stage_guardrail) + guardrail_func = input_guardrail(guardrail_func) else: - stage_guardrail = output_guardrail(stage_guardrail) + guardrail_func = output_guardrail(guardrail_func) - guardrail_functions.append(stage_guardrail) + guardrail_functions.append(guardrail_func) return guardrail_functions From a26126da8531f49ecea03e060228fa9a29df2aaf Mon Sep 17 00:00:00 2001 From: Steven C Date: Fri, 7 Nov 2025 16:04:16 -0500 Subject: [PATCH 2/5] add unit test --- examples/basic/agents_sdk.py | 4 - src/guardrails/agents.py | 10 +- tests/unit/test_agents.py | 354 ++++++++++++++++++++++++++++++++++- 3 files changed, 357 insertions(+), 11 deletions(-) diff --git a/examples/basic/agents_sdk.py b/examples/basic/agents_sdk.py index ac1c59f..4ade9d1 100644 --- a/examples/basic/agents_sdk.py +++ b/examples/basic/agents_sdk.py @@ -72,10 +72,6 @@ async def main() -> None: run_config=RunConfig(tracing_disabled=True), session=session, ) - agent = result.new_items[0].agent - print(f"Input guardrails: {[x.name for x in agent.input_guardrails]}") - breakpoint() - print(f"Output guardrails: {[x.name for x in agent.output_guardrails]}") print(f"Assistant: {result.final_output}") except EOFError: print("\nExiting.") diff --git a/src/guardrails/agents.py b/src/guardrails/agents.py index c7f8b52..f118b68 100644 --- a/src/guardrails/agents.py +++ b/src/guardrails/agents.py @@ -356,7 +356,10 @@ def _extract_text_from_input(input_data: Any) -> str: return input_data # If it's a list (conversation history), extract the latest user message - if isinstance(input_data, list) and len(input_data) > 0: + if isinstance(input_data, list): + if len(input_data) == 0: + return "" # Empty list returns empty string + # Iterate from the end to find the latest user message for msg in reversed(input_data): if isinstance(msg, dict): @@ -381,6 +384,9 @@ def _extract_text_from_input(input_data: Any) -> str: elif content is not None: return str(content) + # No user message found in list + return "" + # Fallback: convert to string return str(input_data) @@ -483,7 +489,7 @@ async def single_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_dat # Set the function name to the guardrail name (e.g., "Moderation" → "Moderation") single_guardrail.__name__ = guardrail.definition.name.replace(" ", "_") - + return single_guardrail guardrail_functions = [] diff --git a/tests/unit/test_agents.py b/tests/unit/test_agents.py index ea96c33..e9a4f52 100644 --- a/tests/unit/test_agents.py +++ b/tests/unit/test_agents.py @@ -7,7 +7,7 @@ from collections.abc import Callable from dataclasses import dataclass from types import SimpleNamespace -from typing import Any +from typing import Any, TypedDict import pytest @@ -130,6 +130,15 @@ async def run(self, *args: Any, **kwargs: Any) -> Any: import guardrails.runtime as runtime_module # noqa: E402 +# Add mock for TResponseInputItem for testing +class TResponseInputItem(TypedDict): + """Mock type for Agents SDK response input item.""" + + role: str + content: Any + type: str + + def _make_guardrail(name: str) -> Any: class _DummyCtxModel: model_fields: dict[str, Any] = {} @@ -274,7 +283,7 @@ async def test_create_tool_guardrail_rejects_on_tripwire(monkeypatch: pytest.Mon agents._agent_conversation.set(({"role": "user", "content": "Original request"},)) async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: - assert kwargs["stage_name"] == "tool_input_test_guardrail" # noqa: S101 + assert kwargs["stage_name"] == "tool_input" # noqa: S101 # Updated: now uses simple stage name history = kwargs["ctx"].get_conversation_history() assert history[-1]["tool_name"] == "weather" # noqa: S101 return [GuardrailResult(tripwire_triggered=True, info=expected_info)] @@ -426,8 +435,10 @@ async def test_create_agents_guardrails_from_config_tripwire(monkeypatch: pytest lambda stage, registry=None: [_make_guardrail("Input Guard")] if stage is pipeline.input else [], ) + expected_info = {"reason": "blocked", "guardrail_name": "Input Guard"} + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: - return [GuardrailResult(tripwire_triggered=True, info={"reason": "blocked"})] + return [GuardrailResult(tripwire_triggered=True, info=expected_info)] monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails) @@ -442,7 +453,9 @@ async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: result = await guardrails[0](agents_module.RunContextWrapper(None), Agent("a", "b"), "hi") assert result.tripwire_triggered is True # noqa: S101 - assert result.output_info == "Guardrail unknown triggered tripwire" # noqa: S101 + # Updated: now returns full metadata dict instead of string + assert result.output_info == expected_info # noqa: S101 + assert result.output_info["reason"] == "blocked" # noqa: S101 @pytest.mark.asyncio @@ -472,7 +485,10 @@ async def failing_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: result = await guardrails[0](agents_module.RunContextWrapper(None), Agent("name", "instr"), "msg") assert result.tripwire_triggered is True # noqa: S101 - assert "Error running input guardrails" in result.output_info # noqa: S101 + # Updated: output_info is now a dict with error information + assert isinstance(result.output_info, dict) # noqa: S101 + assert "error" in result.output_info # noqa: S101 + assert "boom" in result.output_info["error"] # noqa: S101 @pytest.mark.asyncio @@ -729,3 +745,331 @@ def test_guardrail_agent_with_empty_user_guardrails(monkeypatch: pytest.MonkeyPa assert isinstance(agent_instance, agents_module.Agent) # noqa: S101 assert agent_instance.input_guardrails == [] # noqa: S101 assert agent_instance.output_guardrails == [] # noqa: S101 + + +# ============================================================================= +# Tests for _extract_text_from_input (new function for conversation history bug fix) +# ============================================================================= + + +def test_extract_text_from_input_with_plain_string() -> None: + """Plain string input should be returned as-is.""" + result = agents._extract_text_from_input("Hello world") + assert result == "Hello world" # noqa: S101 + + +def test_extract_text_from_input_with_empty_string() -> None: + """Empty string should be returned as-is.""" + result = agents._extract_text_from_input("") + assert result == "" # noqa: S101 + + +def test_extract_text_from_input_with_single_message() -> None: + """Single message in list format should extract text content.""" + input_data = [ + { + "role": "user", + "type": "message", + "content": [ + { + "type": "input_text", + "text": "What is the weather?", + } + ], + } + ] + result = agents._extract_text_from_input(input_data) + assert result == "What is the weather?" # noqa: S101 + + +def test_extract_text_from_input_with_conversation_history() -> None: + """Multi-turn conversation should extract latest user message.""" + input_data = [ + { + "role": "user", + "type": "message", + "content": [{"type": "input_text", "text": "Hello"}], + }, + { + "role": "assistant", + "type": "message", + "content": [{"type": "output_text", "text": "Hi there!"}], + }, + { + "role": "user", + "type": "message", + "content": [{"type": "input_text", "text": "How are you?"}], + }, + ] + result = agents._extract_text_from_input(input_data) + assert result == "How are you?" # noqa: S101 + + +def test_extract_text_from_input_with_multiple_content_parts() -> None: + """Message with multiple text parts should be concatenated.""" + input_data = [ + { + "role": "user", + "type": "message", + "content": [ + {"type": "input_text", "text": "Hello"}, + {"type": "input_text", "text": "world"}, + ], + } + ] + result = agents._extract_text_from_input(input_data) + assert result == "Hello world" # noqa: S101 + + +def test_extract_text_from_input_with_non_text_content() -> None: + """Non-text content parts should be ignored.""" + input_data = [ + { + "role": "user", + "type": "message", + "content": [ + {"type": "image", "url": "http://example.com/image.jpg"}, + {"type": "input_text", "text": "What is this?"}, + ], + } + ] + result = agents._extract_text_from_input(input_data) + assert result == "What is this?" # noqa: S101 + + +def test_extract_text_from_input_with_string_content() -> None: + """Message with string content (legacy format) should work.""" + input_data = [ + { + "role": "user", + "content": "Simple string content", + } + ] + result = agents._extract_text_from_input(input_data) + assert result == "Simple string content" # noqa: S101 + + +def test_extract_text_from_input_with_empty_list() -> None: + """Empty list should return empty string.""" + result = agents._extract_text_from_input([]) + assert result == "" # noqa: S101 + + +def test_extract_text_from_input_with_no_user_messages() -> None: + """List with only assistant messages should return empty string.""" + input_data = [ + { + "role": "assistant", + "type": "message", + "content": [{"type": "output_text", "text": "Assistant message"}], + } + ] + result = agents._extract_text_from_input(input_data) + assert result == "" # noqa: S101 + + +# ============================================================================= +# Tests for updated agent-level guardrail behavior (stage_name and metadata) +# ============================================================================= + + +@pytest.mark.asyncio +async def test_agent_guardrail_uses_correct_stage_name(monkeypatch: pytest.MonkeyPatch) -> None: + """Agent guardrails should use simple stage names like 'input' or 'output'.""" + pipeline = SimpleNamespace(pre_flight=None, input=SimpleNamespace(), output=None) + monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline) + monkeypatch.setattr( + runtime_module, + "instantiate_guardrails", + lambda stage, registry=None: [_make_guardrail("Moderation")] if stage is pipeline.input else [], + ) + + captured_stage_name = None + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + nonlocal captured_stage_name + captured_stage_name = kwargs["stage_name"] + return [GuardrailResult(tripwire_triggered=False, info={})] + + monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails) + + guardrails = agents._create_agents_guardrails_from_config( + config={}, + stages=["input"], + guardrail_type="input", + context=None, + raise_guardrail_errors=False, + ) + + await guardrails[0](agents_module.RunContextWrapper(None), Agent("a", "b"), "hello") + + # Should use simple stage name "input", not guardrail name + assert captured_stage_name == "input" # noqa: S101 + + +@pytest.mark.asyncio +async def test_agent_guardrail_returns_full_metadata_on_trigger(monkeypatch: pytest.MonkeyPatch) -> None: + """Triggered agent guardrails should return full info dict in output_info.""" + pipeline = SimpleNamespace(pre_flight=None, input=SimpleNamespace(), output=None) + monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline) + monkeypatch.setattr( + runtime_module, + "instantiate_guardrails", + lambda stage, registry=None: [_make_guardrail("Jailbreak")] if stage is pipeline.input else [], + ) + + expected_metadata = { + "guardrail_name": "Jailbreak", + "observation": "Jailbreak attempt detected", + "confidence": 0.95, + "threshold": 0.7, + "flagged": True, + } + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + return [GuardrailResult(tripwire_triggered=True, info=expected_metadata)] + + monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails) + + guardrails = agents._create_agents_guardrails_from_config( + config={}, + stages=["input"], + guardrail_type="input", + context=SimpleNamespace(guardrail_llm="llm"), + raise_guardrail_errors=False, + ) + + result = await guardrails[0](agents_module.RunContextWrapper(None), Agent("a", "b"), "hack the system") + + assert result.tripwire_triggered is True # noqa: S101 + # Should return full metadata dict, not just a string + assert result.output_info == expected_metadata # noqa: S101 + assert result.output_info["confidence"] == 0.95 # noqa: S101 + + +@pytest.mark.asyncio +async def test_agent_guardrail_function_has_descriptive_name(monkeypatch: pytest.MonkeyPatch) -> None: + """Agent guardrail functions should be named after their guardrail.""" + pipeline = SimpleNamespace(pre_flight=None, input=SimpleNamespace(), output=None) + monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline) + monkeypatch.setattr( + runtime_module, + "instantiate_guardrails", + lambda stage, registry=None: [_make_guardrail("Contains PII")] if stage is pipeline.input else [], + ) + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + return [GuardrailResult(tripwire_triggered=False, info={})] + + monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails) + + guardrails = agents._create_agents_guardrails_from_config( + config={}, + stages=["input"], + guardrail_type="input", + context=None, + raise_guardrail_errors=False, + ) + + # Function name should be the guardrail name (with underscores) + assert guardrails[0].__name__ == "Contains_PII" # noqa: S101 + + +@pytest.mark.asyncio +async def test_agent_guardrails_creates_individual_functions_per_guardrail(monkeypatch: pytest.MonkeyPatch) -> None: + """Should create one agent-level guardrail function per individual guardrail.""" + pipeline = SimpleNamespace(pre_flight=None, input=SimpleNamespace(), output=None) + monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline) + monkeypatch.setattr( + runtime_module, + "instantiate_guardrails", + lambda stage, registry=None: [_make_guardrail("Moderation"), _make_guardrail("Jailbreak")] if stage is pipeline.input else [], + ) + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + return [GuardrailResult(tripwire_triggered=False, info={})] + + monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails) + + guardrails = agents._create_agents_guardrails_from_config( + config={}, + stages=["input"], + guardrail_type="input", + context=None, + raise_guardrail_errors=False, + ) + + # Should have 2 separate guardrail functions + assert len(guardrails) == 2 # noqa: S101 + assert guardrails[0].__name__ == "Moderation" # noqa: S101 + assert guardrails[1].__name__ == "Jailbreak" # noqa: S101 + + +# ============================================================================= +# Tests for updated tool-level guardrail behavior (stage_name) +# ============================================================================= + + +@pytest.mark.asyncio +async def test_tool_guardrail_uses_correct_stage_name_input(monkeypatch: pytest.MonkeyPatch) -> None: + """Tool input guardrails should use 'tool_input' as stage_name.""" + guardrail = _make_guardrail("Prompt Injection Detection") + agents._agent_session.set(None) + agents._agent_conversation.set(({"role": "user", "content": "Hello"},)) + + captured_stage_name = None + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + nonlocal captured_stage_name + captured_stage_name = kwargs["stage_name"] + return [GuardrailResult(tripwire_triggered=False, info={})] + + monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails) + + tool_fn = agents._create_tool_guardrail( + guardrail=guardrail, + guardrail_type="input", + context=SimpleNamespace(guardrail_llm="client"), + raise_guardrail_errors=False, + block_on_violations=False, + ) + + data = agents_module.ToolInputGuardrailData(context=ToolContext(tool_name="weather", tool_arguments={"city": "Paris"})) + await tool_fn(data) + + # Should use "tool_input", not a guardrail-specific name + assert captured_stage_name == "tool_input" # noqa: S101 + + +@pytest.mark.asyncio +async def test_tool_guardrail_uses_correct_stage_name_output(monkeypatch: pytest.MonkeyPatch) -> None: + """Tool output guardrails should use 'tool_output' as stage_name.""" + guardrail = _make_guardrail("Prompt Injection Detection") + agents._agent_session.set(None) + agents._agent_conversation.set(({"role": "user", "content": "Hello"},)) + + captured_stage_name = None + + async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: + nonlocal captured_stage_name + captured_stage_name = kwargs["stage_name"] + return [GuardrailResult(tripwire_triggered=False, info={})] + + monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails) + + tool_fn = agents._create_tool_guardrail( + guardrail=guardrail, + guardrail_type="output", + context=SimpleNamespace(guardrail_llm="client"), + raise_guardrail_errors=False, + block_on_violations=False, + ) + + data = agents_module.ToolOutputGuardrailData( + context=ToolContext(tool_name="math", tool_arguments={"x": 1}), + output="Result: 42", + ) + await tool_fn(data) + + # Should use "tool_output", not a guardrail-specific name + assert captured_stage_name == "tool_output" # noqa: S101 From c47d352dda8a1d2f28a8f45b579faa3d48e29dc1 Mon Sep 17 00:00:00 2001 From: Steven C Date: Fri, 7 Nov 2025 16:28:22 -0500 Subject: [PATCH 3/5] better part checking --- src/guardrails/agents.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/guardrails/agents.py b/src/guardrails/agents.py index f118b68..c972877 100644 --- a/src/guardrails/agents.py +++ b/src/guardrails/agents.py @@ -357,7 +357,7 @@ def _extract_text_from_input(input_data: Any) -> str: # If it's a list (conversation history), extract the latest user message if isinstance(input_data, list): - if len(input_data) == 0: + if not input_data: return "" # Empty list returns empty string # Iterate from the end to find the latest user message @@ -374,8 +374,12 @@ def _extract_text_from_input(input_data: Any) -> str: text_parts = [] for part in content: if isinstance(part, dict): - # Check for various text field names - text = part.get("text") or part.get("input_text") or part.get("output_text") + # Check for various text field names (avoid falsy empty string issue) + text = None + for field in ['text', 'input_text', 'output_text']: + if field in part: + text = part[field] + break if text and isinstance(text, str): text_parts.append(text) if text_parts: From 4d0fd708e71e6e95cf4b4e72a07c6277c9e85655 Mon Sep 17 00:00:00 2001 From: Steven C Date: Fri, 7 Nov 2025 16:43:59 -0500 Subject: [PATCH 4/5] Preserve input empty strings --- src/guardrails/agents.py | 7 ++++--- tests/unit/test_agents.py | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/src/guardrails/agents.py b/src/guardrails/agents.py index c972877..e4c11b4 100644 --- a/src/guardrails/agents.py +++ b/src/guardrails/agents.py @@ -380,7 +380,8 @@ def _extract_text_from_input(input_data: Any) -> str: if field in part: text = part[field] break - if text and isinstance(text, str): + # Preserve empty strings, only filter None + if text is not None and isinstance(text, str): text_parts.append(text) if text_parts: return " ".join(text_parts) @@ -478,8 +479,8 @@ async def single_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_dat except Exception as e: if raise_guardrail_errors: - # Re-raise the exception to stop execution - raise e + # Re-raise the exception to stop execution (preserve traceback) + raise else: # Current behavior: treat errors as tripwires # Return structured error info for consistency diff --git a/tests/unit/test_agents.py b/tests/unit/test_agents.py index e9a4f52..7f48b35 100644 --- a/tests/unit/test_agents.py +++ b/tests/unit/test_agents.py @@ -868,6 +868,24 @@ def test_extract_text_from_input_with_no_user_messages() -> None: assert result == "" # noqa: S101 +def test_extract_text_from_input_preserves_empty_strings() -> None: + """Empty strings in content parts should be preserved, not filtered out.""" + input_data = [ + { + "role": "user", + "type": "message", + "content": [ + {"type": "input_text", "text": "Hello"}, + {"type": "input_text", "text": ""}, # Empty string should be preserved + {"type": "input_text", "text": "World"}, + ], + } + ] + result = agents._extract_text_from_input(input_data) + # Empty string should be included, resulting in extra space + assert result == "Hello World" # noqa: S101 + + # ============================================================================= # Tests for updated agent-level guardrail behavior (stage_name and metadata) # ============================================================================= From a2af559cae58f43239208ad534ffaebc06b229ea Mon Sep 17 00:00:00 2001 From: Steven C Date: Fri, 7 Nov 2025 16:57:52 -0500 Subject: [PATCH 5/5] Address Copilot comments --- src/guardrails/agents.py | 13 +++++++++++-- tests/unit/test_agents.py | 25 +++++++++++++++---------- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/src/guardrails/agents.py b/src/guardrails/agents.py index e4c11b4..31d5722 100644 --- a/src/guardrails/agents.py +++ b/src/guardrails/agents.py @@ -370,6 +370,9 @@ def _extract_text_from_input(input_data: Any) -> str: if isinstance(content, str): return content elif isinstance(content, list): + if not content: + # Empty content list returns empty string (consistent with no text parts found) + return "" # Extract text from content parts text_parts = [] for part in content: @@ -385,6 +388,8 @@ def _extract_text_from_input(input_data: Any) -> str: text_parts.append(text) if text_parts: return " ".join(text_parts) + # No text parts found, return empty string + return "" # If content is something else, try to stringify it elif content is not None: return str(content) @@ -452,8 +457,12 @@ class DefaultContext: def _create_individual_guardrail(guardrail): """Create a function for a single specific guardrail.""" - async def single_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str) -> GuardrailFunctionOutput: - """Guardrail function for a specific guardrail check.""" + async def single_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str | list) -> GuardrailFunctionOutput: + """Guardrail function for a specific guardrail check. + + Note: input_data is typed as str in Agents SDK, but can actually be a list + of message objects when conversation history is used. We handle both cases. + """ try: # Extract text from input_data (handle both string and conversation history formats) text_data = _extract_text_from_input(input_data) diff --git a/tests/unit/test_agents.py b/tests/unit/test_agents.py index 7f48b35..80c51a3 100644 --- a/tests/unit/test_agents.py +++ b/tests/unit/test_agents.py @@ -7,7 +7,7 @@ from collections.abc import Callable from dataclasses import dataclass from types import SimpleNamespace -from typing import Any, TypedDict +from typing import Any import pytest @@ -130,15 +130,6 @@ async def run(self, *args: Any, **kwargs: Any) -> Any: import guardrails.runtime as runtime_module # noqa: E402 -# Add mock for TResponseInputItem for testing -class TResponseInputItem(TypedDict): - """Mock type for Agents SDK response input item.""" - - role: str - content: Any - type: str - - def _make_guardrail(name: str) -> Any: class _DummyCtxModel: model_fields: dict[str, Any] = {} @@ -886,6 +877,20 @@ def test_extract_text_from_input_preserves_empty_strings() -> None: assert result == "Hello World" # noqa: S101 +def test_extract_text_from_input_with_empty_content_list() -> None: + """Empty content list should return empty string, not stringified list.""" + input_data = [ + { + "role": "user", + "type": "message", + "content": [], # Empty content list + } + ] + result = agents._extract_text_from_input(input_data) + # Should return empty string, not "[]" + assert result == "" # noqa: S101 + + # ============================================================================= # Tests for updated agent-level guardrail behavior (stage_name and metadata) # =============================================================================