1212
1313from readchar import readchar
1414
15- from .misc .get_input import get_input
15+ from .misc .get_input import async_get_input
1616
1717# Third-party imports
1818os .environ ["LITELLM_LOCAL_MODEL_COST_MAP" ] = "True"
@@ -227,11 +227,14 @@ def default_system_message(self):
227227
228228 return system_message
229229
230- async def async_respond (self ):
230+ async def async_respond (self , user_input = None ):
231231 """
232232 Agentic sampling loop for the assistant/tool interaction.
233233 Yields chunks and maintains message history on the interpreter instance.
234234 """
235+ if user_input :
236+ self .messages .append ({"role" : "user" , "content" : user_input })
237+
235238 tools = []
236239 if "interpreter" in self .tools :
237240 tools .append (BashTool ())
@@ -925,14 +928,26 @@ def _handle_command(self, cmd: str, parts: list[str]) -> bool:
925928 return self ._command_handler .handle_command (cmd , parts )
926929
927930 def chat (self ):
928- """
929- Interactive mode
930- """
931+ """Chat with the interpreter. Handles both sync and async contexts."""
932+ try :
933+ loop = asyncio .get_running_loop ()
934+ # If we get here, there is a running event loop
935+ loop .create_task (self .async_chat ())
936+ except RuntimeError :
937+ # No running event loop, create one
938+ asyncio .run (self .async_chat ())
939+
940+ async def async_chat (self ):
941+ original_message_length = len (self .messages )
942+
931943 try :
932944 message_count = 0
933945 while True :
934- user_input = get_input ()
935- print ("" )
946+ try :
947+ user_input = await async_get_input ()
948+ except KeyboardInterrupt :
949+ print ()
950+ return self .messages [original_message_length :]
936951
937952 message_count += 1 # Increment counter after each message
938953
@@ -943,23 +958,23 @@ def chat(self):
943958 continue
944959
945960 if user_input == "" :
946- if message_count in range (4 , 7 ):
961+ if message_count in range (8 , 11 ):
947962 print ("Error: Cat is asleep on Enter key\n " )
948963 else :
949964 print ("Error: No input provided\n " )
950965 continue
951966
952- self .messages .append ({"role" : "user" , "content" : user_input })
953-
954- for _ in self .respond ():
955- pass
967+ try :
968+ print ()
969+ async for _ in self .async_respond (user_input ):
970+ pass
971+ except KeyboardInterrupt :
972+ self ._spinner .stop ()
973+ except asyncio .CancelledError :
974+ self ._spinner .stop ()
956975
957976 print ()
958- except KeyboardInterrupt :
959- self ._spinner .stop ()
960- print ()
961- pass
962- except Exception as e :
977+ except :
963978 self ._spinner .stop ()
964979 print (traceback .format_exc ())
965980 print ("\n \n \033 [91mAn error has occurred.\033 [0m" )
@@ -976,35 +991,34 @@ def chat(self):
976991 self ._report_error ("" .join (traceback .format_exc ()))
977992 exit (1 )
978993
979- async def _consume_generator (self , generator ):
980- """Consume the async generator from async_respond """
981- async for chunk in generator :
982- yield chunk
994+ def respond (self , user_input = None , stream = False ):
995+ """Sync method to respond to user input if provided, or to the messages in self.messages. """
996+ if user_input :
997+ self . messages . append ({ "role" : "user" , "content" : user_input })
983998
984- def respond (self ):
985- """
986- Synchronous wrapper around async_respond.
987- Yields chunks from the async generator.
988- """
999+ if stream :
1000+ return self ._sync_respond_stream ()
1001+ else :
1002+ original_message_length = len (self .messages )
1003+ for _ in self ._sync_respond_stream ():
1004+ pass
1005+ return self .messages [original_message_length :]
1006+
1007+ def _sync_respond_stream (self ):
1008+ """Synchronous generator that yields responses. Only use in synchronous contexts."""
1009+ loop = asyncio .new_event_loop ()
1010+ asyncio .set_event_loop (loop )
9891011 try :
990- loop = asyncio .get_event_loop ()
991- except RuntimeError :
992- loop = asyncio .new_event_loop ()
993- asyncio .set_event_loop (loop )
994-
995- async def run ():
996- async for chunk in self .async_respond ():
997- yield chunk
998-
999- agen = run ()
1000- while True :
1001- try :
1002- chunk = loop .run_until_complete (anext (agen ))
1003- yield chunk
1004- except StopAsyncIteration :
1005- break
1006-
1007- return self .messages
1012+ # Convert async generator to sync generator
1013+ async_gen = self .async_respond ()
1014+ while True :
1015+ try :
1016+ chunk = loop .run_until_complete (async_gen .__anext__ ())
1017+ yield chunk
1018+ except StopAsyncIteration :
1019+ break
1020+ finally :
1021+ loop .close ()
10081022
10091023 def server (self ):
10101024 """
0 commit comments