@@ -923,6 +923,19 @@ async def generate_async(
             The completion (when a prompt is provided) or the next message.

         System messages are not yet supported."""
+        # convert options to gen_options of type GenerationOptions
+        gen_options: Optional[GenerationOptions] = None
+
+        if prompt is None and messages is None:
+            raise ValueError("Either prompt or messages must be provided.")
+
+        if prompt is not None and messages is not None:
+            raise ValueError("Only one of prompt or messages can be provided.")
+
+        if prompt is not None:
+            # Currently, we transform the prompt request into a single turn conversation
+            messages = [{"role": "user", "content": prompt}]
+
         # If a state object is specified, then we switch to "generation options" mode.
         # This is because we want the output to be a GenerationResponse which will contain
         # the output state.
@@ -932,15 +945,25 @@ async def generate_async(
                 state = json_to_state(state["state"])

             if options is None:
-                options = GenerationOptions()
-
-        # We allow options to be specified both as a dict and as an object.
-        if options and isinstance(options, dict):
-            options = GenerationOptions(**options)
+                gen_options = GenerationOptions()
+            elif isinstance(options, dict):
+                gen_options = GenerationOptions(**options)
+            else:
+                gen_options = options
+        else:
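+            # Unlike the state path above, a missing `options` is kept as None here
+            # rather than being upgraded to a default GenerationOptions.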
+            # We allow options to be specified both as a dict and as an object.
+            if options and isinstance(options, dict):
+                gen_options = GenerationOptions(**options)
+            elif isinstance(options, GenerationOptions):
+                gen_options = options
+            elif options is None:
+                gen_options = None
+            else:
+                raise TypeError("options must be a dict or GenerationOptions")

         # Save the generation options in the current async context.
-        # At this point, options is either None or GenerationOptions
-        generation_options_var.set(options if not isinstance(options, dict) else None)
+        # At this point, gen_options is either None or GenerationOptions
+        generation_options_var.set(gen_options)

         if streaming_handler:
             streaming_handler_var.set(streaming_handler)
@@ -950,23 +973,14 @@ async def generate_async(
         # requests are made.
         self.explain_info = self._ensure_explain_info()

-        if prompt is not None:
-            # Currently, we transform the prompt request into a single turn conversation
-            messages = [{"role": "user", "content": prompt}]
-            raw_llm_request.set(prompt)
-        else:
-            raw_llm_request.set(messages)
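+        # `prompt` (if provided) has already been converted into `messages` above,
+        # so the raw request is now always recorded as a message list.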
+        raw_llm_request.set(messages)

         # If we have generation options, we also add them to the context
-        if options:
+        if gen_options:
             messages = [
                 {
                     "role": "context",
-                    "content": {
-                        "generation_options": getattr(
-                            options, "dict", lambda: options
-                        )()
-                    },
+                    "content": {"generation_options": gen_options.model_dump()},
                 }
             ] + (messages or [])

@@ -976,9 +990,8 @@ async def generate_async(
         if (
             messages
             and messages[-1]["role"] == "assistant"
-            and options
-            and hasattr(options, "rails")
-            and getattr(getattr(options, "rails", None), "dialog", None) is False
+            and gen_options
+            and gen_options.rails.dialog is False
         ):
             # We already have the first message with a context update, so we use that
             messages[0]["content"]["bot_message"] = messages[-1]["content"]
@@ -995,7 +1008,7 @@ async def generate_async(
         processing_log = []

         # The array of events corresponding to the provided sequence of messages.
-        events = self._get_events_for_messages(messages or [], state)
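+        # `messages` cannot be None here: the validation at the top of the method guarantees
+        # that either `prompt` (already converted to messages) or `messages` was provided.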
+        events = self._get_events_for_messages(messages, state)  # type: ignore

         if self.config.colang_version == "1.0":
             # If we had a state object, we also need to prepend the events from the state.
@@ -1114,7 +1127,7 @@ async def generate_async(
         # If a state object is not used, then we use the implicit caching
         if state is None:
             # Save the new events in the history and update the cache
-            cache_key = get_history_cache_key((messages or []) + [new_message])
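+            # As above, `messages` is guaranteed to be non-None at this point.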
+            cache_key = get_history_cache_key((messages) + [new_message])  # type: ignore
             self.events_history_cache[cache_key] = events
         else:
             output_state = {"events": events}
@@ -1142,33 +1155,29 @@ async def generate_async(
         # IF tracing is enabled we need to set GenerationLog attrs
         original_log_options = None
         if self.config.tracing.enabled:
-            if options is None:
-                options = GenerationOptions()
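+            # Tracing always needs a GenerationOptions instance so the log flags below can be enabled.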
+            if gen_options is None:
+                gen_options = GenerationOptions()
             else:
-                # create a copy of the options to avoid modifying the original
-                if isinstance(options, GenerationOptions):
-                    options = options.model_copy(deep=True)
-                else:
-                    # If options is a dict, convert it to GenerationOptions
-                    options = GenerationOptions(**options)
-            original_log_options = options.log.model_copy(deep=True)
+                # create a copy of the gen_options to avoid modifying the original
+                gen_options = gen_options.model_copy(deep=True)
+            original_log_options = gen_options.log.model_copy(deep=True)

             # enable log options
             # it is aggressive, but these are required for tracing
             if (
-                not options.log.activated_rails
-                or not options.log.llm_calls
-                or not options.log.internal_events
+                not gen_options.log.activated_rails
+                or not gen_options.log.llm_calls
+                or not gen_options.log.internal_events
             ):
-                options.log.activated_rails = True
-                options.log.llm_calls = True
-                options.log.internal_events = True
+                gen_options.log.activated_rails = True
+                gen_options.log.llm_calls = True
+                gen_options.log.internal_events = True

         tool_calls = extract_tool_calls_from_events(new_events)
         llm_metadata = get_and_clear_response_metadata_contextvar()

         # If we have generation options, we prepare a GenerationResponse instance.
-        if options:
+        if gen_options:
             # If a prompt was used, we only need to return the content of the message.
             if prompt:
                 res = GenerationResponse(response=new_message["content"])
@@ -1195,9 +1204,9 @@ async def generate_async(

             if self.config.colang_version == "1.0":
                 # If output variables are specified, we extract their values
-                if getattr(options, "output_vars", None):
+                if gen_options and gen_options.output_vars:
                     context = compute_context(events)
-                    output_vars = getattr(options, "output_vars", None)
+                    output_vars = gen_options.output_vars
                     if isinstance(output_vars, list):
                         # If we have only a selection of keys, we filter to only that.
                         res.output_data = {k: context.get(k) for k in output_vars}
@@ -1208,65 +1217,64 @@ async def generate_async(
                 _log = compute_generation_log(processing_log)

                 # Include information about activated rails and LLM calls if requested
-                log_options = getattr(options, "log", None)
+                log_options = gen_options.log if gen_options else None
                 if log_options and (
-                    getattr(log_options, "activated_rails", False)
-                    or getattr(log_options, "llm_calls", False)
+                    log_options.activated_rails or log_options.llm_calls
                 ):
                     res.log = GenerationLog()

                     # We always include the stats
                     res.log.stats = _log.stats

-                    if getattr(log_options, "activated_rails", False):
+                    if log_options.activated_rails:
                         res.log.activated_rails = _log.activated_rails

-                    if getattr(log_options, "llm_calls", False):
+                    if log_options.llm_calls:
                         res.log.llm_calls = []
                         for activated_rail in _log.activated_rails:
                             for executed_action in activated_rail.executed_actions:
                                 res.log.llm_calls.extend(executed_action.llm_calls)

                 # Include internal events if requested
-                if getattr(log_options, "internal_events", False):
+                if log_options and log_options.internal_events:
                     if res.log is None:
                         res.log = GenerationLog()

                     res.log.internal_events = new_events

                 # Include the Colang history if requested
-                if getattr(log_options, "colang_history", False):
+                if log_options and log_options.colang_history:
                     if res.log is None:
                         res.log = GenerationLog()

                     res.log.colang_history = get_colang_history(events)

                 # Include the raw llm output if requested
-                if getattr(options, "llm_output", False):
+                if gen_options and gen_options.llm_output:
                     # Currently, we include the output from the generation LLM calls.
                     for activated_rail in _log.activated_rails:
                         if activated_rail.type == "generation":
                             for executed_action in activated_rail.executed_actions:
                                 for llm_call in executed_action.llm_calls:
                                     res.llm_output = llm_call.raw_response
             else:
-                if getattr(options, "output_vars", None):
+                if gen_options and gen_options.output_vars:
                     raise ValueError(
                         "The `output_vars` option is not supported for Colang 2.0 configurations."
                     )

-                log_options = getattr(options, "log", None)
+                log_options = gen_options.log if gen_options else None
                 if log_options and (
-                    getattr(log_options, "activated_rails", False)
-                    or getattr(log_options, "llm_calls", False)
-                    or getattr(log_options, "internal_events", False)
-                    or getattr(log_options, "colang_history", False)
+                    log_options.activated_rails
+                    or log_options.llm_calls
+                    or log_options.internal_events
+                    or log_options.colang_history
                 ):
                     raise ValueError(
                         "The `log` option is not supported for Colang 2.0 configurations."
                     )

-                if getattr(options, "llm_output", False):
+                if gen_options and gen_options.llm_output:
                     raise ValueError(
                         "The `llm_output` option is not supported for Colang 2.0 configurations."
                     )
@@ -1300,25 +1308,21 @@ async def generate_async(
             if original_log_options:
                 if not any(
                     (
-                        getattr(original_log_options, "internal_events", False),
-                        getattr(original_log_options, "activated_rails", False),
-                        getattr(original_log_options, "llm_calls", False),
-                        getattr(original_log_options, "colang_history", False),
+                        original_log_options.internal_events,
+                        original_log_options.activated_rails,
+                        original_log_options.llm_calls,
+                        original_log_options.colang_history,
                     )
                 ):
                     res.log = None
                 else:
                     # Ensure res.log exists before setting attributes
                     if res.log is not None:
-                        if not getattr(
-                            original_log_options, "internal_events", False
-                        ):
+                        if not original_log_options.internal_events:
                             res.log.internal_events = []
-                        if not getattr(
-                            original_log_options, "activated_rails", False
-                        ):
+                        if not original_log_options.activated_rails:
                             res.log.activated_rails = []
-                        if not getattr(original_log_options, "llm_calls", False):
+                        if not original_log_options.llm_calls:
                             res.log.llm_calls = []

         return res