Skip to content

Commit 794ff02

Browse files
authored
Support Agents SDK prompt param usage (#34)
* Support prompt usage * adding tests
1 parent bfd7b7a commit 794ff02

File tree

3 files changed

+153
-35
lines changed

3 files changed

+153
-35
lines changed

docs/benchmarking/nsfw.md

Lines changed: 0 additions & 31 deletions
This file was deleted.

src/guardrails/agents.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,7 @@ def __new__(
492492
cls,
493493
config: str | Path | dict[str, Any],
494494
name: str,
495-
instructions: str,
495+
instructions: str | Callable[[Any, Any], Any] | None = None,
496496
raise_guardrail_errors: bool = False,
497497
block_on_tool_violations: bool = False,
498498
**agent_kwargs: Any,
@@ -511,7 +511,9 @@ def __new__(
511511
Args:
512512
config: Pipeline configuration (file path, dict, or JSON string)
513513
name: Agent name
514-
instructions: Agent instructions
514+
instructions: Agent instructions. Can be a string, a callable that dynamically
515+
generates instructions, or None. If a callable, it will receive the context
516+
and agent instance and must return a string.
515517
raise_guardrail_errors: If True, raise exceptions when guardrails fail to execute.
516518
If False (default), treat guardrail errors as safe and continue execution.
517519
block_on_tool_violations: If True, tool guardrail violations raise exceptions (halt execution).
@@ -553,7 +555,11 @@ def __new__(
553555
input_tool, input_agent = _separate_tool_level_from_agent_level(stage_guardrails.get("input", []))
554556
output_tool, output_agent = _separate_tool_level_from_agent_level(stage_guardrails.get("output", []))
555557

556-
# Create agent-level INPUT guardrails
558+
# Extract any user-provided guardrails from agent_kwargs
559+
user_input_guardrails = agent_kwargs.pop("input_guardrails", [])
560+
user_output_guardrails = agent_kwargs.pop("output_guardrails", [])
561+
562+
# Create agent-level INPUT guardrails from config
557563
input_guardrails = []
558564

559565
# Add agent-level guardrails from pre_flight and input stages
@@ -573,7 +579,10 @@ def __new__(
573579
)
574580
)
575581

576-
# Create agent-level OUTPUT guardrails
582+
# Merge with user-provided input guardrails (config ones run first, then user ones)
583+
input_guardrails.extend(user_input_guardrails)
584+
585+
# Create agent-level OUTPUT guardrails from config
577586
output_guardrails = []
578587
if output_agent:
579588
output_guardrails = _create_agents_guardrails_from_config(
@@ -583,6 +592,9 @@ def __new__(
583592
raise_guardrail_errors=raise_guardrail_errors,
584593
)
585594

595+
# Merge with user-provided output guardrails (config ones run first, then user ones)
596+
output_guardrails.extend(user_output_guardrails)
597+
586598
# Apply tool-level guardrails
587599
tools = agent_kwargs.get("tools", [])
588600

tests/unit/test_agents.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,3 +597,140 @@ def test_guardrail_agent_without_tools(monkeypatch: pytest.MonkeyPatch) -> None:
597597
agent_instance = agents.GuardrailAgent(config={}, name="NoTools", instructions="None")
598598

599599
assert getattr(agent_instance, "input_guardrails", []) == [] # noqa: S101
600+
601+
602+
def test_guardrail_agent_without_instructions(monkeypatch: pytest.MonkeyPatch) -> None:
603+
"""GuardrailAgent should work without instructions parameter."""
604+
pipeline = SimpleNamespace(pre_flight=None, input=None, output=None)
605+
606+
monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline, raising=False)
607+
monkeypatch.setattr(runtime_module, "instantiate_guardrails", lambda *args, **kwargs: [], raising=False)
608+
609+
# Should not raise TypeError about missing instructions
610+
agent_instance = agents.GuardrailAgent(config={}, name="NoInstructions")
611+
612+
assert isinstance(agent_instance, agents_module.Agent) # noqa: S101
613+
assert agent_instance.instructions is None # noqa: S101
614+
615+
616+
def test_guardrail_agent_with_callable_instructions(monkeypatch: pytest.MonkeyPatch) -> None:
617+
"""GuardrailAgent should accept callable instructions."""
618+
pipeline = SimpleNamespace(pre_flight=None, input=None, output=None)
619+
620+
monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline, raising=False)
621+
monkeypatch.setattr(runtime_module, "instantiate_guardrails", lambda *args, **kwargs: [], raising=False)
622+
623+
def dynamic_instructions(ctx: Any, agent: Any) -> str:
624+
return f"You are {agent.name}"
625+
626+
agent_instance = agents.GuardrailAgent(
627+
config={},
628+
name="DynamicAgent",
629+
instructions=dynamic_instructions,
630+
)
631+
632+
assert isinstance(agent_instance, agents_module.Agent) # noqa: S101
633+
assert callable(agent_instance.instructions) # noqa: S101
634+
assert agent_instance.instructions == dynamic_instructions # noqa: S101
635+
636+
637+
def test_guardrail_agent_merges_user_input_guardrails(monkeypatch: pytest.MonkeyPatch) -> None:
638+
"""User input guardrails should be merged with config guardrails."""
639+
agent_guard = _make_guardrail("Config Input Guard")
640+
641+
class FakePipeline:
642+
def __init__(self) -> None:
643+
self.pre_flight = None
644+
self.input = SimpleNamespace()
645+
self.output = None
646+
647+
pipeline = FakePipeline()
648+
649+
def fake_load_pipeline_bundles(config: Any) -> FakePipeline:
650+
return pipeline
651+
652+
def fake_instantiate_guardrails(stage: Any, registry: Any | None = None) -> list[Any]:
653+
if stage is pipeline.input:
654+
return [agent_guard]
655+
return []
656+
657+
from guardrails import runtime as runtime_module
658+
659+
monkeypatch.setattr(runtime_module, "load_pipeline_bundles", fake_load_pipeline_bundles)
660+
monkeypatch.setattr(runtime_module, "instantiate_guardrails", fake_instantiate_guardrails)
661+
662+
# Create a custom user guardrail
663+
custom_guardrail = lambda ctx, agent, input: None # noqa: E731
664+
665+
agent_instance = agents.GuardrailAgent(
666+
config={},
667+
name="MergedAgent",
668+
instructions="Test",
669+
input_guardrails=[custom_guardrail],
670+
)
671+
672+
# Should have both config and user guardrails merged
673+
assert isinstance(agent_instance, agents_module.Agent) # noqa: S101
674+
assert len(agent_instance.input_guardrails) == 2 # noqa: S101
675+
# Config guardrail from _create_agents_guardrails_from_config, then user guardrail
676+
677+
678+
def test_guardrail_agent_merges_user_output_guardrails(monkeypatch: pytest.MonkeyPatch) -> None:
679+
"""User output guardrails should be merged with config guardrails."""
680+
agent_guard = _make_guardrail("Config Output Guard")
681+
682+
class FakePipeline:
683+
def __init__(self) -> None:
684+
self.pre_flight = None
685+
self.input = None
686+
self.output = SimpleNamespace()
687+
688+
pipeline = FakePipeline()
689+
690+
def fake_load_pipeline_bundles(config: Any) -> FakePipeline:
691+
return pipeline
692+
693+
def fake_instantiate_guardrails(stage: Any, registry: Any | None = None) -> list[Any]:
694+
if stage is pipeline.output:
695+
return [agent_guard]
696+
return []
697+
698+
from guardrails import runtime as runtime_module
699+
700+
monkeypatch.setattr(runtime_module, "load_pipeline_bundles", fake_load_pipeline_bundles)
701+
monkeypatch.setattr(runtime_module, "instantiate_guardrails", fake_instantiate_guardrails)
702+
703+
# Create a custom user guardrail
704+
custom_guardrail = lambda ctx, agent, output: None # noqa: E731
705+
706+
agent_instance = agents.GuardrailAgent(
707+
config={},
708+
name="MergedAgent",
709+
instructions="Test",
710+
output_guardrails=[custom_guardrail],
711+
)
712+
713+
# Should have both config and user guardrails merged
714+
assert isinstance(agent_instance, agents_module.Agent) # noqa: S101
715+
assert len(agent_instance.output_guardrails) == 2 # noqa: S101
716+
# Config guardrail from _create_agents_guardrails_from_config, then user guardrail
717+
718+
719+
def test_guardrail_agent_with_empty_user_guardrails(monkeypatch: pytest.MonkeyPatch) -> None:
720+
"""GuardrailAgent should handle empty user guardrail lists gracefully."""
721+
pipeline = SimpleNamespace(pre_flight=None, input=None, output=None)
722+
723+
monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline, raising=False)
724+
monkeypatch.setattr(runtime_module, "instantiate_guardrails", lambda *args, **kwargs: [], raising=False)
725+
726+
agent_instance = agents.GuardrailAgent(
727+
config={},
728+
name="EmptyListAgent",
729+
instructions="Test",
730+
input_guardrails=[],
731+
output_guardrails=[],
732+
)
733+
734+
assert isinstance(agent_instance, agents_module.Agent) # noqa: S101
735+
assert agent_instance.input_guardrails == [] # noqa: S101
736+
assert agent_instance.output_guardrails == [] # noqa: S101

0 commit comments

Comments
 (0)