Skip to content

Commit 5a2c212

Browse files
committed
Agent conversation handling
1 parent bf0cb52 commit 5a2c212

File tree

2 files changed

+98
-36
lines changed

2 files changed

+98
-36
lines changed

examples/basic/agents_sdk.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
"categories": ["hate", "violence", "self-harm"],
2626
},
2727
},
28+
{"name": "Contains PII", "config": {"entities": ["US_SSN", "PHONE_NUMBER", "EMAIL_ADDRESS"]}},
2829
],
2930
},
3031
"input": {
@@ -71,15 +72,23 @@ async def main() -> None:
7172
run_config=RunConfig(tracing_disabled=True),
7273
session=session,
7374
)
75+
agent = result.new_items[0].agent
76+
print(f"Input guardrails: {[x.name for x in agent.input_guardrails]}")
77+
breakpoint()
78+
print(f"Output guardrails: {[x.name for x in agent.output_guardrails]}")
7479
print(f"Assistant: {result.final_output}")
7580
except EOFError:
7681
print("\nExiting.")
7782
break
78-
except InputGuardrailTripwireTriggered:
83+
except InputGuardrailTripwireTriggered as exc:
7984
print("🛑 Input guardrail triggered!")
85+
print(exc.guardrail_result.guardrail.name)
86+
print(exc.guardrail_result.output.output_info)
8087
continue
81-
except OutputGuardrailTripwireTriggered:
88+
except OutputGuardrailTripwireTriggered as exc:
8289
print("🛑 Output guardrail triggered!")
90+
print(exc.guardrail_result.guardrail.name)
91+
print(exc.guardrail_result.output.output_info)
8392
continue
8493

8594

src/guardrails/agents.py

Lines changed: 87 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ async def tool_input_gr(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOu
257257
media_type="text/plain",
258258
guardrails=[guardrail],
259259
suppress_tripwire=True,
260-
stage_name=f"tool_input_{guardrail_name.lower().replace(' ', '_')}",
260+
stage_name="tool_input",
261261
raise_guardrail_errors=raise_guardrail_errors,
262262
)
263263

@@ -312,7 +312,7 @@ async def tool_output_gr(data: ToolOutputGuardrailData) -> ToolGuardrailFunction
312312
media_type="text/plain",
313313
guardrails=[guardrail],
314314
suppress_tripwire=True,
315-
stage_name=f"tool_output_{guardrail_name.lower().replace(' ', '_')}",
315+
stage_name="tool_output",
316316
raise_guardrail_errors=raise_guardrail_errors,
317317
)
318318

@@ -338,6 +338,53 @@ async def tool_output_gr(data: ToolOutputGuardrailData) -> ToolGuardrailFunction
338338
return tool_output_gr
339339

340340

341+
def _extract_text_from_input(input_data: Any) -> str:
342+
"""Extract text from input_data, handling both string and conversation history formats.
343+
344+
The Agents SDK may pass input_data in different formats:
345+
- String: Direct text input
346+
- List of dicts: Conversation history with message objects
347+
348+
Args:
349+
input_data: Input from Agents SDK (string or list of messages)
350+
351+
Returns:
352+
Extracted text string from the latest user message
353+
"""
354+
# If it's already a string, return it
355+
if isinstance(input_data, str):
356+
return input_data
357+
358+
# If it's a list (conversation history), extract the latest user message
359+
if isinstance(input_data, list) and len(input_data) > 0:
360+
# Iterate from the end to find the latest user message
361+
for msg in reversed(input_data):
362+
if isinstance(msg, dict):
363+
role = msg.get("role")
364+
if role == "user":
365+
content = msg.get("content")
366+
# Content can be a string or a list of content parts
367+
if isinstance(content, str):
368+
return content
369+
elif isinstance(content, list):
370+
# Extract text from content parts
371+
text_parts = []
372+
for part in content:
373+
if isinstance(part, dict):
374+
# Check for various text field names
375+
text = part.get("text") or part.get("input_text") or part.get("output_text")
376+
if text and isinstance(text, str):
377+
text_parts.append(text)
378+
if text_parts:
379+
return " ".join(text_parts)
380+
# If content is something else, try to stringify it
381+
elif content is not None:
382+
return str(content)
383+
384+
# Fallback: convert to string
385+
return str(input_data)
386+
387+
341388
def _create_agents_guardrails_from_config(
342389
config: str | Path | dict[str, Any], stages: list[str], guardrail_type: str = "input", context: Any = None, raise_guardrail_errors: bool = False
343390
) -> list[Any]:
@@ -355,7 +402,7 @@ def _create_agents_guardrails_from_config(
355402
If False (default), treat guardrail errors as safe and continue execution.
356403
357404
Returns:
358-
List of guardrail functions that can be used with Agents SDK
405+
List of guardrail functions (one per individual guardrail) ready for Agents SDK
359406
360407
Raises:
361408
ImportError: If agents package is not available
@@ -372,17 +419,15 @@ def _create_agents_guardrails_from_config(
372419
# Load and parse the pipeline configuration
373420
pipeline = load_pipeline_bundles(config)
374421

375-
# Instantiate guardrails for requested stages and filter out tool-level guardrails
376-
stage_guardrails = {}
422+
# Collect all individual guardrails from requested stages (filter out tool-level)
423+
all_guardrails = []
377424
for stage_name in stages:
378425
stage = getattr(pipeline, stage_name, None)
379426
if stage:
380-
all_guardrails = instantiate_guardrails(stage, default_spec_registry)
427+
stage_guardrails = instantiate_guardrails(stage, default_spec_registry)
381428
# Filter out tool-level guardrails - they're handled separately
382-
_, agent_level_guardrails = _separate_tool_level_from_agent_level(all_guardrails)
383-
stage_guardrails[stage_name] = agent_level_guardrails
384-
else:
385-
stage_guardrails[stage_name] = []
429+
_, agent_level_guardrails = _separate_tool_level_from_agent_level(stage_guardrails)
430+
all_guardrails.extend(agent_level_guardrails)
386431

387432
# Create default context if none provided
388433
if context is None:
@@ -394,31 +439,30 @@ class DefaultContext:
394439

395440
context = DefaultContext(guardrail_llm=AsyncOpenAI())
396441

397-
def _create_stage_guardrail(stage_name: str):
398-
async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str) -> GuardrailFunctionOutput:
399-
"""Guardrail function for a specific pipeline stage."""
442+
def _create_individual_guardrail(guardrail):
443+
"""Create a function for a single specific guardrail."""
444+
async def single_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str) -> GuardrailFunctionOutput:
445+
"""Guardrail function for a specific guardrail check."""
400446
try:
401-
# Get guardrails for this stage (already filtered to exclude prompt injection)
402-
guardrails = stage_guardrails.get(stage_name, [])
403-
if not guardrails:
404-
return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False)
447+
# Extract text from input_data (handle both string and conversation history formats)
448+
text_data = _extract_text_from_input(input_data)
405449

