Skip to content

Commit 623f812

Browse files
chore: completes OPEN-7287 remove concrete guardrail implementations
1 parent c80943f commit 623f812

File tree

3 files changed

+32
-424
lines changed

3 files changed

+32
-424
lines changed

src/openlayer/lib/guardrails/__init__.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,12 @@
66
GuardrailResult,
77
BaseGuardrail,
88
GuardrailBlockedException,
9-
GuardrailRegistry,
109
)
11-
from .pii import PIIGuardrail
1210

1311
__all__ = [
1412
"GuardrailAction",
1513
"BlockStrategy",
16-
"GuardrailResult",
14+
"GuardrailResult",
1715
"BaseGuardrail",
1816
"GuardrailBlockedException",
19-
"GuardrailRegistry",
20-
"PIIGuardrail",
2117
]

src/openlayer/lib/guardrails/base.py

Lines changed: 31 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -3,49 +3,61 @@
33
import abc
44
import enum
55
import logging
6-
from typing import Any, Dict, List, Optional, Union, Type
6+
from typing import Any, Dict, Optional
77
from dataclasses import dataclass
88

99
logger = logging.getLogger(__name__)
1010

1111

1212
class GuardrailAction(enum.Enum):
1313
"""Actions that a guardrail can take."""
14+
1415
ALLOW = "allow"
1516
BLOCK = "block"
1617
MODIFY = "modify"
1718

1819

1920
class BlockStrategy(enum.Enum):
2021
"""Strategies for handling blocked requests."""
21-
RAISE_EXCEPTION = "raise_exception" # Raise GuardrailBlockedException (breaks pipeline)
22-
RETURN_EMPTY = "return_empty" # Return empty/None response (graceful)
22+
23+
RAISE_EXCEPTION = (
24+
"raise_exception" # Raise GuardrailBlockedException (breaks pipeline)
25+
)
26+
RETURN_EMPTY = "return_empty" # Return empty/None response (graceful)
2327
RETURN_ERROR_MESSAGE = "return_error_message" # Return error message (graceful)
24-
SKIP_FUNCTION = "skip_function" # Skip function execution, return None (graceful)
28+
SKIP_FUNCTION = "skip_function" # Skip function execution, return None (graceful)
2529

2630

2731
@dataclass
2832
class GuardrailResult:
2933
"""Result of applying a guardrail."""
34+
3035
action: GuardrailAction
3136
modified_data: Optional[Any] = None
3237
metadata: Optional[Dict[str, Any]] = None
3338
reason: Optional[str] = None
3439
block_strategy: Optional[BlockStrategy] = None
3540
error_message: Optional[str] = None
36-
41+
3742
def __post_init__(self):
3843
"""Validate the result after initialization."""
3944
if self.action == GuardrailAction.MODIFY and self.modified_data is None:
4045
raise ValueError("modified_data must be provided when action is MODIFY")
4146
if self.action == GuardrailAction.BLOCK and self.block_strategy is None:
42-
self.block_strategy = BlockStrategy.RAISE_EXCEPTION # Default to existing behavior
47+
self.block_strategy = (
48+
BlockStrategy.RAISE_EXCEPTION
49+
) # Default to existing behavior
4350

4451

