Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions examples/basic/agents_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"categories": ["hate", "violence", "self-harm"],
},
},
{"name": "Contains PII", "config": {"entities": ["US_SSN", "PHONE_NUMBER", "EMAIL_ADDRESS"]}},
],
},
"input": {
Expand Down Expand Up @@ -75,11 +76,15 @@ async def main() -> None:
except EOFError:
print("\nExiting.")
break
except InputGuardrailTripwireTriggered:
except InputGuardrailTripwireTriggered as exc:
print("🛑 Input guardrail triggered!")
print(exc.guardrail_result.guardrail.name)
print(exc.guardrail_result.output.output_info)
continue
except OutputGuardrailTripwireTriggered:
except OutputGuardrailTripwireTriggered as exc:
print("🛑 Output guardrail triggered!")
print(exc.guardrail_result.guardrail.name)
print(exc.guardrail_result.output.output_info)
continue


Expand Down
143 changes: 108 additions & 35 deletions src/guardrails/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ async def tool_input_gr(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOu
media_type="text/plain",
guardrails=[guardrail],
suppress_tripwire=True,
stage_name=f"tool_input_{guardrail_name.lower().replace(' ', '_')}",
stage_name="tool_input",
raise_guardrail_errors=raise_guardrail_errors,
)

Expand Down Expand Up @@ -312,7 +312,7 @@ async def tool_output_gr(data: ToolOutputGuardrailData) -> ToolGuardrailFunction
media_type="text/plain",
guardrails=[guardrail],
suppress_tripwire=True,
stage_name=f"tool_output_{guardrail_name.lower().replace(' ', '_')}",
stage_name="tool_output",
raise_guardrail_errors=raise_guardrail_errors,
)

Expand All @@ -338,6 +338,69 @@ async def tool_output_gr(data: ToolOutputGuardrailData) -> ToolGuardrailFunction
return tool_output_gr


def _extract_text_from_input(input_data: Any) -> str:
"""Extract text from input_data, handling both string and conversation history formats.

The Agents SDK may pass input_data in different formats:
- String: Direct text input
- List of dicts: Conversation history with message objects

Args:
input_data: Input from Agents SDK (string or list of messages)

Returns:
Extracted text string from the latest user message
"""
# If it's already a string, return it
if isinstance(input_data, str):
return input_data

# If it's a list (conversation history), extract the latest user message
if isinstance(input_data, list):
if not input_data:
return "" # Empty list returns empty string

# Iterate from the end to find the latest user message
for msg in reversed(input_data):
if isinstance(msg, dict):
role = msg.get("role")
if role == "user":
content = msg.get("content")
# Content can be a string or a list of content parts
if isinstance(content, str):
return content
elif isinstance(content, list):
if not content:
# Empty content list returns empty string (consistent with no text parts found)
return ""
# Extract text from content parts
text_parts = []
for part in content:
if isinstance(part, dict):
# Check for various text field names (avoid falsy empty string issue)
text = None
for field in ['text', 'input_text', 'output_text']:
if field in part:
text = part[field]
break
# Preserve empty strings, only filter None
if text is not None and isinstance(text, str):
text_parts.append(text)
if text_parts:
return " ".join(text_parts)
# No text parts found, return empty string
return ""
# If content is something else, try to stringify it
elif content is not None:
return str(content)

# No user message found in list
return ""

# Fallback: convert to string
return str(input_data)


