Skip to content

Commit f95ae68

Browse files
authored
Fix issues with Agent conversation handling (#45)
1 parent bf0cb52 commit f95ae68

File tree

3 files changed

+486
-41
lines changed

3 files changed

+486
-41
lines changed

examples/basic/agents_sdk.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
"categories": ["hate", "violence", "self-harm"],
2626
},
2727
},
28+
{"name": "Contains PII", "config": {"entities": ["US_SSN", "PHONE_NUMBER", "EMAIL_ADDRESS"]}},
2829
],
2930
},
3031
"input": {
@@ -75,11 +76,15 @@ async def main() -> None:
7576
except EOFError:
7677
print("\nExiting.")
7778
break
78-
except InputGuardrailTripwireTriggered:
79+
except InputGuardrailTripwireTriggered as exc:
7980
print("🛑 Input guardrail triggered!")
81+
print(exc.guardrail_result.guardrail.name)
82+
print(exc.guardrail_result.output.output_info)
8083
continue
81-
except OutputGuardrailTripwireTriggered:
84+
except OutputGuardrailTripwireTriggered as exc:
8285
print("🛑 Output guardrail triggered!")
86+
print(exc.guardrail_result.guardrail.name)
87+
print(exc.guardrail_result.output.output_info)
8388
continue
8489

8590

src/guardrails/agents.py

Lines changed: 108 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ async def tool_input_gr(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOu
257257
media_type="text/plain",
258258
guardrails=[guardrail],
259259
suppress_tripwire=True,
260-
stage_name=f"tool_input_{guardrail_name.lower().replace(' ', '_')}",
260+
stage_name="tool_input",
261261
raise_guardrail_errors=raise_guardrail_errors,
262262
)
263263

@@ -312,7 +312,7 @@ async def tool_output_gr(data: ToolOutputGuardrailData) -> ToolGuardrailFunction
312312
media_type="text/plain",
313313
guardrails=[guardrail],
314314
suppress_tripwire=True,
315-
stage_name=f"tool_output_{guardrail_name.lower().replace(' ', '_')}",
315+
stage_name="tool_output",
316316
raise_guardrail_errors=raise_guardrail_errors,
317317
)
318318

@@ -338,6 +338,69 @@ async def tool_output_gr(data: ToolOutputGuardrailData) -> ToolGuardrailFunction
338338
return tool_output_gr
339339

340340

341+
def _extract_text_from_input(input_data: Any) -> str:
342+
"""Extract text from input_data, handling both string and conversation history formats.
343+
344+
The Agents SDK may pass input_data in different formats:
345+
- String: Direct text input
346+
- List of dicts: Conversation history with message objects
347+
348+
Args:
349+
input_data: Input from Agents SDK (string or list of messages)
350+
351+
Returns:
352+
Extracted text string from the latest user message
353+
"""
354+
# If it's already a string, return it
355+
if isinstance(input_data, str):
356+
return input_data
357+
358+
# If it's a list (conversation history), extract the latest user message
359+
if isinstance(input_data, list):
360+
if not input_data:
361+
return "" # Empty list returns empty string
362+
363+
# Iterate from the end to find the latest user message
364+
for msg in reversed(input_data):
365+
if isinstance(msg, dict):
366+
role = msg.get("role")
367+
if role == "user":
368+
content = msg.get("content")
369+
# Content can be a string or a list of content parts
370+
if isinstance(content, str):
371+
return content
372+
elif isinstance(content, list):
373+
if not content:
374+
# Empty content list returns empty string (consistent with no text parts found)
375+
return ""
376+
# Extract text from content parts
377+
text_parts = []
378+
for part in content:
379+
if isinstance(part, dict):
380+
# Check for various text field names (avoid falsy empty string issue)
381+
text = None
382+
for field in ['text', 'input_text', 'output_text']:
383+
if field in part:
384+
text = part[field]
385+
break
386+
# Preserve empty strings, only filter None
387+
if text is not None and isinstance(text, str):
388+
text_parts.append(text)
389+
if text_parts:
390+
return " ".join(text_parts)
391+
# No text parts found, return empty string
392+
return ""
393+
# If content is something else, try to stringify it
394+
elif content is not None:
395+
return str(content)
396+
397+
# No user message found in list
398+
return ""
399+
400+
# Fallback: convert to string
401+
return str(input_data)
402+
403+
341404
def _create_agents_guardrails_from_config(
342405
config: str | Path | dict[str, Any], stages: list[str], guardrail_type: str = "input", context: Any = None, raise_guardrail_errors: bool = False
343406
) -> list[Any]:
@@ -355,7 +418,7 @@ def _create_agents_guardrails_from_config(
355418
If False (default), treat guardrail errors as safe and continue execution.
356419
357420
Returns:
358-
List of guardrail functions that can be used with Agents SDK
421+
List of guardrail functions (one per individual guardrail) ready for Agents SDK
359422
360423
Raises:
361424
ImportError: If agents package is not available
@@ -372,17 +435,15 @@ def _create_agents_guardrails_from_config(
372435
# Load and parse the pipeline configuration
373436
pipeline = load_pipeline_bundles(config)
374437

375-
# Instantiate guardrails for requested stages and filter out tool-level guardrails
376-
stage_guardrails = {}
438+
# Collect all individual guardrails from requested stages (filter out tool-level)
439+
all_guardrails = []
377440
for stage_name in stages:
378441
stage = getattr(pipeline, stage_name, None)
379442
if stage:
380-
all_guardrails = instantiate_guardrails(stage, default_spec_registry)
443+
stage_guardrails = instantiate_guardrails(stage, default_spec_registry)
381444
# Filter out tool-level guardrails - they're handled separately
382-
_, agent_level_guardrails = _separate_tool_level_from_agent_level(all_guardrails)
383-
stage_guardrails[stage_name] = agent_level_guardrails
384-
else:
385-
stage_guardrails[stage_name] = []
445+
_, agent_level_guardrails = _separate_tool_level_from_agent_level(stage_guardrails)
446+
all_guardrails.extend(agent_level_guardrails)
386447

387448
# Create default context if none provided
388449
if context is None:
@@ -394,58 +455,70 @@ class DefaultContext:
394455

395456
context = DefaultContext(guardrail_llm=AsyncOpenAI())
396457

397-
def _create_stage_guardrail(stage_name: str):
398-
async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str) -> GuardrailFunctionOutput:
399-
"""Guardrail function for a specific pipeline stage."""
458+
def _create_individual_guardrail(guardrail):
459+
"""Create a function for a single specific guardrail."""
460+
async def single_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str | list) -> GuardrailFunctionOutput:
461+
"""Guardrail function for a specific guardrail check.
462+
463+
Note: input_data is typed as str in Agents SDK, but can actually be a list
464+
of message objects when conversation history is used. We handle both cases.
465+
"""
400466
try:
401-
# Get guardrails for this stage (already filtered to exclude prompt injection)
402-
guardrails = stage_guardrails.get(stage_name, [])
403-
if not guardrails:
404-
return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False)
467+
# Extract text from input_data (handle both string and conversation history formats)
468+
text_data = _extract_text_from_input(input_data)
405469

