
Commit 5e46e5a

Pouyanpi authored and tgasser-nv committed
review: type hint fixes
fix
1 parent 483c547 commit 5e46e5a

File tree

2 files changed: +72 -66 lines

nemoguardrails/rails/llm/llmrails.py
nemoguardrails/rails/llm/options.py


nemoguardrails/rails/llm/llmrails.py

Lines changed: 70 additions & 65 deletions
@@ -776,6 +776,19 @@ async def generate_async(
             The completion (when a prompt is provided) or the next message.

             System messages are not yet supported."""
+        # convert options to gen_options of type GenerationOptions
+        gen_options: Optional[GenerationOptions] = None
+
+        if prompt is None and messages is None:
+            raise ValueError("Either prompt or messages must be provided.")
+
+        if prompt is not None and messages is not None:
+            raise ValueError("Only one of prompt or messages can be provided.")
+
+        if prompt is not None:
+            # Currently, we transform the prompt request into a single turn conversation
+            messages = [{"role": "user", "content": prompt}]
+
         # If a state object is specified, then we switch to "generation options" mode.
         # This is because we want the output to be a GenerationResponse which will contain
         # the output state.
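For readers skimming the hunk above: the new guards make `prompt` and `messages` required and mutually exclusive, and a bare prompt is folded into a single-turn `messages` list before anything else runs. A minimal standalone sketch of the same logic (the helper name `_normalize_input` is illustrative, not part of the codebase):

```python
from typing import List, Optional


def _normalize_input(
    prompt: Optional[str], messages: Optional[List[dict]]
) -> List[dict]:
    # Exactly one of prompt/messages must be supplied.
    if prompt is None and messages is None:
        raise ValueError("Either prompt or messages must be provided.")
    if prompt is not None and messages is not None:
        raise ValueError("Only one of prompt or messages can be provided.")
    if prompt is not None:
        # A prompt becomes a single-turn conversation, as in the diff.
        messages = [{"role": "user", "content": prompt}]
    return messages


print(_normalize_input("Hello!", None))
# -> [{'role': 'user', 'content': 'Hello!'}]
```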
@@ -785,15 +798,25 @@ async def generate_async(
                 state = json_to_state(state["state"])

             if options is None:
-                options = GenerationOptions()
-
-        # We allow options to be specified both as a dict and as an object.
-        if options and isinstance(options, dict):
-            options = GenerationOptions(**options)
+                gen_options = GenerationOptions()
+            elif isinstance(options, dict):
+                gen_options = GenerationOptions(**options)
+            else:
+                gen_options = options
+        else:
+            # We allow options to be specified both as a dict and as an object.
+            if options and isinstance(options, dict):
+                gen_options = GenerationOptions(**options)
+            elif isinstance(options, GenerationOptions):
+                gen_options = options
+            elif options is None:
+                gen_options = None
+            else:
+                raise TypeError("options must be a dict or GenerationOptions")

         # Save the generation options in the current async context.
-        # At this point, options is either None or GenerationOptions
-        generation_options_var.set(options if not isinstance(options, dict) else None)
+        # At this point, gen_options is either None or GenerationOptions
+        generation_options_var.set(gen_options)

         if streaming_handler:
             streaming_handler_var.set(streaming_handler)
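This restructuring is the heart of the commit: instead of mutating `options` in place (which left its type as `Union[dict, GenerationOptions, None]` for the rest of the function), the input is normalized once into `gen_options: Optional[GenerationOptions]`. A standalone sketch of the normalization rule, using a simplified stand-in for the real Pydantic model:

```python
from typing import Optional, Union

from pydantic import BaseModel


class GenerationOptions(BaseModel):  # simplified stand-in for the real model
    llm_output: bool = False


def normalize(
    options: Optional[Union[dict, GenerationOptions]], state_provided: bool
) -> Optional[GenerationOptions]:
    if state_provided:
        # With a state object, None defaults to a fresh GenerationOptions.
        if options is None:
            return GenerationOptions()
        if isinstance(options, dict):
            return GenerationOptions(**options)
        return options
    # Without a state object, None stays None.
    if options and isinstance(options, dict):
        return GenerationOptions(**options)
    if isinstance(options, GenerationOptions):
        return options
    if options is None:
        return None
    # Note: as written in the diff, an empty dict also ends up here.
    raise TypeError("options must be a dict or GenerationOptions")


assert normalize(None, state_provided=True) == GenerationOptions()
assert normalize(None, state_provided=False) is None
```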
@@ -803,23 +826,14 @@ async def generate_async(
         # requests are made.
         self.explain_info = self._ensure_explain_info()

-        if prompt is not None:
-            # Currently, we transform the prompt request into a single turn conversation
-            messages = [{"role": "user", "content": prompt}]
-            raw_llm_request.set(prompt)
-        else:
-            raw_llm_request.set(messages)
+        raw_llm_request.set(messages)

         # If we have generation options, we also add them to the context
-        if options:
+        if gen_options:
             messages = [
                 {
                     "role": "context",
-                    "content": {
-                        "generation_options": getattr(
-                            options, "dict", lambda: options
-                        )()
-                    },
+                    "content": {"generation_options": gen_options.model_dump()},
                 }
             ] + (messages or [])
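The old `getattr(options, "dict", lambda: options)()` dance existed because `options` could still be a plain dict at this point. Now that `gen_options` is always a `GenerationOptions`, the Pydantic v2 serializer can be called directly. A quick illustration with the same simplified stand-in:

```python
from pydantic import BaseModel


class GenerationOptions(BaseModel):  # simplified stand-in
    llm_output: bool = False


opts = GenerationOptions(llm_output=True)
print(opts.model_dump())  # {'llm_output': True}
```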

@@ -848,7 +862,7 @@ async def generate_async(
         processing_log = []

         # The array of events corresponding to the provided sequence of messages.
-        events = self._get_events_for_messages(messages or [], state)
+        events = self._get_events_for_messages(messages, state)  # type: ignore

         if self.config.colang_version == "1.0":
             # If we had a state object, we also need to prepend the events from the state.
@@ -967,7 +981,7 @@ async def generate_async(
            # If a state object is not used, then we use the implicit caching
            if state is None:
                # Save the new events in the history and update the cache
-                cache_key = get_history_cache_key((messages or []) + [new_message])
+                cache_key = get_history_cache_key((messages) + [new_message])  # type: ignore
                self.events_history_cache[cache_key] = events
            else:
                output_state = {"events": events}
@@ -995,33 +1009,29 @@ async def generate_async(
         # IF tracing is enabled we need to set GenerationLog attrs
         original_log_options = None
         if self.config.tracing.enabled:
-            if options is None:
-                options = GenerationOptions()
+            if gen_options is None:
+                gen_options = GenerationOptions()
             else:
-                # create a copy of the options to avoid modifying the original
-                if isinstance(options, GenerationOptions):
-                    options = options.model_copy(deep=True)
-                else:
-                    # If options is a dict, convert it to GenerationOptions
-                    options = GenerationOptions(**options)
-            original_log_options = options.log.model_copy(deep=True)
+                # create a copy of the gen_options to avoid modifying the original
+                gen_options = gen_options.model_copy(deep=True)
+            original_log_options = gen_options.log.model_copy(deep=True)

            # enable log options
            # it is aggressive, but these are required for tracing
            if (
-                not options.log.activated_rails
-                or not options.log.llm_calls
-                or not options.log.internal_events
+                not gen_options.log.activated_rails
+                or not gen_options.log.llm_calls
+                or not gen_options.log.internal_events
            ):
-                options.log.activated_rails = True
-                options.log.llm_calls = True
-                options.log.internal_events = True
+                gen_options.log.activated_rails = True
+                gen_options.log.llm_calls = True
+                gen_options.log.internal_events = True

        tool_calls = extract_tool_calls_from_events(new_events)
        llm_metadata = get_and_clear_response_metadata_contextvar()

        # If we have generation options, we prepare a GenerationResponse instance.
-        if options:
+        if gen_options:
            # If a prompt was used, we only need to return the content of the message.
            if prompt:
                res = GenerationResponse(response=new_message["content"])
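Tracing force-enables the log flags, so the hunk deep-copies `gen_options` first; the single `model_copy(deep=True)` replaces the old two-branch copy/convert logic. A sketch of why the deep copy matters (simplified stand-in models; `model_copy` is the Pydantic v2 API):

```python
from pydantic import BaseModel, Field


class LogOptions(BaseModel):  # simplified stand-in
    llm_calls: bool = False


class GenerationOptions(BaseModel):  # simplified stand-in
    log: LogOptions = Field(default_factory=LogOptions)


caller_opts = GenerationOptions()
internal = caller_opts.model_copy(deep=True)
internal.log.llm_calls = True  # mutate only the internal copy

print(caller_opts.log.llm_calls)  # False -- the caller's object is untouched
```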
@@ -1048,9 +1058,9 @@ async def generate_async(

            if self.config.colang_version == "1.0":
                # If output variables are specified, we extract their values
-                if getattr(options, "output_vars", None):
+                if gen_options and gen_options.output_vars:
                    context = compute_context(events)
-                    output_vars = getattr(options, "output_vars", None)
+                    output_vars = gen_options.output_vars
                    if isinstance(output_vars, list):
                        # If we have only a selection of keys, we filter to only that.
                        res.output_data = {k: context.get(k) for k in output_vars}
@@ -1061,65 +1071,64 @@ async def generate_async(
                _log = compute_generation_log(processing_log)

                # Include information about activated rails and LLM calls if requested
-                log_options = getattr(options, "log", None)
+                log_options = gen_options.log if gen_options else None
                if log_options and (
-                    getattr(log_options, "activated_rails", False)
-                    or getattr(log_options, "llm_calls", False)
+                    log_options.activated_rails or log_options.llm_calls
                ):
                    res.log = GenerationLog()

                    # We always include the stats
                    res.log.stats = _log.stats

-                    if getattr(log_options, "activated_rails", False):
+                    if log_options.activated_rails:
                        res.log.activated_rails = _log.activated_rails

-                    if getattr(log_options, "llm_calls", False):
+                    if log_options.llm_calls:
                        res.log.llm_calls = []
                        for activated_rail in _log.activated_rails:
                            for executed_action in activated_rail.executed_actions:
                                res.log.llm_calls.extend(executed_action.llm_calls)

                # Include internal events if requested
-                if getattr(log_options, "internal_events", False):
+                if log_options and log_options.internal_events:
                    if res.log is None:
                        res.log = GenerationLog()

                    res.log.internal_events = new_events

                # Include the Colang history if requested
-                if getattr(log_options, "colang_history", False):
+                if log_options and log_options.colang_history:
                    if res.log is None:
                        res.log = GenerationLog()

                    res.log.colang_history = get_colang_history(events)

                # Include the raw llm output if requested
-                if getattr(options, "llm_output", False):
+                if gen_options and gen_options.llm_output:
                    # Currently, we include the output from the generation LLM calls.
                    for activated_rail in _log.activated_rails:
                        if activated_rail.type == "generation":
                            for executed_action in activated_rail.executed_actions:
                                for llm_call in executed_action.llm_calls:
                                    res.llm_output = llm_call.raw_response
            else:
-                if getattr(options, "output_vars", None):
+                if gen_options and gen_options.output_vars:
                    raise ValueError(
                        "The `output_vars` option is not supported for Colang 2.0 configurations."
                    )

-                log_options = getattr(options, "log", None)
+                log_options = gen_options.log if gen_options else None
                if log_options and (
-                    getattr(log_options, "activated_rails", False)
-                    or getattr(log_options, "llm_calls", False)
-                    or getattr(log_options, "internal_events", False)
-                    or getattr(log_options, "colang_history", False)
+                    log_options.activated_rails
+                    or log_options.llm_calls
+                    or log_options.internal_events
+                    or log_options.colang_history
                ):
                    raise ValueError(
                        "The `log` option is not supported for Colang 2.0 configurations."
                    )

-                if getattr(options, "llm_output", False):
+                if gen_options and gen_options.llm_output:
                    raise ValueError(
                        "The `llm_output` option is not supported for Colang 2.0 configurations."
                    )
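The pattern throughout this hunk is the payoff of the typed `gen_options`: defensive `getattr(obj, "name", False)` lookups become plain attribute access, which type checkers can actually verify, and the remaining `if log_options and ...` guards only cover the `None` case. A small illustration of the difference (simplified stand-in model):

```python
from typing import Optional

from pydantic import BaseModel


class LogOptions(BaseModel):  # simplified stand-in
    llm_calls: bool = False


log_options: Optional[LogOptions] = LogOptions(llm_calls=True)

# getattr hides typos from type checkers:
#   getattr(log_options, "llm_cals", False) silently returns False.
# Attribute access surfaces them, and the None-guard keeps it safe:
if log_options and log_options.llm_calls:
    print("collecting LLM calls")
```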
@@ -1153,25 +1162,21 @@ async def generate_async(
            if original_log_options:
                if not any(
                    (
-                        getattr(original_log_options, "internal_events", False),
-                        getattr(original_log_options, "activated_rails", False),
-                        getattr(original_log_options, "llm_calls", False),
-                        getattr(original_log_options, "colang_history", False),
+                        original_log_options.internal_events,
+                        original_log_options.activated_rails,
+                        original_log_options.llm_calls,
+                        original_log_options.colang_history,
                    )
                ):
                    res.log = None
                else:
                    # Ensure res.log exists before setting attributes
                    if res.log is not None:
-                        if not getattr(
-                            original_log_options, "internal_events", False
-                        ):
+                        if not original_log_options.internal_events:
                            res.log.internal_events = []
-                        if not getattr(
-                            original_log_options, "activated_rails", False
-                        ):
+                        if not original_log_options.activated_rails:
                            res.log.activated_rails = []
-                        if not getattr(original_log_options, "llm_calls", False):
+                        if not original_log_options.llm_calls:
                            res.log.llm_calls = []

        return res

nemoguardrails/rails/llm/options.py

Lines changed: 2 additions & 1 deletion
@@ -76,6 +76,7 @@
 # {..., log: {"llm_calls": [...]}}

 """
+
 from typing import Any, Dict, List, Optional, Union

 from pydantic import BaseModel, Field, root_validator
@@ -156,7 +157,7 @@ class GenerationOptions(BaseModel):
         default=None,
         description="Additional parameters that should be used for the LLM call",
     )
-    llm_output: Optional[bool] = Field(
+    llm_output: bool = Field(
         default=False,
         description="Whether the response should also include any custom LLM output.",
     )
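On the `options.py` side: the field always had `default=False` and is only ever truth-tested, so the `Optional[bool]` annotation added a spurious `None` to its type. With plain `bool`, `gen_options.llm_output` is guaranteed to be a real boolean, which is what the new `if gen_options and gen_options.llm_output:` checks in `llmrails.py` rely on. A minimal before/after sketch (the class names here are illustrative):

```python
from typing import Optional

from pydantic import BaseModel, Field, ValidationError


class Before(BaseModel):
    llm_output: Optional[bool] = Field(default=False)  # None is allowed


class After(BaseModel):
    llm_output: bool = Field(default=False)  # always a real bool


Before(llm_output=None)  # accepted: None is part of the annotation
try:
    After(llm_output=None)  # now rejected at validation time
except ValidationError as e:
    print("rejected:", e.error_count(), "error")
```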
