1515import asyncio
1616import inspect
1717import logging
18+ import uuid
1819from textwrap import indent
1920from time import time
2021from typing import Any , Dict , List , Optional , Tuple
2425from langchain .chains .base import Chain
2526
2627from nemoguardrails .actions .actions import ActionResult
28+ from nemoguardrails .actions .core import create_event
2729from nemoguardrails .actions .output_mapping import is_output_blocked
2830from nemoguardrails .colang import parse_colang_file
2931from nemoguardrails .colang .runtime import Runtime
@@ -169,7 +171,7 @@ async def generate_events(
169171 next_events = await self ._process_start_action (events )
170172
171173 # If we need to start a flow, we parse the content and register it.
172- elif last_event ["type" ] == "start_flow" :
174+ elif last_event ["type" ] == "start_flow" and last_event . get ( "flow_body" ) :
173175 next_events = await self ._process_start_flow (
174176 events , processing_log = processing_log
175177 )
@@ -189,18 +191,30 @@ async def generate_events(
189191 new_events .extend (next_events )
190192
191193 for event in next_events :
192- processing_log .append (
193- {"type" : "event" , "timestamp" : time (), "data" : event }
194- )
194+ if event ["type" ] != "EventHistoryUpdate" :
195+ processing_log .append (
196+ {"type" : "event" , "timestamp" : time (), "data" : event }
197+ )
195198
196199 # If the next event is a listen, we stop the processing.
197200 if next_events [- 1 ]["type" ] == "Listen" :
198201 break
199202
200203 # As a safety measure, we stop the processing if we have too many events.
201- if len (new_events ) > 100 :
204+ if len (new_events ) > 300 :
202205 raise Exception ("Too many events." )
203206
207+ # Unpack and insert events in event history update event if available
208+ temp_events = []
209+ for event in new_events :
210+ if event ["type" ] == "EventHistoryUpdate" :
211+ temp_events .extend (
212+ [e for e in event ["data" ]["events" ] if e ["type" ] != "Listen" ]
213+ )
214+ else :
215+ temp_events .append (event )
216+ new_events = temp_events
217+
204218 return new_events
205219
206220 async def _compute_next_steps (
@@ -261,6 +275,210 @@ def _internal_error_action_result(message: str):
261275 ]
262276 )
263277
278+ async def _run_flows_in_parallel (
279+ self ,
280+ flows : List [str ],
281+ events : List [dict ],
282+ pre_events : Optional [List [dict ]] = None ,
283+ post_events : Optional [List [dict ]] = None ,
284+ ) -> ActionResult :
285+ """
286+ Run flows in parallel.
287+
288+ Running flows in parallel is done by triggering a separate event loop with a `start_flow` event for each flow, in the context of the current event loop.
289+
290+ Args:
291+ flows (List[str]): The list of flow names to run in parallel.
292+ events (List[dict]): The current events.
293+ pre_events (List[dict], optional): Events to be added before starting each flow.
294+ post_events (List[dict], optional): Events to be added after finishing each flow.
295+ """
296+
297+ if pre_events is not None and len (pre_events ) != len (flows ):
298+ raise ValueError ("Number of pre-events must match number of flows." )
299+ if post_events is not None and len (post_events ) != len (flows ):
300+ raise ValueError ("Number of post-events must match number of flows." )
301+
302+ unique_flow_ids = {} # Keep track of unique flow IDs order
303+ task_results : Dict [str , List ] = {} # Store results keyed by flow_id
304+ task_processing_logs : dict = {} # Store resulting processing logs for each flow
305+
306+ # Wrapper function to help reverse map the task result to the flow ID
307+ async def task_call_helper (flow_uid , post_event , func , * args , ** kwargs ):
308+ result = await func (* args , ** kwargs )
309+ if post_event :
310+ result .append (post_event )
311+ args [1 ].append (
312+ {"type" : "event" , "timestamp" : time (), "data" : post_event }
313+ )
314+ return flow_uid , result
315+
316+ # Create a task for each flow but don't await them yet
317+ tasks = []
318+ for index , flow_name in enumerate (flows ):
319+ # Copy the events to avoid modifying the original list
320+ _events = events .copy ()
321+
322+ flow_params = _get_flow_params (flow_name )
323+ flow_id = _normalize_flow_id (flow_name )
324+
325+ if flow_params :
326+ _events .append (
327+ {"type" : "start_flow" , "flow_id" : flow_id , "params" : flow_params }
328+ )
329+ else :
330+ _events .append ({"type" : "start_flow" , "flow_id" : flow_id })
331+
332+ # Generate a unique flow ID
333+ flow_uid = f"{ flow_id } :{ str (uuid .uuid4 ())} "
334+
335+ # Initialize task results and processing logs for this flow
336+ task_results [flow_uid ] = []
337+ task_processing_logs [flow_uid ] = []
338+
339+ # Add pre-event if provided
340+ if pre_events :
341+ task_results [flow_uid ].append (pre_events [index ])
342+ task_processing_logs [flow_uid ].append (
343+ {"type" : "event" , "timestamp" : time (), "data" : pre_events [index ]}
344+ )
345+
346+ task = asyncio .create_task (
347+ task_call_helper (
348+ flow_uid ,
349+ post_events [index ] if post_events else None ,
350+ self .generate_events ,
351+ _events ,
352+ task_processing_logs [flow_uid ],
353+ )
354+ )
355+ tasks .append (task )
356+ unique_flow_ids [flow_uid ] = task
357+
358+ stopped_task_results : List [dict ] = []
359+
360+ # Process tasks as they complete using as_completed
361+ try :
362+ for future in asyncio .as_completed (tasks ):
363+ try :
364+ (flow_id , result ) = await future
365+
366+ # Check if this rail requested to stop
367+ has_stop = any (
368+ event ["type" ] == "BotIntent" and event ["intent" ] == "stop"
369+ for event in result
370+ )
371+
372+ # If this flow had a stop event
373+ if has_stop :
374+ stopped_task_results = task_results [flow_id ] + result
375+
376+ # Cancel all remaining tasks
377+ for pending_task in tasks :
378+ # Don't include results and processing logs for cancelled or stopped tasks
379+ if (
380+ pending_task != unique_flow_ids [flow_id ]
381+ and not pending_task .done ()
382+ ):
383+ # Cancel the task if it is not done
384+ pending_task .cancel ()
385+ # Find the flow_uid for this task and remove it from the dict
386+ for k , v in list (unique_flow_ids .items ()):
387+ if v == pending_task :
388+ del unique_flow_ids [k ]
389+ break
390+ del unique_flow_ids [flow_id ]
391+ break
392+ else :
393+ # Store the result for this specific flow
394+ task_results [flow_id ].extend (result )
395+
396+ except asyncio .exceptions .CancelledError :
397+ pass
398+
399+ except Exception as e :
400+ log .error (f"Error in parallel rail execution: { str (e )} " )
401+ raise
402+
403+ context_updates : dict = {}
404+ processing_log = processing_log_var .get ()
405+
406+ finished_task_processing_logs : List [dict ] = [] # Collect all results in order
407+ finished_task_results : List [dict ] = [] # Collect all results in order
408+
409+ # Compose results in original flow order of all completed tasks
410+ for flow_id in unique_flow_ids :
411+ result = task_results [flow_id ]
412+
413+ # Extract context updates
414+ for event in result :
415+ if event ["type" ] == "ContextUpdate" :
416+ context_updates = {** context_updates , ** event ["data" ]}
417+
418+ finished_task_results .extend (result )
419+ finished_task_processing_logs .extend (task_processing_logs [flow_id ])
420+
421+ if processing_log :
422+ for plog in finished_task_processing_logs :
423+ # Filter out "Listen" and "start_flow" events from task processing log
424+ if plog ["type" ] == "event" and (
425+ plog ["data" ]["type" ] == "Listen"
426+ or plog ["data" ]["type" ] == "start_flow"
427+ ):
428+ continue
429+ processing_log .append (plog )
430+
431+ # We pack all events into a single event to add it to the event history.
432+ history_events = new_event_dict (
433+ "EventHistoryUpdate" ,
434+ data = {"events" : finished_task_results },
435+ )
436+
437+ return ActionResult (
438+ events = [history_events ] + stopped_task_results ,
439+ context_updates = context_updates ,
440+ )
441+
442+ async def _run_input_rails_in_parallel (
443+ self , flows : List [str ], events : List [dict ]
444+ ) -> ActionResult :
445+ """Run the input rails in parallel."""
446+ pre_events = [
447+ (await create_event ({"_type" : "StartInputRail" , "flow_id" : flow })).events [0 ]
448+ for flow in flows
449+ ]
450+ post_events = [
451+ (
452+ await create_event ({"_type" : "InputRailFinished" , "flow_id" : flow })
453+ ).events [0 ]
454+ for flow in flows
455+ ]
456+
457+ return await self ._run_flows_in_parallel (
458+ flows = flows , events = events , pre_events = pre_events , post_events = post_events
459+ )
460+
461+ async def _run_output_rails_in_parallel (
462+ self , flows : List [str ], events : List [dict ]
463+ ) -> ActionResult :
464+ """Run the output rails in parallel."""
465+ pre_events = [
466+ (await create_event ({"_type" : "StartOutputRail" , "flow_id" : flow })).events [
467+ 0
468+ ]
469+ for flow in flows
470+ ]
471+ post_events = [
472+ (
473+ await create_event ({"_type" : "OutputRailFinished" , "flow_id" : flow })
474+ ).events [0 ]
475+ for flow in flows
476+ ]
477+
478+ return await self ._run_flows_in_parallel (
479+ flows = flows , events = events , pre_events = pre_events , post_events = post_events
480+ )
481+
264482 async def _run_output_rails_in_parallel_streaming (
265483 self , flows_with_params : Dict [str , dict ], events : List [dict ]
266484 ) -> ActionResult :
@@ -472,15 +690,7 @@ async def _process_start_action(self, events: List[dict]) -> List[dict]:
472690 next_steps = []
473691
474692 if context_updates :
475- # We check if at least one key changed
476- changes = False
477- for k , v in context_updates .items ():
478- if context .get (k ) != v :
479- changes = True
480- break
481-
482- if changes :
483- next_steps .append (new_event_dict ("ContextUpdate" , data = context_updates ))
693+ next_steps .append (new_event_dict ("ContextUpdate" , data = context_updates ))
484694
485695 next_steps .append (
486696 new_event_dict (
0 commit comments