Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions examples/basic/agents_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"categories": ["hate", "violence", "self-harm"],
},
},
{"name": "Contains PII", "config": {"entities": ["US_SSN", "PHONE_NUMBER", "EMAIL_ADDRESS"]}},
],
},
"input": {
Expand Down Expand Up @@ -75,11 +76,15 @@ async def main() -> None:
except EOFError:
print("\nExiting.")
break
except InputGuardrailTripwireTriggered:
except InputGuardrailTripwireTriggered as exc:
print("🛑 Input guardrail triggered!")
print(exc.guardrail_result.guardrail.name)
print(exc.guardrail_result.output.output_info)
continue
except OutputGuardrailTripwireTriggered:
except OutputGuardrailTripwireTriggered as exc:
print("🛑 Output guardrail triggered!")
print(exc.guardrail_result.guardrail.name)
print(exc.guardrail_result.output.output_info)
continue


Expand Down
125 changes: 92 additions & 33 deletions src/guardrails/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ async def tool_input_gr(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOu
media_type="text/plain",
guardrails=[guardrail],
suppress_tripwire=True,
stage_name=f"tool_input_{guardrail_name.lower().replace(' ', '_')}",
stage_name="tool_input",
raise_guardrail_errors=raise_guardrail_errors,
)

Expand Down Expand Up @@ -312,7 +312,7 @@ async def tool_output_gr(data: ToolOutputGuardrailData) -> ToolGuardrailFunction
media_type="text/plain",
guardrails=[guardrail],
suppress_tripwire=True,
stage_name=f"tool_output_{guardrail_name.lower().replace(' ', '_')}",
stage_name="tool_output",
raise_guardrail_errors=raise_guardrail_errors,
)

Expand All @@ -338,6 +338,59 @@ async def tool_output_gr(data: ToolOutputGuardrailData) -> ToolGuardrailFunction
return tool_output_gr


def _extract_text_from_input(input_data: Any) -> str:
"""Extract text from input_data, handling both string and conversation history formats.

The Agents SDK may pass input_data in different formats:
- String: Direct text input
- List of dicts: Conversation history with message objects

Args:
input_data: Input from Agents SDK (string or list of messages)

Returns:
Extracted text string from the latest user message
"""
# If it's already a string, return it
if isinstance(input_data, str):
return input_data

# If it's a list (conversation history), extract the latest user message
if isinstance(input_data, list):
if len(input_data) == 0:
return "" # Empty list returns empty string

# Iterate from the end to find the latest user message
for msg in reversed(input_data):
if isinstance(msg, dict):
role = msg.get("role")
if role == "user":
content = msg.get("content")
# Content can be a string or a list of content parts
if isinstance(content, str):
return content
elif isinstance(content, list):
# Extract text from content parts
text_parts = []
for part in content:
if isinstance(part, dict):
# Check for various text field names
text = part.get("text") or part.get("input_text") or part.get("output_text")
if text and isinstance(text, str):
text_parts.append(text)
if text_parts:
return " ".join(text_parts)
# If content is something else, try to stringify it
elif content is not None:
return str(content)

# No user message found in list
return ""

# Fallback: convert to string
return str(input_data)


