@@ -776,6 +776,19 @@ async def generate_async(
776776 The completion (when a prompt is provided) or the next message.
777777
778778 System messages are not yet supported."""
779+ # convert options to gen_options of type GenerationOptions
780+ gen_options : Optional [GenerationOptions ] = None
781+
782+ if prompt is None and messages is None :
783+ raise ValueError ("Either prompt or messages must be provided." )
784+
785+ if prompt is not None and messages is not None :
786+ raise ValueError ("Only one of prompt or messages can be provided." )
787+
788+ if prompt is not None :
789+ # Currently, we transform the prompt request into a single turn conversation
790+ messages = [{"role" : "user" , "content" : prompt }]
791+
779792 # If a state object is specified, then we switch to "generation options" mode.
780793 # This is because we want the output to be a GenerationResponse which will contain
781794 # the output state.
@@ -785,15 +798,25 @@ async def generate_async(
785798 state = json_to_state (state ["state" ])
786799
787800 if options is None :
788- options = GenerationOptions ()
789-
790- # We allow options to be specified both as a dict and as an object.
791- if options and isinstance (options , dict ):
792- options = GenerationOptions (** options )
801+ gen_options = GenerationOptions ()
802+ elif isinstance (options , dict ):
803+ gen_options = GenerationOptions (** options )
804+ else :
805+ gen_options = options
806+ else :
807+ # We allow options to be specified both as a dict and as an object.
808+ if options and isinstance (options , dict ):
809+ gen_options = GenerationOptions (** options )
810+ elif isinstance (options , GenerationOptions ):
811+ gen_options = options
812+ elif options is None :
813+ gen_options = None
814+ else :
815+ raise TypeError ("options must be a dict or GenerationOptions" )
793816
794817 # Save the generation options in the current async context.
795- # At this point, options is either None or GenerationOptions
796- generation_options_var .set (options if not isinstance ( options , dict ) else None )
818+ # At this point, gen_options is either None or GenerationOptions
819+ generation_options_var .set (gen_options )
797820
798821 if streaming_handler :
799822 streaming_handler_var .set (streaming_handler )
@@ -803,23 +826,14 @@ async def generate_async(
803826 # requests are made.
804827 self .explain_info = self ._ensure_explain_info ()
805828
806- if prompt is not None :
807- # Currently, we transform the prompt request into a single turn conversation
808- messages = [{"role" : "user" , "content" : prompt }]
809- raw_llm_request .set (prompt )
810- else :
811- raw_llm_request .set (messages )
829+ raw_llm_request .set (messages )
812830
813831 # If we have generation options, we also add them to the context
814- if options :
832+ if gen_options :
815833 messages = [
816834 {
817835 "role" : "context" ,
818- "content" : {
819- "generation_options" : getattr (
820- options , "dict" , lambda : options
821- )()
822- },
836+ "content" : {"generation_options" : gen_options .model_dump ()},
823837 }
824838 ] + (messages or [])
825839
@@ -848,7 +862,7 @@ async def generate_async(
848862 processing_log = []
849863
850864 # The array of events corresponding to the provided sequence of messages.
851- events = self ._get_events_for_messages (messages or [] , state )
865+ events = self ._get_events_for_messages (messages , state ) # type: ignore
852866
853867 if self .config .colang_version == "1.0" :
854868 # If we had a state object, we also need to prepend the events from the state.
@@ -967,7 +981,7 @@ async def generate_async(
967981 # If a state object is not used, then we use the implicit caching
968982 if state is None :
969983 # Save the new events in the history and update the cache
970- cache_key = get_history_cache_key ((messages or [] ) + [new_message ])
984+ cache_key = get_history_cache_key ((messages ) + [new_message ]) # type: ignore
971985 self .events_history_cache [cache_key ] = events
972986 else :
973987 output_state = {"events" : events }
@@ -995,33 +1009,29 @@ async def generate_async(
9951009 # IF tracing is enabled we need to set GenerationLog attrs
9961010 original_log_options = None
9971011 if self .config .tracing .enabled :
998- if options is None :
999- options = GenerationOptions ()
1012+ if gen_options is None :
1013+ gen_options = GenerationOptions ()
10001014 else :
1001- # create a copy of the options to avoid modifying the original
1002- if isinstance (options , GenerationOptions ):
1003- options = options .model_copy (deep = True )
1004- else :
1005- # If options is a dict, convert it to GenerationOptions
1006- options = GenerationOptions (** options )
1007- original_log_options = options .log .model_copy (deep = True )
1015+ # create a copy of the gen_options to avoid modifying the original
1016+ gen_options = gen_options .model_copy (deep = True )
1017+ original_log_options = gen_options .log .model_copy (deep = True )
10081018
10091019 # enable log options
10101020 # it is aggressive, but these are required for tracing
10111021 if (
1012- not options .log .activated_rails
1013- or not options .log .llm_calls
1014- or not options .log .internal_events
1022+ not gen_options .log .activated_rails
1023+ or not gen_options .log .llm_calls
1024+ or not gen_options .log .internal_events
10151025 ):
1016- options .log .activated_rails = True
1017- options .log .llm_calls = True
1018- options .log .internal_events = True
1026+ gen_options .log .activated_rails = True
1027+ gen_options .log .llm_calls = True
1028+ gen_options .log .internal_events = True
10191029
10201030 tool_calls = extract_tool_calls_from_events (new_events )
10211031 llm_metadata = get_and_clear_response_metadata_contextvar ()
10221032
10231033 # If we have generation options, we prepare a GenerationResponse instance.
1024- if options :
1034+ if gen_options :
10251035 # If a prompt was used, we only need to return the content of the message.
10261036 if prompt :
10271037 res = GenerationResponse (response = new_message ["content" ])
@@ -1048,9 +1058,9 @@ async def generate_async(
10481058
10491059 if self .config .colang_version == "1.0" :
10501060 # If output variables are specified, we extract their values
1051- if getattr ( options , "output_vars" , None ) :
1061+ if gen_options and gen_options . output_vars :
10521062 context = compute_context (events )
1053- output_vars = getattr ( options , " output_vars" , None )
1063+ output_vars = gen_options . output_vars
10541064 if isinstance (output_vars , list ):
10551065 # If we have only a selection of keys, we filter to only that.
10561066 res .output_data = {k : context .get (k ) for k in output_vars }
@@ -1061,65 +1071,64 @@ async def generate_async(
10611071 _log = compute_generation_log (processing_log )
10621072
10631073 # Include information about activated rails and LLM calls if requested
1064- log_options = getattr ( options , " log" , None )
1074+ log_options = gen_options . log if gen_options else None
10651075 if log_options and (
1066- getattr (log_options , "activated_rails" , False )
1067- or getattr (log_options , "llm_calls" , False )
1076+ log_options .activated_rails or log_options .llm_calls
10681077 ):
10691078 res .log = GenerationLog ()
10701079
10711080 # We always include the stats
10721081 res .log .stats = _log .stats
10731082
1074- if getattr ( log_options , " activated_rails" , False ) :
1083+ if log_options . activated_rails :
10751084 res .log .activated_rails = _log .activated_rails
10761085
1077- if getattr ( log_options , " llm_calls" , False ) :
1086+ if log_options . llm_calls :
10781087 res .log .llm_calls = []
10791088 for activated_rail in _log .activated_rails :
10801089 for executed_action in activated_rail .executed_actions :
10811090 res .log .llm_calls .extend (executed_action .llm_calls )
10821091
10831092 # Include internal events if requested
1084- if getattr ( log_options , "internal_events" , False ) :
1093+ if log_options and log_options . internal_events :
10851094 if res .log is None :
10861095 res .log = GenerationLog ()
10871096
10881097 res .log .internal_events = new_events
10891098
10901099 # Include the Colang history if requested
1091- if getattr ( log_options , "colang_history" , False ) :
1100+ if log_options and log_options . colang_history :
10921101 if res .log is None :
10931102 res .log = GenerationLog ()
10941103
10951104 res .log .colang_history = get_colang_history (events )
10961105
10971106 # Include the raw llm output if requested
1098- if getattr ( options , "llm_output" , False ) :
1107+ if gen_options and gen_options . llm_output :
10991108 # Currently, we include the output from the generation LLM calls.
11001109 for activated_rail in _log .activated_rails :
11011110 if activated_rail .type == "generation" :
11021111 for executed_action in activated_rail .executed_actions :
11031112 for llm_call in executed_action .llm_calls :
11041113 res .llm_output = llm_call .raw_response
11051114 else :
1106- if getattr ( options , "output_vars" , None ) :
1115+ if gen_options and gen_options . output_vars :
11071116 raise ValueError (
11081117 "The `output_vars` option is not supported for Colang 2.0 configurations."
11091118 )
11101119
1111- log_options = getattr ( options , " log" , None )
1120+ log_options = gen_options . log if gen_options else None
11121121 if log_options and (
1113- getattr ( log_options , " activated_rails" , False )
1114- or getattr ( log_options , " llm_calls" , False )
1115- or getattr ( log_options , " internal_events" , False )
1116- or getattr ( log_options , " colang_history" , False )
1122+ log_options . activated_rails
1123+ or log_options . llm_calls
1124+ or log_options . internal_events
1125+ or log_options . colang_history
11171126 ):
11181127 raise ValueError (
11191128 "The `log` option is not supported for Colang 2.0 configurations."
11201129 )
11211130
1122- if getattr ( options , "llm_output" , False ) :
1131+ if gen_options and gen_options . llm_output :
11231132 raise ValueError (
11241133 "The `llm_output` option is not supported for Colang 2.0 configurations."
11251134 )
@@ -1153,25 +1162,21 @@ async def generate_async(
11531162 if original_log_options :
11541163 if not any (
11551164 (
1156- getattr ( original_log_options , " internal_events" , False ) ,
1157- getattr ( original_log_options , " activated_rails" , False ) ,
1158- getattr ( original_log_options , " llm_calls" , False ) ,
1159- getattr ( original_log_options , " colang_history" , False ) ,
1165+ original_log_options . internal_events ,
1166+ original_log_options . activated_rails ,
1167+ original_log_options . llm_calls ,
1168+ original_log_options . colang_history ,
11601169 )
11611170 ):
11621171 res .log = None
11631172 else :
11641173 # Ensure res.log exists before setting attributes
11651174 if res .log is not None :
1166- if not getattr (
1167- original_log_options , "internal_events" , False
1168- ):
1175+ if not original_log_options .internal_events :
11691176 res .log .internal_events = []
1170- if not getattr (
1171- original_log_options , "activated_rails" , False
1172- ):
1177+ if not original_log_options .activated_rails :
11731178 res .log .activated_rails = []
1174- if not getattr ( original_log_options , " llm_calls" , False ) :
1179+ if not original_log_options . llm_calls :
11751180 res .log .llm_calls = []
11761181
11771182 return res
0 commit comments