406-
# Run the guardrails for this stage
470+
# Run this single guardrail
407471
results = await run_guardrails(
408472
ctx=context,
409-
data=input_data,
473+
data=text_data,
410474
media_type="text/plain",
411-
guardrails=guardrails,
475+
guardrails=[guardrail], # Just this one guardrail
412476
suppress_tripwire=True, # We handle tripwires manually
413-
stage_name=stage_name,
477+
stage_name=guardrail_type, # "input" or "output" - indicates which stage
414478
raise_guardrail_errors=raise_guardrail_errors,
415479
)
416480

417-
# Check if any tripwires were triggered
481+
# Check if tripwire was triggered
418482
for result in results:
419483
if result.tripwire_triggered:
420-
guardrail_name = result.info.get("guardrail_name", "unknown") if isinstance(result.info, dict) else "unknown"
421-
return GuardrailFunctionOutput(output_info=f"Guardrail {guardrail_name} triggered tripwire", tripwire_triggered=True)
484+
# Return full metadata in output_info for consistency with tool guardrails
485+
return GuardrailFunctionOutput(output_info=result.info, tripwire_triggered=True)
422486

423487
return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False)
424488

425489
except Exception as e:
426490
if raise_guardrail_errors:
427-
# Re-raise the exception to stop execution
428-
raise e
491+
# Re-raise the exception to stop execution (preserve traceback)
492+
raise
429493
else:
430494
# Current behavior: treat errors as tripwires
431-
return GuardrailFunctionOutput(output_info=f"Error running {stage_name} guardrails: {str(e)}", tripwire_triggered=True)
495+
# Return structured error info for consistency
496+
return GuardrailFunctionOutput(
497+
output_info={
498+
"error": str(e),
499+
"guardrail_name": guardrail.definition.name,
500+
},
501+
tripwire_triggered=True,
502+
)
503+
504+
# Set the function name to the guardrail name (e.g., "Moderation" → "Moderation")
505+
single_guardrail.__name__ = guardrail.definition.name.replace(" ", "_")
432506

433-
# Set the function name for debugging
434-
stage_guardrail.__name__ = f"{stage_name}_guardrail"
435-
return stage_guardrail
507+
return single_guardrail
436508

437509
guardrail_functions = []
438510

439-
for stage in stages:
440-
stage_guardrail = _create_stage_guardrail(stage)
511+
# Create one function per individual guardrail (Agents SDK runs them concurrently)
512+
for guardrail in all_guardrails:
513+
guardrail_func = _create_individual_guardrail(guardrail)
441514

442515
# Decorate with the appropriate guardrail decorator
443516
if guardrail_type == "input":
444-
stage_guardrail = input_guardrail(stage_guardrail)
517+
guardrail_func = input_guardrail(guardrail_func)
445518
else:
446-
stage_guardrail = output_guardrail(stage_guardrail)
519+
guardrail_func = output_guardrail(guardrail_func)
447520

448-
guardrail_functions.append(stage_guardrail)
521+
guardrail_functions.append(guardrail_func)
449522

450523
return guardrail_functions
451524

0 commit comments

Comments
 (0)