def _create_agents_guardrails_from_config(
config: str | Path | dict[str, Any], stages: list[str], guardrail_type: str = "input", context: Any = None, raise_guardrail_errors: bool = False
) -> list[Any]:
Expand All @@ -355,7 +408,7 @@ def _create_agents_guardrails_from_config(
If False (default), treat guardrail errors as safe and continue execution.

Returns:
List of guardrail functions that can be used with Agents SDK
List of guardrail functions (one per individual guardrail) ready for Agents SDK

Raises:
ImportError: If agents package is not available
Expand All @@ -372,17 +425,15 @@ def _create_agents_guardrails_from_config(
# Load and parse the pipeline configuration
pipeline = load_pipeline_bundles(config)

# Instantiate guardrails for requested stages and filter out tool-level guardrails
stage_guardrails = {}
# Collect all individual guardrails from requested stages (filter out tool-level)
all_guardrails = []
for stage_name in stages:
stage = getattr(pipeline, stage_name, None)
if stage:
all_guardrails = instantiate_guardrails(stage, default_spec_registry)
stage_guardrails = instantiate_guardrails(stage, default_spec_registry)
# Filter out tool-level guardrails - they're handled separately
_, agent_level_guardrails = _separate_tool_level_from_agent_level(all_guardrails)
stage_guardrails[stage_name] = agent_level_guardrails
else:
stage_guardrails[stage_name] = []
_, agent_level_guardrails = _separate_tool_level_from_agent_level(stage_guardrails)
all_guardrails.extend(agent_level_guardrails)

# Create default context if none provided
if context is None:
Expand All @@ -394,31 +445,30 @@ class DefaultContext:

context = DefaultContext(guardrail_llm=AsyncOpenAI())

def _create_stage_guardrail(stage_name: str):
async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str) -> GuardrailFunctionOutput:
"""Guardrail function for a specific pipeline stage."""
def _create_individual_guardrail(guardrail):
"""Create a function for a single specific guardrail."""
async def single_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str) -> GuardrailFunctionOutput:
"""Guardrail function for a specific guardrail check."""
try:
# Get guardrails for this stage (already filtered to exclude prompt injection)
guardrails = stage_guardrails.get(stage_name, [])
if not guardrails:
return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False)
# Extract text from input_data (handle both string and conversation history formats)
text_data = _extract_text_from_input(input_data)

# Run the guardrails for this stage
# Run this single guardrail
results = await run_guardrails(
ctx=context,
data=input_data,
data=text_data,
media_type="text/plain",
guardrails=guardrails,
guardrails=[guardrail], # Just this one guardrail
suppress_tripwire=True, # We handle tripwires manually
stage_name=stage_name,
stage_name=guardrail_type, # "input" or "output" - indicates which stage
raise_guardrail_errors=raise_guardrail_errors,
)

# Check if any tripwires were triggered
# Check if tripwire was triggered
for result in results:
if result.tripwire_triggered:
guardrail_name = result.info.get("guardrail_name", "unknown") if isinstance(result.info, dict) else "unknown"
return GuardrailFunctionOutput(output_info=f"Guardrail {guardrail_name} triggered tripwire", tripwire_triggered=True)
# Return full metadata in output_info for consistency with tool guardrails
return GuardrailFunctionOutput(output_info=result.info, tripwire_triggered=True)

return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False)

Expand All @@ -428,24 +478,33 @@ async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data
raise e
else:
# Current behavior: treat errors as tripwires
return GuardrailFunctionOutput(output_info=f"Error running {stage_name} guardrails: {str(e)}", tripwire_triggered=True)
# Return structured error info for consistency
return GuardrailFunctionOutput(
output_info={
"error": str(e),
"guardrail_name": guardrail.definition.name,
},
tripwire_triggered=True,
)

# Set the function name to the guardrail name with spaces replaced by underscores (e.g., "Contains PII" → "Contains_PII")
single_guardrail.__name__ = guardrail.definition.name.replace(" ", "_")

# Set the function name for debugging
stage_guardrail.__name__ = f"{stage_name}_guardrail"
return stage_guardrail
return single_guardrail

guardrail_functions = []

for stage in stages:
stage_guardrail = _create_stage_guardrail(stage)
# Create one function per individual guardrail (Agents SDK runs them concurrently)
for guardrail in all_guardrails:
guardrail_func = _create_individual_guardrail(guardrail)

# Decorate with the appropriate guardrail decorator
if guardrail_type == "input":
stage_guardrail = input_guardrail(stage_guardrail)
guardrail_func = input_guardrail(guardrail_func)
else:
stage_guardrail = output_guardrail(stage_guardrail)
guardrail_func = output_guardrail(guardrail_func)

guardrail_functions.append(stage_guardrail)
guardrail_functions.append(guardrail_func)

return guardrail_functions

Expand Down
Loading
Loading