Commit ab3f458

Updated prompt injection check (#27)
* Updated prompt injection check
* Formatting changes
* Removed legacy code
* Update results doc
* Update dataset details
1 parent: c1d868b

12 files changed: +510 −180 lines

docs/ref/checks/prompt_injection_detection.md

Lines changed: 8 additions & 10 deletions
@@ -92,10 +92,8 @@ Returns a `GuardrailResult` with the following `info` dictionary:
This benchmark evaluates model performance on agent conversation traces:

- - **Synthetic dataset**: 1,000 samples with 500 positive cases (50% prevalence) simulating realistic agent traces
- - **AgentDojo dataset**: 1,046 samples from AgentDojo's workspace, travel, banking, and Slack suite combined with the "important_instructions" attack (949 positive cases, 97 negative samples)
- - **Test scenarios**: Multi-turn conversations with function calls and tool outputs across realistic workplace domains
- - **Misalignment examples**: Unrelated function calls, harmful operations, and data leakage
+ - **[AgentDojo dataset](https://github.com/ethz-spylab/agentdojo)**: 1,046 samples generated from running AgentDojo's benchmark script on workspace, travel, banking, and Slack suite combined with the "important_instructions" attack (949 positive cases, 97 negative samples)
+ - **Internal synthetic dataset**: 537 positive cases simulating realistic, multi-turn agent conversation traces

**Example of misaligned conversation:**

@@ -113,12 +111,12 @@ This benchmark evaluates model performance on agent conversation traces:
| Model | ROC AUC | Prec@R=0.80 | Prec@R=0.90 | Prec@R=0.95 | Recall@FPR=0.01 |
|---------------|---------|-------------|-------------|-------------|-----------------|
- | gpt-5 | 0.9604 | 0.998 | 0.995 | 0.963 | 0.431 |
- | gpt-5-mini | 0.9796 | 0.999 | 0.999 | 0.966 | 0.000 |
- | gpt-5-nano | 0.8651 | 0.963 | 0.963 | 0.951 | 0.056 |
- | gpt-4.1 | 0.9846 | 0.998 | 0.998 | 0.998 | 0.000 |
- | gpt-4.1-mini (default) | 0.9728 | 0.995 | 0.995 | 0.995 | 0.000 |
- | gpt-4.1-nano | 0.8677 | 0.974 | 0.974 | 0.974 | 0.000 |
+ | gpt-5 | 0.9931 | 0.9992 | 0.9992 | 0.9992 | 0.5845 |
+ | gpt-5-mini | 0.9536 | 0.9951 | 0.9951 | 0.9951 | 0.0000 |
+ | gpt-5-nano | 0.9283 | 0.9913 | 0.9913 | 0.9717 | 0.0350 |
+ | gpt-4.1 | 0.9794 | 0.9973 | 0.9973 | 0.9973 | 0.0000 |
+ | gpt-4.1-mini (default) | 0.9865 | 0.9986 | 0.9986 | 0.9986 | 0.0000 |
+ | gpt-4.1-nano | 0.9142 | 0.9948 | 0.9948 | 0.9387 | 0.0000 |

**Notes:**

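As an editorial aside: under the standard definitions of these metrics, the table's numbers can be recomputed from per-sample benchmark outputs. A minimal sketch using scikit-learn — `y_true` and `y_score` are illustrative names for the binary labels and guardrail confidence scores, and this helper is not part of the repo:

```python
import numpy as np
from sklearn.metrics import precision_recall_curve, roc_auc_score, roc_curve

def summarize(y_true: np.ndarray, y_score: np.ndarray) -> dict[str, float]:
    """Sketch of the table's metrics: ROC AUC, Prec@R, and Recall@FPR."""
    # Area under the ROC curve.
    auc = float(roc_auc_score(y_true, y_score))

    # Best precision achievable at a minimum recall target, e.g. Prec@R=0.90.
    precision, recall, _ = precision_recall_curve(y_true, y_score)
    prec_at_r90 = float(precision[recall >= 0.90].max())

    # Best recall (TPR) under a false-positive-rate budget, e.g. Recall@FPR=0.01.
    fpr, tpr, _ = roc_curve(y_true, y_score)
    recall_at_fpr01 = float(tpr[fpr <= 0.01].max())

    return {"roc_auc": auc, "prec@r=0.90": prec_at_r90, "recall@fpr=0.01": recall_at_fpr01}
```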
src/guardrails/checks/text/hallucination_detection.py

Lines changed: 31 additions & 21 deletions
@@ -52,7 +52,13 @@
5252
from guardrails.spec import GuardrailSpecMetadata
5353
from guardrails.types import GuardrailLLMContextProto, GuardrailResult
5454

55-
from .llm_base import LLMConfig, LLMOutput, _invoke_openai_callable
55+
from .llm_base import (
56+
LLMConfig,
57+
LLMErrorOutput,
58+
LLMOutput,
59+
_invoke_openai_callable,
60+
create_error_result,
61+
)
5662

5763
logger = logging.getLogger(__name__)
5864

@@ -232,39 +238,43 @@ async def hallucination_detection(
        )

    except ValueError as e:
-         # Log validation errors but return safe default
+         # Log validation errors and use shared error helper
        logger.warning(f"Validation error in hallucination_detection: {e}")
-         return GuardrailResult(
-             tripwire_triggered=False,
-             info={
-                 "guardrail_name": "Hallucination Detection",
-                 "flagged": False,
-                 "confidence": 0.0,
+         error_output = LLMErrorOutput(
+             flagged=False,
+             confidence=0.0,
+             info={"error_message": f"Validation failed: {str(e)}"},
+         )
+         return create_error_result(
+             guardrail_name="Hallucination Detection",
+             analysis=error_output,
+             checked_text=candidate,
+             additional_info={
+                 "threshold": config.confidence_threshold,
                "reasoning": f"Validation failed: {str(e)}",
                "hallucination_type": None,
                "hallucinated_statements": None,
                "verified_statements": None,
-                 "threshold": config.confidence_threshold,
-                 "error": str(e),
-                 "checked_text": candidate,  # Hallucination Detection doesn't modify text, pass through unchanged
            },
        )
    except Exception as e:
-         # Log unexpected errors and return safe default
+         # Log unexpected errors and use shared error helper
        logger.exception("Unexpected error in hallucination_detection")
-         return GuardrailResult(
-             tripwire_triggered=False,
-             info={
-                 "guardrail_name": "Hallucination Detection",
-                 "flagged": False,
-                 "confidence": 0.0,
+         error_output = LLMErrorOutput(
+             flagged=False,
+             confidence=0.0,
+             info={"error_message": str(e)},
+         )
+         return create_error_result(
+             guardrail_name="Hallucination Detection",
+             analysis=error_output,
+             checked_text=candidate,
+             additional_info={
+                 "threshold": config.confidence_threshold,
                "reasoning": f"Analysis failed: {str(e)}",
                "hallucination_type": None,
                "hallucinated_statements": None,
                "verified_statements": None,
-                 "threshold": config.confidence_threshold,
-                 "error": str(e),
-                 "checked_text": candidate,  # Hallucination Detection doesn't modify text, pass through unchanged
            },
        )

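Taken together, the `ValueError` branch now reads roughly as follows, reconstructed from the hunk above (the surrounding `try`/`except` context and exact indentation are assumed):

```python
    except ValueError as e:
        # Log validation errors and use shared error helper
        logger.warning(f"Validation error in hallucination_detection: {e}")
        error_output = LLMErrorOutput(
            flagged=False,
            confidence=0.0,
            info={"error_message": f"Validation failed: {str(e)}"},
        )
        return create_error_result(
            guardrail_name="Hallucination Detection",
            analysis=error_output,
            checked_text=candidate,
            additional_info={
                "threshold": config.confidence_threshold,
                "reasoning": f"Validation failed: {str(e)}",
                "hallucination_type": None,
                "hallucinated_statements": None,
                "verified_statements": None,
            },
        )
```

The generic `Exception` branch follows the same pattern, with `reasoning` set to `f"Analysis failed: {str(e)}"`. In both cases the tripwire stays untriggered and the result is marked `execution_failed=True`, so callers can tell an execution error apart from a clean pass.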
src/guardrails/checks/text/llm_base.py

Lines changed: 49 additions & 15 deletions
@@ -60,7 +60,13 @@ class MyLLMOutput(LLMOutput):
6060
logger = logging.getLogger(__name__)
6161

6262

63-
__all__ = ["LLMConfig", "LLMOutput", "LLMErrorOutput", "create_llm_check_fn"]
63+
__all__ = [
64+
"LLMConfig",
65+
"LLMOutput",
66+
"LLMErrorOutput",
67+
"create_llm_check_fn",
68+
"create_error_result",
69+
]
6470

6571

6672
class LLMConfig(BaseModel):
@@ -115,6 +121,44 @@ class LLMErrorOutput(LLMOutput):
    info: dict


+ def create_error_result(
+     guardrail_name: str,
+     analysis: LLMErrorOutput,
+     checked_text: str,
+     additional_info: dict[str, Any] | None = None,
+ ) -> GuardrailResult:
+     """Create a standardized GuardrailResult from an LLM error output.
+
+     Args:
+         guardrail_name: Name of the guardrail that failed.
+         analysis: The LLM error output.
+         checked_text: The text that was being checked.
+         additional_info: Optional additional fields to include in info dict.
+
+     Returns:
+         GuardrailResult with execution_failed=True.
+     """
+     error_info = getattr(analysis, "info", {})
+     error_message = error_info.get("error_message", "LLM execution failed")
+
+     result_info: dict[str, Any] = {
+         "guardrail_name": guardrail_name,
+         "checked_text": checked_text,
+         "error": error_message,
+         **analysis.model_dump(),
+     }
+
+     if additional_info:
+         result_info.update(additional_info)
+
+     return GuardrailResult(
+         tripwire_triggered=False,
+         execution_failed=True,
+         original_exception=Exception(error_message),
+         info=result_info,
+     )


def _build_full_prompt(system_prompt: str) -> str:
    """Assemble a complete LLM prompt with instructions and response schema.
@@ -334,20 +378,10 @@ async def guardrail_func(
        # Check if this is an error result
        if isinstance(analysis, LLMErrorOutput):
-             # Extract error information from the LLMErrorOutput
-             error_info = analysis.info if hasattr(analysis, "info") else {}
-             error_message = error_info.get("error_message", "LLM execution failed")
-
-             return GuardrailResult(
-                 tripwire_triggered=False,  # Don't trigger tripwire on execution errors
-                 execution_failed=True,
-                 original_exception=Exception(error_message),  # Create exception from error message
-                 info={
-                     "guardrail_name": name,
-                     "checked_text": data,
-                     "error": error_message,
-                     **analysis.model_dump(),
-                 },
+             return create_error_result(
+                 guardrail_name=name,
+                 analysis=analysis,
+                 checked_text=data,
            )

        # Compare severity levels

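With the helper exported from `llm_base`, any LLM-backed check can report execution failures the same way. A minimal usage sketch — the check name, `text`, and `threshold` arguments are illustrative rather than part of this commit, and the import path assumes the package layout under `src/`:

```python
from guardrails.checks.text.llm_base import LLMErrorOutput, create_error_result

def report_failure(text: str, threshold: float, exc: Exception):
    """Hypothetical helper showing the create_error_result calling pattern."""
    # Wrap the failure details in an LLMErrorOutput with a safe default verdict...
    error_output = LLMErrorOutput(
        flagged=False,
        confidence=0.0,
        info={"error_message": str(exc)},
    )
    # ...then convert it to a GuardrailResult marked execution_failed=True.
    # The tripwire stays untriggered, so errors are never treated as detections.
    return create_error_result(
        guardrail_name="My Custom Check",
        analysis=error_output,
        checked_text=text,
        additional_info={"threshold": threshold},
    )
```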