66from tqdm .auto import tqdm
77
88from ...server_version .server_version import ServerVersion
9- from .progress_provider import ProgressProvider
9+ from .progress_provider import ProgressProvider , TaskWithProgress
1010from .query_progress_provider import CypherQueryFunction , QueryProgressProvider , ServerVersionFunction
1111from .static_progress_provider import StaticProgressProvider , StaticProgressStore
1212
1313DataFrameProducer = Callable [[], DataFrame ]
1414
1515
1616class QueryProgressLogger :
17- _LOG_POLLING_INTERVAL = 0.5
18-
1917 def __init__ (
2018 self ,
2119 run_cypher_func : CypherQueryFunction ,
2220 server_version_func : ServerVersionFunction ,
21+ polling_interval : float = 0.5 ,
22+ progress_bar_options : dict [str , Any ] = {},
2323 ):
2424 self ._run_cypher_func = run_cypher_func
2525 self ._server_version_func = server_version_func
2626 self ._static_progress_provider = StaticProgressProvider ()
2727 self ._query_progress_provider = QueryProgressProvider (run_cypher_func , server_version_func )
28+ self ._polling_interval = polling_interval
29+ self ._progress_bar_options = progress_bar_options
2830
2931 def run_with_progress_logging (
3032 self , runnable : DataFrameProducer , job_id : str , database : Optional [str ] = None
@@ -54,39 +56,18 @@ def _select_progress_provider(self, job_id: str) -> ProgressProvider:
5456 )
5557
5658 def _log (
57- self , future : " Future[Any]" , job_id : str , progress_provider : ProgressProvider , database : Optional [str ] = None
59+ self , future : Future [Any ], job_id : str , progress_provider : ProgressProvider , database : Optional [str ] = None
5860 ) -> None :
5961 pbar : Optional [tqdm [NoReturn ]] = None
6062 warn_if_failure = True
6163
62- while wait ([future ], timeout = self ._LOG_POLLING_INTERVAL ).not_done :
64+ while wait ([future ], timeout = self ._polling_interval ).not_done :
6365 try :
6466 task_with_progress = progress_provider .root_task_with_progress (job_id , database )
65- root_task_name = task_with_progress .task_name
66- progress_percent = task_with_progress .progress_percent
67-
68- has_relative_progress = progress_percent != "n/a"
6967 if pbar is None :
70- if has_relative_progress :
71- pbar = tqdm (total = 100 , unit = "%" , desc = root_task_name , maxinterval = self ._LOG_POLLING_INTERVAL )
72- else :
73- pbar = tqdm (
74- total = None ,
75- unit = "" ,
76- desc = root_task_name ,
77- maxinterval = self ._LOG_POLLING_INTERVAL ,
78- bar_format = "{desc} [elapsed: {elapsed} {postfix}]" ,
79- )
80-
81- pbar .set_postfix_str (
82- f"status: { task_with_progress .status } , task: { task_with_progress .sub_tasks_description } "
83- )
84- if has_relative_progress :
85- parsed_progress = float (progress_percent [:- 1 ])
86- new_progress = parsed_progress - pbar .n
87- pbar .update (new_progress )
88- else :
89- pbar .refresh () # show latest elapsed time + postfix
68+ pbar = self ._init_pbar (task_with_progress )
69+
70+ self ._update_pbar (pbar , task_with_progress )
9071 except Exception as e :
9172 # Do nothing if the procedure either:
9273 # * has not started yet,
@@ -100,7 +81,51 @@ def _log(
10081 continue
10182
10283 if pbar is not None :
103- if pbar .total is not None :
104- pbar .update (pbar .total - pbar .n )
105- pbar .set_postfix_str ("status: finished" )
84+ self ._finish_pbar (pbar )
85+
86+ def _init_pbar (self , task_with_progress : TaskWithProgress ) -> tqdm : # type: ignore
87+ root_task_name = task_with_progress .task_name
88+ parsed_progress = QueryProgressLogger ._relative_progress (task_with_progress )
89+ if parsed_progress is None : # Qualitative progress report
90+ return tqdm (
91+ total = None ,
92+ unit = "" ,
93+ desc = root_task_name ,
94+ maxinterval = self ._polling_interval ,
95+ bar_format = "{desc} [elapsed: {elapsed} {postfix}]" ,
96+ ** self ._progress_bar_options ,
97+ )
98+ else :
99+ return tqdm (
100+ total = 100 ,
101+ unit = "%" ,
102+ desc = root_task_name ,
103+ maxinterval = self ._polling_interval ,
104+ ** self ._progress_bar_options ,
105+ )
106+
107+ def _update_pbar (self , pbar : tqdm , task_with_progress : TaskWithProgress ) -> None : # type: ignore
108+ parsed_progress = QueryProgressLogger ._relative_progress (task_with_progress )
109+ postfix = (
110+ f"status: { task_with_progress .status } , task: { task_with_progress .sub_tasks_description } "
111+ if task_with_progress .sub_tasks_description
112+ else f"status: { task_with_progress .status } "
113+ )
114+ pbar .set_postfix_str (postfix , refresh = False )
115+ if parsed_progress is not None :
116+ new_progress = parsed_progress - pbar .n
117+ pbar .update (new_progress )
118+ else :
106119 pbar .refresh ()
120+
121+ def _finish_pbar (self , pbar : tqdm ) -> None : # type: ignore
122+ if pbar .total is not None :
123+ pbar .update (pbar .total - pbar .n )
124+ pbar .set_postfix_str ("status: FINISHED" , refresh = True )
125+
126+ @staticmethod
127+ def _relative_progress (task : TaskWithProgress ) -> Optional [float ]:
128+ try :
129+ return float (task .progress_percent .removesuffix ("%" ))
130+ except ValueError :
131+ return None
0 commit comments