Skip to content

Commit 01ab146

Browse files
authored
fix(runtime): set stop flag for exception-based rails in parallel mode (#1487)
1 parent 121fc8f commit 01ab146

File tree

6 files changed

+605 -7 lines changed

nemoguardrails/colang/v1_0/runtime/runtime.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,8 @@ async def task_call_helper(flow_uid, post_event, func, *args, **kwargs):
313313
result = await func(*args, **kwargs)
314314

315315
has_stop = any(
316-
event["type"] == "BotIntent" and event["intent"] == "stop"
316+
(event["type"] == "BotIntent" and event["intent"] == "stop")
317+
or event["type"].endswith("Exception")
317318
for event in result
318319
)
319320

@@ -377,7 +378,8 @@ async def task_call_helper(flow_uid, post_event, func, *args, **kwargs):
377378

378379
# Check if this rail requested to stop
379380
has_stop = any(
380-
event["type"] == "BotIntent" and event["intent"] == "stop"
381+
(event["type"] == "BotIntent" and event["intent"] == "stop")
382+
or event["type"].endswith("Exception")
381383
for event in result
382384
)
383385

@@ -402,6 +404,7 @@ async def task_call_helper(flow_uid, post_event, func, *args, **kwargs):
402404
if v == pending_task:
403405
del unique_flow_ids[k]
404406
break
407+
# Remove the stopped flow from unique_flow_ids so it's not in finished_task_results
405408
del unique_flow_ids[flow_id]
406409
break
407410
else:
@@ -446,15 +449,14 @@ async def task_call_helper(flow_uid, post_event, func, *args, **kwargs):
446449

447450
def filter_and_append(logs, target_log):
448451
for plog in logs:
449-
# Filter out "Listen" and "start_flow" events from task processing log
450452
if plog["type"] == "event" and (
451-
plog["data"]["type"] == "Listen"
452-
or plog["data"]["type"] == "start_flow"
453+
plog["data"]["type"] == "start_flow"
453454
):
454455
continue
455456
target_log.append(plog)
456457

457-
filter_and_append(stopped_task_processing_logs, processing_log)
458+
# Only append finished rails logs. Stopped rail logs should not be appended
459+
# again since they're already in the processing log from when they started
458460
filter_and_append(finished_task_processing_logs, processing_log)
459461

460462
# We pack all events into a single event to add it to the event history.
@@ -463,6 +465,7 @@ def filter_and_append(logs, target_log):
463465
data={"events": finished_task_results},
464466
)
465467

