Commit bfe81f1

feat: implement parallel streaming output rails execution (#1263)
* feat: implement parallel streaming output rails execution

  - Add _run_output_rails_in_parallel_streaming method to run output rails concurrently
  - Use asyncio tasks to execute multiple rails simultaneously during streaming
  - Implement early termination when any rail blocks content to optimize performance
  - Register the new action in the runtime dispatcher
  - Add proper error handling and cancellation for robust parallel execution
  - Avoid full flow state management issues that can occur with hide_prev_turn logic during streaming
  - Add comprehensive tests for parallel streaming functionality

* rename result to is_blocked
1 parent 0d6fa42 commit bfe81f1

File tree: 5 files changed, +1457 −41 lines changed

nemoguardrails/colang/runtime.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -37,6 +37,12 @@ def __init__(self, config: RailsConfig, verbose: bool = False):
             import_paths=list(config.imported_paths.values()),
         )
 
+        if hasattr(self, "_run_output_rails_in_parallel_streaming"):
+            self.action_dispatcher.register_action(
+                self._run_output_rails_in_parallel_streaming,
+                name="run_output_rails_in_parallel_streaming",
+            )
+
         # The list of additional parameters that can be passed to the actions.
         self.registered_action_params: dict = {}
```
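
The `hasattr` guard keeps the base `Runtime` agnostic: only subclasses that actually implement `_run_output_rails_in_parallel_streaming` (the Colang 1.0 runtime in this commit) get the action registered. A minimal sketch of the pattern, using a stripped-down stand-in for the real `ActionDispatcher`:

```python
# Stand-in dispatcher for illustration only; not the actual
# nemoguardrails ActionDispatcher implementation.
class MiniDispatcher:
    def __init__(self):
        self._registry = {}

    def register_action(self, fn, name=None):
        # Actions are resolved by name at execution time.
        self._registry[name or fn.__name__] = fn


class MiniRuntime:
    def __init__(self):
        self.action_dispatcher = MiniDispatcher()
        # Register only when this runtime provides the implementation;
        # runtimes without the method simply skip registration.
        if hasattr(self, "_run_output_rails_in_parallel_streaming"):
            self.action_dispatcher.register_action(
                self._run_output_rails_in_parallel_streaming,
                name="run_output_rails_in_parallel_streaming",
            )
```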

nemoguardrails/colang/v1_0/runtime/runtime.py

Lines changed: 90 additions & 4 deletions
```diff
@@ -12,10 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import asyncio
 import inspect
 import logging
-import uuid
 from textwrap import indent
 from time import time
 from typing import Any, Dict, List, Optional, Tuple
@@ -25,10 +24,13 @@
 from langchain.chains.base import Chain
 
 from nemoguardrails.actions.actions import ActionResult
+from nemoguardrails.actions.output_mapping import is_output_blocked
 from nemoguardrails.colang import parse_colang_file
 from nemoguardrails.colang.runtime import Runtime
 from nemoguardrails.colang.v1_0.runtime.flows import (
     FlowConfig,
+    _get_flow_params,
+    _normalize_flow_id,
     compute_context,
     compute_next_steps,
 )
```
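
The newly imported `is_output_blocked` translates an action's raw return value into a block/allow decision via the action's registered output mapping. A hedged sketch of the convention (the `output_mapping` argument mirrors the `@action` decorator's mapping mechanism; the rail itself and its policy are purely illustrative):

```python
from typing import Optional

from nemoguardrails.actions import action


# Illustrative rail: output_mapping receives the action's return value and
# yields True when the content should be treated as BLOCKED. A self-check
# style action that returns True for "allowed" therefore maps via negation.
@action(is_system_action=True, output_mapping=lambda result: result is False)
async def illustrative_check_output(context: Optional[dict] = None) -> bool:
    bot_message = (context or {}).get("bot_message", "")
    # Toy policy for the sketch: block if a placeholder marker appears.
    return "forbidden" not in bot_message.lower()
```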
```diff
@@ -259,6 +261,89 @@ def _internal_error_action_result(message: str):
             ]
         )
 
+    async def _run_output_rails_in_parallel_streaming(
+        self, flows_with_params: Dict[str, dict], events: List[dict]
+    ) -> ActionResult:
+        """Run the output rails in parallel for streaming chunks.
+
+        This is a streamlined version that avoids the full flow state management
+        which can cause issues with hide_prev_turn logic during streaming.
+
+        Args:
+            flows_with_params: Dictionary mapping flow_id to {"action_name": str, "params": dict}
+            events: The events list for context
+        """
+        tasks = []
+
+        async def run_single_rail(flow_id: str, action_info: dict) -> tuple:
+            """Run a single rail flow and return (flow_id, result)"""
+
+            try:
+                action_name = action_info["action_name"]
+                params = action_info["params"]
+
+                result_tuple = await self.action_dispatcher.execute_action(
+                    action_name, params
+                )
+                result, status = result_tuple
+
+                if status != "success":
+                    log.error(f"Action {action_name} failed with status: {status}")
+                    return flow_id, False  # Allow on failure
+
+                action_func = self.action_dispatcher.get_action(action_name)
+
+                # use the mapping to decide if the result indicates blocked content.
+                # True means blocked, False means allowed
+                result = is_output_blocked(result, action_func)
+
+                return flow_id, result
+
+            except Exception as e:
+                log.error(f"Error executing rail {flow_id}: {e}")
+                return flow_id, False  # Allow on error
+
+        # create tasks for all flows
+        for flow_id, action_info in flows_with_params.items():
+            task = asyncio.create_task(run_single_rail(flow_id, action_info))
+            tasks.append(task)
+
+        stopped_events = []
+
+        try:
+            for future in asyncio.as_completed(tasks):
+                try:
+                    flow_id, is_blocked = await future
+
+                    # check if this rail blocked the content
+                    if is_blocked:
+                        # create stop events
+                        stopped_events = [
+                            {
+                                "type": "BotIntent",
+                                "intent": "stop",
+                                "flow_id": flow_id,
+                            }
+                        ]
+
+                        # cancel remaining tasks
+                        for pending_task in tasks:
+                            if not pending_task.done():
+                                pending_task.cancel()
+                        break
+
+                except asyncio.CancelledError:
+                    pass
+                except Exception as e:
+                    log.error(f"Error in parallel rail task: {e}")
+                    continue
+
+        except Exception as e:
+            log.error(f"Error in parallel rail execution: {e}")
+            return ActionResult(events=[])
+
+        return ActionResult(events=stopped_events)
+
     async def _process_start_action(self, events: List[dict]) -> List[dict]:
         """
         Start the specified action, wait for it to finish, and post back the result.
```
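
The concurrency core of the method is a fan-out / first-block-wins pattern: every rail becomes a task, results are consumed in completion order, and the first blocking verdict cancels the stragglers. A self-contained sketch of the same pattern with dummy rails (names and timings are illustrative, not actual guardrails flows):

```python
import asyncio
from typing import List, Optional, Tuple


async def dummy_rail(flow_id: str, delay: float, blocks: bool) -> Tuple[str, bool]:
    await asyncio.sleep(delay)  # stand-in for an LLM or moderation API call
    return flow_id, blocks


async def first_block_wins(rails: List[tuple]) -> Optional[str]:
    tasks = [asyncio.create_task(dummy_rail(*rail)) for rail in rails]
    for future in asyncio.as_completed(tasks):
        flow_id, is_blocked = await future
        if is_blocked:
            # First blocking verdict wins: cancel the rails still running.
            for task in tasks:
                if not task.done():
                    task.cancel()
            return flow_id
    return None


blocked = asyncio.run(
    first_block_wins(
        [("self check output", 0.3, False), ("content safety", 0.1, True)]
    )
)
print(blocked)  # "content safety", after ~0.1s instead of the ~0.3s worst case
```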
```diff
@@ -458,8 +543,9 @@ async def _get_action_resp(
                     )
 
                     resp = await resp.json()
-                    result, status = resp.get("result", result), resp.get(
-                        "status", status
+                    result, status = (
+                        resp.get("result", result),
+                        resp.get("status", status),
                     )
             except Exception as e:
                 log.info(f"Exception {e} while making request to {action_name}")
```

nemoguardrails/rails/llm/config.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -455,6 +455,11 @@ class OutputRailsStreamingConfig(BaseModel):
 class OutputRails(BaseModel):
     """Configuration of output rails."""
 
+    parallel: Optional[bool] = Field(
+        default=False,
+        description="If True, the output rails are executed in parallel.",
+    )
+
     flows: List[str] = Field(
         default_factory=list,
         description="The names of all the flows that implement output rails.",
```

nemoguardrails/rails/llm/llmrails.py

Lines changed: 135 additions & 37 deletions
```diff
@@ -66,7 +66,7 @@
 from nemoguardrails.logging.verbose import set_verbose
 from nemoguardrails.patch_asyncio import check_sync_call_from_async_loop
 from nemoguardrails.rails.llm.buffer import get_buffer_strategy
-from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, Model, RailsConfig
+from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, RailsConfig
 from nemoguardrails.rails.llm.options import (
     GenerationLog,
     GenerationOptions,
```
```diff
@@ -1351,6 +1351,32 @@ def _get_latest_user_message(
                 return message
         return {}
 
+        def _prepare_context_for_parallel_rails(
+            chunk_str: str,
+            prompt: Optional[str] = None,
+            messages: Optional[List[dict]] = None,
+        ) -> dict:
+            """Prepare context for parallel rails execution."""
+            context_message = _get_last_context_message(messages)
+            user_message = prompt or _get_latest_user_message(messages)
+
+            context = {
+                "user_message": user_message,
+                "bot_message": chunk_str,
+            }
+
+            if context_message:
+                context.update(context_message["content"])
+
+            return context
+
+        def _create_events_for_chunk(chunk_str: str, context: dict) -> List[dict]:
+            """Create events for running output rails on a chunk."""
+            return [
+                {"type": "ContextUpdate", "data": context},
+                {"type": "BotMessage", "text": chunk_str},
+            ]
+
         def _prepare_params(
             flow_id: str,
             action_name: str,
```
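
For a single buffered chunk, the two helpers above produce a small context dict plus the two events the rails run against. Roughly (values are hypothetical; the exact keys depend on the conversation state):

```python
# Illustrative output of the helpers for one chunk.
context = {
    "user_message": "Tell me about your services.",  # prompt or last user msg
    "bot_message": "We offer a range of consulting",  # the current chunk
    # ...plus any keys merged in from a trailing "context" message
}
events = [
    {"type": "ContextUpdate", "data": context},
    {"type": "BotMessage", "text": "We offer a range of consulting"},
]
```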
```diff
@@ -1404,6 +1430,8 @@ def _prepare_params(
             _get_action_details_from_flow_id, flows=self.config.flows
         )
 
+        parallel_mode = getattr(self.config.rails.output, "parallel", False)
+
         async for chunk_batch in buffer_strategy(streaming_handler):
             user_output_chunks = chunk_batch.user_output_chunks
             # format processing_context for output rails processing (needs full context)
```
```diff
@@ -1427,48 +1455,118 @@ def _prepare_params(
             for chunk in user_output_chunks:
                 yield chunk
 
-            for flow_id in output_rails_flows_id:
-                action_name, action_params = get_action_details(flow_id)
+            if parallel_mode:
+                try:
+                    context = _prepare_context_for_parallel_rails(
+                        bot_response_chunk, prompt, messages
+                    )
+                    events = _create_events_for_chunk(bot_response_chunk, context)
+
+                    flows_with_params = {}
+                    for flow_id in output_rails_flows_id:
+                        action_name, action_params = get_action_details(flow_id)
+                        params = _prepare_params(
+                            flow_id=flow_id,
+                            action_name=action_name,
+                            bot_response_chunk=bot_response_chunk,
+                            prompt=prompt,
+                            messages=messages,
+                            action_params=action_params,
+                        )
+                        flows_with_params[flow_id] = {
+                            "action_name": action_name,
+                            "params": params,
+                        }
+
+                    result_tuple = await self.runtime.action_dispatcher.execute_action(
+                        "run_output_rails_in_parallel_streaming",
+                        {
+                            "flows_with_params": flows_with_params,
+                            "events": events,
+                        },
+                    )
 
-                params = _prepare_params(
-                    flow_id=flow_id,
-                    action_name=action_name,
-                    bot_response_chunk=bot_response_chunk,
-                    prompt=prompt,
-                    messages=messages,
-                    action_params=action_params,
-                )
+                    # ActionDispatcher.execute_action always returns (result, status)
+                    result, status = result_tuple
 
-                result = await self.runtime.action_dispatcher.execute_action(
-                    action_name, params
-                )
+                    if status != "success":
+                        log.error(
+                            f"Parallel rails execution failed with status: {status}"
+                        )
+                        # continue processing the chunk even if rails fail
+                        pass
+                    else:
+                        # if there are any stop events, content was blocked
+                        if result.events:
+                            # extract the blocked flow from the first stop event
+                            blocked_flow = result.events[0].get(
+                                "flow_id", "output rails"
+                            )
+
+                            reason = f"Blocked by {blocked_flow} rails."
+                            error_data = {
+                                "error": {
+                                    "message": reason,
+                                    "type": "guardrails_violation",
+                                    "param": blocked_flow,
+                                    "code": "content_blocked",
+                                }
+                            }
+                            yield json.dumps(error_data)
+                            return
+
+                except Exception as e:
+                    log.error(f"Error in parallel rail execution: {e}")
+                    # don't block the stream for rail execution errors
+                    # continue processing the chunk
+                    pass
+
+                # update explain info for parallel mode
                 self.explain_info = self._ensure_explain_info()
 
-                action_func = self.runtime.action_dispatcher.get_action(action_name)
-
-                # Use the mapping to decide if the result indicates blocked content.
-                if is_output_blocked(result, action_func):
-                    reason = f"Blocked by {flow_id} rails."
-
-                    # return the error as a plain JSON string (not in SSE format)
-                    # NOTE: When integrating with the OpenAI Python client, the server code should:
-                    # 1. detect this JSON error object in the stream
-                    # 2. terminate the stream
-                    # 3. format the error following OpenAI's SSE format
-                    # the OpenAI client will then properly raise an APIError with this error message
-
-                    error_data = {
-                        "error": {
-                            "message": reason,
-                            "type": "guardrails_violation",
-                            "param": flow_id,
-                            "code": "content_blocked",
-                        }
-                    }
-
-                    # return as plain JSON: the server should detect this JSON and convert it to an HTTP error
-                    yield json.dumps(error_data)
-                    return
+            else:
+                for flow_id in output_rails_flows_id:
+                    action_name, action_params = get_action_details(flow_id)
+
+                    params = _prepare_params(
+                        flow_id=flow_id,
+                        action_name=action_name,
+                        bot_response_chunk=bot_response_chunk,
+                        prompt=prompt,
+                        messages=messages,
+                        action_params=action_params,
+                    )
+
+                    result = await self.runtime.action_dispatcher.execute_action(
+                        action_name, params
+                    )
+                    self.explain_info = self._ensure_explain_info()
+
+                    action_func = self.runtime.action_dispatcher.get_action(action_name)
+
+                    # Use the mapping to decide if the result indicates blocked content.
+                    if is_output_blocked(result, action_func):
+                        reason = f"Blocked by {flow_id} rails."
+
+                        # return the error as a plain JSON string (not in SSE format)
+                        # NOTE: When integrating with the OpenAI Python client, the server code should:
+                        # 1. detect this JSON error object in the stream
+                        # 2. terminate the stream
+                        # 3. format the error following OpenAI's SSE format
+                        # the OpenAI client will then properly raise an APIError with this error message
+
+                        error_data = {
+                            "error": {
+                                "message": reason,
+                                "type": "guardrails_violation",
+                                "param": flow_id,
+                                "code": "content_blocked",
+                            }
+                        }
+
+                        # return as plain JSON: the server should detect this JSON and convert it to an HTTP error
+                        yield json.dumps(error_data)
+                        return
 
             if not stream_first:
                 # yield the individual chunks directly from the buffer strategy
```
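
As the NOTE comments indicate, a blocked stream ends with a plain-JSON error object rather than an SSE frame, so the serving layer is expected to detect it. A hedged sketch of that detection on the consumer side (`stream` is any async iterator of chunk strings, e.g. what `stream_async` yields; how the error is surfaced is up to the server):

```python
import json


async def forward_stream(stream):
    async for chunk in stream:
        try:
            maybe_error = json.loads(chunk)
        except (json.JSONDecodeError, TypeError):
            maybe_error = None

        if isinstance(maybe_error, dict) and "error" in maybe_error:
            # Terminate the stream and surface the error in the format the
            # client expects (e.g. an OpenAI-style SSE error frame).
            raise RuntimeError(maybe_error["error"]["message"])

        yield chunk
```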
