Skip to content

Commit a67c5b0

Browse files
Pouyanpitgasser-nv
authored andcommitted
feat(tool-rails): add support for tool output rails and validation (#1382)
Introduce tool output/input rails configuration and Colang flows for tool call validation and parameter security checks. Add support for BotToolCall event emission in passthrough mode, enabling tool call guardrails before execution.
1 parent c44396c commit a67c5b0

13 files changed

+1943
-35
lines changed

nemoguardrails/actions/llm/generation.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,21 @@ async def generate_user_intent(
582582
if streaming_handler:
583583
await streaming_handler.push_chunk(text)
584584

585-
output_events.append(new_event_dict("BotMessage", text=text))
585+
if self.config.passthrough:
586+
from nemoguardrails.actions.llm.utils import (
587+
get_and_clear_tool_calls_contextvar,
588+
)
589+
590+
tool_calls = get_and_clear_tool_calls_contextvar()
591+
592+
if tool_calls:
593+
output_events.append(
594+
new_event_dict("BotToolCall", tool_calls=tool_calls)
595+
)
596+
else:
597+
output_events.append(new_event_dict("BotMessage", text=text))
598+
else:
599+
output_events.append(new_event_dict("BotMessage", text=text))
586600

587601
return ActionResult(events=output_events)
588602

nemoguardrails/actions/llm/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -673,6 +673,21 @@ def get_and_clear_tool_calls_contextvar() -> Optional[list]:
673673
return None
674674

675675

676+
def extract_tool_calls_from_events(events: list) -> Optional[list]:
677+
"""Extract tool_calls from BotToolCall events.
678+
679+
Args:
680+
events: List of events to search through
681+
682+
Returns:
683+
tool_calls if found in BotToolCall event, None otherwise
684+
"""
685+
for event in events:
686+
if event.get("type") == "BotToolCall":
687+
return event.get("tool_calls")
688+
return None
689+
690+
676691
def get_and_clear_response_metadata_contextvar() -> Optional[dict]:
677692
"""Get the current response metadata and clear it from the context.
678693

nemoguardrails/rails/llm/config.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,40 @@ class ActionRails(BaseModel):
527527
)
528528

529529

530+
class ToolOutputRails(BaseModel):
531+
"""Configuration of tool output rails.
532+
533+
Tool output rails are applied to tool calls before they are executed.
534+
They can validate tool names, parameters, and context to ensure safe tool usage.
535+
"""
536+
537+
flows: List[str] = Field(
538+
default_factory=list,
539+
description="The names of all the flows that implement tool output rails.",
540+
)
541+
parallel: Optional[bool] = Field(
542+
default=False,
543+
description="If True, the tool output rails are executed in parallel.",
544+
)
545+
546+
547+
class ToolInputRails(BaseModel):
548+
"""Configuration of tool input rails.
549+
550+
Tool input rails are applied to tool results before they are processed.
551+
They can validate, filter, or transform tool outputs for security and safety.
552+
"""
553+
554+
flows: List[str] = Field(
555+
default_factory=list,
556+
description="The names of all the flows that implement tool input rails.",
557+
)
558+
parallel: Optional[bool] = Field(
559+
default=False,
560+
description="If True, the tool input rails are executed in parallel.",
561+
)
562+
563+
530564
class SingleCallConfig(BaseModel):
531565
"""Configuration for the single LLM call option for topical rails."""
532566

@@ -912,6 +946,14 @@ class Rails(BaseModel):
912946
actions: ActionRails = Field(
913947
default_factory=ActionRails, description="Configuration of action rails."
914948
)
949+
tool_output: ToolOutputRails = Field(
950+
default_factory=ToolOutputRails,
951+
description="Configuration of tool output rails.",
952+
)
953+
tool_input: ToolInputRails = Field(
954+
default_factory=ToolInputRails,
955+
description="Configuration of tool input rails.",
956+
)
915957

916958

917959
def merge_two_dicts(dict_1: dict, dict_2: dict, ignore_keys: Set[str]) -> None:

nemoguardrails/rails/llm/llm_flows.co

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,34 @@ define parallel extension flow generate bot message
102102
execute generate_bot_message
103103

104104

105+
define parallel extension flow process bot tool call
106+
"""Processes tool calls from the bot."""
107+
priority 100
108+
109+
event BotToolCall
110+
111+
$tool_calls = $event.tool_calls
112+
113+
# Run tool-specific output rails if configured (Phase 2)
114+
if $config.rails.tool_output.flows
115+
# If we have generation options, we make sure the tool output rails are enabled.
116+
if $generation_options is None or $generation_options.rails.tool_output:
117+
# Create a marker event.
118+
create event StartToolOutputRails
119+
event StartToolOutputRails
120+
121+
# Run all the tool output rails
122+
# This can potentially alter or block the tool calls
123+
do run tool output rails
124+
125+
# Create a marker event.
126+
create event ToolOutputRailsFinished
127+
event ToolOutputRailsFinished
128+
129+
# Create the action event for tool execution
130+
create event StartToolCallBotAction(tool_calls=$tool_calls)
131+
132+
105133
define parallel extension flow process bot message
106134
"""Runs the output rails on a bot message."""
107135
priority 100
@@ -164,3 +192,25 @@ define subflow run retrieval rails
164192
while $i < len($retrieval_flows)
165193
do $retrieval_flows[$i]
166194
$i = $i + 1
195+
196+
197+
define subflow run tool output rails
198+
"""Runs all the tool output rails in a sequential order."""
199+
$tool_output_flows = $config.rails.tool_output.flows
200+
201+
$i = 0
202+
while $i < len($tool_output_flows)
203+
# We set the current rail as being triggered.
204+
$triggered_tool_output_rail = $tool_output_flows[$i]
205+
206+
create event StartToolOutputRail(flow_id=$triggered_tool_output_rail)
207+
event StartToolOutputRail
208+
209+
do $tool_output_flows[$i]
210+
$i = $i + 1
211+
212+
create event ToolOutputRailFinished(flow_id=$triggered_tool_output_rail)
213+
event ToolOutputRailFinished
214+
215+
# If all went smooth, we remove it.
216+
$triggered_tool_output_rail = None

nemoguardrails/rails/llm/llmrails.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@
3232

3333
from nemoguardrails.actions.llm.generation import LLMGenerationActions
3434
from nemoguardrails.actions.llm.utils import (
35+
extract_tool_calls_from_events,
3536
get_and_clear_reasoning_trace_contextvar,
3637
get_and_clear_response_metadata_contextvar,
37-
get_and_clear_tool_calls_contextvar,
3838
get_colang_history,
3939
)
4040
from nemoguardrails.actions.output_mapping import is_output_blocked
@@ -1086,7 +1086,7 @@ async def generate_async(
10861086
options.log.llm_calls = True
10871087
options.log.internal_events = True
10881088

1089-
tool_calls = get_and_clear_tool_calls_contextvar()
1089+
tool_calls = extract_tool_calls_from_events(new_events)
10901090
llm_metadata = get_and_clear_response_metadata_contextvar()
10911091

10921092
# If we have generation options, we prepare a GenerationResponse instance.

nemoguardrails/rails/llm/options.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,16 @@ class GenerationRailsOptions(BaseModel):
127127
default=True,
128128
description="Whether the dialog rails are enabled or not.",
129129
)
130+
tool_output: Union[bool, List[str]] = Field(
131+
default=True,
132+
description="Whether the tool output rails are enabled or not. "
133+
"If a list of names is specified, then only the specified tool output rails will be applied.",
134+
)
135+
tool_input: Union[bool, List[str]] = Field(
136+
default=True,
137+
description="Whether the tool input rails are enabled or not. "
138+
"If a list of names is specified, then only the specified tool input rails will be applied.",
139+
)
130140

131141

132142
class GenerationOptions(BaseModel):
@@ -177,6 +187,8 @@ def check_fields(cls, values):
177187
"dialog": False,
178188
"retrieval": False,
179189
"output": False,
190+
"tool_output": False,
191+
"tool_input": False,
180192
}
181193
for rail_type in values["rails"]:
182194
_rails[rail_type] = True

0 commit comments

Comments
 (0)