def _create_agents_guardrails_from_config(
config: str | Path | dict[str, Any], stages: list[str], guardrail_type: str = "input", context: Any = None, raise_guardrail_errors: bool = False
) -> list[Any]:
Expand All @@ -355,7 +418,7 @@ def _create_agents_guardrails_from_config(
If False (default), treat guardrail errors as safe and continue execution.

Returns:
List of guardrail functions that can be used with Agents SDK
List of guardrail functions (one per individual guardrail) ready for Agents SDK

Raises:
ImportError: If agents package is not available
Expand All @@ -372,17 +435,15 @@ def _create_agents_guardrails_from_config(
# Load and parse the pipeline configuration
pipeline = load_pipeline_bundles(config)

# Instantiate guardrails for requested stages and filter out tool-level guardrails
stage_guardrails = {}
# Collect all individual guardrails from requested stages (filter out tool-level)
all_guardrails = []
for stage_name in stages:
stage = getattr(pipeline, stage_name, None)
if stage:
all_guardrails = instantiate_guardrails(stage, default_spec_registry)
stage_guardrails = instantiate_guardrails(stage, default_spec_registry)
# Filter out tool-level guardrails - they're handled separately
_, agent_level_guardrails = _separate_tool_level_from_agent_level(all_guardrails)
stage_guardrails[stage_name] = agent_level_guardrails
else:
stage_guardrails[stage_name] = []
_, agent_level_guardrails = _separate_tool_level_from_agent_level(stage_guardrails)
all_guardrails.extend(agent_level_guardrails)

# Create default context if none provided
if context is None:
Expand All @@ -394,58 +455,70 @@ class DefaultContext:

context = DefaultContext(guardrail_llm=AsyncOpenAI())

def _create_stage_guardrail(stage_name: str):
async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str) -> GuardrailFunctionOutput:
"""Guardrail function for a specific pipeline stage."""
def _create_individual_guardrail(guardrail):
"""Create a function for a single specific guardrail."""
async def single_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str | list) -> GuardrailFunctionOutput:
"""Guardrail function for a specific guardrail check.

Note: input_data is typed as str in Agents SDK, but can actually be a list
of message objects when conversation history is used. We handle both cases.
"""
try:
# Get guardrails for this stage (already filtered to exclude prompt injection)
guardrails = stage_guardrails.get(stage_name, [])
if not guardrails:
return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False)
# Extract text from input_data (handle both string and conversation history formats)
text_data = _extract_text_from_input(input_data)

# Run the guardrails for this stage
# Run this single guardrail
results = await run_guardrails(
ctx=context,
data=input_data,
data=text_data,
media_type="text/plain",
guardrails=guardrails,
guardrails=[guardrail], # Just this one guardrail
suppress_tripwire=True, # We handle tripwires manually
stage_name=stage_name,
stage_name=guardrail_type, # "input" or "output" - indicates which stage
raise_guardrail_errors=raise_guardrail_errors,
)

# Check if any tripwires were triggered
# Check if tripwire was triggered
for result in results:
if result.tripwire_triggered:
guardrail_name = result.info.get("guardrail_name", "unknown") if isinstance(result.info, dict) else "unknown"
return GuardrailFunctionOutput(output_info=f"Guardrail {guardrail_name} triggered tripwire", tripwire_triggered=True)
# Return full metadata in output_info for consistency with tool guardrails
return GuardrailFunctionOutput(output_info=result.info, tripwire_triggered=True)

return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False)

except Exception as e:
if raise_guardrail_errors:
# Re-raise the exception to stop execution
raise e
# Re-raise the exception to stop execution (preserve traceback)
raise
else:
# Current behavior: treat errors as tripwires
return GuardrailFunctionOutput(output_info=f"Error running {stage_name} guardrails: {str(e)}", tripwire_triggered=True)
# Return structured error info for consistency
return GuardrailFunctionOutput(
output_info={
"error": str(e),
"guardrail_name": guardrail.definition.name,
},
tripwire_triggered=True,
)

# Set the function name to the guardrail name with spaces replaced by underscores (e.g., "Contains PII" → "Contains_PII")
single_guardrail.__name__ = guardrail.definition.name.replace(" ", "_")

# Set the function name for debugging
stage_guardrail.__name__ = f"{stage_name}_guardrail"
return stage_guardrail
return single_guardrail

guardrail_functions = []

for stage in stages:
stage_guardrail = _create_stage_guardrail(stage)
# Create one function per individual guardrail (Agents SDK runs them concurrently)
for guardrail in all_guardrails:
guardrail_func = _create_individual_guardrail(guardrail)

# Decorate with the appropriate guardrail decorator
if guardrail_type == "input":
stage_guardrail = input_guardrail(stage_guardrail)
guardrail_func = input_guardrail(guardrail_func)
else:
stage_guardrail = output_guardrail(stage_guardrail)
guardrail_func = output_guardrail(guardrail_func)

guardrail_functions.append(stage_guardrail)
guardrail_functions.append(guardrail_func)

return guardrail_functions

Expand Down
Loading
Loading