 
 from langchain_core.language_models import BaseChatModel
 from langchain_core.language_models.llms import BaseLLM
+from langchain_text_splitters import ElementType
 from pytest_asyncio.plugin import event_loop
 from rich.text import Text
 
@@ -215,7 +216,8 @@ async def _collect_user_intent_and_examples(
         # We add all currently active user intents (heads on match statements)
         heads = find_all_active_event_matchers(state)
         for head in heads:
-            element = get_element_from_head(state, head)
+            el = get_element_from_head(state, head)
+            element = el if type(el) == SpecOp else SpecOp(**cast(Dict, el))
             flow_state = state.flow_states[head.flow_state_uid]
             event = get_event_from_element(state, flow_state, element)
             if (
@@ -235,9 +237,11 @@ async def _collect_user_intent_and_examples(
             ):
                 if flow_config.elements[1]["_type"] == "doc_string_stmt":
                     examples += "user action: <" + (
-                        flow_config.elements[1]["elements"][0]["elements"][
-                            0
+                        flow_config.elements[1]["elements"][0][
+                            "elements"
                         ][  # pyright: ignore (TODO - Don't know where to even start with this line of code)
+                            0
+                        ][
                             "elements"
                         ][
                             0
@@ -279,7 +283,7 @@ async def generate_user_intent( # pyright: ignore (TODO - Signature completely
279283 """Generate the canonical form for what the user said i.e. user intent."""
280284
281285 # Use action specific llm if registered else fallback to main llm
282- llm = llm or self .llm
286+ generation_llm : Union [ BaseLLM , BaseChatModel ] = llm if llm else self .llm
283287
284288 log .info ("Phase 1 :: Generating user intent" )
285289 (
@@ -311,8 +315,8 @@ async def generate_user_intent( # pyright: ignore (TODO - Signature completely
         )
 
         # We make this call with lowest temperature to have it as deterministic as possible.
-        with llm_params(llm, temperature=self.config.lowest_temperature):
-            result = await llm_call(llm, prompt, stop=stop)
+        with llm_params(generation_llm, temperature=self.config.lowest_temperature):
+            result = await llm_call(generation_llm, prompt, stop=stop)
 
         # Parse the output using the associated parser
         result = self.llm_task_manager.parse_task_output(
@@ -356,7 +360,7 @@ async def generate_user_intent_and_bot_action(
356360 """Generate the canonical form for what the user said i.e. user intent and a suitable bot action."""
357361
358362 # Use action specific llm if registered else fallback to main llm
359- llm = llm or self .llm
363+ generation_llm : Union [ BaseLLM , BaseChatModel ] = llm if llm else self .llm
360364
361365 log .info ("Phase 1 :: Generating user intent and bot action" )
362366
@@ -389,8 +393,8 @@ async def generate_user_intent_and_bot_action(
         )
 
         # We make this call with lowest temperature to have it as deterministic as possible.
-        with llm_params(llm, temperature=self.config.lowest_temperature):
-            result = await llm_call(llm, prompt, stop=stop)
+        with llm_params(generation_llm, temperature=self.config.lowest_temperature):
+            result = await llm_call(generation_llm, prompt, stop=stop)
 
         # Parse the output using the associated parser
         result = self.llm_task_manager.parse_task_output(
@@ -439,7 +443,12 @@ async def passthrough_llm_action(
         events: List[dict],
         llm: Optional[BaseLLM] = None,
     ):
+        if not llm:
+            raise Exception("No LLM provided to passthrough LLM Action")
+
         event = get_last_user_utterance_event_v2_x(events)
+        if not event:
+            raise Exception("Passthrough LLM Action couldn't find last user utterance")
 
         # We check if we have a raw request. If the guardrails API is using
         # the `generate_events` API, this will not be set.
@@ -465,7 +474,10 @@ async def passthrough_llm_action(
         # Initialize the LLMCallInfo object
         llm_call_info_var.set(LLMCallInfo(task=Task.GENERAL.value))
 
-        generation_options: GenerationOptions = generation_options_var.get()
+        generation_options: Optional[GenerationOptions] = generation_options_var.get()
+
+        streaming_handler: Optional[StreamingHandler] = streaming_handler_var.get()
+        custom_callback_handlers = [streaming_handler] if streaming_handler else None
 
         with llm_params(
             llm,
@@ -474,7 +486,7 @@ async def passthrough_llm_action(
         text = await llm_call(
             llm,
             user_message,
-            custom_callback_handlers=[streaming_handler_var.get()],
+            custom_callback_handlers=custom_callback_handlers,
         )
 
         text = self.llm_task_manager.parse_task_output(Task.GENERAL, output=text)
@@ -526,12 +538,12 @@ async def generate_flow_from_instructions(
             raise RuntimeError("No instruction flows index has been created.")
 
         # Use action specific llm if registered else fallback to main llm
-        llm = llm or self.llm
+        generation_llm: Union[BaseLLM, BaseChatModel] = llm if llm else self.llm
 
         log.info("Generating flow for instructions: %s", instructions)
 
         results = await self.instruction_flows_index.search(
-            text=instructions, max_results=5
+            text=instructions, max_results=5, threshold=None
         )
 
         examples = ""
@@ -557,8 +569,8 @@ async def generate_flow_from_instructions(
         )
 
         # We make this call with temperature 0 to have it as deterministic as possible.
-        with llm_params(llm, temperature=self.config.lowest_temperature):
-            result = await llm_call(llm, prompt)
+        with llm_params(generation_llm, temperature=self.config.lowest_temperature):
+            result = await llm_call(generation_llm, prompt)
 
         result = self.llm_task_manager.parse_task_output(
             task=Task.GENERATE_FLOW_FROM_INSTRUCTIONS, output=result
@@ -604,12 +616,15 @@ async def generate_flow_from_name(
             raise RuntimeError("No flows index has been created.")
 
         # Use action specific llm if registered else fallback to main llm
-        llm = llm or self.llm
+        generation_llm: Union[BaseLLM, BaseChatModel] = llm if llm else self.llm
 
         log.info("Generating flow for name: {name}")
 
+        if not self.instruction_flows_index:
+            raise Exception("No instruction flows index has been created.")
+
         results = await self.instruction_flows_index.search(
-            text=f"flow {name}", max_results=5
+            text=f"flow {name}", max_results=5, threshold=None
         )
 
         examples = ""
@@ -631,8 +646,8 @@ async def generate_flow_from_name(
         stop = self.llm_task_manager.get_stop_tokens(Task.GENERATE_FLOW_FROM_NAME)
 
         # We make this call with temperature 0 to have it as deterministic as possible.
-        with llm_params(llm, temperature=self.config.lowest_temperature):
-            result = await llm_call(llm, prompt, stop)
+        with llm_params(generation_llm, temperature=self.config.lowest_temperature):
+            result = await llm_call(generation_llm, prompt, stop=stop)
 
         result = self.llm_task_manager.parse_task_output(
             task=Task.GENERATE_FLOW_FROM_NAME, output=result
@@ -666,7 +681,7 @@ async def generate_flow_continuation(
             raise RuntimeError("No instruction flows index has been created.")
 
         # Use action specific llm if registered else fallback to main llm
-        llm = llm or self.llm
+        generation_llm: Union[BaseLLM, BaseChatModel] = llm if llm else self.llm
 
         log.info("Generating flow continuation.")
 
@@ -675,7 +690,11 @@ async def generate_flow_continuation(
 
         # We use the last line from the history to search for relevant flows
         search_text = colang_history.split("\n")[-1]
-        results = await self.flows_index.search(text=search_text, max_results=10)
+        if self.flows_index is None:
+            raise RuntimeError("No flows index has been created.")
+        results = await self.flows_index.search(
+            text=search_text, max_results=10, threshold=None
+        )
 
         examples = ""
         for result in reversed(results):
@@ -697,8 +716,8 @@ async def generate_flow_continuation(
         )
 
         # We make this call with temperature 0 to have it as deterministic as possible.
-        with llm_params(llm, temperature=temperature):
-            result = await llm_call(llm, prompt)
+        with llm_params(generation_llm, temperature=temperature):
+            result = await llm_call(generation_llm, prompt)
 
         # TODO: Currently, we only support generating a bot action as continuation. This could be generalized
         # Colang statements.
@@ -775,7 +794,7 @@ async def create_flow(
         }
 
     @action(name="GenerateValueAction", is_system_action=True, execute_async=True)
-    async def generate_value(
+    async def generate_value(  # pyright: ignore (TODO - different arguments to base-class)
         self,
         state: State,
         instructions: str,
@@ -791,15 +810,21 @@ async def generate_value(
         :param llm: Custom llm model to generate_value
         """
         # Use action specific llm if registered else fallback to main llm
-        llm = llm or self.llm
+        generation_llm: Union[BaseLLM, BaseChatModel] = llm if llm else self.llm
 
         # We search for the most relevant flows.
         examples = ""
         if self.flows_index:
-            if var_name:
-                results = await self.flows_index.search(
-                    text=f"${var_name} = ", max_results=5
+            results = (
+                await self.flows_index.search(
+                    text=f"${var_name} = ", max_results=5, threshold=None
                 )
+                if var_name
+                else None
+            )
+
+            if not results:
+                raise Exception("No results found while generating value")
 
             # We add these in reverse order so the most relevant is towards the end.
             for result in reversed(results):
@@ -827,8 +852,8 @@ async def generate_value(
             Task.GENERATE_USER_INTENT_FROM_USER_ACTION
         )
 
-        with llm_params(llm, temperature=0.1):
-            result = await llm_call(llm, prompt, stop)
+        with llm_params(generation_llm, temperature=0.1):
+            result = await llm_call(generation_llm, prompt, stop=stop)
 
         # Parse the output using the associated parser
         result = self.llm_task_manager.parse_task_output(
@@ -871,11 +896,17 @@ async def generate_flow(
     ) -> dict:
         """Generate the body for a flow."""
         # Use action specific llm if registered else fallback to main llm
-        llm = llm or self.llm
+        generation_llm: Union[BaseLLM, BaseChatModel] = llm if llm else self.llm
 
         triggering_flow_id = flow_id
+        if not triggering_flow_id:
+            raise Exception(
+                f"No flow_id provided to generate flow."
+            )  # TODO! Should flow_id be mandatory?
 
         flow_config = state.flow_configs[triggering_flow_id]
+        if not flow_config.source_code:
+            raise Exception(f"No source_code in flow_config {flow_config}")
         docstrings = re.findall(r'"""(.*?)"""', flow_config.source_code, re.DOTALL)
 
         if len(docstrings) > 0:
@@ -897,6 +928,10 @@ async def generate_flow(
         for flow_config in state.flow_configs.values():
             if flow_config.decorators.get("meta", {}).get("tool") is True:
                 # We get rid of the first line, which is the decorator
+
+                if not flow_config.source_code:
+                    raise Exception(f"No source_code in flow_config {flow_config}")
+
                 body = flow_config.source_code.split("\n", maxsplit=1)[1]
 
                 # We only need the part up to the docstring
@@ -936,8 +971,8 @@ async def generate_flow(
             Task.GENERATE_FLOW_CONTINUATION_FROM_NLD
         )
 
-        with llm_params(llm, temperature=self.config.lowest_temperature):
-            result = await llm_call(llm, prompt, stop)
+        with llm_params(generation_llm, temperature=self.config.lowest_temperature):
+            result = await llm_call(generation_llm, prompt, stop=stop)
 
         # Parse the output using the associated parser
         result = self.llm_task_manager.parse_task_output(
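
The recurring change across these hunks is the same: instead of rebinding the optional `llm` argument in place (`llm = llm or self.llm`), each action assigns the fallback to a separately annotated `generation_llm`, so the type checker sees one explicit union type at every `llm_params(...)` / `llm_call(...)` call site. A minimal sketch of that pattern, under the assumption that `llm_params` and `llm_call` keep the signatures used above; the helper name `resolve_generation_llm` is illustrative only and is not part of this change:

```python
from typing import Optional, Union

from langchain_core.language_models import BaseChatModel
from langchain_core.language_models.llms import BaseLLM


def resolve_generation_llm(
    action_llm: Optional[Union[BaseLLM, BaseChatModel]],
    main_llm: Union[BaseLLM, BaseChatModel],
) -> Union[BaseLLM, BaseChatModel]:
    # Illustrative helper, not part of the diff above: prefer the
    # action-specific model when one is registered, otherwise fall back to
    # the main model. Returning a value with an explicit union annotation
    # (rather than rebinding the Optional parameter) keeps pyright's
    # narrowing straightforward.
    return action_llm if action_llm is not None else main_llm
```

The call sites themselves stay unchanged apart from passing `generation_llm` instead of `llm` into `llm_params(...)` and `llm_call(...)`.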