@@ -257,7 +257,7 @@ async def tool_input_gr(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOu
257257 media_type = "text/plain" ,
258258 guardrails = [guardrail ],
259259 suppress_tripwire = True ,
260- stage_name = f"tool_input_ { guardrail_name . lower (). replace ( ' ' , '_' ) } " ,
260+ stage_name = "tool_input " ,
261261 raise_guardrail_errors = raise_guardrail_errors ,
262262 )
263263
@@ -312,7 +312,7 @@ async def tool_output_gr(data: ToolOutputGuardrailData) -> ToolGuardrailFunction
312312 media_type = "text/plain" ,
313313 guardrails = [guardrail ],
314314 suppress_tripwire = True ,
315- stage_name = f"tool_output_ { guardrail_name . lower (). replace ( ' ' , '_' ) } " ,
315+ stage_name = "tool_output " ,
316316 raise_guardrail_errors = raise_guardrail_errors ,
317317 )
318318
@@ -338,6 +338,53 @@ async def tool_output_gr(data: ToolOutputGuardrailData) -> ToolGuardrailFunction
338338 return tool_output_gr
339339
340340
341+ def _extract_text_from_input (input_data : Any ) -> str :
342+ """Extract text from input_data, handling both string and conversation history formats.
343+
344+ The Agents SDK may pass input_data in different formats:
345+ - String: Direct text input
346+ - List of dicts: Conversation history with message objects
347+
348+ Args:
349+ input_data: Input from Agents SDK (string or list of messages)
350+
351+ Returns:
352+ Extracted text string from the latest user message
353+ """
354+ # If it's already a string, return it
355+ if isinstance (input_data , str ):
356+ return input_data
357+
358+ # If it's a list (conversation history), extract the latest user message
359+ if isinstance (input_data , list ) and len (input_data ) > 0 :
360+ # Iterate from the end to find the latest user message
361+ for msg in reversed (input_data ):
362+ if isinstance (msg , dict ):
363+ role = msg .get ("role" )
364+ if role == "user" :
365+ content = msg .get ("content" )
366+ # Content can be a string or a list of content parts
367+ if isinstance (content , str ):
368+ return content
369+ elif isinstance (content , list ):
370+ # Extract text from content parts
371+ text_parts = []
372+ for part in content :
373+ if isinstance (part , dict ):
374+ # Check for various text field names
375+ text = part .get ("text" ) or part .get ("input_text" ) or part .get ("output_text" )
376+ if text and isinstance (text , str ):
377+ text_parts .append (text )
378+ if text_parts :
379+ return " " .join (text_parts )
380+ # If content is something else, try to stringify it
381+ elif content is not None :
382+ return str (content )
383+
384+ # Fallback: convert to string
385+ return str (input_data )
386+
387+
341388def _create_agents_guardrails_from_config (
342389 config : str | Path | dict [str , Any ], stages : list [str ], guardrail_type : str = "input" , context : Any = None , raise_guardrail_errors : bool = False
343390) -> list [Any ]:
@@ -355,7 +402,7 @@ def _create_agents_guardrails_from_config(
355402 If False (default), treat guardrail errors as safe and continue execution.
356403
357404 Returns:
358- List of guardrail functions that can be used with Agents SDK
405+ List of guardrail functions (one per individual guardrail) ready for Agents SDK
359406
360407 Raises:
361408 ImportError: If agents package is not available
@@ -372,17 +419,15 @@ def _create_agents_guardrails_from_config(
372419 # Load and parse the pipeline configuration
373420 pipeline = load_pipeline_bundles (config )
374421
375- # Instantiate guardrails for requested stages and filter out tool-level guardrails
376- stage_guardrails = {}
422+ # Collect all individual guardrails from requested stages ( filter out tool-level)
423+ all_guardrails = []
377424 for stage_name in stages :
378425 stage = getattr (pipeline , stage_name , None )
379426 if stage :
380- all_guardrails = instantiate_guardrails (stage , default_spec_registry )
427+ stage_guardrails = instantiate_guardrails (stage , default_spec_registry )
381428 # Filter out tool-level guardrails - they're handled separately
382- _ , agent_level_guardrails = _separate_tool_level_from_agent_level (all_guardrails )
383- stage_guardrails [stage_name ] = agent_level_guardrails
384- else :
385- stage_guardrails [stage_name ] = []
429+ _ , agent_level_guardrails = _separate_tool_level_from_agent_level (stage_guardrails )
430+ all_guardrails .extend (agent_level_guardrails )
386431
387432 # Create default context if none provided
388433 if context is None :
@@ -394,31 +439,30 @@ class DefaultContext:
394439
395440 context = DefaultContext (guardrail_llm = AsyncOpenAI ())
396441
397- def _create_stage_guardrail (stage_name : str ):
398- async def stage_guardrail (ctx : RunContextWrapper [None ], agent : Agent , input_data : str ) -> GuardrailFunctionOutput :
399- """Guardrail function for a specific pipeline stage."""
442+ def _create_individual_guardrail (guardrail ):
443+ """Create a function for a single specific guardrail."""
444+ async def single_guardrail (ctx : RunContextWrapper [None ], agent : Agent , input_data : str ) -> GuardrailFunctionOutput :
445+ """Guardrail function for a specific guardrail check."""
400446 try :
401- # Get guardrails for this stage (already filtered to exclude prompt injection)
402- guardrails = stage_guardrails .get (stage_name , [])
403- if not guardrails :
404- return GuardrailFunctionOutput (output_info = None , tripwire_triggered = False )
447+ # Extract text from input_data (handle both string and conversation history formats)
448+ text_data = _extract_text_from_input (input_data )
405449
406- # Run the guardrails for this stage
450+ # Run this single guardrail
407451 results = await run_guardrails (
408452 ctx = context ,
409- data = input_data ,
453+ data = text_data ,
410454 media_type = "text/plain" ,
411- guardrails = guardrails ,
455+ guardrails = [ guardrail ], # Just this one guardrail
412456 suppress_tripwire = True , # We handle tripwires manually
413- stage_name = stage_name ,
457+ stage_name = guardrail_type , # "input" or "output" - indicates which stage
414458 raise_guardrail_errors = raise_guardrail_errors ,
415459 )
416460
417- # Check if any tripwires were triggered
461+ # Check if tripwire was triggered
418462 for result in results :
419463 if result .tripwire_triggered :
420- guardrail_name = result . info . get ( "guardrail_name" , "unknown" ) if isinstance ( result . info , dict ) else "unknown"
421- return GuardrailFunctionOutput (output_info = f"Guardrail { guardrail_name } triggered tripwire" , tripwire_triggered = True )
464+ # Return full metadata in output_info for consistency with tool guardrails
465+ return GuardrailFunctionOutput (output_info = result . info , tripwire_triggered = True )
422466
423467 return GuardrailFunctionOutput (output_info = None , tripwire_triggered = False )
424468
@@ -428,24 +472,33 @@ async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data
428472 raise e
429473 else :
430474 # Current behavior: treat errors as tripwires
431- return GuardrailFunctionOutput (output_info = f"Error running { stage_name } guardrails: { str (e )} " , tripwire_triggered = True )
432-
433- # Set the function name for debugging
434- stage_guardrail .__name__ = f"{ stage_name } _guardrail"
435- return stage_guardrail
475+ # Return structured error info for consistency
476+ return GuardrailFunctionOutput (
477+ output_info = {
478+ "error" : str (e ),
479+ "guardrail_name" : guardrail .definition .name ,
480+ },
481+ tripwire_triggered = True ,
482+ )
483+
484+ # Set the function name to the guardrail name (e.g., "Moderation" → "Moderation")
485+ single_guardrail .__name__ = guardrail .definition .name .replace (" " , "_" )
486+
487+ return single_guardrail
436488
437489 guardrail_functions = []
438490
439- for stage in stages :
440- stage_guardrail = _create_stage_guardrail (stage )
491+ # Create one function per individual guardrail (Agents SDK runs them concurrently)
492+ for guardrail in all_guardrails :
493+ guardrail_func = _create_individual_guardrail (guardrail )
441494
442495 # Decorate with the appropriate guardrail decorator
443496 if guardrail_type == "input" :
444- stage_guardrail = input_guardrail (stage_guardrail )
497+ guardrail_func = input_guardrail (guardrail_func )
445498 else :
446- stage_guardrail = output_guardrail (stage_guardrail )
499+ guardrail_func = output_guardrail (guardrail_func )
447500
448- guardrail_functions .append (stage_guardrail )
501+ guardrail_functions .append (guardrail_func )
449502
450503 return guardrail_functions
451504
0 commit comments