4552
class GuardrailBlockedException(Exception):
4653
"""Exception raised when a guardrail blocks execution."""
47-
48-
def __init__(self, guardrail_name: str, reason: str, metadata: Optional[Dict[str, Any]] = None):
54+
55+
def __init__(
56+
self,
57+
guardrail_name: str,
58+
reason: str,
59+
metadata: Optional[Dict[str, Any]] = None,
60+
):
4961
self.guardrail_name = guardrail_name
5062
self.reason = reason
5163
self.metadata = metadata or {}
@@ -54,10 +66,10 @@ def __init__(self, guardrail_name: str, reason: str, metadata: Optional[Dict[str
5466

5567
class BaseGuardrail(abc.ABC):
5668
"""Base class for all guardrails."""
57-
69+
5870
def __init__(self, name: str, enabled: bool = True, **config):
5971
"""Initialize the guardrail.
60-
72+
6173
Args:
6274
name: Human-readable name for this guardrail
6375
enabled: Whether this guardrail is active
@@ -66,113 +78,41 @@ def __init__(self, name: str, enabled: bool = True, **config):
6678
self.name = name
6779
self.enabled = enabled
6880
self.config = config
69-
81+
7082
@abc.abstractmethod
7183
def check_input(self, inputs: Dict[str, Any]) -> GuardrailResult:
7284
"""Check and potentially modify function inputs.
73-
85+
7486
Args:
7587
inputs: Dictionary of function inputs (parameter_name -> value)
76-
88+
7789
Returns:
7890
GuardrailResult indicating the action to take
7991
"""
8092
pass
81-
93+
8294
@abc.abstractmethod
8395
def check_output(self, output: Any, inputs: Dict[str, Any]) -> GuardrailResult:
8496
"""Check and potentially modify function output.
85-
97+
8698
Args:
8799
output: The function's output
88100
inputs: Dictionary of function inputs for context
89-
101+
90102
Returns:
91103
GuardrailResult indicating the action to take
92104
"""
93105
pass
94-
106+
95107
def is_enabled(self) -> bool:
96108
"""Check if this guardrail is enabled."""
97109
return self.enabled
98-
110+
99111
def get_metadata(self) -> Dict[str, Any]:
100112
"""Get metadata about this guardrail for trace logging."""
101113
return {
102114
"name": self.name,
103115
"type": self.__class__.__name__,
104116
"enabled": self.enabled,
105-
"config": self.config
117+
"config": self.config,
106118
}
107-
108-
109-
class GuardrailRegistry:
110-
"""Registry for managing guardrails."""
111-
112-
def __init__(self):
113-
self._guardrails: Dict[str, Type[BaseGuardrail]] = {}
114-
115-
def register(self, name: str, guardrail_class: Type[BaseGuardrail]):
116-
"""Register a guardrail class.
117-
118-
Args:
119-
name: Name to register the guardrail under
120-
guardrail_class: The guardrail class to register
121-
"""
122-
if not issubclass(guardrail_class, BaseGuardrail):
123-
raise ValueError(f"Guardrail class must inherit from BaseGuardrail")
124-
125-
self._guardrails[name] = guardrail_class
126-
logger.debug(f"Registered guardrail: {name}")
127-
128-
def create(self, name: str, **kwargs) -> BaseGuardrail:
129-
"""Create an instance of a registered guardrail.
130-
131-
Args:
132-
name: Name of the registered guardrail
133-
**kwargs: Arguments to pass to the guardrail constructor
134-
135-
Returns:
136-
Instance of the guardrail
137-
"""
138-
if name not in self._guardrails:
139-
raise ValueError(f"Guardrail '{name}' not found in registry. Available: {list(self._guardrails.keys())}")
140-
141-
guardrail_class = self._guardrails[name]
142-
return guardrail_class(**kwargs)
143-
144-
def list_available(self) -> List[str]:
145-
"""Get list of available guardrail names."""
146-
return list(self._guardrails.keys())
147-
148-
149-
# Global registry instance
150-
_registry = GuardrailRegistry()
151-
152-
153-
def register_guardrail(name: str, guardrail_class: Type[BaseGuardrail]):
154-
"""Register a guardrail globally.
155-
156-
Args:
157-
name: Name to register the guardrail under
158-
guardrail_class: The guardrail class to register
159-
"""
160-
_registry.register(name, guardrail_class)
161-
162-
163-
def create_guardrail(name: str, **kwargs) -> BaseGuardrail:
164-
"""Create an instance of a registered guardrail.
165-
166-
Args:
167-
name: Name of the registered guardrail
168-
**kwargs: Arguments to pass to the guardrail constructor
169-
170-
Returns:
171-
Instance of the guardrail
172-
"""
173-
return _registry.create(name, **kwargs)
174-
175-
176-
def list_available_guardrails() -> List[str]:
177-
"""Get list of available guardrail names."""
178-
return _registry.list_available()

0 commit comments

Comments
 (0)