Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions examples/basic/agents_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"categories": ["hate", "violence", "self-harm"],
},
},
{"name": "Contains PII", "config": {"entities": ["US_SSN", "PHONE_NUMBER", "EMAIL_ADDRESS"]}},
],
},
"input": {
Expand Down Expand Up @@ -71,15 +72,23 @@ async def main() -> None:
run_config=RunConfig(tracing_disabled=True),
session=session,
)
agent = result.new_items[0].agent
print(f"Input guardrails: {[x.name for x in agent.input_guardrails]}")
breakpoint()
print(f"Output guardrails: {[x.name for x in agent.output_guardrails]}")
print(f"Assistant: {result.final_output}")
except EOFError:
print("\nExiting.")
break
except InputGuardrailTripwireTriggered:
except InputGuardrailTripwireTriggered as exc:
print("🛑 Input guardrail triggered!")
print(exc.guardrail_result.guardrail.name)
print(exc.guardrail_result.output.output_info)
continue
except OutputGuardrailTripwireTriggered:
except OutputGuardrailTripwireTriggered as exc:
print("🛑 Output guardrail triggered!")
print(exc.guardrail_result.guardrail.name)
print(exc.guardrail_result.output.output_info)
continue


Expand Down
121 changes: 87 additions & 34 deletions src/guardrails/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ async def tool_input_gr(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOu
media_type="text/plain",
guardrails=[guardrail],
suppress_tripwire=True,
stage_name=f"tool_input_{guardrail_name.lower().replace(' ', '_')}",
stage_name="tool_input",
raise_guardrail_errors=raise_guardrail_errors,
)

Expand Down Expand Up @@ -312,7 +312,7 @@ async def tool_output_gr(data: ToolOutputGuardrailData) -> ToolGuardrailFunction
media_type="text/plain",
guardrails=[guardrail],
suppress_tripwire=True,
stage_name=f"tool_output_{guardrail_name.lower().replace(' ', '_')}",
stage_name="tool_output",
raise_guardrail_errors=raise_guardrail_errors,
)

Expand All @@ -338,6 +338,53 @@ async def tool_output_gr(data: ToolOutputGuardrailData) -> ToolGuardrailFunction
return tool_output_gr


def _extract_text_from_input(input_data: Any) -> str:
"""Extract text from input_data, handling both string and conversation history formats.

The Agents SDK may pass input_data in different formats:
- String: Direct text input
- List of dicts: Conversation history with message objects

Args:
input_data: Input from Agents SDK (string or list of messages)

Returns:
Extracted text string from the latest user message
"""
# If it's already a string, return it
if isinstance(input_data, str):
return input_data

# If it's a list (conversation history), extract the latest user message
if isinstance(input_data, list) and len(input_data) > 0:
# Iterate from the end to find the latest user message
for msg in reversed(input_data):
if isinstance(msg, dict):
role = msg.get("role")
if role == "user":
content = msg.get("content")
# Content can be a string or a list of content parts
if isinstance(content, str):
return content
elif isinstance(content, list):
# Extract text from content parts
text_parts = []
for part in content:
if isinstance(part, dict):
# Check for various text field names
text = part.get("text") or part.get("input_text") or part.get("output_text")
if text and isinstance(text, str):
text_parts.append(text)
if text_parts:
return " ".join(text_parts)
# If content is something else, try to stringify it
elif content is not None:
return str(content)

# Fallback: convert to string
return str(input_data)