406-
# Run the guardrails for this stage
450+
# Run this single guardrail
407451
results = await run_guardrails(
408452
ctx=context,
409-
data=input_data,
453+
data=text_data,
410454
media_type="text/plain",
411-
guardrails=guardrails,
455+
guardrails=[guardrail], # Just this one guardrail
412456
suppress_tripwire=True, # We handle tripwires manually
413-
stage_name=stage_name,
457+
stage_name=guardrail_type, # "input" or "output" - indicates which stage
414458
raise_guardrail_errors=raise_guardrail_errors,
415459
)
416460

417-
# Check if any tripwires were triggered
461+
# Check if tripwire was triggered
418462
for result in results:
419463
if result.tripwire_triggered:
420-
guardrail_name = result.info.get("guardrail_name", "unknown") if isinstance(result.info, dict) else "unknown"
421-
return GuardrailFunctionOutput(output_info=f"Guardrail {guardrail_name} triggered tripwire", tripwire_triggered=True)
464+
# Return full metadata in output_info for consistency with tool guardrails
465+
return GuardrailFunctionOutput(output_info=result.info, tripwire_triggered=True)
422466

423467
return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False)
424468

@@ -428,24 +472,33 @@ async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data
428472
raise e
429473
else:
430474
# Current behavior: treat errors as tripwires
431-
return GuardrailFunctionOutput(output_info=f"Error running {stage_name} guardrails: {str(e)}", tripwire_triggered=True)
432-
433-
# Set the function name for debugging
434-
stage_guardrail.__name__ = f"{stage_name}_guardrail"
435-
return stage_guardrail
475+
# Return structured error info for consistency
476+
return GuardrailFunctionOutput(
477+
output_info={
478+
"error": str(e),
479+
"guardrail_name": guardrail.definition.name,
480+
},
481+
tripwire_triggered=True,
482+
)
483+
484+
# Set the function name to the guardrail name with spaces replaced by underscores (e.g., "Contains PII" → "Contains_PII")
485+
single_guardrail.__name__ = guardrail.definition.name.replace(" ", "_")
486+
487+
return single_guardrail
436488

437489
guardrail_functions = []
438490

439-
for stage in stages:
440-
stage_guardrail = _create_stage_guardrail(stage)
491+
# Create one function per individual guardrail (Agents SDK runs them concurrently)
492+
for guardrail in all_guardrails:
493+
guardrail_func = _create_individual_guardrail(guardrail)
441494

442495
# Decorate with the appropriate guardrail decorator
443496
if guardrail_type == "input":
444-
stage_guardrail = input_guardrail(stage_guardrail)
497+
guardrail_func = input_guardrail(guardrail_func)
445498
else:
446-
stage_guardrail = output_guardrail(stage_guardrail)
499+
guardrail_func = output_guardrail(guardrail_func)
447500

448-
guardrail_functions.append(stage_guardrail)
501+
guardrail_functions.append(guardrail_func)
449502

450503
return guardrail_functions
451504

0 commit comments

Comments
 (0)