Skip to content

Commit 8393e8a

Browse files
Pouyanpitgasser-nv
authored andcommitted
feat(tool-rails): implement tool input rails for tool message validation and processing (#1386)
- Add UserToolMessages event handling and tool input rails processing - Fix message-to-event conversion to properly handle tool messages in conversation history - Preserve tool call context in passthrough mode by using full conversation history - Support tool_calls and tool message metadata in LangChain format conversion - Include comprehensive test suite for tool input rails functionality test(runnable_rails): fix prompt format in passthrough mode feat: support ToolMessage in message dicts refactor: rename BotToolCall to BotToolCalls
1 parent a67c5b0 commit 8393e8a

File tree

12 files changed

+1333
-36
lines changed

12 files changed

+1333
-36
lines changed

nemoguardrails/actions/llm/generation.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ async def generate_user_intent(
591591

592592
if tool_calls:
593593
output_events.append(
594-
new_event_dict("BotToolCall", tool_calls=tool_calls)
594+
new_event_dict("BotToolCalls", tool_calls=tool_calls)
595595
)
596596
else:
597597
output_events.append(new_event_dict("BotMessage", text=text))
@@ -905,9 +905,23 @@ async def generate_bot_message(
905905
LLMCallInfo(task=Task.GENERATE_BOT_MESSAGE.value)
906906
)
907907

908-
# We use the potentially updated $user_message. This means that even
909-
# in passthrough mode, input rails can still alter the input.
910-
prompt = context.get("user_message")
908+
# In passthrough mode, we should use the full conversation history
909+
# instead of just the last user message to preserve tool message context
910+
raw_prompt = raw_llm_request.get()
911+
912+
if raw_prompt is not None and isinstance(raw_prompt, list):
913+
# Use the full conversation including tool messages
914+
prompt = raw_prompt.copy()
915+
916+
# Update the last user message if it was altered by input rails
917+
user_message = context.get("user_message")
918+
if user_message and prompt:
919+
for i in reversed(range(len(prompt))):
920+
if prompt[i]["role"] == "user":
921+
prompt[i]["content"] = user_message
922+
break
923+
else:
924+
prompt = context.get("user_message")
911925

912926
generation_options: GenerationOptions = generation_options_var.get()
913927
with llm_params(

nemoguardrails/actions/llm/utils.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -153,16 +153,23 @@ def _convert_messages_to_langchain_format(prompt: List[dict]) -> List:
153153
if msg_type == "user":
154154
messages.append(HumanMessage(content=msg["content"]))
155155
elif msg_type in ["bot", "assistant"]:
156-
messages.append(AIMessage(content=msg["content"]))
156+
tool_calls = msg.get("tool_calls")
157+
if tool_calls:
158+
messages.append(
159+
AIMessage(content=msg["content"], tool_calls=tool_calls)
160+
)
161+
else:
162+
messages.append(AIMessage(content=msg["content"]))
157163
elif msg_type == "system":
158164
messages.append(SystemMessage(content=msg["content"]))
159165
elif msg_type == "tool":
160-
messages.append(
161-
ToolMessage(
162-
content=msg["content"],
163-
tool_call_id=msg.get("tool_call_id", ""),
164-
)
166+
tool_message = ToolMessage(
167+
content=msg["content"],
168+
tool_call_id=msg.get("tool_call_id", ""),
165169
)
170+
if msg.get("name"):
171+
tool_message.name = msg["name"]
172+
messages.append(tool_message)
166173
else:
167174
raise ValueError(f"Unknown message type {msg_type}")
168175

@@ -674,16 +681,16 @@ def get_and_clear_tool_calls_contextvar() -> Optional[list]:
674681

675682

676683
def extract_tool_calls_from_events(events: list) -> Optional[list]:
677-
"""Extract tool_calls from BotToolCall events.
684+
"""Extract tool_calls from BotToolCalls events.
678685
679686
Args:
680687
events: List of events to search through
681688
682689
Returns:
683-
tool_calls if found in BotToolCall event, None otherwise
690+
tool_calls if found in BotToolCalls event, None otherwise
684691
"""
685692
for event in events:
686-
if event.get("type") == "BotToolCall":
693+
if event.get("type") == "BotToolCalls":
687694
return event.get("tool_calls")
688695
return None
689696

nemoguardrails/integrations/langchain/runnable_rails.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
BaseMessage,
2727
HumanMessage,
2828
SystemMessage,
29+
ToolMessage,
2930
)
3031
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
3132
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
@@ -231,11 +232,23 @@ def _create_passthrough_messages(self, _input) -> List[Dict[str, Any]]:
231232
def _message_to_dict(self, msg: BaseMessage) -> Dict[str, Any]:
232233
"""Convert a BaseMessage to dictionary format."""
233234
if isinstance(msg, AIMessage):
234-
return {"role": "assistant", "content": msg.content}
235+
result = {"role": "assistant", "content": msg.content}
236+
if hasattr(msg, "tool_calls") and msg.tool_calls:
237+
result["tool_calls"] = msg.tool_calls
238+
return result
235239
elif isinstance(msg, HumanMessage):
236240
return {"role": "user", "content": msg.content}
237241
elif isinstance(msg, SystemMessage):
238242
return {"role": "system", "content": msg.content}
243+
elif isinstance(msg, ToolMessage):
244+
result = {
245+
"role": "tool",
246+
"content": msg.content,
247+
"tool_call_id": msg.tool_call_id,
248+
}
249+
if hasattr(msg, "name") and msg.name:
250+
result["name"] = msg.name
251+
return result
239252
else: # Handle other message types
240253
role = getattr(msg, "type", "user")
241254
return {"role": role, "content": msg.content}

nemoguardrails/rails/llm/llm_flows.co

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ define parallel extension flow process bot tool call
106106
"""Processes tool calls from the bot."""
107107
priority 100
108108

109-
event BotToolCall
109+
event BotToolCalls
110110

111111
$tool_calls = $event.tool_calls
112112

@@ -130,6 +130,40 @@ define parallel extension flow process bot tool call
130130
create event StartToolCallBotAction(tool_calls=$tool_calls)
131131

132132

133+
define parallel flow process user tool messages
134+
"""Run all the tool input rails on the tool messages."""
135+
priority 200
136+
event UserToolMessages
137+
138+
$tool_messages = $event["tool_messages"]
139+
140+
# If we have tool input rails, we run them, otherwise we just create the user message event
141+
if $config.rails.tool_input.flows
142+
# If we have generation options, we make sure the tool input rails are enabled.
143+
$tool_input_enabled = True
144+
if $generation_options is not None
145+
if $generation_options.rails.tool_input == False
146+
$tool_input_enabled = False
147+
148+
if $tool_input_enabled:
149+
create event StartToolInputRails
150+
event StartToolInputRails
151+
152+
$i = 0
153+
while $i < len($tool_messages)
154+
$tool_message = $tool_messages[$i].content
155+
$tool_name = $tool_messages[$i].name
156+
if "tool_call_id" in $tool_messages[$i]
157+
$tool_call_id = $tool_messages[$i].tool_call_id
158+
else
159+
$tool_call_id = ""
160+
161+
do run tool input rails
162+
$i = $i + 1
163+
164+
create event ToolInputRailsFinished
165+
event ToolInputRailsFinished
166+
133167
define parallel extension flow process bot message
134168
"""Runs the output rails on a bot message."""
135169
priority 100
@@ -214,3 +248,24 @@ define subflow run tool output rails
214248

215249
# If all went smooth, we remove it.
216250
$triggered_tool_output_rail = None
251+
252+
define subflow run tool input rails
253+
"""Runs all the tool input rails in a sequential order."""
254+
$tool_input_flows = $config.rails.tool_input.flows
255+
256+
$i = 0
257+
while $i < len($tool_input_flows)
258+
# We set the current rail as being triggered.
259+
$triggered_tool_input_rail = $tool_input_flows[$i]
260+
261+
create event StartToolInputRail(flow_id=$triggered_tool_input_rail)
262+
event StartToolInputRail
263+
264+
do $tool_input_flows[$i]
265+
$i = $i + 1
266+
267+
create event ToolInputRailFinished(flow_id=$triggered_tool_input_rail)
268+
event ToolInputRailFinished
269+
270+
# If all went smooth, we remove it.
271+
$triggered_tool_input_rail = None

nemoguardrails/rails/llm/llmrails.py

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -747,26 +747,74 @@ def _get_events_for_messages(self, messages: List[dict], state: Any):
747747
)
748748

749749
elif msg["role"] == "assistant":
750-
action_uid = new_uuid()
751-
start_event = new_event_dict(
752-
"StartUtteranceBotAction",
753-
script=msg["content"],
754-
action_uid=action_uid,
755-
)
756-
finished_event = new_event_dict(
757-
"UtteranceBotActionFinished",
758-
final_script=msg["content"],
759-
is_success=True,
760-
action_uid=action_uid,
761-
)
762-
events.extend([start_event, finished_event])
750+
if msg.get("tool_calls"):
751+
events.append(
752+
{"type": "BotToolCalls", "tool_calls": msg["tool_calls"]}
753+
)
754+
else:
755+
action_uid = new_uuid()
756+
start_event = new_event_dict(
757+
"StartUtteranceBotAction",
758+
script=msg["content"],
759+
action_uid=action_uid,
760+
)
761+
finished_event = new_event_dict(
762+
"UtteranceBotActionFinished",
763+
final_script=msg["content"],
764+
is_success=True,
765+
action_uid=action_uid,
766+
)
767+
events.extend([start_event, finished_event])
763768
elif msg["role"] == "context":
764769
events.append({"type": "ContextUpdate", "data": msg["content"]})
765770
elif msg["role"] == "event":
766771
events.append(msg["event"])
767772
elif msg["role"] == "system":
768773
# Handle system messages - convert them to SystemMessage events
769774
events.append({"type": "SystemMessage", "content": msg["content"]})
775+
elif msg["role"] == "tool":
776+
# For the last tool message, create grouped tool event and synthetic UserMessage
777+
if idx == len(messages) - 1:
778+
# Find the original user message for response generation
779+
user_message = None
780+
for prev_msg in reversed(messages[:idx]):
781+
if prev_msg["role"] == "user":
782+
user_message = prev_msg["content"]
783+
break
784+
785+
if user_message:
786+
# If tool input rails are configured, group all tool messages
787+
if self.config.rails.tool_input.flows:
788+
# Collect all tool messages for grouped processing
789+
tool_messages = []
790+
for tool_idx in range(len(messages)):
791+
if messages[tool_idx]["role"] == "tool":
792+
tool_messages.append(
793+
{
794+
"content": messages[tool_idx][
795+
"content"
796+
],
797+
"name": messages[tool_idx].get(
798+
"name", "unknown"
799+
),
800+
"tool_call_id": messages[tool_idx].get(
801+
"tool_call_id", ""
802+
),
803+
}
804+
)
805+
806+
events.append(
807+
{
808+
"type": "UserToolMessages",
809+
"tool_messages": tool_messages,
810+
}
811+
)
812+
813+
else:
814+
events.append(
815+
{"type": "UserMessage", "text": user_message}
816+
)
817+
770818
else:
771819
for idx in range(len(messages)):
772820
msg = messages[idx]

0 commit comments

Comments
 (0)