33import abc
44import enum
55import logging
6- from typing import Any , Dict , List , Optional , Union , Type
6+ from typing import Any , Dict , Optional
77from dataclasses import dataclass
88
99logger = logging .getLogger (__name__ )
1010
1111
1212class GuardrailAction (enum .Enum ):
1313 """Actions that a guardrail can take."""
14+
1415 ALLOW = "allow"
1516 BLOCK = "block"
1617 MODIFY = "modify"
1718
1819
1920class BlockStrategy (enum .Enum ):
2021 """Strategies for handling blocked requests."""
21- RAISE_EXCEPTION = "raise_exception" # Raise GuardrailBlockedException (breaks pipeline)
22- RETURN_EMPTY = "return_empty" # Return empty/None response (graceful)
22+
23+ RAISE_EXCEPTION = (
24+ "raise_exception" # Raise GuardrailBlockedException (breaks pipeline)
25+ )
26+ RETURN_EMPTY = "return_empty" # Return empty/None response (graceful)
2327 RETURN_ERROR_MESSAGE = "return_error_message" # Return error message (graceful)
24- SKIP_FUNCTION = "skip_function" # Skip function execution, return None (graceful)
28+ SKIP_FUNCTION = "skip_function" # Skip function execution, return None (graceful)
2529
2630
2731@dataclass
2832class GuardrailResult :
2933 """Result of applying a guardrail."""
34+
3035 action : GuardrailAction
3136 modified_data : Optional [Any ] = None
3237 metadata : Optional [Dict [str , Any ]] = None
3338 reason : Optional [str ] = None
3439 block_strategy : Optional [BlockStrategy ] = None
3540 error_message : Optional [str ] = None
36-
41+
3742 def __post_init__ (self ):
3843 """Validate the result after initialization."""
3944 if self .action == GuardrailAction .MODIFY and self .modified_data is None :
4045 raise ValueError ("modified_data must be provided when action is MODIFY" )
4146 if self .action == GuardrailAction .BLOCK and self .block_strategy is None :
42- self .block_strategy = BlockStrategy .RAISE_EXCEPTION # Default to existing behavior
47+ self .block_strategy = (
48+ BlockStrategy .RAISE_EXCEPTION
49+ ) # Default to existing behavior
4350
4451
4552class GuardrailBlockedException (Exception ):
4653 """Exception raised when a guardrail blocks execution."""
47-
48- def __init__ (self , guardrail_name : str , reason : str , metadata : Optional [Dict [str , Any ]] = None ):
54+
55+ def __init__ (
56+ self ,
57+ guardrail_name : str ,
58+ reason : str ,
59+ metadata : Optional [Dict [str , Any ]] = None ,
60+ ):
4961 self .guardrail_name = guardrail_name
5062 self .reason = reason
5163 self .metadata = metadata or {}
@@ -54,10 +66,10 @@ def __init__(self, guardrail_name: str, reason: str, metadata: Optional[Dict[str
5466
5567class BaseGuardrail (abc .ABC ):
5668 """Base class for all guardrails."""
57-
69+
5870 def __init__ (self , name : str , enabled : bool = True , ** config ):
5971 """Initialize the guardrail.
60-
72+
6173 Args:
6274 name: Human-readable name for this guardrail
6375 enabled: Whether this guardrail is active
@@ -66,113 +78,41 @@ def __init__(self, name: str, enabled: bool = True, **config):
6678 self .name = name
6779 self .enabled = enabled
6880 self .config = config
69-
81+
7082 @abc .abstractmethod
7183 def check_input (self , inputs : Dict [str , Any ]) -> GuardrailResult :
7284 """Check and potentially modify function inputs.
73-
85+
7486 Args:
7587 inputs: Dictionary of function inputs (parameter_name -> value)
76-
88+
7789 Returns:
7890 GuardrailResult indicating the action to take
7991 """
8092 pass
81-
93+
8294 @abc .abstractmethod
8395 def check_output (self , output : Any , inputs : Dict [str , Any ]) -> GuardrailResult :
8496 """Check and potentially modify function output.
85-
97+
8698 Args:
8799 output: The function's output
88100 inputs: Dictionary of function inputs for context
89-
101+
90102 Returns:
91103 GuardrailResult indicating the action to take
92104 """
93105 pass
94-
106+
95107 def is_enabled (self ) -> bool :
96108 """Check if this guardrail is enabled."""
97109 return self .enabled
98-
110+
99111 def get_metadata (self ) -> Dict [str , Any ]:
100112 """Get metadata about this guardrail for trace logging."""
101113 return {
102114 "name" : self .name ,
103115 "type" : self .__class__ .__name__ ,
104116 "enabled" : self .enabled ,
105- "config" : self .config
117+ "config" : self .config ,
106118 }
107-
108-
109- class GuardrailRegistry :
110- """Registry for managing guardrails."""
111-
112- def __init__ (self ):
113- self ._guardrails : Dict [str , Type [BaseGuardrail ]] = {}
114-
115- def register (self , name : str , guardrail_class : Type [BaseGuardrail ]):
116- """Register a guardrail class.
117-
118- Args:
119- name: Name to register the guardrail under
120- guardrail_class: The guardrail class to register
121- """
122- if not issubclass (guardrail_class , BaseGuardrail ):
123- raise ValueError (f"Guardrail class must inherit from BaseGuardrail" )
124-
125- self ._guardrails [name ] = guardrail_class
126- logger .debug (f"Registered guardrail: { name } " )
127-
128- def create (self , name : str , ** kwargs ) -> BaseGuardrail :
129- """Create an instance of a registered guardrail.
130-
131- Args:
132- name: Name of the registered guardrail
133- **kwargs: Arguments to pass to the guardrail constructor
134-
135- Returns:
136- Instance of the guardrail
137- """
138- if name not in self ._guardrails :
139- raise ValueError (f"Guardrail '{ name } ' not found in registry. Available: { list (self ._guardrails .keys ())} " )
140-
141- guardrail_class = self ._guardrails [name ]
142- return guardrail_class (** kwargs )
143-
144- def list_available (self ) -> List [str ]:
145- """Get list of available guardrail names."""
146- return list (self ._guardrails .keys ())
147-
148-
149- # Global registry instance
150- _registry = GuardrailRegistry ()
151-
152-
153- def register_guardrail (name : str , guardrail_class : Type [BaseGuardrail ]):
154- """Register a guardrail globally.
155-
156- Args:
157- name: Name to register the guardrail under
158- guardrail_class: The guardrail class to register
159- """
160- _registry .register (name , guardrail_class )
161-
162-
163- def create_guardrail (name : str , ** kwargs ) -> BaseGuardrail :
164- """Create an instance of a registered guardrail.
165-
166- Args:
167- name: Name of the registered guardrail
168- **kwargs: Arguments to pass to the guardrail constructor
169-
170- Returns:
171- Instance of the guardrail
172- """
173- return _registry .create (name , ** kwargs )
174-
175-
176- def list_available_guardrails () -> List [str ]:
177- """Get list of available guardrail names."""
178- return _registry .list_available ()
0 commit comments