Commit b7bc444

All but one error left to clean

1 parent 0622bd3 commit b7bc444
4 files changed: +74 -37 lines changed

nemoguardrails/actions/llm/generation.py

Lines changed: 3 additions & 1 deletion
@@ -972,7 +972,9 @@ async def generate_bot_message(
 
         # We use the potentially updated $user_message. This means that even
         # in passthrough mode, input rails can still alter the input.
-        prompt: Optional[str] = context.get(
+        prompt: Optional[
+            str
+        ] = context.get(  # pyright: ignore (TODO Refactor these branches into separate methods)
            "user_message"
        )  # pyright: ignore (TODO - refactor nested `prompt` definitions)
        if not prompt:
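Why pyright complains here: `context.get("user_message")` is typed `Optional[str]` (or broader), so `prompt` cannot be used where a plain `str` is required until the `None` case is ruled out. A minimal, self-contained sketch of the narrowing at play (a hypothetical context shape, not the repo's actual types):

```python
from typing import Optional

def build_prompt(context: dict) -> str:
    # dict.get returns Optional[...]: pyright will not let `prompt` be used
    # as a plain `str` until the None case is ruled out.
    prompt: Optional[str] = context.get("user_message")
    if not prompt:
        raise ValueError("no user_message in context")
    return prompt  # narrowed to `str` after the guard
```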

nemoguardrails/actions/v2_x/generation.py

Lines changed: 68 additions & 33 deletions
@@ -23,6 +23,7 @@
 
 from langchain_core.language_models import BaseChatModel
 from langchain_core.language_models.llms import BaseLLM
+from langchain_text_splitters import ElementType
 from pytest_asyncio.plugin import event_loop
 from rich.text import Text
 
@@ -215,7 +216,8 @@ async def _collect_user_intent_and_examples(
         # We add all currently active user intents (heads on match statements)
         heads = find_all_active_event_matchers(state)
         for head in heads:
-            element = get_element_from_head(state, head)
+            el = get_element_from_head(state, head)
+            element = el if type(el) == SpecOp else SpecOp(**cast(Dict, el))
             flow_state = state.flow_states[head.flow_state_uid]
             event = get_event_from_element(state, flow_state, element)
             if (
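`get_element_from_head` evidently returns a union of `SpecOp` and a dict-shaped element (an assumption read off the diff), so the new lines normalize everything to `SpecOp` via `cast`. A compact sketch of that convert-or-keep idiom, with a stand-in `SpecOp`:

```python
from dataclasses import dataclass
from typing import Dict, Union, cast

@dataclass
class SpecOp:  # stand-in for the real Colang SpecOp
    op: str = ""

def as_spec_op(el: Union[SpecOp, Dict]) -> SpecOp:
    # Keep SpecOp instances unchanged; rebuild dict-shaped elements.
    # cast() only informs the type checker - no runtime conversion happens.
    return el if type(el) == SpecOp else SpecOp(**cast(Dict, el))

print(as_spec_op({"op": "match"}))    # SpecOp(op='match')
print(as_spec_op(SpecOp(op="send")))  # returned unchanged
```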
@@ -235,9 +237,11 @@ async def _collect_user_intent_and_examples(
             ):
                 if flow_config.elements[1]["_type"] == "doc_string_stmt":
                     examples += "user action: <" + (
-                        flow_config.elements[1]["elements"][0]["elements"][
-                            0
+                        flow_config.elements[1]["elements"][0][
+                            "elements"
                         ][  # pyright: ignore (TODO - Don't know where to even start with this line of code)
+                            0
+                        ][
                             "elements"
                         ][
                             0
@@ -279,7 +283,7 @@ async def generate_user_intent(  # pyright: ignore (TODO - Signature completely
         """Generate the canonical form for what the user said i.e. user intent."""
 
         # Use action specific llm if registered else fallback to main llm
-        llm = llm or self.llm
+        generation_llm: Union[BaseLLM, BaseChatModel] = llm if llm else self.llm
 
         log.info("Phase 1 :: Generating user intent")
         (
@@ -311,8 +315,8 @@ async def generate_user_intent(  # pyright: ignore (TODO - Signature completely
         )
 
         # We make this call with lowest temperature to have it as deterministic as possible.
-        with llm_params(llm, temperature=self.config.lowest_temperature):
-            result = await llm_call(llm, prompt, stop=stop)
+        with llm_params(generation_llm, temperature=self.config.lowest_temperature):
+            result = await llm_call(generation_llm, prompt, stop=stop)
 
         # Parse the output using the associated parser
         result = self.llm_task_manager.parse_task_output(
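The `llm = llm or self.llm` → `generation_llm = ...` rename recurs throughout this file. Rebinding a parameter annotated `Optional[BaseLLM]` to `self.llm` (a `Union[BaseLLM, BaseChatModel]`) violates the parameter's declared type under strict pyright; a fresh name with the wider annotation avoids that. A sketch with stand-in types:

```python
from typing import Optional, Union

class BaseLLM: ...        # stand-ins for the langchain base classes
class BaseChatModel: ...

class Actions:
    def __init__(self) -> None:
        self.llm: Union[BaseLLM, BaseChatModel] = BaseChatModel()

    def pick_llm(self, llm: Optional[BaseLLM] = None) -> Union[BaseLLM, BaseChatModel]:
        # `llm = llm or self.llm` would assign a Union back into a name
        # declared Optional[BaseLLM] - a strict-mode type error. A new,
        # correctly annotated variable satisfies the checker.
        generation_llm: Union[BaseLLM, BaseChatModel] = llm if llm else self.llm
        return generation_llm
```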
@@ -356,7 +360,7 @@ async def generate_user_intent_and_bot_action(
         """Generate the canonical form for what the user said i.e. user intent and a suitable bot action."""
 
         # Use action specific llm if registered else fallback to main llm
-        llm = llm or self.llm
+        generation_llm: Union[BaseLLM, BaseChatModel] = llm if llm else self.llm
 
         log.info("Phase 1 :: Generating user intent and bot action")
 
@@ -389,8 +393,8 @@ async def generate_user_intent_and_bot_action(
         )
 
         # We make this call with lowest temperature to have it as deterministic as possible.
-        with llm_params(llm, temperature=self.config.lowest_temperature):
-            result = await llm_call(llm, prompt, stop=stop)
+        with llm_params(generation_llm, temperature=self.config.lowest_temperature):
+            result = await llm_call(generation_llm, prompt, stop=stop)
 
         # Parse the output using the associated parser
         result = self.llm_task_manager.parse_task_output(
@@ -439,7 +443,12 @@ async def passthrough_llm_action(
         events: List[dict],
         llm: Optional[BaseLLM] = None,
     ):
+        if not llm:
+            raise Exception("No LLM provided to passthrough LLM Action")
+
         event = get_last_user_utterance_event_v2_x(events)
+        if not event:
+            raise Exception("Passthrough LLM Action couldn't find last user utterance")
 
         # We check if we have a raw request. If the guardrails API is using
         # the `generate_events` API, this will not be set.
@@ -465,7 +474,10 @@ async def passthrough_llm_action(
         # Initialize the LLMCallInfo object
         llm_call_info_var.set(LLMCallInfo(task=Task.GENERAL.value))
 
-        generation_options: GenerationOptions = generation_options_var.get()
+        generation_options: Optional[GenerationOptions] = generation_options_var.get()
+
+        streaming_handler: Optional[StreamingHandler] = streaming_handler_var.get()
+        custom_callback_handlers = [streaming_handler] if streaming_handler else None
 
         with llm_params(
             llm,
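`generation_options_var.get()` and `streaming_handler_var.get()` are `ContextVar`s with `None` defaults (an assumption from the annotations in the diff), so their results are `Optional[...]` and the callback list must be built conditionally rather than wrapping a possible `None`. A sketch:

```python
from contextvars import ContextVar
from typing import List, Optional

class StreamingHandler: ...  # stand-in for nemoguardrails' handler

streaming_handler_var: ContextVar[Optional[StreamingHandler]] = ContextVar(
    "streaming_handler", default=None
)

def callback_handlers() -> Optional[List[StreamingHandler]]:
    handler = streaming_handler_var.get()
    # The old code passed [streaming_handler_var.get()] unconditionally,
    # which could smuggle a None into a list of handlers. Guard instead.
    return [handler] if handler else None
```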
@@ -474,7 +486,7 @@ async def passthrough_llm_action(
         text = await llm_call(
             llm,
             user_message,
-            custom_callback_handlers=[streaming_handler_var.get()],
+            custom_callback_handlers=custom_callback_handlers,
         )
 
         text = self.llm_task_manager.parse_task_output(Task.GENERAL, output=text)
@@ -526,12 +538,12 @@ async def generate_flow_from_instructions(
             raise RuntimeError("No instruction flows index has been created.")
 
         # Use action specific llm if registered else fallback to main llm
-        llm = llm or self.llm
+        generation_llm: Union[BaseLLM, BaseChatModel] = llm if llm else self.llm
 
         log.info("Generating flow for instructions: %s", instructions)
 
         results = await self.instruction_flows_index.search(
-            text=instructions, max_results=5
+            text=instructions, max_results=5, threshold=None
         )
 
         examples = ""
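The added `threshold=None` suggests the index's `search` signature grew a `threshold` parameter that strict checking wants spelled out at call sites (an assumption; check the actual `EmbeddingsIndex` interface). A sketch of a signature shape that would force this:

```python
from abc import ABC, abstractmethod
from typing import List, Optional

class EmbeddingsIndex(ABC):
    # Hypothetical shape: a required `threshold` (even if Optional-typed)
    # makes every call site pass it explicitly, as the diff now does.
    @abstractmethod
    async def search(
        self, text: str, max_results: int, threshold: Optional[float]
    ) -> List[str]: ...
```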
@@ -557,8 +569,8 @@ async def generate_flow_from_instructions(
         )
 
         # We make this call with temperature 0 to have it as deterministic as possible.
-        with llm_params(llm, temperature=self.config.lowest_temperature):
-            result = await llm_call(llm, prompt)
+        with llm_params(generation_llm, temperature=self.config.lowest_temperature):
+            result = await llm_call(generation_llm, prompt)
 
         result = self.llm_task_manager.parse_task_output(
             task=Task.GENERATE_FLOW_FROM_INSTRUCTIONS, output=result
@@ -604,12 +616,15 @@ async def generate_flow_from_name(
             raise RuntimeError("No flows index has been created.")
 
         # Use action specific llm if registered else fallback to main llm
-        llm = llm or self.llm
+        generation_llm: Union[BaseLLM, BaseChatModel] = llm if llm else self.llm
 
         log.info("Generating flow for name: {name}")
 
+        if not self.instruction_flows_index:
+            raise Exception("No instruction flows index has been created.")
+
         results = await self.instruction_flows_index.search(
-            text=f"flow {name}", max_results=5
+            text=f"flow {name}", max_results=5, threshold=None
         )
 
         examples = ""
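Beyond failing fast, the added guard lets pyright narrow `self.instruction_flows_index` from `Optional[...]` before `.search()` is called; pyright narrows member accesses like this within straight-line code. A sketch:

```python
from typing import List, Optional

class Index:
    async def search(self, text: str, max_results: int) -> List[str]:
        return []

class Generator:
    def __init__(self, index: Optional[Index] = None) -> None:
        self.instruction_flows_index = index

    async def flows_for(self, name: str) -> List[str]:
        # Without this guard, strict pyright rejects the .search() call:
        # self.instruction_flows_index might still be None.
        if not self.instruction_flows_index:
            raise Exception("No instruction flows index has been created.")
        return await self.instruction_flows_index.search(f"flow {name}", max_results=5)
```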
@@ -631,8 +646,8 @@ async def generate_flow_from_name(
         stop = self.llm_task_manager.get_stop_tokens(Task.GENERATE_FLOW_FROM_NAME)
 
         # We make this call with temperature 0 to have it as deterministic as possible.
-        with llm_params(llm, temperature=self.config.lowest_temperature):
-            result = await llm_call(llm, prompt, stop)
+        with llm_params(generation_llm, temperature=self.config.lowest_temperature):
+            result = await llm_call(generation_llm, prompt, stop=stop)
 
         result = self.llm_task_manager.parse_task_output(
             task=Task.GENERATE_FLOW_FROM_NAME, output=result
@@ -666,7 +681,7 @@ async def generate_flow_continuation(
             raise RuntimeError("No instruction flows index has been created.")
 
         # Use action specific llm if registered else fallback to main llm
-        llm = llm or self.llm
+        generation_llm: Union[BaseLLM, BaseChatModel] = llm if llm else self.llm
 
         log.info("Generating flow continuation.")
 
@@ -675,7 +690,11 @@ async def generate_flow_continuation(
         # We use the last line from the history to search for relevant flows
         search_text = colang_history.split("\n")[-1]
 
-        results = await self.flows_index.search(text=search_text, max_results=10)
+        if self.flows_index is None:
+            raise RuntimeError("No flows index has been created.")
+        results = await self.flows_index.search(
+            text=search_text, max_results=10, threshold=None
+        )
 
         examples = ""
         for result in reversed(results):
@@ -697,8 +716,8 @@ async def generate_flow_continuation(
         )
 
         # We make this call with temperature 0 to have it as deterministic as possible.
-        with llm_params(llm, temperature=temperature):
-            result = await llm_call(llm, prompt)
+        with llm_params(generation_llm, temperature=temperature):
+            result = await llm_call(generation_llm, prompt)
 
         # TODO: Currently, we only support generating a bot action as continuation. This could be generalized
         # Colang statements.
@@ -775,7 +794,7 @@ async def create_flow(
         }
 
     @action(name="GenerateValueAction", is_system_action=True, execute_async=True)
-    async def generate_value(
+    async def generate_value(  # pyright: ignore (TODO - different arguments to base-class)
         self,
         state: State,
         instructions: str,
@@ -791,15 +810,21 @@ async def generate_value(
         :param llm: Custom llm model to generate_value
         """
         # Use action specific llm if registered else fallback to main llm
-        llm = llm or self.llm
+        generation_llm: Union[BaseLLM, BaseChatModel] = llm if llm else self.llm
 
         # We search for the most relevant flows.
         examples = ""
         if self.flows_index:
-            if var_name:
-                results = await self.flows_index.search(
-                    text=f"${var_name} = ", max_results=5
+            results = (
+                await self.flows_index.search(
+                    text=f"${var_name} = ", max_results=5, threshold=None
                 )
+                if var_name
+                else None
+            )
+
+            if not results:
+                raise Exception("No results found while generating value")
 
             # We add these in reverse order so the most relevant is towards the end.
             for result in reversed(results):
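The rewrite folds the old nested `if var_name:` into a single conditional expression typed `Optional[List[...]]`, then guards once so everything below sees a concrete list. Note this also changes behavior: a missing `var_name` or an empty result now raises instead of silently skipping the search. A runnable sketch of the shape:

```python
import asyncio
from typing import List, Optional

async def search_index(text: str) -> List[str]:
    # Stand-in for self.flows_index.search(...)
    return [f"{text}<flow body>"]

async def examples_for(var_name: Optional[str]) -> List[str]:
    # One expression of type Optional[List[str]], then one guard, so every
    # later use (e.g. reversed(results)) type-checks against List[str].
    results = await search_index(f"${var_name} = ") if var_name else None
    if not results:
        raise Exception("No results found while generating value")
    return list(reversed(results))

print(asyncio.run(examples_for("answer")))
```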
@@ -827,8 +852,8 @@ async def generate_value(
                 Task.GENERATE_USER_INTENT_FROM_USER_ACTION
             )
 
-        with llm_params(llm, temperature=0.1):
-            result = await llm_call(llm, prompt, stop)
+        with llm_params(generation_llm, temperature=0.1):
+            result = await llm_call(generation_llm, prompt, stop=stop)
 
         # Parse the output using the associated parser
         result = self.llm_task_manager.parse_task_output(
@@ -871,11 +896,17 @@ async def generate_flow(
     ) -> dict:
         """Generate the body for a flow."""
         # Use action specific llm if registered else fallback to main llm
-        llm = llm or self.llm
+        generation_llm: Union[BaseLLM, BaseChatModel] = llm if llm else self.llm
 
         triggering_flow_id = flow_id
+        if not triggering_flow_id:
+            raise Exception(
+                f"No flow_id provided to generate flow."
+            )  # TODO! Should flow_id be mandatory?
 
         flow_config = state.flow_configs[triggering_flow_id]
+        if not flow_config.source_code:
+            raise Exception(f"No source_code in flow_config {flow_config}")
         docstrings = re.findall(r'"""(.*?)"""', flow_config.source_code, re.DOTALL)
 
         if len(docstrings) > 0:
@@ -897,6 +928,10 @@ async def generate_flow(
         for flow_config in state.flow_configs.values():
             if flow_config.decorators.get("meta", {}).get("tool") is True:
                 # We get rid of the first line, which is the decorator
+
+                if not flow_config.source_code:
+                    raise Exception(f"No source_code in flow_config {flow_config}")
+
                 body = flow_config.source_code.split("\n", maxsplit=1)[1]
 
                 # We only need the part up to the docstring
@@ -936,8 +971,8 @@ async def generate_flow(
             Task.GENERATE_FLOW_CONTINUATION_FROM_NLD
         )
 
-        with llm_params(llm, temperature=self.config.lowest_temperature):
-            result = await llm_call(llm, prompt, stop)
+        with llm_params(generation_llm, temperature=self.config.lowest_temperature):
+            result = await llm_call(generation_llm, prompt, stop=stop)
 
         # Parse the output using the associated parser
         result = self.llm_task_manager.parse_task_output(

nemoguardrails/actions/validation/base.py

Lines changed: 2 additions & 2 deletions
@@ -14,15 +14,15 @@
 # limitations under the License.
 import json
 import re
-from typing import List
+from typing import List, Sequence
 from urllib.parse import quote
 
 from nemoguardrails.actions.validation.filter_secrets import contains_secrets
 
 MAX_LEN = 50
 
 
-def validate_input(attribute: str, validators: List[str] = (), **validation_args):
+def validate_input(attribute: str, validators: Sequence[str] = (), **validation_args):
     """A generic decorator that can be used by any action (class method or function) for input validation.
 
     Supported validation choices are: length and quote.
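This small-looking change fixes a real typing bug: the default `()` is a tuple, which is not a `List[str]`. `Sequence[str]` admits tuples and lists alike and documents that the decorator only reads the argument. A sketch:

```python
from typing import Sequence

# With `validators: List[str] = ()`, pyright rejects the default value:
# a tuple is not a list. Sequence[str] covers both and matches read-only use.
def validate_input(attribute: str, validators: Sequence[str] = (), **validation_args):
    for validator in validators:
        print(f"validating {attribute!r} with {validator!r}")

validate_input("user_message")                       # tuple default is fine
validate_input("user_message", ["length", "quote"])  # lists still accepted
```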

nemoguardrails/actions/validation/filter_secrets.py

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ def contains_secrets(resp):
     ArtifactoryDetector : False
     """
     try:
-        import detect_secrets
+        import detect_secrets  # pyright: ignore (Assume user installs detect_secrets with instructions below)
     except ModuleNotFoundError:
         raise ValueError(
             "Could not import detect_secrets. Please install using `pip install detect-secrets`"
