@@ -776,6 +776,55 @@ def live_info(self, *, update_interval: float = 0.1) -> None:
776776 """
777777 return live_info (self , update_interval = update_interval )
778778
779+ def live_info_terminal (
780+ self , * , update_interval : float = 0.5 , overwrite_previous : bool = True
781+ ) -> asyncio .Task :
782+ """
783+ Display live information about the runner in the terminal.
784+
785+ This function provides a live update of the runner's status in the terminal.
786+ The update can either overwrite the previous status or be printed on a new line.
787+
788+ Parameters
789+ ----------
790+ update_interval : float, optional
791+ The time interval (in seconds) at which the runner's status is updated in the terminal.
792+ Default is 0.5 seconds.
793+ overwrite_previous : bool, optional
794+ If True, each update will overwrite the previous status in the terminal.
795+ If False, each update will be printed on a new line.
796+ Default is True.
797+
798+ Returns
799+ -------
800+ asyncio.Task
801+ The asynchronous task responsible for updating the runner's status in the terminal.
802+
803+ Examples
804+ --------
805+ >>> runner = AsyncRunner(...)
806+ >>> runner.live_info_terminal(update_interval=1.0, overwrite_previous=False)
807+
808+ Notes
809+ -----
810+ This function uses ANSI escape sequences to control the terminal's cursor position.
811+ It might not work as expected on all terminal emulators.
812+ """
813+
814+ async def _update (runner : AsyncRunner ) -> None :
815+ try :
816+ while not runner .task .done ():
817+ if overwrite_previous :
818+ # Clear the terminal
819+ print ("\033 [H\033 [J" , end = "" )
820+ print (_info_text (runner , separator = "\t " ))
821+ await asyncio .sleep (update_interval )
822+
823+ except asyncio .CancelledError :
824+ print ("Live info display cancelled." )
825+
826+ return self .ioloop .create_task (_update (self ))
827+
779828 async def _run (self ) -> None :
780829 first_completed = asyncio .FIRST_COMPLETED
781830
@@ -855,6 +904,43 @@ async def _saver():
855904 return self .saving_task
856905
857906
907+ def _info_text (runner , separator : str = "\n " ):
908+ status = runner .status ()
909+
910+ color_map = {
911+ "cancelled" : "\033 [33m" , # Yellow
912+ "failed" : "\033 [31m" , # Red
913+ "running" : "\033 [34m" , # Blue
914+ "finished" : "\033 [32m" , # Green
915+ }
916+
917+ overhead = runner .overhead ()
918+ if overhead < 50 :
919+ overhead_color = "\033 [32m" # Green
920+ else :
921+ overhead_color = "\033 [31m" # Red
922+
923+ info = [
924+ ("time" , str (datetime .now ())),
925+ ("status" , f"{ color_map [status ]} { status } \033 [0m" ),
926+ ("elapsed time" , str (timedelta (seconds = runner .elapsed_time ()))),
927+ ("overhead" , f"{ overhead_color } { overhead :.2f} %\033 [0m" ),
928+ ]
929+
930+ with suppress (Exception ):
931+ info .append (("# of points" , runner .learner .npoints ))
932+
933+ with suppress (Exception ):
934+ info .append (("# of samples" , runner .learner .nsamples ))
935+
936+ with suppress (Exception ):
937+ info .append (("latest loss" , f'{ runner .learner ._cache ["loss" ]:.3f} ' ))
938+
939+ width = 30
940+ formatted_info = [f"{ k } : { v } " .ljust (width ) for i , (k , v ) in enumerate (info )]
941+ return separator .join (formatted_info )
942+
943+
858944# Default runner
859945Runner = AsyncRunner
860946
0 commit comments