468+
# Return stopped_task_results separately so the caller knows to stop processing
466469
return ActionResult(
467470
events=[history_events] + stopped_task_results,
468471
context_updates=context_updates,

nemoguardrails/logging/processing_log.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,15 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:
191191
elif event_type == "OutputRailsFinished":
192192
output_rails_finished_at = event["timestamp"]
193193

194+
elif event_type.endswith("Exception"):
195+
if activated_rail is not None and activated_rail.type in [
196+
"input",
197+
"output",
198+
]:
199+
activated_rail.stop = True
200+
if "stop" not in activated_rail.decisions:
201+
activated_rail.decisions.append("stop")
202+
194203
elif event["type"] == "llm_call_info":
195204
if executed_action is not None:
196205
executed_action.llm_calls.append(event["data"])
@@ -210,7 +219,8 @@ def compute_generation_log(processing_log: List[dict]) -> GenerationLog:
210219

211220
if activated_rail.type in ["input", "output"]:
212221
activated_rail.stop = True
213-
activated_rail.decisions.append("stop")
222+
if "stop" not in activated_rail.decisions:
223+
activated_rail.decisions.append("stop")
214224

215225
# If we have input rails, we also record the general stats
216226
if input_rails_started_at:
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from nemoguardrails.actions import action
17+
18+
19+
@action(is_system_action=True)
async def check_safety_action(context: dict):
    """Report whether the user message is free of the flagged safety terms.

    Reads ``user_message`` from ``context`` (falling back to an empty string
    when the key is absent) and performs a case-insensitive substring search
    for each flagged term.

    Args:
        context: Mapping from which ``user_message`` is read.

    Returns:
        ``True`` when none of the flagged terms occurs in the message,
        ``False`` as soon as one of them is found.
    """
    lowered_message = context.get("user_message", "").lower()

    # Substring match, so a term embedded in a longer word also counts.
    for flagged in ("unsafe", "dangerous", "harmful", "kill", "violence"):
        if flagged in lowered_message:
            return False
    return True
27+
28+
29+
@action(is_system_action=True)
async def check_topic_action(context: dict):
    """Report whether the user message avoids every off-topic marker term.

    Reads ``user_message`` from ``context`` (empty string when the key is
    absent) and checks it, case-insensitively, against a fixed list of
    marker substrings.

    Args:
        context: Mapping from which ``user_message`` is read.

    Returns:
        ``True`` when no marker term is contained in the message,
        ``False`` otherwise.
    """
    markers = ("offtopic", "irrelevant", "unrelated", "stupid", "idiot")
    text = context.get("user_message", "").lower()

    # "No marker present" expressed as "every marker is absent".
    return all(marker not in text for marker in markers)
37+
38+
39+
@action(is_system_action=True)
async def check_with_context_update(context: dict):
    """Increment ``violation_count`` in ``context`` and check the user message.

    The counter is bumped on every call, before the message is inspected and
    regardless of the outcome of the check. The write goes directly into the
    ``context`` dict that was passed in.

    Args:
        context: Mapping providing ``user_message`` and, optionally, a
            numeric ``violation_count`` (treated as 0 when absent).

    Returns:
        ``True`` when the message contains neither "blocked" nor
        "forbidden" (case-insensitive substring match), ``False`` otherwise.
    """
    message = context.get("user_message", "")

    # Side effect first: the counter is updated even if the check below fails.
    context["violation_count"] = context.get("violation_count", 0) + 1

    lowered = message.lower()
    return "blocked" not in lowered and "forbidden" not in lowered
50+
51+
52+
@action(is_system_action=True)
async def check_output_safety_action(context: dict):
    """Report whether the bot message is free of the flagged safety terms.

    Reads ``bot_message`` from ``context`` (empty string when the key is
    absent) and looks for each flagged term as a case-insensitive substring.

    Args:
        context: Mapping from which ``bot_message`` is read.

    Returns:
        ``True`` when no flagged term is found in the bot message,
        ``False`` otherwise.
    """
    reply = context.get("bot_message", "").lower()

    flagged_terms = ("harmful", "dangerous", "unsafe", "violence")
    matches = [term for term in flagged_terms if term in reply]

    # An empty match list means the output passed the check.
    return not matches
60+
61+
62+
@action(is_system_action=True)
async def check_output_length_action(context: dict):
    """Report whether the bot message fits within the 500-character limit.

    Reads ``bot_message`` from ``context`` (empty string when the key is
    absent) and compares its length against a fixed maximum.

    Args:
        context: Mapping from which ``bot_message`` is read.

    Returns:
        ``True`` when the message length is at most 500, ``False`` when it
        is longer.
    """
    character_count = len(context.get("bot_message", ""))

    # The limit is inclusive: exactly 500 characters is still valid.
    return character_count <= 500
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
models:
2+
- type: main
3+
engine: openai
4+
model: gpt-3.5-turbo-instruct
5+
6+
rails:
7+
input:
8+
flows:
9+
- check safety with exception
10+
- check topic with exception
11+
parallel: True
12+
13+
output:
14+
flows:
15+
- check output safety with exception
16+
- check output length with exception
17+
parallel: True
18+
19+
enable_rails_exceptions: True
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
define flow check safety with exception
2+
$is_safe = execute check_safety_action
3+
4+
if not $is_safe
5+
create event SafetyCheckException(message="Input blocked by safety check")
6+
stop
7+
8+
define flow check topic with exception
9+
$is_on_topic = execute check_topic_action
10+
11+
if not $is_on_topic
12+
create event TopicCheckException(message="Input blocked by topic check")
13+
stop
14+
15+
define flow check with context update
16+
$is_allowed = execute check_with_context_update
17+
18+
if not $is_allowed
19+
create event ContextUpdateException(message="Input blocked with context update")
20+
stop
21+
22+
define flow check output safety with exception
23+
$is_safe = execute check_output_safety_action
24+
25+
if not $is_safe
26+
create event OutputSafetyException(message="Output blocked by safety check")
27+
stop
28+
29+
define flow check output length with exception
30+
$is_valid_length = execute check_output_length_action
31+
32+
if not $is_valid_length
33+
create event OutputLengthException(message="Output blocked due to length")
34+
stop

0 commit comments

Comments
 (0)