1515import asyncio
1616import json
1717import os
18- from dataclasses import dataclass , field
19- from typing import Dict , List , Optional , cast
18+ from dataclasses import asdict , dataclass , field
19+ from typing import Dict , List , Optional , Tuple , Union , cast
2020
2121import aiohttp
2222from prompt_toolkit import HTML , PromptSession
3030from nemoguardrails .colang .v2_x .runtime .runtime import RuntimeV2_x
3131from nemoguardrails .logging import verbose
3232from nemoguardrails .logging .verbose import console
33- from nemoguardrails .streaming import StreamingHandler
33+ from nemoguardrails .rails .llm .options import (
34+ GenerationLog ,
35+ GenerationOptions ,
36+ GenerationResponse ,
37+ )
3438from nemoguardrails .utils import get_or_create_event_loop , new_event_dict , new_uuid
3539
3640os .environ ["TOKENIZERS_PARALLELISM" ] = "false"
@@ -61,6 +65,8 @@ async def _run_chat_v1_0(
6165 )
6266
6367 if not server_url :
68+ if config_path is None :
69+ raise RuntimeError ("config_path cannot be None when server_url is None" )
6470 rails_config = RailsConfig .from_path (config_path )
6571 rails_app = LLMRails (rails_config , verbose = verbose )
6672 if streaming and not rails_config .streaming_supported :
@@ -82,7 +88,12 @@ async def _run_chat_v1_0(
8288
8389 if not server_url :
8490 # If we have streaming from a locally loaded config, we initialize the handler.
85- if streaming and not server_url and rails_app .main_llm_supports_streaming :
91+ if (
92+ streaming
93+ and not server_url
94+ and rails_app
95+ and rails_app .main_llm_supports_streaming
96+ ):
8697 bot_message_list = []
8798 async for chunk in rails_app .stream_async (messages = history ):
8899 if '{"event": "ABORT"' in chunk :
@@ -101,11 +112,40 @@ async def _run_chat_v1_0(
101112 bot_message = {"role" : "assistant" , "content" : bot_message_text }
102113
103114 else :
104- bot_message = await rails_app .generate_async (messages = history )
115+ if rails_app is None :
116+ raise RuntimeError ("Rails App is None" )
117+ response : Union [
118+ str , Dict , GenerationResponse , Tuple [Dict , Dict ]
119+ ] = await rails_app .generate_async (messages = history )
120+
121+ # Handle different return types from generate_async
122+ if isinstance (response , tuple ) and len (response ) == 2 :
123+ bot_message = (
124+ response [0 ]
125+ if response
126+ else {"role" : "assistant" , "content" : "" }
127+ )
128+ elif isinstance (response , GenerationResponse ):
129+ # GenerationResponse case
130+ response_attr = getattr (response , "response" , None )
131+ if isinstance (response_attr , list ) and len (response_attr ) > 0 :
132+ bot_message = response_attr [0 ]
133+ else :
134+ bot_message = {
135+ "role" : "assistant" ,
136+ "content" : str (response_attr ),
137+ }
138+ elif isinstance (response , dict ):
139+ # Direct dict case
140+ bot_message = response
141+ else :
142+ # String or other fallback case
143+ bot_message = {"role" : "assistant" , "content" : str (response )}
105144
106145 if not streaming or not rails_app .main_llm_supports_streaming :
107146 # We print bot messages in green.
108- console .print ("[green]" + f"{ bot_message ['content' ]} " + "[/]" )
147+ content = bot_message .get ("content" , str (bot_message ))
148+ console .print ("[green]" + f"{ content } " + "[/]" )
109149 else :
110150 data = {
111151 "config_id" : config_id ,
@@ -116,19 +156,19 @@ async def _run_chat_v1_0(
116156 async with session .post (
117157 f"{ server_url } /v1/chat/completions" ,
118158 json = data ,
119- ) as response :
159+ ) as http_response :
120160 # If the response is streaming, we show each chunk as it comes
121- if response .headers .get ("Transfer-Encoding" ) == "chunked" :
161+ if http_response .headers .get ("Transfer-Encoding" ) == "chunked" :
122162 bot_message_text = ""
123- async for chunk in response .content .iter_any ():
124- chunk = chunk .decode ("utf-8" )
163+ async for chunk_bytes in http_response .content .iter_any ():
164+ chunk = chunk_bytes .decode ("utf-8" )
125165 console .print ("[green]" + f"{ chunk } " + "[/]" , end = "" )
126166 bot_message_text += chunk
127167 console .print ("" )
128168
129169 bot_message = {"role" : "assistant" , "content" : bot_message_text }
130170 else :
131- result = await response .json ()
171+ result = await http_response .json ()
132172 bot_message = result ["messages" ][0 ]
133173
134174 # We print bot messages in green.
@@ -297,7 +337,8 @@ def _process_output():
297337 else :
298338 console .print (
299339 "[black on magenta]"
300- + f"scene information (start): (title={ event ['title' ]} , action_uid={ event ['action_uid' ]} , content={ event ['content' ]} )"
340+ + f"scene information (start): (title={ event ['title' ]} , "
341+ + f"action_uid={ event ['action_uid' ]} , content={ event ['content' ]} )"
301342 + "[/]"
302343 )
303344
@@ -333,7 +374,8 @@ def _process_output():
333374 else :
334375 console .print (
335376 "[black on magenta]"
336- + f"scene form (start): (prompt={ event ['prompt' ]} , action_uid={ event ['action_uid' ]} , inputs={ event ['inputs' ]} )"
377+ + f"scene form (start): (prompt={ event ['prompt' ]} , "
378+ + f"action_uid={ event ['action_uid' ]} , inputs={ event ['inputs' ]} )"
337379 + "[/]"
338380 )
339381 chat_state .input_events .append (
@@ -370,7 +412,8 @@ def _process_output():
370412 else :
371413 console .print (
372414 "[black on magenta]"
373- + f"scene choice (start): (prompt={ event ['prompt' ]} , action_uid={ event ['action_uid' ]} , options={ event ['options' ]} )"
415+ + f"scene choice (start): (prompt={ event ['prompt' ]} , "
416+ + f"action_uid={ event ['action_uid' ]} , options={ event ['options' ]} )"
374417 + "[/]"
375418 )
376419 chat_state .input_events .append (
@@ -452,12 +495,16 @@ async def _check_local_async_actions():
452495 # We need to copy input events to prevent race condition
453496 input_events_copy = chat_state .input_events .copy ()
454497 chat_state .input_events = []
455- (
456- chat_state .output_events ,
457- chat_state .output_state ,
458- ) = await rails_app .process_events_async (
459- input_events_copy , chat_state .state
498+
499+ output_events , output_state = await rails_app .process_events_async (
500+ input_events_copy ,
501+ asdict (chat_state .state ) if chat_state .state else None ,
460502 )
503+ chat_state .output_events = output_events
504+
505+ # process_events_async returns a Dict `state`, need to convert to dataclass for ChatState object
506+ if output_state :
507+ chat_state .output_state = cast (State , State (** output_state ))
461508
462509 # Process output_events and potentially generate new input_events
463510 _process_output ()
@@ -470,7 +517,8 @@ async def _check_local_async_actions():
470517 # If there are no pending actions, we stop
471518 check_task .cancel ()
472519 check_task = None
473- debugger .set_output_state (chat_state .output_state )
520+ if chat_state .output_state is not None :
521+ debugger .set_output_state (chat_state .output_state )
474522 chat_state .status .stop ()
475523 enable_input .set ()
476524 return
@@ -485,13 +533,16 @@ async def _process_input_events():
485533 # We need to copy input events to prevent race condition
486534 input_events_copy = chat_state .input_events .copy ()
487535 chat_state .input_events = []
488- (
489- chat_state .output_events ,
490- chat_state .output_state ,
491- ) = await rails_app .process_events_async (
492- input_events_copy , chat_state .state
536+ output_events , output_state = await rails_app .process_events_async (
537+ input_events_copy ,
538+ asdict (chat_state .state ) if chat_state .state else None ,
493539 )
494- debugger .set_output_state (chat_state .output_state )
540+ chat_state .output_events = output_events
541+ if output_state :
542+ # process_events_async returns a Dict `state`, need to convert to dataclass for ChatState object
543+ output_state_typed : State = cast (State , State (** output_state ))
544+ chat_state .output_state = output_state_typed
545+ debugger .set_output_state (output_state_typed )
495546
496547 _process_output ()
497548 # If we don't have a check task, we start it
@@ -653,6 +704,8 @@ def run_chat(
653704 server_url (Optional[str]): The URL of the chat server. Defaults to None.
654705 config_id (Optional[str]): The configuration ID. Defaults to None.
655706 """
707+ if config_path is None :
708+ raise RuntimeError ("config_path cannot be None" )
656709 rails_config = RailsConfig .from_path (config_path )
657710
658711 if verbose and verbose_llm_calls :
0 commit comments