Commit 7491c69

fix(config): validate content safety and topic control configs at creation time (#1450)
Refactor the internal error tests to work with the new config-time validation: remove the tests for missing prompts/models (now caught at config time) and update the single-error-message test to use runtime error injection.
1 parent a588d4c commit 7491c69
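To illustrate the behavioral change, here is a minimal sketch (not part of the commit) of what the new creation-time validation is expected to catch. It mirrors the config from the removed `test_content_safety_missing_model`: an input flow references `$model=content_safety` but no `content_safety` model is configured. The exact exception type is an assumption (pydantic typically wraps a validator's ValueError in a ValidationError), so the sketch catches broadly.

# Minimal sketch: a config that should now be rejected when RailsConfig is created,
# instead of surfacing as an "internal error" at generation time.
from nemoguardrails import RailsConfig

config_data = {
    "models": [
        {"type": "main", "engine": "openai", "model": "gpt-3.5-turbo"}
        # no content_safety model configured
    ],
    "rails": {
        "input": {
            "flows": ["content safety check input $model=content_safety"],
        }
    },
}

try:
    RailsConfig.from_content(config=config_data)
except Exception as exc:  # assumption: pydantic wraps the validator's ValueError
    print(f"Rejected at config time: {exc}")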

File tree

4 files changed: +890 -185 lines changed

nemoguardrails/rails/llm/config.py

Lines changed: 102 additions & 1 deletion
@@ -35,6 +35,7 @@
 
 from nemoguardrails import utils
 from nemoguardrails.colang import parse_colang_file, parse_flow_elements
+from nemoguardrails.colang.v1_0.runtime.flows import _normalize_flow_id
 from nemoguardrails.colang.v2_x.lang.utils import format_colang_parsing_error_message
 from nemoguardrails.colang.v2_x.runtime.errors import ColangParsingError
 from nemoguardrails.llm.types import Task
@@ -1451,7 +1452,65 @@ class RailsConfig(BaseModel):
         description="Configuration for tracing.",
     )
 
-    @root_validator(pre=True, allow_reuse=True)
+    @root_validator(pre=True)
+    def check_model_exists_for_input_rails(cls, values):
+        """Make sure we have a model for each input rail where one is provided using $model=<model_type>"""
+        rails = values.get("rails", {})
+        input_flows = rails.get("input", {}).get("flows", [])
+
+        # If no flows have a model, early-out
+        input_flows_without_model = [
+            _get_flow_model(flow) is None for flow in input_flows
+        ]
+        if all(input_flows_without_model):
+            return values
+
+        models = values.get("models", []) or []
+        model_types = {
+            model.type if isinstance(model, Model) else model["type"]
+            for model in models
+        }
+
+        for flow in input_flows:
+            flow_model = _get_flow_model(flow)
+            if not flow_model:
+                continue
+            if flow_model not in model_types:
+                raise ValueError(
+                    f"No `{flow_model}` model provided for input flow `{_normalize_flow_id(flow)}`"
+                )
+        return values
+
+    @root_validator(pre=True)
+    def check_model_exists_for_output_rails(cls, values):
+        """Make sure we have a model for each output rail where one is provided using $model=<model_type>"""
+        rails = values.get("rails", {})
+        output_flows = rails.get("output", {}).get("flows", [])
+
+        # If no flows have a model, early-out
+        output_flows_without_model = [
+            _get_flow_model(flow) is None for flow in output_flows
+        ]
+        if all(output_flows_without_model):
+            return values
+
+        models = values.get("models", []) or []
+        model_types = {
+            model.type if isinstance(model, Model) else model["type"]
+            for model in models
+        }
+
+        for flow in output_flows:
+            flow_model = _get_flow_model(flow)
+            if not flow_model:
+                continue
+            if flow_model not in model_types:
+                raise ValueError(
+                    f"No `{flow_model}` model provided for output flow `{_normalize_flow_id(flow)}`"
+                )
+        return values
+
+    @root_validator(pre=True)
     def check_prompt_exist_for_self_check_rails(cls, values):
         rails = values.get("rails", {})
         prompts = values.get("prompts", []) or []
@@ -1477,6 +1536,16 @@ def check_prompt_exist_for_self_check_rails(cls, values):
                 "You must provide a `llama_guard_check_input` prompt template."
             )
 
+        # Only content-safety and topic-safety include a $model reference in the rail flow text
+        # Need to match rails with flow_id (excluding $model reference) and match prompts
+        # on the full flow_id (including $model reference)
+        _validate_rail_prompts(
+            enabled_input_rails, provided_task_prompts, "content safety check input"
+        )
+        _validate_rail_prompts(
+            enabled_input_rails, provided_task_prompts, "topic safety check input"
+        )
+
         # Output moderation prompt verification
         if (
             "self check output" in enabled_output_rails
@@ -1504,6 +1573,13 @@ def check_prompt_exist_for_self_check_rails(cls, values):
         ):
             raise ValueError("You must provide a `self_check_facts` prompt template.")
 
+        # Only content-safety and topic-safety include a $model reference in the rail flow text
+        # Need to match rails with flow_id (excluding $model reference) and match prompts
+        # on the full flow_id (including $model reference)
+        _validate_rail_prompts(
+            enabled_output_rails, provided_task_prompts, "content safety check output"
+        )
+
         return values
 
     @root_validator(pre=True, allow_reuse=True)
@@ -1833,3 +1909,28 @@ def _generate_rails_flows(flows):
         flow_definitions.insert(1, _LIBRARY_IMPORT + _NEWLINE * 2)
 
     return flow_definitions
+
+
+MODEL_PREFIX = "$model="
+
+
+def _get_flow_model(flow_text) -> Optional[str]:
+    """Helper to return a model name from a flow definition"""
+    if MODEL_PREFIX not in flow_text:
+        return None
+    return flow_text.split(MODEL_PREFIX)[-1].strip()
+
+
+def _validate_rail_prompts(
+    rails: list[str], prompts: list[Any], validation_rail: str
+) -> None:
+    for rail in rails:
+        flow_id = _normalize_flow_id(rail)
+        flow_model = _get_flow_model(rail)
+        if flow_id == validation_rail:
+            prompt_flow_id = flow_id.replace(" ", "_")
+            expected_prompt = f"{prompt_flow_id} $model={flow_model}"
+            if expected_prompt not in prompts:
+                raise ValueError(
+                    f"You must provide a `{expected_prompt}` prompt template."
+                )
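
For orientation, a small sketch of how the new module-level helpers behave. It is illustrative only and rests on two assumptions: that the private helpers can be imported from nemoguardrails.rails.llm.config, and that _normalize_flow_id strips the $model suffix, as the comments in the diff describe. The convention it exercises: the flow text carries the `$model=` suffix, the normalized flow id drops it, and the matching prompt task is the flow id with spaces replaced by underscores plus the full `$model=` reference.

# Illustrative sketch of the helper behavior added above (see assumptions in the lead-in).
from nemoguardrails.rails.llm.config import _get_flow_model, _validate_rail_prompts

flow = "content safety check input $model=content_safety"

# The $model= suffix names the model type the rail should use.
assert _get_flow_model(flow) == "content_safety"
assert _get_flow_model("self check input") is None

# The matching prompt task keeps the $model reference but joins words with underscores.
_validate_rail_prompts(
    rails=[flow],
    prompts=["content_safety_check_input $model=content_safety"],
    validation_rail="content safety check input",
)  # passes; an empty prompts list would raise ValueError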

tests/test_internal_error_parallel_rails.py

Lines changed: 55 additions & 120 deletions
@@ -78,48 +78,6 @@ async def test_internal_error_stops_execution():
     not _has_langchain_openai or not _has_openai_key,
     reason="langchain-openai not available",
 )
-@pytest.mark.asyncio
-async def test_content_safety_missing_prompt():
-    config_data = {
-        "instructions": [
-            {"type": "general", "content": "You are a helpful assistant."}
-        ],
-        "models": [
-            {"type": "main", "engine": "openai", "model": "gpt-3.5-turbo"},
-            {"type": "content_safety", "engine": "openai", "model": "gpt-3.5-turbo"},
-        ],
-        "rails": {
-            "input": {
-                "flows": [
-                    "content safety check input $model=content_safety",
-                    "self check input",
-                ],
-                "parallel": True,
-            }
-        },
-    }
-
-    config = RailsConfig.from_content(
-        config=config_data,
-        yaml_content="prompts:\n  - task: self_check_input\n    content: 'Is the user input safe? Answer Yes or No.'",
-    )
-
-    chat = TestChat(config, llm_completions=["Safe response"])
-    chat >> "test message"
-
-    result = await chat.app.generate_async(messages=chat.history, options=OPTIONS)
-
-    assert result is not None
-    assert "internal error" in result.response[0]["content"].lower()
-
-    stop_events = [
-        event
-        for event in result.log.internal_events
-        if event.get("type") == "BotIntent" and event.get("intent") == "stop"
-    ]
-    assert len(stop_events) > 0
-
-
 @pytest.mark.asyncio
 async def test_no_app_llm_request_on_internal_error():
     """Test that App LLM request is not sent when internal error occurs."""
@@ -164,48 +122,6 @@ async def test_no_app_llm_request_on_internal_error():
     ), "Expected BotIntent stop event after internal error"
 
 
-@pytest.mark.asyncio
-async def test_content_safety_missing_model():
-    """Test content safety with missing model configuration."""
-    config_data = {
-        "instructions": [
-            {"type": "general", "content": "You are a helpful assistant."}
-        ],
-        "models": [
-            {"type": "main", "engine": "openai", "model": "gpt-3.5-turbo"}
-            # missing content_safety model
-        ],
-        "rails": {
-            "input": {
-                "flows": ["content safety check input $model=content_safety"],
-                "parallel": True,
-            }
-        },
-    }
-
-    config = RailsConfig.from_content(
-        config=config_data,
-        yaml_content="prompts:\n  - task: content_safety_check_input $model=content_safety\n    content: 'Check if this is safe: {{ user_input }}'",
-    )
-
-    chat = TestChat(config, llm_completions=["Response"])
-    chat >> "test message"
-
-    result = await chat.app.generate_async(messages=chat.history, options=OPTIONS)
-
-    # should get internal error due to missing model
-    assert result is not None
-    assert "internal error" in result.response[0]["content"].lower()
-
-    # verify stop event was generated
-    stop_events = [
-        event
-        for event in result.log.internal_events
-        if event.get("type") == "BotIntent" and event.get("intent") == "stop"
-    ]
-    assert len(stop_events) > 0
-
-
 @pytest.mark.asyncio
 async def test_parallel_rails_partial_failure():
     """Test that partial failure in parallel rails is handled properly."""
@@ -343,21 +259,28 @@ async def test_action_execution_returns_failed():
     ), "Expected BotIntent stop event after action failure"
 
 
+@pytest.mark.skipif(
+    not _has_langchain_openai or not _has_openai_key,
+    reason="langchain-openai not available",
+)
 @pytest.mark.asyncio
 async def test_single_error_message_not_multiple():
     """Test that we get exactly one error message, not multiple for each failed rail.
 
     Before the fix, if we had multiple rails failing, we'd get multiple error messages.
     This test verifies we only get one error message even with multiple parallel rails.
+    Now with config-time validation, we provide valid config and trigger runtime failures.
     """
     config_data = {
-        "models": [{"type": "main", "engine": "openai", "model": "gpt-3.5-turbo"}],
+        "models": [
+            {"type": "main", "engine": "openai", "model": "gpt-3.5-turbo"},
+            {"type": "content_safety", "engine": "openai", "model": "gpt-3.5-turbo"},
+        ],
         "rails": {
             "input": {
                 "flows": [
                     "self check input",
                     "content safety check input $model=content_safety",
-                    "llama guard check input $model=llama_guard",
                 ],
                 "parallel": True,
             }
@@ -366,44 +289,56 @@ async def test_single_error_message_not_multiple():
             {
                 "task": "self_check_input",
                 "content": "Is the user input safe? Answer Yes or No.",
-            }
-            # missing prompts and models for content_safety and llama_guard
+            },
+            {
+                "task": "content_safety_check_input $model=content_safety",
+                "content": "Check content safety: {{ user_input }}",
+            },
         ],
     }
 
     config = RailsConfig.from_content(config=config_data)
 
-    chat = TestChat(config, llm_completions=["Test response"])
-    chat >> "test message"
+    with patch(
+        "nemoguardrails.llm.taskmanager.LLMTaskManager.render_task_prompt"
+    ) as mock_render:
+        mock_render.side_effect = Exception("Runtime error in multiple rails")
 
-    result = await chat.app.generate_async(messages=chat.history, options=OPTIONS)
+        chat = TestChat(config, llm_completions=["Test response"])
+        chat >> "test message"
 
-    # should get exactly one response, not multiple
-    assert result is not None
-    assert len(result.response) == 1, f"Expected 1 response, got {len(result.response)}"
-
-    # that single response should be an internal error
-    assert "internal error" in result.response[0]["content"].lower()
-
-    # count how many times "internal error" appears in the response
-    error_count = result.response[0]["content"].lower().count("internal error")
-    assert error_count == 1, f"Expected 1 'internal error' message, found {error_count}"
-
-    # verify stop event was generated
-    stop_events = [
-        event
-        for event in result.log.internal_events
-        if event.get("type") == "BotIntent" and event.get("intent") == "stop"
-    ]
-    assert len(stop_events) >= 1, "Expected at least one BotIntent stop event"
-
-    # verify we don't have multiple StartUtteranceBotAction events with error messages
-    error_utterances = [
-        event
-        for event in result.log.internal_events
-        if event.get("type") == "StartUtteranceBotAction"
-        and "internal error" in event.get("script", "").lower()
-    ]
-    assert (
-        len(error_utterances) == 1
-    ), f"Expected 1 error utterance, found {len(error_utterances)}"
+        result = await chat.app.generate_async(messages=chat.history, options=OPTIONS)
+
+        # should get exactly one response, not multiple
+        assert result is not None
+        assert (
+            len(result.response) == 1
+        ), f"Expected 1 response, got {len(result.response)}"
+
+        # that single response should be an internal error
+        assert "internal error" in result.response[0]["content"].lower()
+
+        # count how many times "internal error" appears in the response
+        error_count = result.response[0]["content"].lower().count("internal error")
+        assert (
+            error_count == 1
+        ), f"Expected 1 'internal error' message, found {error_count}"
+
+        # verify stop event was generated
+        stop_events = [
+            event
+            for event in result.log.internal_events
+            if event.get("type") == "BotIntent" and event.get("intent") == "stop"
+        ]
+        assert len(stop_events) >= 1, "Expected at least one BotIntent stop event"
+
+        # verify we don't have multiple StartUtteranceBotAction events with error messages
+        error_utterances = [
+            event
+            for event in result.log.internal_events
+            if event.get("type") == "StartUtteranceBotAction"
+            and "internal error" in event.get("script", "").lower()
+        ]
+        assert (
+            len(error_utterances) == 1
+        ), f"Expected 1 error utterance, found {len(error_utterances)}"