def _create_agents_guardrails_from_config(
config: str | Path | dict[str, Any], stages: list[str], guardrail_type: str = "input", context: Any = None, raise_guardrail_errors: bool = False
) -> list[Any]:
Expand All @@ -355,7 +402,7 @@ def _create_agents_guardrails_from_config(
If False (default), treat guardrail errors as safe and continue execution.

Returns:
List of guardrail functions that can be used with Agents SDK
List of guardrail functions (one per individual guardrail) ready for Agents SDK

Raises:
ImportError: If agents package is not available
Expand All @@ -372,17 +419,15 @@ def _create_agents_guardrails_from_config(
# Load and parse the pipeline configuration
pipeline = load_pipeline_bundles(config)

# Instantiate guardrails for requested stages and filter out tool-level guardrails
stage_guardrails = {}
# Collect all individual guardrails from requested stages (filter out tool-level)
all_guardrails = []
for stage_name in stages:
stage = getattr(pipeline, stage_name, None)
if stage:
all_guardrails = instantiate_guardrails(stage, default_spec_registry)
stage_guardrails = instantiate_guardrails(stage, default_spec_registry)
# Filter out tool-level guardrails - they're handled separately
_, agent_level_guardrails = _separate_tool_level_from_agent_level(all_guardrails)
stage_guardrails[stage_name] = agent_level_guardrails
else:
stage_guardrails[stage_name] = []
_, agent_level_guardrails = _separate_tool_level_from_agent_level(stage_guardrails)
all_guardrails.extend(agent_level_guardrails)

# Create default context if none provided
if context is None:
Expand All @@ -394,31 +439,30 @@ class DefaultContext:

context = DefaultContext(guardrail_llm=AsyncOpenAI())

def _create_stage_guardrail(stage_name: str):
async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str) -> GuardrailFunctionOutput:
"""Guardrail function for a specific pipeline stage."""
def _create_individual_guardrail(guardrail):
"""Create a function for a single specific guardrail."""
async def single_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str) -> GuardrailFunctionOutput:
"""Guardrail function for a specific guardrail check."""
try:
# Get guardrails for this stage (already filtered to exclude prompt injection)
guardrails = stage_guardrails.get(stage_name, [])
if not guardrails:
return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False)
# Extract text from input_data (handle both string and conversation history formats)
text_data = _extract_text_from_input(input_data)

# Run the guardrails for this stage
# Run this single guardrail
results = await run_guardrails(
ctx=context,
data=input_data,
data=text_data,
media_type="text/plain",
guardrails=guardrails,
guardrails=[guardrail], # Just this one guardrail
suppress_tripwire=True, # We handle tripwires manually
stage_name=stage_name,
stage_name=guardrail_type, # "input" or "output" - indicates which stage
raise_guardrail_errors=raise_guardrail_errors,
)

# Check if any tripwires were triggered
# Check if tripwire was triggered
for result in results:
if result.tripwire_triggered:
guardrail_name = result.info.get("guardrail_name", "unknown") if isinstance(result.info, dict) else "unknown"
return GuardrailFunctionOutput(output_info=f"Guardrail {guardrail_name} triggered tripwire", tripwire_triggered=True)
# Return full metadata in output_info for consistency with tool guardrails
return GuardrailFunctionOutput(output_info=result.info, tripwire_triggered=True)

return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False)

Expand All @@ -428,24 +472,33 @@ async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data
raise e
else:
# Current behavior: treat errors as tripwires
return GuardrailFunctionOutput(output_info=f"Error running {stage_name} guardrails: {str(e)}", tripwire_triggered=True)

# Set the function name for debugging
stage_guardrail.__name__ = f"{stage_name}_guardrail"
return stage_guardrail
# Return structured error info for consistency
return GuardrailFunctionOutput(
output_info={
"error": str(e),
"guardrail_name": guardrail.definition.name,
},
tripwire_triggered=True,
)

# Set the function name to the guardrail name with spaces replaced by underscores (e.g., "Contains PII" → "Contains_PII")
single_guardrail.__name__ = guardrail.definition.name.replace(" ", "_")

return single_guardrail

guardrail_functions = []

for stage in stages:
stage_guardrail = _create_stage_guardrail(stage)
# Create one function per individual guardrail (Agents SDK runs them concurrently)
for guardrail in all_guardrails:
guardrail_func = _create_individual_guardrail(guardrail)

# Decorate with the appropriate guardrail decorator
if guardrail_type == "input":
stage_guardrail = input_guardrail(stage_guardrail)
guardrail_func = input_guardrail(guardrail_func)
else:
stage_guardrail = output_guardrail(stage_guardrail)
guardrail_func = output_guardrail(guardrail_func)

guardrail_functions.append(stage_guardrail)
guardrail_functions.append(guardrail_func)

return guardrail_functions

Expand Down
Loading